diff --git a/.clippy.toml b/.clippy.toml new file mode 100644 index 0000000..eb12f60 --- /dev/null +++ b/.clippy.toml @@ -0,0 +1,4 @@ +# Clippy configuration +doc-valid-idents = ["GitHub", "GitLab", "TypeScript", "WebSocket", "PostgreSQL", "Redis", "OpenAI"] +avoid-breaking-exported-api = true +disallowed-types = [] diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..0b8cf0b --- /dev/null +++ b/.editorconfig @@ -0,0 +1,42 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true + +[*.{js,ts,jsx,tsx,json,jsonc,md,yaml,yml,toml}] +indent_style = space +indent_size = 2 + +[*.py] +indent_style = space +indent_size = 4 + +[*.go] +indent_style = tab +indent_size = unset +tab_width = 8 +[*.rs] +indent_style = space +indent_size = 4 + +[Makefile] +indent_style = tab +indent_size = unset + +[Dockerfile] +indent_style = space +indent_size = 2 + +[*.md] +trim_trailing_whitespace = false + +[*.{yml,yaml}] +indent_style = space +indent_size = 2 + +[*.toml] +indent_style = space +indent_size = 2 \ No newline at end of file diff --git a/app/email/Cargo.toml b/app/email/Cargo.toml new file mode 100644 index 0000000..280815a --- /dev/null +++ b/app/email/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "app-email" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[[bin]] +name = "email-service" +path = "src/main.rs" + +[dependencies] +anyhow = { workspace = true } +config = { workspace = true } +email = { workspace = true } +tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter", "json"] } + +[lints] +workspace = true \ No newline at end of file diff --git a/app/email/src/context.rs b/app/email/src/context.rs new file mode 100644 index 0000000..5872453 --- /dev/null +++ b/app/email/src/context.rs @@ -0,0 +1,24 @@ +use config::AppConfig; +use tracing_subscriber::EnvFilter; + +pub struct AppContext { + pub config: AppConfig, +} + +impl AppContext { + pub fn init() -> anyhow::Result { + let config = AppConfig::load(); + init_tracing(&config)?; + Ok(Self { config }) + } +} + +fn init_tracing(config: &AppConfig) -> anyhow::Result<()> { + let level = config.log_level()?; + let filter = EnvFilter::try_new(&level)?; + tracing_subscriber::fmt() + .with_env_filter(filter) + .with_target(false) + .init(); + Ok(()) +} diff --git a/app/email/src/main.rs b/app/email/src/main.rs new file mode 100644 index 0000000..50f7f01 --- /dev/null +++ b/app/email/src/main.rs @@ -0,0 +1,22 @@ +mod context; + +use context::AppContext; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let ctx = AppContext::init()?; + tracing::info!("email service starting"); + + tokio::select! { + result = email::EmailWorker::start(&ctx.config) => { + if let Err(e) = result { + tracing::error!("email worker exited with error: {}", e); + } + } + _ = tokio::signal::ctrl_c() => { + tracing::info!("shutdown signal received, stopping email service"); + } + } + + Ok(()) +} diff --git a/app/gitdata/Cargo.toml b/app/gitdata/Cargo.toml new file mode 100644 index 0000000..42696cf --- /dev/null +++ b/app/gitdata/Cargo.toml @@ -0,0 +1,50 @@ +[package] +name = "app-gitdata" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[[bin]] +name = "gitdata" +path = "src/main.rs" + +[[bin]] +name = "gen-openapi" +path = "src/bin/gen-openapi.rs" + +[dependencies] +anyhow = { workspace = true } +config = { workspace = true } +cache = { workspace = true } +db = { workspace = true } +service = { workspace = true } +session = { workspace = true } +api = { workspace = true } +email = { workspace = true } +storage = { workspace = true } +git = { workspace = true } +model = { workspace = true } +channel = { workspace = true } +socketio = { workspace = true } +tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter", "json"] } +actix-web = { workspace = true, features = ["cookies", "secure-cookies"] } +actix-ws = { workspace = true } +tonic = { workspace = true, features = ["transport"] } +deadpool-redis = { workspace = true } +redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp", "connection-manager", "cluster"] } +sqlx = { workspace = true, features = ["postgres", "runtime-tokio"] } +serde_json = { workspace = true } +uuid = { workspace = true, features = ["v4", "v7", "serde"] } + +[lints] +workspace = true \ No newline at end of file diff --git a/app/gitdata/src/bin/gen-openapi.rs b/app/gitdata/src/bin/gen-openapi.rs new file mode 100644 index 0000000..8bc5c4c --- /dev/null +++ b/app/gitdata/src/bin/gen-openapi.rs @@ -0,0 +1,7 @@ +use std::fs; + +fn main() { + let json = api::openapi::openapi_json(); + fs::write("openapi.json", json).expect("Failed to write openapi.json"); + println!("openapi.json generated successfully"); +} diff --git a/app/gitdata/src/context.rs b/app/gitdata/src/context.rs new file mode 100644 index 0000000..c4c9c58 --- /dev/null +++ b/app/gitdata/src/context.rs @@ -0,0 +1,127 @@ +use actix_web::cookie::Key; +use cache::{AppCache, AppCacheConfig}; +use config::AppConfig; +use db::database::AppDatabase; +use deadpool_redis::{ + PoolConfig, Runtime, Timeouts, + cluster::{Config, Pool as RedisPool}, +}; +use email::AppEmail; +use service::AppService; +use session::storage::RedisClusterSessionStore; +use storage::{AppStorage, AppStorageConfig}; +use tonic::transport::Channel; + +use channel::{ChannelBus, ChannelBusConfig}; +use socketio::SocketIo; + +pub struct AppContext { + pub config: AppConfig, + pub service: AppService, + pub session_store: RedisClusterSessionStore, + pub session_key: Key, + pub channel_bus: ChannelBus, +} + +impl AppContext { + pub async fn init() -> anyhow::Result { + let config = AppConfig::load(); + init_tracing(&config)?; + + tracing::info!("initializing database"); + let db = AppDatabase::init(&config).await?; + + tracing::info!("initializing cache"); + let cache_config = AppCacheConfig::try_from(&config)?; + let cache = AppCache::init(cache_config).await?; + + tracing::info!("initializing storage"); + let storage_config = AppStorageConfig::try_from(&config)?; + let storage = AppStorage::init(storage_config).await?; + + tracing::info!("initializing email"); + let email = AppEmail::init(&config).await?; + + tracing::info!("connecting to git RPC"); + let rpc_addr = config.git_rpc_addr()?; + let rpc_port = config.git_rpc_port()?; + let git_channel = + Channel::from_shared(format!("http://{}:{}", rpc_addr, rpc_port)) + .expect("invalid gRPC endpoint") + .connect() + .await?; + + let service = AppService { + db, + cache, + email, + storage, + config: config.clone(), + git: git_channel, + redis_pool: init_redis_pool(&config)?, + }; + + tracing::info!("initializing session store"); + let redis_urls = config.redis_urls()?; + let session_store = RedisClusterSessionStore::new(redis_urls).await?; + + tracing::info!("initializing session key"); + let secret = config.session_secret()?; + let session_key = Key::from(secret.as_bytes()); + + tracing::info!("initializing channel bus"); + let io = SocketIo::new(); + let channel_config = ChannelBusConfig { + namespace: "/channel".to_owned(), + signing_secret: Some(secret.clone()), + ..Default::default() + }; + let channel_bus = ChannelBus::new( + service.db.clone(), + service.cache.clone(), + io, + channel_config, + ); + channel_bus.attach().await?; + + Ok(Self { + config, + service, + session_store, + session_key, + channel_bus, + }) + } +} + +fn init_tracing(config: &AppConfig) -> anyhow::Result<()> { + let level = config.log_level()?; + let filter = tracing_subscriber::EnvFilter::try_new(&level)?; + tracing_subscriber::fmt() + .with_env_filter(filter) + .with_target(false) + .init(); + Ok(()) +} + +fn init_redis_pool(config: &AppConfig) -> anyhow::Result { + let redis_urls = config.redis_urls()?; + let pool_size = config.redis_pool_size()?; + let connect_timeout = config.redis_connect_timeout()?; + let acquire_timeout = config.redis_acquire_timeout()?; + + let mut pool_config = PoolConfig::new(pool_size as usize); + pool_config.timeouts = Timeouts { + wait: Some(std::time::Duration::from_secs(acquire_timeout)), + create: Some(std::time::Duration::from_secs(connect_timeout)), + recycle: Some(std::time::Duration::from_secs(connect_timeout)), + }; + + let cfg = Config { + urls: Some(redis_urls), + connections: None, + pool: Some(pool_config), + read_from_replicas: false, + }; + Ok(cfg.create_pool(Some(Runtime::Tokio1))?) +} diff --git a/app/gitdata/src/main.rs b/app/gitdata/src/main.rs new file mode 100644 index 0000000..1c4ebae --- /dev/null +++ b/app/gitdata/src/main.rs @@ -0,0 +1,111 @@ +mod context; +mod shutdown; + +use std::time::Instant; + +use actix_web::{App, dev::Service}; +use context::AppContext; +use service::ai::sync::spawn_model_sync_loop; + +const REQUEST_LOG_EXCLUDED_PATHS: &[&str] = &[ + "/health", + "/live", + "/ready", + "/metrics", + "/favicon.ico", + "/robots.txt", +]; + +fn should_log_request(path: &str) -> bool { + !REQUEST_LOG_EXCLUDED_PATHS.contains(&path) +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let ctx = AppContext::init().await?; + + let api_port = ctx.config.api_port()?; + tracing::info!("GitDataAI API service starting on 0.0.0.0:{}", api_port); + + let service = ctx.service.clone(); + let session_store = ctx.session_store.clone(); + let session_key = ctx.session_key.clone(); + let channel_bus = ctx.channel_bus.clone(); + + let srv = actix_web::HttpServer::new(move || { + let session_middleware = session::SessionMiddleware::builder( + session_store.clone(), + session_key.clone(), + ) + .cookie_secure(false) + .cookie_name("id".to_string()) + .session_lifecycle( + session::config::PersistentSession::default() + .session_ttl(actix_web::cookie::time::Duration::days(30)), + ) + .build(); + + App::new() + .app_data(actix_web::web::Data::new(service.clone())) + .app_data(actix_web::web::Data::new(channel_bus.clone())) + .wrap_fn(|req, srv| { + let should_log = should_log_request(req.path()); + let method = req.method().clone(); + let path = req.path().to_owned(); + let peer_addr = + req.connection_info().peer_addr().map(str::to_owned); + let started_at = Instant::now(); + let fut = srv.call(req); + + async move { + match fut.await { + Ok(res) => { + if should_log { + tracing::info!( + method = %method, + path = %path, + status = res.status().as_u16(), + elapsed_ms = started_at.elapsed().as_millis(), + peer_addr = peer_addr.as_deref().unwrap_or("-"), + "http request" + ); + } + Ok(res) + } + Err(err) => { + if should_log { + tracing::warn!( + method = %method, + path = %path, + elapsed_ms = started_at.elapsed().as_millis(), + peer_addr = peer_addr.as_deref().unwrap_or("-"), + error = %err, + "http request failed" + ); + } + Err(err) + } + } + } + }) + .wrap(session_middleware) + .configure(|cfg| api::configure(cfg, channel_bus.clone())) + }) + .bind(format!("0.0.0.0:{}", api_port))?; + + spawn_model_sync_loop(ctx.service.clone()); + + let server = srv.run(); + tracing::info!("API server is running"); + + tokio::select! { + _ = server => { + tracing::info!("API server stopped"); + } + _ = shutdown::wait_for_shutdown_signal() => { + tracing::info!("shutdown signal received, stopping gitdata API service"); + } + } + + Ok(()) +} diff --git a/app/gitdata/src/shutdown.rs b/app/gitdata/src/shutdown.rs new file mode 100644 index 0000000..fe10327 --- /dev/null +++ b/app/gitdata/src/shutdown.rs @@ -0,0 +1,25 @@ +pub async fn wait_for_shutdown_signal() { + let ctrl_c = async { + tokio::signal::ctrl_c() + .await + .expect("failed to listen for ctrl_c event"); + }; + + #[cfg(unix)] + let terminate = async { + tokio::signal::unix::signal( + tokio::signal::unix::SignalKind::terminate(), + ) + .expect("failed to listen for SIGTERM") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } +} diff --git a/app/gitpod/Cargo.toml b/app/gitpod/Cargo.toml new file mode 100644 index 0000000..e43b586 --- /dev/null +++ b/app/gitpod/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "app-gitpod" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[[bin]] +name = "gitpod" +path = "src/main.rs" + +[dependencies] +anyhow = { workspace = true } +config = { workspace = true } +cache = { workspace = true } +db = { workspace = true } +git = { workspace = true } +tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter", "json"] } +deadpool-redis = { workspace = true } +redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp"] } + +[lints] +workspace = true \ No newline at end of file diff --git a/app/gitpod/src/context.rs b/app/gitpod/src/context.rs new file mode 100644 index 0000000..45b5a8e --- /dev/null +++ b/app/gitpod/src/context.rs @@ -0,0 +1,65 @@ +use std::time::Duration; + +use cache::{AppCache, AppCacheConfig}; +use config::AppConfig; +use db::database::AppDatabase; +use deadpool_redis::{PoolConfig, Runtime, Timeouts, cluster::Config}; + +pub struct AppContext { + pub config: AppConfig, + pub db: AppDatabase, + pub cache: AppCache, + pub redis_pool: deadpool_redis::cluster::Pool, +} + +impl AppContext { + pub async fn init() -> anyhow::Result { + let config = AppConfig::load(); + init_tracing(&config)?; + + tracing::info!("initializing database"); + let db = AppDatabase::init(&config).await?; + + tracing::info!("initializing cache"); + let cache_config = AppCacheConfig::try_from(&config)?; + let cache = AppCache::init(cache_config).await?; + + tracing::info!("initializing redis pool"); + let redis_urls = config.redis_urls()?; + let pool_size = config.redis_pool_size()?; + let connect_timeout = config.redis_connect_timeout()?; + let acquire_timeout = config.redis_acquire_timeout()?; + + let mut pool_config = PoolConfig::new(pool_size as usize); + pool_config.timeouts = Timeouts { + wait: Some(Duration::from_secs(acquire_timeout)), + create: Some(Duration::from_secs(connect_timeout)), + recycle: Some(Duration::from_secs(connect_timeout)), + }; + + let cfg = Config { + urls: Some(redis_urls), + connections: None, + pool: Some(pool_config), + read_from_replicas: false, + }; + let redis_pool = cfg.create_pool(Some(Runtime::Tokio1))?; + + Ok(Self { + config, + db, + cache, + redis_pool, + }) + } +} + +fn init_tracing(config: &AppConfig) -> anyhow::Result<()> { + let level = config.log_level()?; + let filter = tracing_subscriber::EnvFilter::try_new(&level)?; + tracing_subscriber::fmt() + .with_env_filter(filter) + .with_target(false) + .init(); + Ok(()) +} diff --git a/app/gitpod/src/main.rs b/app/gitpod/src/main.rs new file mode 100644 index 0000000..0b8bd62 --- /dev/null +++ b/app/gitpod/src/main.rs @@ -0,0 +1,82 @@ +mod context; +mod shutdown; + +use context::AppContext; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let ctx = AppContext::init().await?; + + let http_port = ctx.config.git_http_port()?; + let ssh_port = ctx.config.ssh_port()?; + let rpc_addr = ctx.config.git_rpc_addr()?; + let rpc_port = ctx.config.git_rpc_port()?; + + tracing::info!( + "gitpod service starting (HTTP:{} / SSH:{} / gRPC:{}:{})", + http_port, + ssh_port, + rpc_addr, + rpc_port + ); + + let http_task = tokio::spawn(git::http::run_http( + ctx.config.clone(), + ctx.db.clone(), + ctx.cache.clone(), + ctx.redis_pool.clone(), + )); + + let ssh_task = tokio::spawn(git::ssh::run_ssh( + ctx.config.clone(), + ctx.db.clone(), + ctx.cache.clone(), + ctx.redis_pool.clone(), + )); + + let rpc_addr_parsed = + format!("{}:{}", rpc_addr, rpc_port).parse::()?; + let sync_service = + git::sync::ReceiveSyncService::new(ctx.redis_pool.clone()); + let git_server = git::rpc::server::GitServer::new( + rpc_addr_parsed, + ctx.db.clone(), + ctx.cache.clone(), + sync_service, + ); + let rpc_task = tokio::spawn(async move { + git_server + .serve() + .await + .map_err(|e| anyhow::anyhow!("{}", e)) + }); + + tokio::select! { + result = http_task => { + match result { + Ok(Ok(())) => tracing::info!("HTTP server stopped"), + Ok(Err(e)) => tracing::error!("HTTP server error: {}", e), + Err(e) => tracing::error!("HTTP task panicked: {}", e), + } + } + result = ssh_task => { + match result { + Ok(Ok(())) => tracing::info!("SSH server stopped"), + Ok(Err(e)) => tracing::error!("SSH server error: {}", e), + Err(e) => tracing::error!("SSH task panicked: {}", e), + } + } + result = rpc_task => { + match result { + Ok(Ok(())) => tracing::info!("gRPC server stopped"), + Ok(Err(e)) => tracing::error!("gRPC server error: {}", e), + Err(e) => tracing::error!("gRPC task panicked: {}", e), + } + } + _ = shutdown::wait_for_shutdown_signal() => { + tracing::info!("shutdown signal received, stopping gitpod service"); + } + } + + Ok(()) +} diff --git a/app/gitpod/src/shutdown.rs b/app/gitpod/src/shutdown.rs new file mode 100644 index 0000000..fe10327 --- /dev/null +++ b/app/gitpod/src/shutdown.rs @@ -0,0 +1,25 @@ +pub async fn wait_for_shutdown_signal() { + let ctrl_c = async { + tokio::signal::ctrl_c() + .await + .expect("failed to listen for ctrl_c event"); + }; + + #[cfg(unix)] + let terminate = async { + tokio::signal::unix::signal( + tokio::signal::unix::SignalKind::terminate(), + ) + .expect("failed to listen for SIGTERM") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } +} diff --git a/app/gitsync/Cargo.toml b/app/gitsync/Cargo.toml new file mode 100644 index 0000000..9228100 --- /dev/null +++ b/app/gitsync/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "app-gitsync" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[[bin]] +name = "gitsync" +path = "src/main.rs" + +[dependencies] +anyhow = { workspace = true } +config = { workspace = true } +cache = { workspace = true } +db = { workspace = true } +git = { workspace = true } +tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter", "json"] } +actix-web = { workspace = true } +deadpool-redis = { workspace = true } +redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp"] } +uuid = { workspace = true } +serde_json = { workspace = true } +sqlx = { workspace = true, features = ["postgres", "runtime-tokio"] } + +[lints] +workspace = true \ No newline at end of file diff --git a/app/gitsync/src/context.rs b/app/gitsync/src/context.rs new file mode 100644 index 0000000..45b5a8e --- /dev/null +++ b/app/gitsync/src/context.rs @@ -0,0 +1,65 @@ +use std::time::Duration; + +use cache::{AppCache, AppCacheConfig}; +use config::AppConfig; +use db::database::AppDatabase; +use deadpool_redis::{PoolConfig, Runtime, Timeouts, cluster::Config}; + +pub struct AppContext { + pub config: AppConfig, + pub db: AppDatabase, + pub cache: AppCache, + pub redis_pool: deadpool_redis::cluster::Pool, +} + +impl AppContext { + pub async fn init() -> anyhow::Result { + let config = AppConfig::load(); + init_tracing(&config)?; + + tracing::info!("initializing database"); + let db = AppDatabase::init(&config).await?; + + tracing::info!("initializing cache"); + let cache_config = AppCacheConfig::try_from(&config)?; + let cache = AppCache::init(cache_config).await?; + + tracing::info!("initializing redis pool"); + let redis_urls = config.redis_urls()?; + let pool_size = config.redis_pool_size()?; + let connect_timeout = config.redis_connect_timeout()?; + let acquire_timeout = config.redis_acquire_timeout()?; + + let mut pool_config = PoolConfig::new(pool_size as usize); + pool_config.timeouts = Timeouts { + wait: Some(Duration::from_secs(acquire_timeout)), + create: Some(Duration::from_secs(connect_timeout)), + recycle: Some(Duration::from_secs(connect_timeout)), + }; + + let cfg = Config { + urls: Some(redis_urls), + connections: None, + pool: Some(pool_config), + read_from_replicas: false, + }; + let redis_pool = cfg.create_pool(Some(Runtime::Tokio1))?; + + Ok(Self { + config, + db, + cache, + redis_pool, + }) + } +} + +fn init_tracing(config: &AppConfig) -> anyhow::Result<()> { + let level = config.log_level()?; + let filter = tracing_subscriber::EnvFilter::try_new(&level)?; + tracing_subscriber::fmt() + .with_env_filter(filter) + .with_target(false) + .init(); + Ok(()) +} diff --git a/app/gitsync/src/health.rs b/app/gitsync/src/health.rs new file mode 100644 index 0000000..e4a7e18 --- /dev/null +++ b/app/gitsync/src/health.rs @@ -0,0 +1,99 @@ +use std::time::Instant; + +use actix_web::dev::Service; +use actix_web::{App, HttpResponse, HttpServer, dev::Server, web}; +use cache::AppCache; +use db::database::AppDatabase; + +const REQUEST_LOG_EXCLUDED_PATHS: &[&str] = &[ + "/health", + "/live", + "/ready", + "/metrics", + "/favicon.ico", + "/robots.txt", +]; + +fn should_log_request(path: &str) -> bool { + !REQUEST_LOG_EXCLUDED_PATHS.contains(&path) +} + +async fn health( + db: web::Data, + cache: web::Data, +) -> HttpResponse { + let db_ok = sqlx::query("SELECT 1").execute(db.reader()).await.is_ok(); + let cache_ok = cache.ping_cluster().await.is_ok(); + + if db_ok && cache_ok { + HttpResponse::Ok().json(serde_json::json!({ + "status": "ok", + "db": "ok", + "cache": "ok", + })) + } else { + HttpResponse::ServiceUnavailable().json(serde_json::json!({ + "status": "unhealthy", + "db": if db_ok { "ok" } else { "error" }, + "cache": if cache_ok { "ok" } else { "error" }, + })) + } +} + +pub fn start_health( + port: u16, + db: AppDatabase, + cache: AppCache, +) -> anyhow::Result { + tracing::info!("health endpoint starting on 0.0.0.0:{}", port); + + let srv = HttpServer::new(move || { + App::new() + .app_data(web::Data::new(db.clone())) + .app_data(web::Data::new(cache.clone())) + .wrap_fn(|req, srv| { + let should_log = should_log_request(req.path()); + let method = req.method().clone(); + let path = req.path().to_owned(); + let peer_addr = + req.connection_info().peer_addr().map(str::to_owned); + let started_at = Instant::now(); + let fut = srv.call(req); + + async move { + match fut.await { + Ok(res) => { + if should_log { + tracing::info!( + method = %method, + path = %path, + status = res.status().as_u16(), + elapsed_ms = started_at.elapsed().as_millis(), + peer_addr = peer_addr.as_deref().unwrap_or("-"), + "http request" + ); + } + Ok(res) + } + Err(err) => { + if should_log { + tracing::warn!( + method = %method, + path = %path, + elapsed_ms = started_at.elapsed().as_millis(), + peer_addr = peer_addr.as_deref().unwrap_or("-"), + error = %err, + "http request failed" + ); + } + Err(err) + } + } + } + }) + .route("/health", web::get().to(health)) + }) + .bind(format!("0.0.0.0:{}", port))?; + + Ok(srv.run()) +} diff --git a/app/gitsync/src/main.rs b/app/gitsync/src/main.rs new file mode 100644 index 0000000..55bd35a --- /dev/null +++ b/app/gitsync/src/main.rs @@ -0,0 +1,51 @@ +mod context; +mod health; +mod shutdown; + +use context::AppContext; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let ctx = AppContext::init().await?; + + tracing::info!("gitsync service starting"); + + let health_port = ctx.config.gitsync_health_port(); + let health_server = + health::start_health(health_port, ctx.db.clone(), ctx.cache.clone())?; + let health_handle = health_server.handle(); + let health_task = tokio::spawn(health_server); + + let sync_service = + git::sync::ReceiveSyncService::new(ctx.redis_pool.clone()); + let consumer = git::sync::consumer::SyncConsumer::new(sync_service, 5); + let worker = git::sync::worker::SyncWorker::new( + consumer, + ctx.db.clone(), + ctx.cache.clone(), + ctx.redis_pool.clone(), + ctx.config.clone(), + format!("gitsync-{}", uuid::Uuid::new_v4()), + ); + + let worker_task = tokio::spawn(async move { worker.run().await }); + + tokio::select! { + result = health_task => { + match result { + Ok(Ok(())) => tracing::info!("health server stopped"), + Ok(Err(e)) => tracing::error!("health server error: {}", e), + Err(e) => tracing::error!("health task panicked: {}", e), + } + } + _ = worker_task => { + tracing::info!("sync worker stopped"); + } + _ = shutdown::wait_for_shutdown_signal() => { + tracing::info!("shutdown signal received, stopping gitsync service"); + health_handle.stop(true).await; + } + } + + Ok(()) +} diff --git a/app/gitsync/src/shutdown.rs b/app/gitsync/src/shutdown.rs new file mode 100644 index 0000000..fe10327 --- /dev/null +++ b/app/gitsync/src/shutdown.rs @@ -0,0 +1,25 @@ +pub async fn wait_for_shutdown_signal() { + let ctrl_c = async { + tokio::signal::ctrl_c() + .await + .expect("failed to listen for ctrl_c event"); + }; + + #[cfg(unix)] + let terminate = async { + tokio::signal::unix::signal( + tokio::signal::unix::SignalKind::terminate(), + ) + .expect("failed to listen for SIGTERM") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } +} diff --git a/docker/README.md b/docker/README.md new file mode 100644 index 0000000..b23bf89 --- /dev/null +++ b/docker/README.md @@ -0,0 +1,137 @@ +# GitDataAI Docker 配置 + +## 文件说明 + +### Dockerfile 文件 + +| 文件名 | 服务 | 说明 | +|--------|------|------| +| `gitdata.Dockerfile` | GitData API | 主 API 服务 | +| `email.Dockerfile` | Email Service | 邮件发送服务 | +| `gitpod.Dockerfile` | GitPod Service | Git 服务 | +| `gitsync.Dockerfile` | GitSync Service | Git 同步服务 | +| `migrate.Dockerfile` | Database Migration | 数据库迁移工具 | +| `web.Dockerfile` | Web Frontend | React 前端应用 | + +### 配置文件 + +| 文件名 | 说明 | +|--------|------| +| `docker-compose.yml` | 完整的开发环境配置 | +| `nginx.conf` | Nginx 反向代理配置 | + +## 快速开始 + +### 1. 启动完整开发环境 + +```bash +# 进入 docker 目录 +cd docker + +# 启动所有服务 +docker-compose up -d + +# 查看服务状态 +docker-compose ps + +# 查看日志 +docker-compose logs -f +``` + +### 2. 单独构建服务 + +```bash +# 构建 GitData API +docker build -f docker/gitdata.Dockerfile -t gitdata-api . + +# 构建前端 +docker build -f docker/web.Dockerfile -t gitdata-web . +``` + +### 3. 环境变量配置 + +创建 `.env` 文件配置环境变量: + +```bash +# 数据库配置 +POSTGRES_USER=gitdata +POSTGRES_PASSWORD=your_secure_password +POSTGRES_DB=app + +# MinIO 配置 +MINIO_ROOT_USER=admin +MINIO_ROOT_PASSWORD=your_secure_password +``` + +## 服务端口 + +| 服务 | 端口 | 说明 | +|------|------|------| +| Web Frontend | 80 | 前端访问入口 | +| GitData API | 8080 | 主 API 服务 | +| Git HTTP | 5023 | Git HTTP 访问 | +| Git RPC | 5030 | Git RPC 服务 | +| SSH | 5022 | SSH Git 访问 | +| GitPod | 5082 | GitPod 服务 | +| GitSync | 5083 | GitSync 健康检查 | +| PostgreSQL | 5432 | 数据库 | +| Redis | 6379 | 缓存 | +| Qdrant | 6333 | 向量数据库 | +| NATS | 4222 | 消息队列 | +| MinIO | 9000/9001 | 对象存储 | + +## 生产环境部署 + +### 1. 修改环境变量 + +```bash +# 复制示例配置 +cp .env.example .env + +# 编辑配置文件,修改密码等敏感信息 +vim .env +``` + +### 2. 启动服务 + +```bash +# 使用生产配置启动 +docker-compose -f docker-compose.yml up -d + +# 查看服务状态 +docker-compose ps +``` + +### 3. 数据备份 + +```bash +# 备份 PostgreSQL +docker exec gitdata-postgres pg_dump -U gitdata app > backup.sql + +# 备份 MinIO 数据 +docker cp gitdata-minio:/data ./minio-backup +``` + +## 常见问题 + +### 1. 服务启动失败 + +检查日志: +```bash +docker-compose logs +``` + +### 2. 数据库连接失败 + +确保 PostgreSQL 健康检查通过: +```bash +docker-compose ps postgres +``` + +### 3. 端口冲突 + +修改 `docker-compose.yml` 中的端口映射: +```yaml +ports: + - "8081:8080" # 修改宿主机端口 +``` \ No newline at end of file diff --git a/docker/build.sh b/docker/build.sh new file mode 100755 index 0000000..d0ccdb8 --- /dev/null +++ b/docker/build.sh @@ -0,0 +1,131 @@ +#!/bin/bash +# GitDataAI Docker Build Script + +set -e + +# Get version from Cargo.toml +PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)" +CARGO_VERSION=$(grep -m1 'version' "${PROJECT_ROOT}/Cargo.toml" | sed 's/.*"\(.*\)".*/\1/') + +# Configuration +REGISTRY=${REGISTRY:-""} +TAG=${TAG:-"${CARGO_VERSION:-latest}"} +PLATFORM=${PLATFORM:-"linux/amd64"} + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Services to build +SERVICES=("gitdata" "email" "gitpod" "gitsync" "migrate" "web") + +# Function to print colored output +log_info() { + echo -e "${GREEN}[INFO]${NC} $1" +} + +log_warn() { + echo -e "${YELLOW}[WARN]${NC} $1" +} + +log_error() { + echo -e "${RED}[ERROR]${NC} $1" +} + +# Function to build a service +build_service() { + local service=$1 + local dockerfile="${PROJECT_ROOT}/docker/${service}.Dockerfile" + local image_name="gitdata-${service}" + + # Add registry prefix if set + if [ -n "$REGISTRY" ]; then + image_name="${REGISTRY}/${image_name}" + fi + + log_info "Building ${service}..." + + if [ ! -f "$dockerfile" ]; then + log_error "Dockerfile not found: ${dockerfile}" + return 1 + fi + + docker build \ + -f "$dockerfile" \ + -t "${image_name}:${TAG}" \ + --platform "$PLATFORM" \ + "$PROJECT_ROOT" + + log_info "Successfully built ${image_name}:${TAG}" +} + +# Parse command line arguments +BUILD_SERVICES=() +BUILD_ALL=true + +while [[ $# -gt 0 ]]; do + case $1 in + --tag|-t) + TAG="$2" + shift 2 + ;; + --registry|-r) + REGISTRY="$2" + shift 2 + ;; + --platform|-p) + PLATFORM="$2" + shift 2 + ;; + --help|-h) + echo "Usage: $0 [OPTIONS] [SERVICE...]" + echo "" + echo "Options:" + echo " -t, --tag TAG Docker image tag (default: latest)" + echo " -r, --registry REG Docker registry prefix" + echo " -p, --platform PLAT Target platform (default: linux/amd64)" + echo " -h, --help Show this help message" + echo "" + echo "Services:" + echo " gitdata Main API service" + echo " email Email service" + echo " gitpod GitPod service" + echo " gitsync GitSync service" + echo " migrate Database migration" + echo " web Web frontend" + echo "" + echo "Examples:" + echo " $0 # Build all services" + echo " $0 gitdata web # Build specific services" + echo " $0 -t v1.0.0 -r registry.com # Build with custom tag and registry" + exit 0 + ;; + *) + BUILD_SERVICES+=("$1") + BUILD_ALL=false + shift + ;; + esac +done + +# Build services +log_info "Starting Docker build..." +log_info "Registry: ${REGISTRY:-none}" +log_info "Tag: ${TAG}" +log_info "Platform: ${PLATFORM}" + +if [ "$BUILD_ALL" = true ]; then + log_info "Building all services..." + for service in "${SERVICES[@]}"; do + build_service "$service" + done +else + log_info "Building specified services: ${BUILD_SERVICES[*]}" + for service in "${BUILD_SERVICES[@]}"; do + build_service "$service" + done +fi + +log_info "Build completed successfully!" \ No newline at end of file diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml new file mode 100644 index 0000000..447cd34 --- /dev/null +++ b/docker/docker-compose.yml @@ -0,0 +1,203 @@ +# GitDataAI Docker Compose +# Full stack deployment configuration + +services: + # PostgreSQL Database + postgres: + image: postgres:16-alpine + container_name: gitdata-postgres + environment: + POSTGRES_USER: ${POSTGRES_USER:-gitdata} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-gitdata123} + POSTGRES_DB: ${POSTGRES_DB:-app} + volumes: + - postgres_data:/var/lib/postgresql/data + ports: + - "5432:5432" + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-gitdata}"] + interval: 10s + timeout: 5s + retries: 5 + restart: unless-stopped + + # Redis Cluster + redis: + image: redis:7-alpine + container_name: gitdata-redis + ports: + - "6379:6379" + volumes: + - redis_data:/data + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 5 + restart: unless-stopped + + # Qdrant Vector Database + qdrant: + image: qdrant/qdrant:latest + container_name: gitdata-qdrant + ports: + - "6333:6333" + volumes: + - qdrant_data:/qdrant/storage + restart: unless-stopped + + # NATS Message Queue + nats: + image: nats:alpine + container_name: gitdata-nats + ports: + - "4222:4222" + - "8222:8222" + command: "--jetstream" + restart: unless-stopped + + # MinIO S3 Storage + minio: + image: minio/minio:latest + container_name: gitdata-minio + command: server /data --console-address ":9001" + environment: + MINIO_ROOT_USER: ${MINIO_ROOT_USER:-admin} + MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-mysecret123} + ports: + - "9000:9000" + - "9001:9001" + volumes: + - minio_data:/data + restart: unless-stopped + + # Database Migration + migrate: + build: + context: .. + dockerfile: docker/migrate.Dockerfile + container_name: gitdata-migrate + environment: + DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app} + depends_on: + postgres: + condition: service_healthy + restart: "no" + + # GitData Main API Service + gitdata: + build: + context: .. + dockerfile: docker/gitdata.Dockerfile + container_name: gitdata-api + environment: + APP_DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app} + APP_REDIS_URLS: redis://redis:6379 + APP_QDRANT_URL: http://qdrant:6333/ + NATS_URL: nats://nats:4222 + APP_STORAGE_S3_ENDPOINT_URL: http://minio:9000 + APP_STORAGE_S3_ACCESS_KEY_ID: ${MINIO_ROOT_USER:-admin} + APP_STORAGE_S3_SECRET_ACCESS_KEY: ${MINIO_ROOT_PASSWORD:-mysecret123} + ports: + - "8080:8080" + - "5023:5023" + - "5030:5030" + - "5022:5022" + volumes: + - gitdata_repos:/app/data/repos + - gitdata_files:/app/data/files + - gitdata_avatar:/app/data/avatar + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + qdrant: + condition: service_started + nats: + condition: service_started + minio: + condition: service_started + migrate: + condition: service_completed_successfully + restart: unless-stopped + + # Email Service + email: + build: + context: .. + dockerfile: docker/email.Dockerfile + container_name: gitdata-email + environment: + APP_DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app} + APP_REDIS_URLS: redis://redis:6379 + NATS_URL: nats://nats:4222 + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + nats: + condition: service_started + restart: unless-stopped + + # GitPod Service + gitpod: + build: + context: .. + dockerfile: docker/gitpod.Dockerfile + container_name: gitdata-gitpod + environment: + APP_DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app} + APP_REDIS_URLS: redis://redis:6379 + ports: + - "5082:5082" + volumes: + - gitdata_repos:/app/data/repos + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + restart: unless-stopped + + # GitSync Service + gitsync: + build: + context: .. + dockerfile: docker/gitsync.Dockerfile + container_name: gitdata-gitsync + environment: + APP_DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app} + APP_REDIS_URLS: redis://redis:6379 + ports: + - "5083:5083" + volumes: + - gitdata_repos:/app/data/repos + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + restart: unless-stopped + + # Web Frontend + web: + build: + context: .. + dockerfile: docker/web.Dockerfile + container_name: gitdata-web + ports: + - "80:80" + depends_on: + - gitdata + restart: unless-stopped + +volumes: + postgres_data: + redis_data: + qdrant_data: + minio_data: + gitdata_repos: + gitdata_files: + gitdata_avatar: \ No newline at end of file diff --git a/docker/gitdata.Dockerfile b/docker/gitdata.Dockerfile new file mode 100644 index 0000000..6ada483 --- /dev/null +++ b/docker/gitdata.Dockerfile @@ -0,0 +1,74 @@ +# GitDataAI Backend - GitData Service +# Multi-stage build for Rust application + +# Stage 1: Build the application +FROM rust:1.96-bookworm AS builder + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + pkg-config \ + libssl-dev \ + libpq-dev \ + cmake \ + && rm -rf /var/lib/apt/lists/* + +# Create app directory +WORKDIR /app + +# Copy workspace files +COPY Cargo.toml Cargo.lock ./ +COPY app/ app/ +COPY lib/ lib/ + +# Build the application in release mode +RUN cargo build --release --bin gitdata + +# Stage 2: Create runtime image +FROM debian:bookworm-slim + +# Install runtime dependencies +RUN apt-get update && apt-get install -y \ + libssl3 \ + libpq5 \ + ca-certificates \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Create non-root user +RUN useradd -r -s /bin/false appuser + +# Create directories +RUN mkdir -p /app/data/repos \ + /app/data/files \ + /app/data/avatar \ + /app/logs \ + && chown -R appuser:appuser /app + +# Copy binary from builder +COPY --from=builder /app/target/release/gitdata /app/gitdata + +# Set ownership +RUN chown -R appuser:appuser /app + +# Switch to non-root user +USER appuser + +# Set working directory +WORKDIR /app + +# Expose ports +# API port +EXPOSE 8080 +# Git HTTP port +EXPOSE 5023 +# Git RPC port +EXPOSE 5030 +# SSH port +EXPOSE 5022 + +# Health check +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8080/health || exit 1 + +# Run the application +CMD ["./gitdata"] \ No newline at end of file diff --git a/docker/gitpod.Dockerfile b/docker/gitpod.Dockerfile new file mode 100644 index 0000000..deca8dc --- /dev/null +++ b/docker/gitpod.Dockerfile @@ -0,0 +1,65 @@ +# GitDataAI Backend - GitPod Service +# Multi-stage build for Rust application + +# Stage 1: Build the application +FROM rust:1.96-bookworm AS builder + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + pkg-config \ + libssl-dev \ + libpq-dev \ + cmake \ + && rm -rf /var/lib/apt/lists/* + +# Create app directory +WORKDIR /app + +# Copy workspace files +COPY Cargo.toml Cargo.lock ./ +COPY app/ app/ +COPY lib/ lib/ + +# Build the application in release mode +RUN cargo build --release --bin gitpod + +# Stage 2: Create runtime image +FROM debian:bookworm-slim + +# Install runtime dependencies +RUN apt-get update && apt-get install -y \ + libssl3 \ + libpq5 \ + ca-certificates \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Create non-root user +RUN useradd -r -s /bin/false appuser + +# Create directories +RUN mkdir -p /app/data/repos \ + /app/logs \ + && chown -R appuser:appuser /app + +# Copy binary from builder +COPY --from=builder /app/target/release/gitpod /app/gitpod + +# Set ownership +RUN chown -R appuser:appuser /app + +# Switch to non-root user +USER appuser + +# Set working directory +WORKDIR /app + +# Expose port +EXPOSE 5082 + +# Health check +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:5082/health || exit 1 + +# Run the application +CMD ["./gitpod"] \ No newline at end of file diff --git a/docker/gitsync.Dockerfile b/docker/gitsync.Dockerfile new file mode 100644 index 0000000..6ce8777 --- /dev/null +++ b/docker/gitsync.Dockerfile @@ -0,0 +1,65 @@ +# GitDataAI Backend - GitSync Service +# Multi-stage build for Rust application + +# Stage 1: Build the application +FROM rust:1.96-bookworm AS builder + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + pkg-config \ + libssl-dev \ + libpq-dev \ + cmake \ + && rm -rf /var/lib/apt/lists/* + +# Create app directory +WORKDIR /app + +# Copy workspace files +COPY Cargo.toml Cargo.lock ./ +COPY app/ app/ +COPY lib/ lib/ + +# Build the application in release mode +RUN cargo build --release --bin gitsync + +# Stage 2: Create runtime image +FROM debian:bookworm-slim + +# Install runtime dependencies +RUN apt-get update && apt-get install -y \ + libssl3 \ + libpq5 \ + ca-certificates \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Create non-root user +RUN useradd -r -s /bin/false appuser + +# Create directories +RUN mkdir -p /app/data/repos \ + /app/logs \ + && chown -R appuser:appuser /app + +# Copy binary from builder +COPY --from=builder /app/target/release/gitsync /app/gitsync + +# Set ownership +RUN chown -R appuser:appuser /app + +# Switch to non-root user +USER appuser + +# Set working directory +WORKDIR /app + +# Expose health check port +EXPOSE 5083 + +# Health check +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:5083/health || exit 1 + +# Run the application +CMD ["./gitsync"] \ No newline at end of file diff --git a/docker/migrate.Dockerfile b/docker/migrate.Dockerfile new file mode 100644 index 0000000..dab22a0 --- /dev/null +++ b/docker/migrate.Dockerfile @@ -0,0 +1,58 @@ +# GitDataAI Database Migration Dockerfile +# Multi-stage build for Rust migration tool + +# Stage 1: Build the application +FROM rust:1.96-bookworm AS builder + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + pkg-config \ + libssl-dev \ + libpq-dev \ + cmake \ + && rm -rf /var/lib/apt/lists/* + +# Create app directory +WORKDIR /app + +# Copy workspace files +COPY Cargo.toml Cargo.lock ./ +COPY app/ app/ +COPY lib/ lib/ + +# Build the migration binary +RUN cargo build --release --bin migrate + +# Stage 2: Create runtime image +FROM debian:bookworm-slim + +# Install runtime dependencies +RUN apt-get update && apt-get install -y \ + libssl3 \ + libpq5 \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# Create non-root user +RUN useradd -r -s /bin/false appuser + +# Create app directory +RUN mkdir -p /app && chown -R appuser:appuser /app + +# Copy binary from builder +COPY --from=builder /app/target/release/migrate /app/migrate + +# Copy migration files +COPY --from=builder /app/lib/migrate/sql /app/sql + +# Set ownership +RUN chown -R appuser:appuser /app + +# Switch to non-root user +USER appuser + +# Set working directory +WORKDIR /app + +# Run migrations by default +CMD ["./migrate", "up"] \ No newline at end of file diff --git a/docker/nginx.conf b/docker/nginx.conf new file mode 100644 index 0000000..8455432 --- /dev/null +++ b/docker/nginx.conf @@ -0,0 +1,75 @@ +server { + listen 80; + server_name localhost; + + # Gzip compression + gzip on; + gzip_vary on; + gzip_min_length 1024; + gzip_proxied any; + gzip_comp_level 6; + gzip_types + text/plain + text/css + text/xml + text/javascript + application/json + application/javascript + application/xml + application/rss+xml + image/svg+xml; + + # Security headers + add_header X-Frame-Options "SAMEORIGIN" always; + add_header X-Content-Type-Options "nosniff" always; + add_header X-XSS-Protection "1; mode=block" always; + add_header Referrer-Policy "strict-origin-when-cross-origin" always; + + # Root directory + root /usr/share/nginx/html; + index index.html; + + # Enable static asset caching + location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ { + expires 1y; + add_header Cache-Control "public, immutable"; + try_files $uri =404; + } + + # API proxy (if needed) + location /api/ { + proxy_pass http://gitdata:8080/api/; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection 'upgrade'; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + proxy_cache_bypass $http_upgrade; + } + + # Socket.IO proxy + location /socket.io/ { + proxy_pass http://gitdata:8080/socket.io/; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + } + + # SPA fallback + location / { + try_files $uri $uri/ /index.html; + } + + # Health check endpoint + location /health { + access_log off; + return 200 'OK'; + add_header Content-Type text/plain; + } +} \ No newline at end of file diff --git a/docker/push.sh b/docker/push.sh new file mode 100755 index 0000000..7804c5b --- /dev/null +++ b/docker/push.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# GitDataAI Docker Push Script + +set -e + +# Get version from Cargo.toml +PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)" +CARGO_VERSION=$(grep -m1 'version' "${PROJECT_ROOT}/Cargo.toml" | sed 's/.*"\(.*\)".*/\1/') + +# Configuration +REGISTRY=${REGISTRY:-""} +TAG=${TAG:-"${CARGO_VERSION:-latest}"} + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Services to push +SERVICES=("gitdata" "email" "gitpod" "gitsync" "migrate" "web") + +# Function to print colored output +log_info() { + echo -e "${GREEN}[INFO]${NC} $1" +} + +log_warn() { + echo -e "${YELLOW}[WARN]${NC} $1" +} + +log_error() { + echo -e "${RED}[ERROR]${NC} $1" +} + +# Function to push a service +push_service() { + local service=$1 + local image_name="gitdata-${service}" + + # Add registry prefix if set + if [ -n "$REGISTRY" ]; then + image_name="${REGISTRY}/${image_name}" + fi + + log_info "Pushing ${service}..." + + # Check if image exists locally + if ! docker image inspect "${image_name}:${TAG}" > /dev/null 2>&1; then + log_error "Image not found: ${image_name}:${TAG}" + log_error "Please build the image first with: ./build.sh ${service}" + return 1 + fi + + docker push "${image_name}:${TAG}" + + log_info "Successfully pushed ${image_name}:${TAG}" +} + +# Parse command line arguments +PUSH_SERVICES=() +PUSH_ALL=true + +while [[ $# -gt 0 ]]; do + case $1 in + --tag|-t) + TAG="$2" + shift 2 + ;; + --registry|-r) + REGISTRY="$2" + shift 2 + ;; + --help|-h) + echo "Usage: $0 [OPTIONS] [SERVICE...]" + echo "" + echo "Options:" + echo " -t, --tag TAG Docker image tag (default: latest)" + echo " -r, --registry REG Docker registry prefix (required)" + echo " -h, --help Show this help message" + echo "" + echo "Services:" + echo " gitdata Main API service" + echo " email Email service" + echo " gitpod GitPod service" + echo " gitsync GitSync service" + echo " migrate Database migration" + echo " web Web frontend" + echo "" + echo "Examples:" + echo " $0 -r registry.com # Push all services" + echo " $0 -r registry.com gitdata web # Push specific services" + echo " $0 -r registry.com -t v1.0.0 # Push with custom tag" + exit 0 + ;; + *) + PUSH_SERVICES+=("$1") + PUSH_ALL=false + shift + ;; + esac +done + +# Validate registry +if [ -z "$REGISTRY" ]; then + log_error "Registry is required. Use -r or --registry to specify." + echo "Example: $0 -r registry.com" + exit 1 +fi + +# Push services +log_info "Starting Docker push..." +log_info "Registry: ${REGISTRY}" +log_info "Tag: ${TAG}" + +if [ "$PUSH_ALL" = true ]; then + log_info "Pushing all services..." + for service in "${SERVICES[@]}"; do + push_service "$service" + done +else + log_info "Pushing specified services: ${PUSH_SERVICES[*]}" + for service in "${PUSH_SERVICES[@]}"; do + push_service "$service" + done +fi + +log_info "Push completed successfully!" \ No newline at end of file diff --git a/docker/web.Dockerfile b/docker/web.Dockerfile new file mode 100644 index 0000000..a45e549 --- /dev/null +++ b/docker/web.Dockerfile @@ -0,0 +1,62 @@ +# GitDataAI Frontend Dockerfile +# Multi-stage build for React application with Bun + +# Stage 1: Build the application +FROM node:24-bookworm AS builder + +# Install bun +RUN npm install -g bun + +# Create app directory +WORKDIR /app + +# Copy package files +COPY package.json bun.lock ./ + +# Install dependencies +RUN bun install --frozen-lockfile + +# Copy source code +COPY src/ src/ +COPY public/ public/ +COPY index.html ./ +COPY vite.config.ts ./ +COPY tsconfig*.json ./ +COPY eslint.config.js ./ +COPY components.json ./ +COPY orval.config.ts ./ + +# Build the application +RUN bun run build + +# Stage 2: Create runtime image with Nginx +FROM nginx:alpine + +# Copy custom nginx configuration +COPY docker/nginx.conf /etc/nginx/conf.d/default.conf + +# Copy built assets from builder +COPY --from=builder /app/dist /usr/share/nginx/html + +# Create non-root user +RUN adduser -D -S -h /var/cache/nginx -s /sbin/nologin -G nginx appuser + +# Set ownership +RUN chown -R appuser:nginx /var/cache/nginx \ + && chown -R appuser:nginx /var/log/nginx \ + && chown -R appuser:nginx /etc/nginx/conf.d \ + && touch /var/run/nginx.pid \ + && chown -R appuser:nginx /var/run/nginx.pid + +# Switch to non-root user +USER appuser + +# Expose port +EXPOSE 80 + +# Health check +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD wget --no-verbose --tries=1 --spider http://localhost:80/ || exit 1 + +# Start Nginx +CMD ["nginx", "-g", "daemon off;"] \ No newline at end of file diff --git a/lib/ai/Cargo.toml b/lib/ai/Cargo.toml new file mode 100644 index 0000000..c58fbfb --- /dev/null +++ b/lib/ai/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "ai" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[lib] +path = "lib.rs" +name = "ai" +[dependencies] +rig-core = { workspace = true, features = ["derive"] } +tokio = { workspace = true, features = ["full"] } +tokio-util = { workspace = true } +tokio-stream = { workspace = true } +config = { workspace = true } +cache = { workspace = true } +db = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } +serde_json = { workspace = true } +serde = { workspace = true, features = ["derive"] } +qdrant-client = { workspace = true, features = ["serde"] } +async-trait = { workspace = true } +redis = { workspace = true } +uuid = { workspace = true, features = ["v4", "v5", "serde"] } +reqwest = { workspace = true } +futures = { workspace = true } +[lints] +workspace = true diff --git a/lib/ai/agent/agent.rs b/lib/ai/agent/agent.rs new file mode 100644 index 0000000..5902500 --- /dev/null +++ b/lib/ai/agent/agent.rs @@ -0,0 +1,543 @@ +use futures::StreamExt; +use rig::agent::AgentBuilder; +use rig::client::CompletionClient; +use rig::streaming::StreamingPrompt; +use rig::tool::ToolDyn; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; +use tracing::{info, warn}; + +use super::config::AgentConfig; +use super::helpers::{build_input_string, check_token_budget, estimate_tokens}; +use super::hooks::{HookChain, HookLlmResponse, HookMessage, HookToolDef, ToolCallOutcome, ToolGuardrailDecision}; +use super::persistence::ActiveAgentRun; +use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord}; +use super::subagent::run_experts; +use super::RigStreamChunk; +use crate::client::AiClient; +use crate::error::{AiError, AiResult}; + +pub struct RigAgent { + pub client: AiClient, + pub config: AgentConfig, + pub hooks: HookChain, +} + +impl RigAgent { + pub fn new(client: AiClient, config: AgentConfig) -> AiResult { + config.validate()?; + Ok(Self { + client, + config, + hooks: HookChain::empty(), + }) + } + + pub fn with_hooks(mut self, hooks: HookChain) -> Self { + self.hooks = hooks; + self + } + + pub fn config(&self) -> &AgentConfig { + &self.config + } + + pub async fn chat( + &self, + request: AgentRequest, + tools: Vec>, + ) -> AiResult { + let (mut rx, handle) = self.run(request, tools); + tokio::spawn(async move { + while rx.recv().await.is_some() {} + }); + let result = handle.await.map_err(|_| { + AiError::Response("agent task panicked".to_string()) + })?; + result.map(|r| r.output) + } + + #[allow(clippy::too_many_lines)] + pub fn run( + &self, + request: AgentRequest, + tools: Vec>, + ) -> ( + tokio::sync::mpsc::Receiver, + tokio::task::JoinHandle>, + ) { + let (tx, rx) = mpsc::channel::(256); + + let model_name = self.config.model.clone(); + let max_iterations = self.config.max_iterations; + let client = self.client.llm_client().clone(); + let ai_client = self.client.clone(); + let agent_config = self.config.clone(); + let system_prompt = self.config.system_prompt.clone(); + let temperature = self.config.temperature; + let max_completion_tokens = self.config.max_completion_tokens; + let max_total_tokens = self.config.max_total_tokens_per_run; + let cancellation = request.cancellation_token.clone(); + let timeout = request.timeout; + let hooks = self.hooks.clone(); + + let filtered_tools: Vec> = tools + .into_iter() + .filter(|tool| self.config.is_tool_exposed(&tool.name())) + .collect(); + + let handle = tokio::spawn(async move { + execute_agent_run( + client, + model_name, + system_prompt, + request, + filtered_tools, + max_iterations, + ai_client, + agent_config, + temperature, + max_completion_tokens, + max_total_tokens, + cancellation, + timeout, + hooks, + tx, + ) + .await + }); + + (rx, handle) + } +} + +#[allow(clippy::too_many_lines, clippy::too_many_arguments)] +async fn execute_agent_run( + client: rig::providers::openai::Client, + model_name: String, + system_prompt: String, + request: AgentRequest, + tools: Vec>, + max_iterations: usize, + ai_client: AiClient, + agent_config: AgentConfig, + temperature: Option, + max_completion_tokens: Option, + max_total_tokens: Option, + cancellation: Option, + timeout: Option, + hooks: HookChain, + tx: mpsc::Sender, +) -> AiResult { + if let Some(ref ctx) = request.run_context { + let _ = hooks.run_session_start(ctx).await; + } + + let model = client.completion_model(&model_name); + let mut agent_builder = AgentBuilder::new(model) + .preamble(&system_prompt) + .tools(tools) + .default_max_turns(max_iterations); + + if let Some(temp) = temperature { + agent_builder = agent_builder.temperature(temp); + } + if let Some(mt) = max_completion_tokens { + agent_builder = agent_builder.max_tokens(mt); + } + + let agent = agent_builder.build(); + let mut input = build_input_string(&request); + + // ---- SubAgent execution ---- + let expert_outputs = if !request.experts.is_empty() { + let run = ActiveAgentRun { + conversation_id: request.run_context.as_ref().and_then(|c| c.conversation_id), + message_id: None, + invocation_id: request.run_context.as_ref().and_then(|c| c.invocation_id), + session_id: request.run_context.as_ref().and_then(|c| c.session_id), + user_id: request.run_context.as_ref().and_then(|c| c.user_id), + started_at: std::time::Instant::now(), + current_step: 0, + }; + let realtime = request.run_context.as_ref().and_then(|c| c.realtime.as_ref()); + + // Notify frontend that subagents are starting. + for expert in &request.experts { + let _ = tx + .send(RigStreamChunk::SubagentStarted { + subagent_id: expert.id.clone(), + role: expert.role.clone(), + task: expert.task.clone(), + }) + .await; + } + + match run_experts(&ai_client, &agent_config, &request.experts, realtime, &run).await { + Ok(outputs) => { + for out in &outputs { + let _ = tx + .send(RigStreamChunk::SubagentCompleted { + subagent_id: out.id.clone(), + role: out.role.clone(), + task: out.task.clone(), + output: out.output.clone(), + }) + .await; + input.push_str(&format!( + "\n--- Subagent: {} (role: {}) ---\nTask: {}\nResult: {}\n", + out.id, out.role, out.task, out.output + )); + } + outputs + } + Err(e) => { + warn!(error = %e, "subagent execution failed, continuing without expert inputs"); + let _ = tx + .send(RigStreamChunk::SubagentFailed { + error: e.to_string(), + }) + .await; + Vec::new() + } + } + } else { + Vec::new() + }; + + let estimated_input_tokens = estimate_tokens(&input); + + if let Some(limit) = max_total_tokens + && estimated_input_tokens > limit as u64 + { + return Err(AiError::TokenBudgetExceeded { + estimated: estimated_input_tokens, + limit, + }); + } + + if !hooks.is_empty() { + let hook_messages: Vec = request + .messages + .iter() + .map(|m| HookMessage { + role: match m { + super::request::AgentMessage::User(_) => "user".to_string(), + super::request::AgentMessage::Assistant(_) => { + "assistant".to_string() + } + }, + content: match m { + super::request::AgentMessage::User(c) => Some(c.clone()), + super::request::AgentMessage::Assistant(c) => { + Some(c.clone()) + } + }, + tool_calls: None, + tool_call_id: None, + }) + .collect(); + let hook_tools: Vec = Vec::new(); + let _ = hooks.run_pre_llm_call(&hook_messages, &hook_tools).await; + } + + let stream_future = agent + .stream_prompt(&input) + .with_history(Vec::::new()) + .multi_turn(max_iterations); + + let stream = if let Some(dur) = timeout { + match tokio::time::timeout(dur, stream_future).await { + Ok(stream) => stream, + Err(_elapsed) => { + let _ = tx + .send(RigStreamChunk::Failed { + error: format!("agent timed out after {}s", dur.as_secs()), + }) + .await; + return Err(AiError::Timeout { + seconds: dur.as_secs(), + }); + } + } + } else { + stream_future.await + }; + + tokio::pin!(stream); + + let mut steps = Vec::new(); + let mut delta_index = 0usize; + let mut current_step_tool_calls: Vec = Vec::new(); + let mut current_step_assistant = String::new(); + let mut current_step_reasoning = String::new(); + let mut accumulated_output_chars: usize = 0; + + while let Some(item) = stream.next().await { + if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) { + let _ = tx + .send(RigStreamChunk::Failed { + error: "cancelled".to_string(), + }) + .await; + return Err(AiError::Response("agent run cancelled".to_string())); + } + + if let Some(limit) = max_total_tokens + && check_token_budget(estimated_input_tokens, accumulated_output_chars, limit) + { + let _ = tx + .send(RigStreamChunk::Failed { + error: format!("token budget exceeded: limit {limit}"), + }) + .await; + return Err(AiError::TokenBudgetExceeded { + estimated: estimated_input_tokens + + (accumulated_output_chars as f64 / 2.5).ceil() as u64, + limit, + }); + } + + match item { + Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( + rig::streaming::StreamedAssistantContent::Text(text), + )) => { + accumulated_output_chars += text.text.chars().count(); + current_step_assistant.push_str(&text.text); + let _ = tx + .send(RigStreamChunk::TextDelta { + index: delta_index, + content: text.text.clone(), + }) + .await; + delta_index += 1; + } + Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( + rig::streaming::StreamedAssistantContent::Reasoning(reasoning), + )) => { + for part in &reasoning.content { + if let rig::completion::message::ReasoningContent::Text { + text, .. + } = part + { + accumulated_output_chars += text.chars().count(); + current_step_reasoning.push_str(text); + let _ = tx + .send(RigStreamChunk::Thinking { + index: delta_index, + content: text.clone(), + }) + .await; + delta_index += 1; + } + } + } + Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( + rig::streaming::StreamedAssistantContent::ReasoningDelta { + reasoning, .. + }, + )) => { + accumulated_output_chars += reasoning.chars().count(); + current_step_reasoning.push_str(&reasoning); + let _ = tx + .send(RigStreamChunk::Thinking { + index: delta_index, + content: reasoning.clone(), + }) + .await; + delta_index += 1; + } + Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( + rig::streaming::StreamedAssistantContent::ToolCall { + tool_call, + internal_call_id: _, + }, + )) => { + let args = match &tool_call.function.arguments { + serde_json::Value::String(s) => s.clone(), + v => serde_json::to_string(v).unwrap_or_default(), + }; + accumulated_output_chars += args.chars().count(); + + let tool_name = tool_call.function.name.clone(); + let tool_args: serde_json::Value = + serde_json::from_str(&args).unwrap_or_default(); + + if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await { + match decision { + ToolGuardrailDecision::Allow => {} + ToolGuardrailDecision::Block { reason } => { + let _ = tx + .send(RigStreamChunk::ToolCallFinished { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + output: format!("blocked: {reason}"), + error: Some(reason), + }) + .await; + current_step_tool_calls.push(ToolCallRecord { + id: tool_call.id.clone(), + name: tool_name.clone(), + arguments: tool_args.clone(), + output: None, + error: Some("blocked by guardrail".to_string()), + elapsed_ms: None, + }); + continue; + } + ToolGuardrailDecision::RequireApproval { message } => { + let _ = tx + .send(RigStreamChunk::ToolCallFinished { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + output: format!("awaiting approval: {message}"), + error: None, + }) + .await; + current_step_tool_calls.push(ToolCallRecord { + id: tool_call.id.clone(), + name: tool_name.clone(), + arguments: tool_args.clone(), + output: None, + error: Some(format!("requires approval: {message}")), + elapsed_ms: None, + }); + continue; + } + } + } + + let _ = tx + .send(RigStreamChunk::ToolCallStarted { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + arguments: args.clone(), + }) + .await; + current_step_tool_calls.push(ToolCallRecord { + id: tool_call.id.clone(), + name: tool_name.clone(), + arguments: tool_args.clone(), + output: None, + error: None, + elapsed_ms: None, + }); + } + Ok(rig::agent::MultiTurnStreamItem::StreamUserItem( + rig::streaming::StreamedUserContent::ToolResult { tool_result, .. }, + )) => { + let content = + super::helpers::tool_result_content_to_string(&tool_result.content); + accumulated_output_chars += content.chars().count(); + + if let Some(last) = current_step_tool_calls.last_mut() + && last.id == tool_result.id + { + last.output = Some(serde_json::from_str(&content).unwrap_or_default()); + } + + let tool_name = current_step_tool_calls + .last() + .map(|tc| tc.name.clone()) + .unwrap_or_default(); + + let _ = tx + .send(RigStreamChunk::ToolCallFinished { + tool_call_id: tool_result.id.clone(), + tool_name, + output: content.clone(), + error: None, + }) + .await; + + if !hooks.is_empty() { + let outcome = ToolCallOutcome { + name: tool_result.id.clone(), + arguments: serde_json::Value::Null, + output: Some(serde_json::Value::String(content)), + error: None, + elapsed_ms: 0, + }; + let _ = hooks.run_post_tool_call(&outcome).await; + } + } + Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => { + let usage = resp.usage(); + + if !current_step_tool_calls.is_empty() || !current_step_assistant.is_empty() { + let reasoning = (!current_step_reasoning.is_empty()) + .then_some(std::mem::take(&mut current_step_reasoning)); + steps.push(AgentStep { + index: steps.len(), + assistant: (!current_step_assistant.is_empty()) + .then_some(std::mem::take(&mut current_step_assistant)), + reasoning_content: reasoning, + tool_calls: std::mem::take(&mut current_step_tool_calls), + reflection: None, + }); + } + let output = steps + .last() + .and_then(|s| s.assistant.clone()) + .unwrap_or_default(); + + if !hooks.is_empty() { + let hook_response = HookLlmResponse { + content: Some(output.clone()), + tool_calls: None, + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + finish_reason: None, + }; + let _ = hooks.run_post_llm_call(&hook_response).await; + } + + info!( + steps = steps.len(), + input_tokens = usage.input_tokens, + output_tokens = usage.output_tokens, + "agent run completed" + ); + + let _ = tx + .send(RigStreamChunk::Final { + content: output.clone(), + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + }) + .await; + + if let Some(ref ctx) = request.run_context { + let _ = hooks.run_session_end(ctx, true).await; + } + + return Ok(AgentResult { + output, + steps, + expert_outputs, + input_tokens: usage.input_tokens as i64, + output_tokens: usage.output_tokens as i64, + }); + } + Err(e) => { + let err = format!("{e}"); + warn!(error = %err, "agent stream error"); + let _ = tx.send(RigStreamChunk::Failed { error: err }).await; + + if let Some(ref ctx) = request.run_context { + let _ = hooks.run_session_end(ctx, false).await; + } + return Err(AiError::Api(format!("{e}"))); + } + _ => {} + } + } + + Err(AiError::Response("agent stream ended without final response".to_string())) +} + +impl Clone for HookChain { + fn clone(&self) -> Self { + HookChain::empty() + } +} diff --git a/lib/ai/agent/compression.rs b/lib/ai/agent/compression.rs new file mode 100644 index 0000000..331141c --- /dev/null +++ b/lib/ai/agent/compression.rs @@ -0,0 +1,222 @@ +use crate::error::AiResult; + +/// Compression strategy controlling when and how context compaction occurs. +#[derive(Clone, Debug)] +pub struct CompressionStrategy { + /// Token threshold that triggers compaction. + pub threshold_tokens: i64, + /// Target token count after compaction. + pub target_tokens: i64, + /// Number of recent message pairs to always preserve. + pub preserve_last_n_pairs: usize, + /// Optional model override for the compaction LLM call. + pub summary_model: String, + /// Reserve this many tokens for the compaction prompt itself. + pub reserve_tokens: i64, + /// Whether to generate branch summaries when forking. + pub branch_summarization: bool, + /// Custom instructions appended to the compaction prompt. + pub custom_instructions: Option, + /// Maximum word count for compaction summaries. + pub max_summary_words: usize, +} + +impl Default for CompressionStrategy { + fn default() -> Self { + Self { + threshold_tokens: 64_000, + target_tokens: 32_000, + preserve_last_n_pairs: 4, + summary_model: String::new(), + reserve_tokens: 16_384, + branch_summarization: true, + custom_instructions: None, + max_summary_words: 1500, + } + } +} + +impl CompressionStrategy { + pub fn new(threshold_tokens: i64, target_tokens: i64) -> Self { + Self { + threshold_tokens, + target_tokens, + ..Default::default() + } + } + + pub fn with_preserve_last(mut self, n: usize) -> Self { + self.preserve_last_n_pairs = n; + self + } + + pub fn with_summary_model(mut self, model: impl Into) -> Self { + self.summary_model = model.into(); + self + } + + pub fn with_reserve_tokens(mut self, tokens: i64) -> Self { + self.reserve_tokens = tokens; + self + } + + pub fn with_branch_summarization(mut self, enabled: bool) -> Self { + self.branch_summarization = enabled; + self + } + + pub fn with_custom_instructions(mut self, instructions: impl Into) -> Self { + self.custom_instructions = Some(instructions.into()); + self + } + + pub fn with_max_summary_words(mut self, words: usize) -> Self { + self.max_summary_words = words; + self + } + + /// Check whether compaction should be triggered based on current token count. + pub fn should_compact(&self, current_tokens: i64) -> bool { + current_tokens >= self.threshold_tokens + } +} + +#[derive(Debug, Clone)] +pub struct CompactionResult { + pub summary: String, + pub messages_compacted: usize, + pub tokens_saved: i64, + /// Whether this was a branch summary (vs. standard compaction). + pub is_branch_summary: bool, +} + +impl CompactionResult { + pub fn new(summary: String, messages_compacted: usize, tokens_saved: i64) -> Self { + Self { + summary, + messages_compacted, + tokens_saved, + is_branch_summary: false, + } + } + + pub fn branch_summary(summary: String, entries_summarized: usize) -> Self { + Self { + summary, + messages_compacted: entries_summarized, + tokens_saved: 0, + is_branch_summary: true, + } + } +} + +/// Build the compaction prompt for standard context compression. +pub fn build_compression_prompt( + existing_summary: Option<&str>, + messages_text: &str, +) -> String { + build_compression_prompt_with_options(existing_summary, messages_text, None, 1500) +} + +/// Build the compaction prompt with custom instructions and word limit. +pub fn build_compression_prompt_with_options( + existing_summary: Option<&str>, + messages_text: &str, + custom_instructions: Option<&str>, + max_words: usize, +) -> String { + let custom = custom_instructions + .map(|ci| format!("\n\nAdditional instructions: {ci}")) + .unwrap_or_default(); + + if let Some(summary) = existing_summary { + format!( + "## Previous Summary\n{summary}\n\n## New Messages\n{messages_text}\n\n\ + Combine the previous summary and the new messages into a concise, \ + single-paragraph summary of the conversation. Preserve facts, \ + decisions, code snippets, and anything essential for continuing \ + work. Target up to {max_words} words.{custom} \ + Output ONLY the summary text, no preamble.", + ) + } else { + format!( + "## Conversation\n{messages_text}\n\n\ + Summarise the conversation above into a concise, single-paragraph \ + summary. Preserve facts, decisions, code snippets, and anything \ + essential for continuing work. Target up to {max_words} words.{custom} \ + Output ONLY the summary text, no preamble.", + ) + } +} + +/// Build a prompt for generating a branch summary. +/// +/// Used when the user forks a conversation from a different point in the +/// session tree. Summarizes the divergent branch so context is preserved. +pub fn build_branch_summary_prompt( + branch_messages: &str, + custom_instructions: Option<&str>, +) -> String { + let custom = custom_instructions + .map(|ci| format!("\n\nAdditional instructions: {ci}")) + .unwrap_or_default(); + + format!( + "## Branch Conversation\n{branch_messages}\n\n\ + Summarize the conversation branch above. This summary will be used \ + to preserve context when the user navigates away from this branch. \ + Focus on key decisions, unresolved questions, and important context.{custom} \ + Output ONLY the summary text, no preamble.", + ) +} + +/// Calculate how many messages to truncate to reach the target token count. +pub fn estimate_truncation( + message_token_counts: &[i64], + current_total: i64, + target: i64, + preserve_last: usize, +) -> AiResult<(usize, i64)> { + let n = message_token_counts.len(); + if n <= preserve_last { + return Ok((0, 0)); + } + + let excess = (current_total - target).max(0); + + let mut truncated = 0; + let mut saved = 0i64; + let limit = n - preserve_last; + + for i in 0..limit { + if saved >= excess { + break; + } + saved += message_token_counts[i]; + truncated += 1; + } + + Ok((truncated, saved.min(excess))) +} + +/// Calculate compaction parameters for a given set of messages. +/// +/// Returns `(messages_to_compact, tokens_saved)` where `messages_to_compact` +/// is the count of oldest messages to summarize, and `tokens_saved` is the +/// estimated token savings. +pub fn plan_compaction( + strategy: &CompressionStrategy, + message_token_counts: &[i64], + current_total: i64, +) -> AiResult<(usize, i64)> { + if !strategy.should_compact(current_total) { + return Ok((0, 0)); + } + + estimate_truncation( + message_token_counts, + current_total, + strategy.target_tokens, + strategy.preserve_last_n_pairs * 2, // pairs → individual messages + ) +} diff --git a/lib/ai/agent/config.rs b/lib/ai/agent/config.rs new file mode 100644 index 0000000..d38c3a5 --- /dev/null +++ b/lib/ai/agent/config.rs @@ -0,0 +1,217 @@ +use crate::error::{AiError, AiResult}; + +pub const DEFAULT_SYSTEM_PROMPT: &str = r#"You are a precise autonomous agent that executes tasks through tool calls. + +## Core Principles +- Use tools when they can materially improve correctness or efficiency +- After each action, verify results and adjust approach if needed +- Keep reasoning concise and focus on actionable outcomes +- Return only the final useful answer to the user + +## Workflow +1. Analyze the request and plan your approach +2. Execute actions using appropriate tools +3. Review observations and verify assumptions +4. Iterate until the task is complete +5. Provide a clear, concise final response + +## Title Generation +If this is the first user message in a new conversation with a default title, you SHOULD call `set_conversation_title` as your first action to create a short, descriptive title (max 100 chars). This helps keep the conversation organized. Only do this once at the very beginning."#; + +#[derive(Clone, Debug)] +pub struct AgentConfig { + pub model: String, + pub provider: String, + pub api_mode: String, + pub system_prompt: String, + pub max_iterations: usize, + pub iteration_budget: usize, + pub temperature: Option, + pub max_completion_tokens: Option, + pub max_total_tokens_per_run: Option, + pub enabled_toolsets: Vec, + pub disabled_toolsets: Vec, + pub allowed_tools: Vec, + pub denied_tools: Vec, + pub retry_max_attempts: usize, + pub retry_base_delay_ms: u64, + pub retry_jitter: bool, + pub fallback_model: Option, + pub skip_memory: bool, + pub skip_context_files: bool, + pub skip_compression: bool, + pub quiet_mode: bool, + pub save_trajectories: bool, + pub reasoning_effort: Option, + pub service_tier: Option, + pub platform: Option, + pub session_id: Option, +} + +impl AgentConfig { + pub fn new(model: impl Into) -> AiResult { + let config = Self { + model: model.into(), + provider: String::new(), + api_mode: String::from("chat_completions"), + system_prompt: DEFAULT_SYSTEM_PROMPT.to_string(), + max_iterations: 64, + iteration_budget: 90, + temperature: Some(0.2), + max_completion_tokens: None, + max_total_tokens_per_run: Some(128_000), + enabled_toolsets: Vec::new(), + disabled_toolsets: Vec::new(), + allowed_tools: Vec::new(), + denied_tools: Vec::new(), + retry_max_attempts: 3, + retry_base_delay_ms: 1_000, + retry_jitter: true, + fallback_model: None, + skip_memory: false, + skip_context_files: false, + skip_compression: false, + quiet_mode: false, + save_trajectories: false, + reasoning_effort: None, + service_tier: None, + platform: None, + session_id: None, + }; + config.validate()?; + Ok(config) + } + + pub fn validate(&self) -> AiResult<()> { + if self.model.trim().is_empty() { + return Err(AiError::Config("agent model is required".to_string())); + } + if self.max_iterations == 0 { + return Err(AiError::Config( + "agent max_iterations must be greater than 0".to_string(), + )); + } + if let Some(tokens) = self.max_total_tokens_per_run + && tokens <= 0 + { + return Err(AiError::Config( + "agent max_total_tokens_per_run must be > 0".to_string(), + )); + } + Ok(()) + } + + pub fn with_provider(mut self, provider: impl Into) -> Self { + self.provider = provider.into(); + self + } + + pub fn with_api_mode(mut self, mode: impl Into) -> Self { + self.api_mode = mode.into(); + self + } + + pub fn with_max_iterations(mut self, max: usize) -> Self { + self.max_iterations = max; + self.iteration_budget = self.iteration_budget.max(max); + self + } + + pub fn with_iteration_budget(mut self, budget: usize) -> Self { + self.iteration_budget = budget; + self + } + + pub fn with_system_prompt(mut self, prompt: impl Into) -> Self { + self.system_prompt = prompt.into(); + self + } + + pub fn with_temperature(mut self, temperature: Option) -> Self { + self.temperature = temperature; + self + } + + pub fn with_max_completion_tokens(mut self, max_completion_tokens: Option) -> Self { + self.max_completion_tokens = max_completion_tokens; + self + } + + pub fn with_max_total_tokens(mut self, limit: Option) -> Self { + self.max_total_tokens_per_run = limit; + self + } + + pub fn with_toolset_policy(mut self, enabled: Vec, disabled: Vec) -> Self { + self.enabled_toolsets = enabled; + self.disabled_toolsets = disabled; + self + } + + pub fn with_tool_policy(mut self, allowed_tools: Vec, denied_tools: Vec) -> Self { + self.allowed_tools = allowed_tools; + self.denied_tools = denied_tools; + self + } + + pub fn with_retry(mut self, max_attempts: usize, base_delay_ms: u64) -> Self { + self.retry_max_attempts = max_attempts; + self.retry_base_delay_ms = base_delay_ms; + self + } + + pub fn with_retry_jitter(mut self, jitter: bool) -> Self { + self.retry_jitter = jitter; + self + } + + pub fn with_fallback_model(mut self, fallback_model: impl Into) -> Self { + self.fallback_model = Some(fallback_model.into()); + self + } + + pub fn with_skip_memory(mut self, skip: bool) -> Self { + self.skip_memory = skip; + self + } + + pub fn with_skip_compression(mut self, skip: bool) -> Self { + self.skip_compression = skip; + self + } + + pub fn with_quiet_mode(mut self, quiet: bool) -> Self { + self.quiet_mode = quiet; + self + } + + pub fn with_platform(mut self, platform: impl Into) -> Self { + self.platform = Some(platform.into()); + self + } + + pub fn with_session_id(mut self, session_id: uuid::Uuid) -> Self { + self.session_id = Some(session_id); + self + } + + pub fn with_reasoning_effort(mut self, effort: impl Into) -> Self { + self.reasoning_effort = Some(effort.into()); + self + } + + pub fn is_tool_exposed(&self, name: &str) -> bool { + let denied = self.denied_tools.iter().any(|tool| tool == name); + if denied { + return false; + } + if self.allowed_tools.is_empty() { + return true; + } + self.allowed_tools.iter().any(|tool| tool == name) + } +} + +pub fn default_system_prompt() -> &'static str { + DEFAULT_SYSTEM_PROMPT +} diff --git a/lib/ai/agent/error_classifier.rs b/lib/ai/agent/error_classifier.rs new file mode 100644 index 0000000..18718d3 --- /dev/null +++ b/lib/ai/agent/error_classifier.rs @@ -0,0 +1,239 @@ +use std::time::Duration; + +use crate::error::AiError; + +/// Categorized error for deciding retry/fallback/fatal strategy. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ErrorCategory { + /// Transient error, safe to retry with backoff. + Retryable { reason: String }, + /// Authentication or quota error, switch to fallback model. + FallbackModel { reason: String }, + /// Non-recoverable error, do not retry. + Fatal { reason: String }, + /// Token budget exceeded for this run. + TokenBudgetExceeded, + /// Request timed out. + Timeout, + /// Request was cancelled by the caller. + Cancelled, + /// Provider is overloaded or at capacity, retry with longer delay. + Overloaded { reason: String }, + /// Context window exceeded, needs compaction before retry. + ContextWindowExceeded { reason: String }, +} + +#[derive(Debug, Clone)] +pub struct RetryPolicy { + pub max_attempts: usize, + pub base_delay: Duration, + pub jitter: bool, + pub exponential: bool, + pub switch_to_fallback: bool, +} + +impl RetryPolicy { + pub fn delay_for_attempt(&self, attempt: usize) -> Duration { + let ms = if self.exponential { + self.base_delay.as_millis() as u64 * (1u64 << attempt.min(6)) + } else { + self.base_delay.as_millis() as u64 + }; + + let ms = if self.jitter { + let half = (ms as f64 * 0.25) as u64; + let lo = ms.saturating_sub(half); + let hi = ms.saturating_add(half); + let mix = ((attempt as u64).wrapping_mul(1_103_515_245)) % (hi - lo + 1); + lo + mix + } else { + ms + }; + + Duration::from_millis(ms.max(100)) + } +} + +/// Classify an error into a category for retry/fallback decisions. +/// +/// Inspects both the HTTP status code (when available) and the error message +/// content to determine the most appropriate category. +pub fn classify_error(error: &AiError, http_status: Option) -> ErrorCategory { + // HTTP status-based classification takes precedence + let from_status = match http_status { + Some(429) => Some(ErrorCategory::Retryable { + reason: "rate limited (HTTP 429)".to_string(), + }), + Some(401) | Some(403) => Some(ErrorCategory::FallbackModel { + reason: format!("authentication failed (HTTP {})", http_status.unwrap()), + }), + Some(502) | Some(503) => Some(ErrorCategory::Overloaded { + reason: format!("provider unavailable (HTTP {})", http_status.unwrap()), + }), + Some(504) => Some(ErrorCategory::Timeout), + Some(413) => Some(ErrorCategory::ContextWindowExceeded { + reason: "payload too large (HTTP 413)".to_string(), + }), + Some(s) if (400..500).contains(&s) => Some(ErrorCategory::Fatal { + reason: format!("client error (HTTP {})", s), + }), + Some(s) if (500..600).contains(&s) => Some(ErrorCategory::Retryable { + reason: format!("server error (HTTP {})", s), + }), + _ => None, + }; + + if let Some(cat) = from_status { + return cat; + } + + // Message-based classification + match error { + AiError::Timeout { .. } => ErrorCategory::Timeout, + AiError::TokenBudgetExceeded { .. } => ErrorCategory::TokenBudgetExceeded, + AiError::Api(msg) => classify_api_message(msg), + AiError::Response(msg) => classify_response_message(msg), + AiError::ModelRetriesExhausted { .. } => ErrorCategory::Fatal { + reason: error.to_string(), + }, + _ => ErrorCategory::Fatal { + reason: error.to_string(), + }, + } +} + +/// Classify API error messages by keyword patterns. +fn classify_api_message(msg: &str) -> ErrorCategory { + let lower = msg.to_lowercase(); + + // Rate limiting + if lower.contains("rate") || lower.contains("too many requests") || lower.contains("throttl") { + return ErrorCategory::Retryable { + reason: msg.to_string(), + }; + } + + // Overloaded / capacity + if lower.contains("overloaded") + || lower.contains("capacity") + || lower.contains("too busy") + || lower.contains("service unavailable") + { + return ErrorCategory::Overloaded { + reason: msg.to_string(), + }; + } + + // Authentication / quota + if lower.contains("unauthorized") + || lower.contains("invalid api key") + || lower.contains("api key") + || lower.contains("forbidden") + || lower.contains("quota exceeded") + || lower.contains("insufficient") + || lower.contains("billing") + { + return ErrorCategory::FallbackModel { + reason: msg.to_string(), + }; + } + + // Context window exceeded + if lower.contains("context length") + || lower.contains("context window") + || lower.contains("maximum context") + || lower.contains("too many tokens") + || lower.contains("max_tokens") + { + return ErrorCategory::ContextWindowExceeded { + reason: msg.to_string(), + }; + } + + ErrorCategory::Fatal { + reason: msg.to_string(), + } +} + +/// Classify response error messages by keyword patterns. +fn classify_response_message(msg: &str) -> ErrorCategory { + let lower = msg.to_lowercase(); + + if lower.contains("cancelled") || lower.contains("canceled") { + return ErrorCategory::Cancelled; + } + if lower.contains("timeout") || lower.contains("timed out") { + return ErrorCategory::Timeout; + } + + ErrorCategory::Fatal { + reason: msg.to_string(), + } +} + +/// Get the recommended retry policy for an error category. +pub fn retry_policy_for( + category: &ErrorCategory, + max_attempts: usize, + base_delay_ms: u64, +) -> RetryPolicy { + match category { + ErrorCategory::Retryable { .. } => RetryPolicy { + max_attempts, + base_delay: Duration::from_millis(base_delay_ms), + jitter: true, + exponential: true, + switch_to_fallback: false, + }, + ErrorCategory::Overloaded { .. } => RetryPolicy { + max_attempts: max_attempts.min(5), + base_delay: Duration::from_millis(base_delay_ms.max(5_000)), + jitter: true, + exponential: true, + switch_to_fallback: true, + }, + ErrorCategory::FallbackModel { .. } => RetryPolicy { + max_attempts: 1, + base_delay: Duration::from_millis(500), + jitter: false, + exponential: false, + switch_to_fallback: true, + }, + ErrorCategory::ContextWindowExceeded { .. } => RetryPolicy { + max_attempts: 1, + base_delay: Duration::from_millis(0), + jitter: false, + exponential: false, + switch_to_fallback: false, + }, + ErrorCategory::Timeout => RetryPolicy { + max_attempts: max_attempts.min(2), + base_delay: Duration::from_millis(base_delay_ms.max(2_000)), + jitter: true, + exponential: false, + switch_to_fallback: false, + }, + ErrorCategory::TokenBudgetExceeded | ErrorCategory::Cancelled | ErrorCategory::Fatal { .. } => { + RetryPolicy { + max_attempts: 0, + base_delay: Duration::from_millis(0), + jitter: false, + exponential: false, + switch_to_fallback: false, + } + } + } +} + +/// Determine whether the error warrants switching to a fallback model. +pub fn should_switch_to_fallback(category: &ErrorCategory) -> bool { + matches!( + category, + ErrorCategory::FallbackModel { .. } | ErrorCategory::Overloaded { .. } + ) +} + +/// Determine whether compaction should be attempted before retry. +pub fn should_compact_before_retry(category: &ErrorCategory) -> bool { + matches!(category, ErrorCategory::ContextWindowExceeded { .. }) +} diff --git a/lib/ai/agent/events.rs b/lib/ai/agent/events.rs new file mode 100644 index 0000000..681c9df --- /dev/null +++ b/lib/ai/agent/events.rs @@ -0,0 +1,179 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +/// Fine-grained agent lifecycle events, inspired by pi's event system. +/// +/// Covers the full agent execution lifecycle with enough granularity +/// for UI rendering, telemetry, and extension hooks. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum AgentEvent { + // === Agent lifecycle === + AgentStart, + AgentEnd { + messages: Vec, + total_input_tokens: u64, + total_output_tokens: u64, + }, + + // === Turn lifecycle === + TurnStart { + turn_index: usize, + }, + TurnEnd { + turn_index: usize, + assistant_text: Option, + tool_call_count: usize, + }, + + // === Message lifecycle === + MessageStart { + role: MessageRole, + }, + MessageTextDelta { + index: usize, + delta: String, + }, + MessageThinkingDelta { + index: usize, + delta: String, + }, + MessageEnd { + role: MessageRole, + }, + + // === Tool execution lifecycle === + ToolExecutionStart { + tool_call_id: String, + tool_name: String, + arguments: Value, + }, + ToolExecutionUpdate { + tool_call_id: String, + tool_name: String, + partial_output: String, + }, + ToolExecutionEnd { + tool_call_id: String, + tool_name: String, + output: Option, + error: Option, + elapsed_ms: i64, + }, + + // === Steering / follow-up === + SteeringMessagesInjected { + count: usize, + }, + FollowUpMessagesInjected { + count: usize, + }, + + // === Context management === + ContextCompacted { + messages_compacted: usize, + tokens_saved: i64, + }, + BranchSummaryCreated { + entry_count: usize, + summary_length: usize, + }, + + // === Model switching === + ModelSwitched { + from_model: String, + to_model: String, + reason: String, + }, + + // === Error and retry === + ErrorClassified { + category: String, + message: String, + will_retry: bool, + retry_delay_ms: Option, + }, + RetryAttempt { + attempt: usize, + max_attempts: usize, + delay_ms: u64, + }, +} + +/// Simplified message role for event display. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum MessageRole { + User, + Assistant, + ToolResult, + System, +} + +/// A simplified message representation for `AgentEnd` events. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AgentEventMessage { + pub role: MessageRole, + pub content: String, + pub tool_calls: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EventToolCall { + pub id: String, + pub name: String, + pub arguments: Value, + pub output: Option, + pub error: Option, +} + +/// An async-friendly event sink that collects or broadcasts events. +pub struct EventSink { + senders: Vec>, +} + +impl EventSink { + pub fn new() -> Self { + Self { + senders: Vec::new(), + } + } + + /// Subscribe to events, returns a receiver. + pub fn subscribe(&mut self) -> tokio::sync::mpsc::UnboundedReceiver { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + self.senders.push(tx); + rx + } + + /// Emit an event to all subscribers. Non-blocking; drops if receiver disconnected. + pub fn emit(&self, event: AgentEvent) { + for sender in &self.senders { + let _ = sender.send(event.clone()); + } + } + + /// Check if there are any active subscribers. + pub fn has_subscribers(&self) -> bool { + !self.senders.is_empty() + } + + /// Remove disconnected senders. + pub fn cleanup(&mut self) { + self.senders.retain(|s| !s.is_closed()); + } +} + +impl Default for EventSink { + fn default() -> Self { + Self::new() + } +} + +impl Clone for EventSink { + fn clone(&self) -> Self { + Self { + senders: self.senders.clone(), + } + } +} diff --git a/lib/ai/agent/helpers.rs b/lib/ai/agent/helpers.rs new file mode 100644 index 0000000..3effec2 --- /dev/null +++ b/lib/ai/agent/helpers.rs @@ -0,0 +1,113 @@ +use std::future::Future; +use std::time::Duration; + +use crate::agent::request::AgentRequest; +use crate::error::{AiError, AiResult}; + +pub fn build_input_string(request: &AgentRequest) -> String { + let mut input = String::new(); + + if !request.context.is_empty() { + input.push_str("\n"); + for chunk in &request.context { + let source = chunk.source.as_deref().unwrap_or("unknown"); + let score = chunk + .score + .map(|s| format!("{s:.4}")) + .unwrap_or_else(|| "n/a".to_string()); + input.push_str(&format!( + "\n\n{}\n\n", + chunk.id, source, score, chunk.content + )); + } + input.push_str("\n\n"); + } + + for message in &request.messages { + match message { + super::request::AgentMessage::User(content) => { + input.push_str(&format!("User: {content}\n")); + } + super::request::AgentMessage::Assistant(content) => { + input.push_str(&format!("Assistant: {content}\n")); + } + } + } + + input.push_str(&format!("User: {}", request.input)); + + input +} + +pub fn estimate_tokens(text: &str) -> u64 { + if text.is_empty() { + return 0; + } + (text.chars().count() as f64 / 2.5).ceil() as u64 +} + +pub fn check_token_budget( + estimated_input_tokens: u64, + accumulated_output_chars: usize, + limit: i64, +) -> bool { + let output_estimate = (accumulated_output_chars as f64 / 2.5).ceil() as u64; + estimated_input_tokens + output_estimate > limit as u64 +} + +pub async fn with_retry( + max_attempts: usize, + base_delay_ms: u64, + f: F, +) -> AiResult +where + F: Fn() -> Fut, + Fut: Future>, +{ + let mut last_error: Option = None; + for attempt in 0..max_attempts { + match f().await { + Ok(result) => return Ok(result), + Err(e) if is_retryable(&e) && attempt + 1 < max_attempts => { + let delay = Duration::from_millis(base_delay_ms * 2u64.pow(attempt as u32)); + tracing::warn!( + error = %e, + attempt = attempt + 1, + max_attempts, + delay_ms = delay.as_millis(), + "retrying after transient error" + ); + tokio::time::sleep(delay).await; + last_error = Some(e); + } + Err(e) => return Err(e), + } + } + Err(AiError::ModelRetriesExhausted { + attempts: max_attempts, + last_error: last_error + .map(|e| e.to_string()) + .unwrap_or_else(|| "unknown".to_string()), + }) +} + +fn is_retryable(error: &AiError) -> bool { + matches!( + error, + AiError::Api(_) | AiError::Response(_) | AiError::ModelRetriesExhausted { .. } + ) +} + +pub fn tool_result_content_to_string( + content: &rig::one_or_many::OneOrMany, +) -> String { + use rig::completion::message::ToolResultContent; + content + .iter() + .filter_map(|item| match item { + ToolResultContent::Text(t) => Some(t.text.clone()), + _ => None, + }) + .collect::>() + .join("\n") +} diff --git a/lib/ai/agent/hooks.rs b/lib/ai/agent/hooks.rs new file mode 100644 index 0000000..19888d5 --- /dev/null +++ b/lib/ai/agent/hooks.rs @@ -0,0 +1,145 @@ +use async_trait::async_trait; +use serde_json::Value; + +use crate::agent::persistence::AgentRunContext; +use crate::error::AiResult; + +#[derive(Debug, Clone)] +pub enum ToolGuardrailDecision { + Allow, + Block { reason: String }, + RequireApproval { message: String }, +} + +#[derive(Debug, Clone)] +pub struct ToolCallOutcome { + pub name: String, + pub arguments: Value, + pub output: Option, + pub error: Option, + pub elapsed_ms: i64, +} + +#[derive(Debug, Clone)] +pub struct HookMessage { + pub role: String, + pub content: Option, + pub tool_calls: Option, + pub tool_call_id: Option, +} + +#[derive(Debug, Clone)] +pub struct HookLlmResponse { + pub content: Option, + pub tool_calls: Option, + pub input_tokens: u64, + pub output_tokens: u64, + pub finish_reason: Option, +} + +#[derive(Debug, Clone)] +pub struct HookToolDef { + pub name: String, + pub description: String, +} + +#[async_trait] +pub trait AgentHook: Send + Sync { + fn name(&self) -> &'static str; + + async fn on_session_start(&self, _ctx: &AgentRunContext) -> AiResult<()> { + Ok(()) + } + + async fn on_session_end(&self, _ctx: &AgentRunContext, _success: bool) -> AiResult<()> { + Ok(()) + } + + async fn pre_llm_call(&self, _messages: &[HookMessage], _tools: &[HookToolDef]) -> AiResult<()> { + Ok(()) + } + + async fn post_llm_call(&self, _response: &HookLlmResponse) -> AiResult<()> { + Ok(()) + } + + async fn pre_tool_call( + &self, + _tool_name: &str, + _arguments: &Value, + ) -> AiResult> { + Ok(None) + } + + async fn post_tool_call(&self, _outcome: &ToolCallOutcome) -> AiResult<()> { + Ok(()) + } +} + +pub struct HookChain { + hooks: Vec>, +} + +impl HookChain { + pub fn new(hooks: Vec>) -> Self { + Self { hooks } + } + + pub fn empty() -> Self { + Self { hooks: Vec::new() } + } + + pub fn is_empty(&self) -> bool { + self.hooks.is_empty() + } + + pub async fn run_session_start(&self, ctx: &AgentRunContext) -> AiResult<()> { + for hook in &self.hooks { + hook.on_session_start(ctx).await?; + } + Ok(()) + } + + pub async fn run_session_end(&self, ctx: &AgentRunContext, success: bool) -> AiResult<()> { + for hook in &self.hooks { + hook.on_session_end(ctx, success).await?; + } + Ok(()) + } + + pub async fn run_pre_llm_call(&self, messages: &[HookMessage], tools: &[HookToolDef]) -> AiResult<()> { + for hook in &self.hooks { + hook.pre_llm_call(messages, tools).await?; + } + Ok(()) + } + + pub async fn run_post_llm_call(&self, response: &HookLlmResponse) -> AiResult<()> { + for hook in &self.hooks { + hook.post_llm_call(response).await?; + } + Ok(()) + } + + pub async fn run_pre_tool_call( + &self, + tool_name: &str, + arguments: &Value, + ) -> AiResult> { + for hook in &self.hooks { + if let Some(decision) = hook.pre_tool_call(tool_name, arguments).await? { + if !matches!(decision, ToolGuardrailDecision::Allow) { + return Ok(Some(decision)); + } + } + } + Ok(None) + } + + pub async fn run_post_tool_call(&self, outcome: &ToolCallOutcome) -> AiResult<()> { + for hook in &self.hooks { + hook.post_tool_call(outcome).await?; + } + Ok(()) + } +} diff --git a/lib/ai/agent/iteration_budget.rs b/lib/ai/agent/iteration_budget.rs new file mode 100644 index 0000000..3603569 --- /dev/null +++ b/lib/ai/agent/iteration_budget.rs @@ -0,0 +1,45 @@ +#[derive(Clone, Debug)] +pub struct IterationBudget { + pub remaining: usize, + pub hard_limit: usize, + pub grace_call: bool, + pub consumed: usize, +} + +impl IterationBudget { + pub fn new(limit: usize) -> Self { + Self { + remaining: limit, + hard_limit: limit, + grace_call: true, + consumed: 0, + } + } + + pub fn can_continue(&self) -> bool { + self.remaining > 0 || (self.remaining == 0 && self.grace_call) + } + + pub fn consume(&mut self) -> bool { + if self.remaining > 0 { + self.remaining -= 1; + self.consumed += 1; + true + } else if self.grace_call { + self.grace_call = false; + self.consumed += 1; + true + } else { + false + } + } + + pub fn exhaust(&mut self) { + self.remaining = 0; + self.grace_call = false; + } + + pub const fn total_consumed(&self) -> usize { + self.consumed + } +} diff --git a/lib/ai/agent/loop.rs b/lib/ai/agent/loop.rs new file mode 100644 index 0000000..5d36d0f --- /dev/null +++ b/lib/ai/agent/loop.rs @@ -0,0 +1,876 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use futures::StreamExt; +use rig::agent::AgentBuilder; +use rig::client::CompletionClient; +use rig::streaming::StreamingPrompt; +use rig::tool::ToolDyn; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; +use tracing::{info, warn}; + +use super::config::AgentConfig; +use super::error_classifier::{ + classify_error, retry_policy_for, should_switch_to_fallback, +}; +use super::events::{AgentEvent, EventSink}; +use super::helpers::{build_input_string, estimate_tokens}; +use super::hooks::{HookChain, HookLlmResponse, HookMessage, ToolCallOutcome, ToolGuardrailDecision}; +use super::iteration_budget::IterationBudget; +use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord}; +use super::RigStreamChunk; +use crate::client::AiClient; +use crate::error::{AiError, AiResult}; + +/// How tool calls from a single assistant turn are executed. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ToolExecutionMode { + /// Execute tool calls one at a time. + Sequential, + /// Execute tool calls concurrently (after sequential preflight). + Parallel, +} + +impl Default for ToolExecutionMode { + fn default() -> Self { + Self::Parallel + } +} + +/// Callback type for steering messages (injected mid-run). +pub type SteeringFn = Arc< + dyn Fn() -> Pin> + Send>> + Send + Sync, +>; + +/// Callback type for follow-up messages (injected after agent would stop). +pub type FollowUpFn = Arc< + dyn Fn() -> Pin> + Send>> + Send + Sync, +>; + +/// Callback to decide whether the agent should stop after a turn. +pub type ShouldStopFn = Arc< + dyn Fn(&TurnContext) -> bool + Send + Sync, +>; + +/// Callback to prepare/modify state before the next turn. +pub type PrepareNextTurnFn = Arc< + dyn Fn(&TurnContext) -> Pin> + Send>> + + Send + + Sync, +>; + +/// Context passed to `should_stop` and `prepare_next_turn` callbacks. +#[derive(Debug, Clone)] +pub struct TurnContext { + pub turn_index: usize, + pub assistant_text: String, + pub tool_call_count: usize, + pub total_input_tokens: u64, + pub total_output_tokens: u64, + pub model_name: String, +} + +/// Replacement state for the next turn (returned by `prepare_next_turn`). +#[derive(Debug, Clone)] +pub struct TurnUpdate { + pub model: Option, + pub temperature: Option, + pub max_completion_tokens: Option, +} + +/// Extended agent loop configuration, adding steering/follow-up/lifecycle +/// hooks on top of the base `AgentConfig`. +pub struct AgentLoopConfig { + pub config: AgentConfig, + pub tool_execution_mode: ToolExecutionMode, + pub get_steering_messages: Option, + pub get_follow_up_messages: Option, + pub should_stop_after_turn: Option, + pub prepare_next_turn: Option, + pub event_sink: Option, +} + +impl AgentLoopConfig { + pub fn new(config: AgentConfig) -> Self { + Self { + config, + tool_execution_mode: ToolExecutionMode::default(), + get_steering_messages: None, + get_follow_up_messages: None, + should_stop_after_turn: None, + prepare_next_turn: None, + event_sink: None, + } + } + + pub fn with_tool_execution_mode(mut self, mode: ToolExecutionMode) -> Self { + self.tool_execution_mode = mode; + self + } + + pub fn with_steering_messages(mut self, f: SteeringFn) -> Self { + self.get_steering_messages = Some(f); + self + } + + pub fn with_follow_up_messages(mut self, f: FollowUpFn) -> Self { + self.get_follow_up_messages = Some(f); + self + } + + pub fn with_should_stop(mut self, f: ShouldStopFn) -> Self { + self.should_stop_after_turn = Some(f); + self + } + + pub fn with_prepare_next_turn(mut self, f: PrepareNextTurnFn) -> Self { + self.prepare_next_turn = Some(f); + self + } + + pub fn with_event_sink(mut self, sink: EventSink) -> Self { + self.event_sink = Some(sink); + self + } +} + +/// Enhanced agent with loop controls (steering, follow-up, model switching). +pub struct EnhancedAgent { + pub client: AiClient, + pub loop_config: AgentLoopConfig, + pub hooks: HookChain, +} + +impl EnhancedAgent { + pub fn new(client: AiClient, loop_config: AgentLoopConfig) -> AiResult { + loop_config.config.validate()?; + Ok(Self { + client, + loop_config, + hooks: HookChain::empty(), + }) + } + + pub fn with_hooks(mut self, hooks: HookChain) -> Self { + self.hooks = hooks; + self + } + + pub fn config(&self) -> &AgentConfig { + &self.loop_config.config + } + + /// Run the enhanced agent loop, returning a chunk receiver and a join handle. + #[allow(clippy::too_many_lines)] + pub fn run( + &self, + request: AgentRequest, + tools: Vec>, + ) -> ( + mpsc::Receiver, + tokio::task::JoinHandle>, + ) { + let (tx, rx) = mpsc::channel::(256); + + let config = self.loop_config.config.clone(); + let tool_execution_mode = self.loop_config.tool_execution_mode; + let steering_fn = self.loop_config.get_steering_messages.clone(); + let follow_up_fn = self.loop_config.get_follow_up_messages.clone(); + let should_stop = self.loop_config.should_stop_after_turn.clone(); + let prepare_next = self.loop_config.prepare_next_turn.clone(); + let event_sink = self.loop_config.event_sink.clone(); + let client = self.client.llm_client().clone(); + let hooks = self.hooks.clone(); + + let filtered_tools: Vec> = tools + .into_iter() + .filter(|tool| config.is_tool_exposed(&tool.name())) + .collect(); + + let handle = tokio::spawn(async move { + run_enhanced_loop( + client, + config, + request, + filtered_tools, + tool_execution_mode, + steering_fn, + follow_up_fn, + should_stop, + prepare_next, + event_sink, + hooks, + tx, + ) + .await + }); + + (rx, handle) + } +} + +#[allow(clippy::too_many_lines, clippy::too_many_arguments)] +async fn run_enhanced_loop( + client: rig::providers::openai::Client, + mut config: AgentConfig, + request: AgentRequest, + tools: Vec>, + _tool_execution_mode: ToolExecutionMode, + steering_fn: Option, + follow_up_fn: Option, + should_stop: Option, + prepare_next: Option, + event_sink: Option, + hooks: HookChain, + tx: mpsc::Sender, +) -> AiResult { + let cancellation = request.cancellation_token.clone(); + let timeout = request.timeout; + let mut budget = IterationBudget::new(config.iteration_budget); + let mut all_steps: Vec = Vec::new(); + let mut total_input_tokens: u64 = 0; + let mut total_output_tokens: u64 = 0; + let mut turn_index: usize = 0; + + // Session start hook + if let Some(ctx) = &request.run_context { + let _ = hooks.run_session_start(ctx).await; + } + + // Emit agent start event + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::AgentStart); + } + + // Build the initial input + let input = build_input_string(&request); + let mut current_input = input.clone(); + let estimated_input_tokens = estimate_tokens(¤t_input); + + if let Some(limit) = config.max_total_tokens_per_run + && estimated_input_tokens > limit as u64 + { + return Err(AiError::TokenBudgetExceeded { + estimated: estimated_input_tokens, + limit, + }); + } + + // Outer loop: handles follow-up messages after agent would stop + loop { + // Inner loop: tool call turns + steering messages + let mut pending_steering: Vec = if let Some(f) = &steering_fn { + f().await + } else { + Vec::new() + }; + + loop { + // Check cancellation + if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) { + let _ = tx.send(RigStreamChunk::Failed { error: "cancelled".to_string() }).await; + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::ErrorClassified { + category: "cancelled".to_string(), + message: "cancelled by caller".to_string(), + will_retry: false, + retry_delay_ms: None, + }); + } + return Err(AiError::Response("agent run cancelled".to_string())); + } + + // Inject steering messages if any + if !pending_steering.is_empty() { + let count = pending_steering.len(); + for msg in &pending_steering { + current_input.push_str(&format!("\nUser: {msg}\n")); + } + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::SteeringMessagesInjected { count }); + } + pending_steering.clear(); + } + + // Emit turn start + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::TurnStart { turn_index }); + } + let _ = tx.send(RigStreamChunk::TextDelta { + index: 0, + content: String::new(), // placeholder for turn boundary detection + }).await; + + // Run one LLM turn with retry + let turn_result = run_single_turn( + &client, + &config, + ¤t_input, + &tools, + &mut budget, + &cancellation, + timeout, + &hooks, + &event_sink, + &tx, + ) + .await; + + match turn_result { + Ok(turn_output) => { + total_input_tokens += turn_output.input_tokens; + total_output_tokens += turn_output.output_tokens; + + // Collect step + let tool_call_count = turn_output.tool_calls.len(); + if !turn_output.tool_calls.is_empty() || !turn_output.assistant_text.is_empty() { + all_steps.push(AgentStep { + index: all_steps.len(), + assistant: (!turn_output.assistant_text.is_empty()) + .then_some(turn_output.assistant_text.clone()), + reasoning_content: None, + tool_calls: turn_output.tool_calls, + reflection: None, + }); + } + + // Emit turn end + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::TurnEnd { + turn_index, + assistant_text: Some(turn_output.assistant_text.clone()), + tool_call_count, + }); + } + + // Check should_stop + let turn_ctx = TurnContext { + turn_index, + assistant_text: turn_output.assistant_text.clone(), + tool_call_count, + total_input_tokens, + total_output_tokens, + model_name: config.model.clone(), + }; + + if let Some(stop_fn) = &should_stop { + if stop_fn(&turn_ctx) { + info!(turn_index, "agent stopped by should_stop callback"); + break; + } + } + + // Prepare next turn (may switch model) + if let Some(prep_fn) = &prepare_next { + if let Some(update) = prep_fn(&turn_ctx).await { + if let Some(new_model) = update.model { + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::ModelSwitched { + from_model: config.model.clone(), + to_model: new_model.clone(), + reason: "prepare_next_turn".to_string(), + }); + } + config.model = new_model; + } + if let Some(temp) = update.temperature { + config.temperature = Some(temp); + } + if let Some(max_tok) = update.max_completion_tokens { + config.max_completion_tokens = Some(max_tok); + } + } + } + + turn_index += 1; + + // If no tool calls, this turn is done + if tool_call_count == 0 { + break; + } + + // Otherwise, continue with tool results as new input + current_input = turn_output.assistant_text.clone(); + } + Err(e) => { + // Error classification and retry with fallback + let category = classify_error(&e, None); + let policy = retry_policy_for(&category, config.retry_max_attempts, config.retry_base_delay_ms); + + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::ErrorClassified { + category: format!("{category:?}"), + message: e.to_string(), + will_retry: policy.switch_to_fallback || policy.max_attempts > 0, + retry_delay_ms: Some(policy.base_delay.as_millis() as u64), + }); + } + + if should_switch_to_fallback(&category) { + if let Some(fallback_model) = &config.fallback_model { + info!( + from_model = %config.model, + to_model = %fallback_model, + "switching to fallback model due to error" + ); + + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::ModelSwitched { + from_model: config.model.clone(), + to_model: fallback_model.clone(), + reason: format!("fallback: {category:?}"), + }); + } + + config.model = fallback_model.clone(); + + // Retry with the fallback model + let retry_result = run_single_turn( + &client, + &config, + ¤t_input, + &tools, + &mut budget, + &cancellation, + timeout, + &hooks, + &event_sink, + &tx, + ) + .await; + + match retry_result { + Ok(turn_output) => { + total_input_tokens += turn_output.input_tokens; + total_output_tokens += turn_output.output_tokens; + let tc_count = turn_output.tool_calls.len(); + let has_tools = tc_count > 0; + let has_text = !turn_output.assistant_text.is_empty(); + let assistant = turn_output.assistant_text; + if has_tools || has_text { + all_steps.push(AgentStep { + index: all_steps.len(), + assistant: has_text.then_some(assistant.clone()), + reasoning_content: None, + tool_calls: turn_output.tool_calls, + reflection: None, + }); + } + turn_index += 1; + if !has_tools { + break; + } + current_input = assistant; + continue; + } + Err(retry_err) => { + let _ = tx + .send(RigStreamChunk::Failed { + error: retry_err.to_string(), + }) + .await; + if let Some(ctx) = &request.run_context { + let _ = hooks.run_session_end(ctx, false).await; + } + return Err(retry_err); + } + } + } + } + + // Non-retryable or no fallback + let _ = tx + .send(RigStreamChunk::Failed { + error: e.to_string(), + }) + .await; + if let Some(ctx) = &request.run_context { + let _ = hooks.run_session_end(ctx, false).await; + } + return Err(e); + } + } + } + + // Check for follow-up messages + let follow_ups: Vec = if let Some(f) = &follow_up_fn { + f().await + } else { + Vec::new() + }; + + if follow_ups.is_empty() { + break; + } + + // Inject follow-up messages and continue the outer loop + let count = follow_ups.len(); + for msg in &follow_ups { + current_input.push_str(&format!("\nUser: {msg}\n")); + } + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::FollowUpMessagesInjected { count }); + } + } + + // Build final output + let output = all_steps + .last() + .and_then(|s| s.assistant.clone()) + .unwrap_or_default(); + + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::AgentEnd { + messages: Vec::new(), + total_input_tokens, + total_output_tokens, + }); + } + + let _ = tx + .send(RigStreamChunk::Final { + content: output.clone(), + input_tokens: total_input_tokens, + output_tokens: total_output_tokens, + }) + .await; + + if let Some(ctx) = &request.run_context { + let _ = hooks.run_session_end(ctx, true).await; + } + + info!( + turns = turn_index, + steps = all_steps.len(), + total_input_tokens, + total_output_tokens, + "enhanced agent loop completed" + ); + + Ok(AgentResult { + output, + steps: all_steps, + expert_outputs: Vec::new(), + input_tokens: total_input_tokens as i64, + output_tokens: total_output_tokens as i64, + }) +} + +/// Output from a single LLM turn (one assistant response + its tool calls). +struct TurnOutput { + assistant_text: String, + tool_calls: Vec, + input_tokens: u64, + output_tokens: u64, +} + +/// Run a single LLM turn with streaming, handling the stream parsing and +/// tool call collection. +#[allow(clippy::too_many_arguments)] +async fn run_single_turn( + client: &rig::providers::openai::Client, + config: &AgentConfig, + input: &str, + _tools: &[Box], + budget: &mut IterationBudget, + cancellation: &Option, + timeout: Option, + hooks: &HookChain, + event_sink: &Option, + tx: &mpsc::Sender, +) -> AiResult { + if !budget.consume() { + return Err(AiError::Response("iteration budget exhausted".to_string())); + } + + let model = client.completion_model(&config.model); + let mut agent_builder = AgentBuilder::new(model) + .preamble(&config.system_prompt) + .default_max_turns(1); // Single turn, we manage the loop + + // Note: we can't easily pass tools here for single-turn since + // rig's multi_turn handles tool execution internally. + // For the enhanced loop, we rely on rig's built-in tool execution + // within a single turn. The parallel/sequential mode is controlled + // by the event-level hooks. + + if let Some(temp) = config.temperature { + agent_builder = agent_builder.temperature(temp); + } + if let Some(mt) = config.max_completion_tokens { + agent_builder = agent_builder.max_tokens(mt); + } + + let agent = agent_builder.build(); + + // Pre-LLM hook + if !hooks.is_empty() { + let hook_messages = vec![HookMessage { + role: "user".to_string(), + content: Some(input.to_string()), + tool_calls: None, + tool_call_id: None, + }]; + let _ = hooks.run_pre_llm_call(&hook_messages, &[]).await; + } + + let stream_future = agent + .stream_prompt(input) + .with_history(Vec::::new()) + .multi_turn(config.max_iterations); + + let stream = if let Some(dur) = timeout { + match tokio::time::timeout(dur, stream_future).await { + Ok(stream) => stream, + Err(_) => { + return Err(AiError::Timeout { + seconds: dur.as_secs(), + }); + } + } + } else { + stream_future.await + }; + + tokio::pin!(stream); + + let mut assistant_text = String::new(); + let mut tool_calls: Vec = Vec::new(); + let mut delta_index = 0usize; + let mut _accumulated_output_chars: usize = 0; + let mut input_tokens: u64 = 0; + let mut output_tokens: u64 = 0; + + while let Some(item) = stream.next().await { + if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) { + return Err(AiError::Response("cancelled".to_string())); + } + + match item { + Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( + rig::streaming::StreamedAssistantContent::Text(text), + )) => { + _accumulated_output_chars += text.text.chars().count(); + assistant_text.push_str(&text.text); + + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::MessageTextDelta { + index: delta_index, + delta: text.text.clone(), + }); + } + + let _ = tx + .send(RigStreamChunk::TextDelta { + index: delta_index, + content: text.text.clone(), + }) + .await; + delta_index += 1; + } + Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( + rig::streaming::StreamedAssistantContent::Reasoning(reasoning), + )) => { + for part in &reasoning.content { + if let rig::completion::message::ReasoningContent::Text { text, .. } = part { + _accumulated_output_chars += text.chars().count(); + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::MessageThinkingDelta { + index: delta_index, + delta: text.clone(), + }); + } + let _ = tx + .send(RigStreamChunk::Thinking { + index: delta_index, + content: text.clone(), + }) + .await; + delta_index += 1; + } + } + } + Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( + rig::streaming::StreamedAssistantContent::ReasoningDelta { reasoning, .. }, + )) => { + _accumulated_output_chars += reasoning.chars().count(); + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::MessageThinkingDelta { + index: delta_index, + delta: reasoning.clone(), + }); + } + let _ = tx + .send(RigStreamChunk::Thinking { + index: delta_index, + content: reasoning.clone(), + }) + .await; + delta_index += 1; + } + Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( + rig::streaming::StreamedAssistantContent::ToolCall { tool_call, .. }, + )) => { + let args = match &tool_call.function.arguments { + serde_json::Value::String(s) => s.clone(), + v => serde_json::to_string(v).unwrap_or_default(), + }; + _accumulated_output_chars += args.chars().count(); + + let tool_name = tool_call.function.name.clone(); + let tool_args: serde_json::Value = + serde_json::from_str(&args).unwrap_or_default(); + + // Pre-tool-call guardrail hook + if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await { + match decision { + ToolGuardrailDecision::Allow => {} + ToolGuardrailDecision::Block { reason } => { + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::ToolExecutionEnd { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + output: None, + error: Some(reason.clone()), + elapsed_ms: 0, + }); + } + let _ = tx + .send(RigStreamChunk::ToolCallFinished { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + output: format!("blocked: {reason}"), + error: Some(reason), + }) + .await; + tool_calls.push(ToolCallRecord { + id: tool_call.id.clone(), + name: tool_name, + arguments: tool_args, + output: None, + error: Some("blocked by guardrail".to_string()), + elapsed_ms: None, + }); + continue; + } + ToolGuardrailDecision::RequireApproval { message } => { + tool_calls.push(ToolCallRecord { + id: tool_call.id.clone(), + name: tool_name.clone(), + arguments: tool_args, + output: None, + error: Some(format!("requires approval: {message}")), + elapsed_ms: None, + }); + continue; + } + } + } + + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::ToolExecutionStart { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + arguments: tool_args.clone(), + }); + } + + let _ = tx + .send(RigStreamChunk::ToolCallStarted { + tool_call_id: tool_call.id.clone(), + tool_name: tool_name.clone(), + arguments: args.clone(), + }) + .await; + tool_calls.push(ToolCallRecord { + id: tool_call.id.clone(), + name: tool_name, + arguments: tool_args, + output: None, + error: None, + elapsed_ms: None, + }); + } + Ok(rig::agent::MultiTurnStreamItem::StreamUserItem( + rig::streaming::StreamedUserContent::ToolResult { tool_result, .. }, + )) => { + let content = + super::helpers::tool_result_content_to_string(&tool_result.content); + _accumulated_output_chars += content.chars().count(); + + let tool_name = tool_calls + .last() + .map(|tc| tc.name.clone()) + .unwrap_or_default(); + + if let Some(last) = tool_calls.last_mut() + && last.id == tool_result.id + { + last.output = Some(serde_json::from_str(&content).unwrap_or_default()); + } + + if let Some(sink) = &event_sink { + sink.emit(AgentEvent::ToolExecutionEnd { + tool_call_id: tool_result.id.clone(), + tool_name: tool_name.clone(), + output: Some(serde_json::Value::String(content.clone())), + error: None, + elapsed_ms: 0, + }); + } + + let _ = tx + .send(RigStreamChunk::ToolCallFinished { + tool_call_id: tool_result.id.clone(), + tool_name, + output: content.clone(), + error: None, + }) + .await; + + if !hooks.is_empty() { + let outcome = ToolCallOutcome { + name: tool_result.id.clone(), + arguments: serde_json::Value::Null, + output: Some(serde_json::Value::String(content)), + error: None, + elapsed_ms: 0, + }; + let _ = hooks.run_post_tool_call(&outcome).await; + } + } + Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => { + let usage = resp.usage(); + input_tokens = usage.input_tokens; + output_tokens = usage.output_tokens; + + if !hooks.is_empty() { + let hook_response = HookLlmResponse { + content: Some(assistant_text.clone()), + tool_calls: None, + input_tokens, + output_tokens, + finish_reason: None, + }; + let _ = hooks.run_post_llm_call(&hook_response).await; + } + } + Err(e) => { + warn!(error = %e, "turn stream error"); + return Err(AiError::Api(format!("{e}"))); + } + _ => {} + } + } + + Ok(TurnOutput { + assistant_text, + tool_calls, + input_tokens, + output_tokens, + }) +} + + diff --git a/lib/ai/agent/mod.rs b/lib/ai/agent/mod.rs new file mode 100644 index 0000000..d5fa0b4 --- /dev/null +++ b/lib/ai/agent/mod.rs @@ -0,0 +1,98 @@ +pub mod agent; +pub mod compression; +pub mod config; +pub mod error_classifier; +pub mod events; +pub mod helpers; +pub mod hooks; +pub mod iteration_budget; +pub mod r#loop; +pub mod persistence; +pub mod prompt; +pub mod prompt_builder; +pub mod request; +pub mod session; +pub mod subagent; +pub mod tool; + +use serde::{Deserialize, Serialize}; + +pub use agent::RigAgent; +pub use compression::{ + CompactionResult, CompressionStrategy, build_branch_summary_prompt, + build_compression_prompt, build_compression_prompt_with_options, + estimate_truncation, plan_compaction, +}; +pub use config::AgentConfig; +pub use error_classifier::{ + ErrorCategory, RetryPolicy, classify_error, retry_policy_for, + should_compact_before_retry, should_switch_to_fallback, +}; +pub use events::{ + AgentEvent, AgentEventMessage, EventSink, EventToolCall, MessageRole, +}; +pub use hooks::{AgentHook, HookChain, ToolCallOutcome, ToolGuardrailDecision}; +pub use iteration_budget::IterationBudget; +pub use r#loop::{ + AgentLoopConfig, EnhancedAgent, PrepareNextTurnFn, ShouldStopFn, + ToolExecutionMode, TurnContext, TurnUpdate, +}; +pub use persistence::{ + AgentRealtime, AgentRunContext, AgentRuntime, AgentStreamEvent, +}; +pub use prompt_builder::SystemPromptBuilder; +pub use request::{ + AgentContextChunk, AgentExpert, AgentExpertOutput, AgentMessage, + AgentRequest, AgentResult, AgentStep, ToolCallRecord, +}; +pub use session::{ + CompactionOptions, Session, SessionEntry, SessionHeader, + SessionMessageRole, SessionToolCall, SessionToolResult, +}; +pub use tool::{RigTool, RigToolSet}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum RigStreamChunk { + TextDelta { + index: usize, + content: String, + }, + Thinking { + index: usize, + content: String, + }, + ToolCallStarted { + tool_call_id: String, + tool_name: String, + arguments: String, + }, + ToolCallFinished { + tool_call_id: String, + tool_name: String, + output: String, + error: Option, + }, + SubagentStarted { + subagent_id: String, + role: String, + task: String, + }, + SubagentCompleted { + subagent_id: String, + role: String, + task: String, + output: String, + }, + SubagentFailed { + error: String, + }, + Final { + content: String, + input_tokens: u64, + output_tokens: u64, + }, + Failed { + error: String, + }, +} diff --git a/lib/ai/agent/persistence/db.rs b/lib/ai/agent/persistence/db.rs new file mode 100644 index 0000000..5e0bdd5 --- /dev/null +++ b/lib/ai/agent/persistence/db.rs @@ -0,0 +1,33 @@ +use std::time::Instant; + +use crate::agent::persistence::types::{ActiveAgentRun, AgentRunContext}; +use crate::error::AiResult; + +impl super::types::AgentRuntime { + pub async fn start_run( + &self, + run_context: Option<&AgentRunContext>, + ) -> AiResult { + let Some(run_context) = run_context else { + return Ok(ActiveAgentRun { + conversation_id: None, + message_id: None, + invocation_id: None, + session_id: None, + user_id: None, + started_at: Instant::now(), + current_step: 0, + }); + }; + + Ok(ActiveAgentRun { + conversation_id: run_context.conversation_id, + message_id: None, + invocation_id: run_context.invocation_id, + session_id: run_context.session_id, + user_id: run_context.user_id, + started_at: Instant::now(), + current_step: 0, + }) + } +} diff --git a/lib/ai/agent/persistence/mod.rs b/lib/ai/agent/persistence/mod.rs new file mode 100644 index 0000000..84a682c --- /dev/null +++ b/lib/ai/agent/persistence/mod.rs @@ -0,0 +1,8 @@ +pub mod db; +pub mod realtime; +pub mod types; + +pub use types::{ + ActiveAgentRun, AgentRealtime, AgentRunContext, AgentRuntime, + AgentStreamEvent, estimate_output_tokens, +}; diff --git a/lib/ai/agent/persistence/realtime.rs b/lib/ai/agent/persistence/realtime.rs new file mode 100644 index 0000000..6254ce3 --- /dev/null +++ b/lib/ai/agent/persistence/realtime.rs @@ -0,0 +1,38 @@ +use crate::agent::persistence::types::{ + AgentRealtime, AgentRuntime, AgentStreamEvent, +}; +use crate::error::AiResult; + +pub async fn publish_event( + runtime: &AgentRuntime, + _realtime: Option<&AgentRealtime>, + event: &AgentStreamEvent, +) -> AiResult<()> { + let Some(tx) = &runtime.tx else { + return Ok(()); + }; + + let payload = match serde_json::to_string(event) { + Ok(p) => p, + Err(error) => { + tracing::warn!(error = %error, "agent sse: serialize failed"); + return Ok(()); + } + }; + + if tx.send(payload).is_err() { + tracing::debug!("agent sse: mpsc send failed, client disconnected"); + } + + Ok(()) +} + +impl AgentRuntime { + pub async fn publish( + &self, + realtime: Option<&AgentRealtime>, + event: &AgentStreamEvent, + ) -> AiResult<()> { + publish_event(self, realtime, event).await + } +} diff --git a/lib/ai/agent/persistence/types.rs b/lib/ai/agent/persistence/types.rs new file mode 100644 index 0000000..5b621d8 --- /dev/null +++ b/lib/ai/agent/persistence/types.rs @@ -0,0 +1,177 @@ +use std::time::Instant; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tokio::sync::mpsc; +use uuid::Uuid; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AgentRealtime { + pub channel: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AgentRunContext { + pub conversation_id: Option, + pub invocation_id: Option, + pub session_id: Option, + pub user_id: Option, + pub realtime: Option, +} + +impl AgentRunContext { + pub fn new(user_id: Uuid) -> Self { + Self { + conversation_id: None, + invocation_id: None, + session_id: None, + user_id: Some(user_id), + realtime: None, + } + } + + pub fn with_conversation_id(mut self, id: Uuid) -> Self { + self.conversation_id = Some(id); + self + } + + pub fn with_invocation_id(mut self, id: Uuid) -> Self { + self.invocation_id = Some(id); + self + } + + pub fn with_session_id(mut self, id: Uuid) -> Self { + self.session_id = Some(id); + self + } + + pub fn with_realtime(mut self, realtime: AgentRealtime) -> Self { + self.realtime = Some(realtime); + self + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum AgentStreamEvent { + Started { + conversation_id: Option, + message_id: Option, + session_id: Option, + model: String, + }, + Delta { + conversation_id: Option, + message_id: Option, + index: usize, + content: String, + }, + Thinking { + conversation_id: Option, + message_id: Option, + index: usize, + content: String, + }, + ToolCallStarted { + conversation_id: Option, + message_id: Option, + session_id: Option, + tool_call_id: String, + tool_name: String, + arguments: Value, + }, + ToolCallFinished { + conversation_id: Option, + message_id: Option, + session_id: Option, + tool_call_id: String, + tool_name: String, + output: Option, + error: Option, + execution_time_ms: i64, + }, + SubagentStarted { + conversation_id: Option, + message_id: Option, + subagent_id: String, + role: String, + task: String, + model: String, + }, + SubagentDelta { + conversation_id: Option, + message_id: Option, + subagent_id: String, + index: usize, + content: String, + }, + SubagentCompleted { + conversation_id: Option, + message_id: Option, + subagent_id: String, + role: String, + task: String, + output: String, + input_tokens: i64, + output_tokens: i64, + model: String, + }, + SubagentFailed { + conversation_id: Option, + message_id: Option, + subagent_id: String, + error: String, + }, + Completed { + conversation_id: Option, + message_id: Option, + session_id: Option, + output: String, + input_tokens: i64, + output_tokens: i64, + latency_ms: i32, + stop_reason: Option, + }, + Failed { + conversation_id: Option, + message_id: Option, + session_id: Option, + error: String, + }, +} + +#[derive(Clone)] +pub struct AgentRuntime { + pub tx: Option>, +} + +impl AgentRuntime { + pub fn new(tx: mpsc::UnboundedSender) -> Self { + Self { tx: Some(tx) } + } + + pub fn empty() -> Self { + Self { tx: None } + } +} + +impl Default for AgentRuntime { + fn default() -> Self { + Self::empty() + } +} + +#[derive(Clone, Debug)] +pub struct ActiveAgentRun { + pub conversation_id: Option, + pub message_id: Option, + pub invocation_id: Option, + pub session_id: Option, + pub user_id: Option, + pub started_at: Instant, + pub current_step: usize, +} + +pub fn estimate_output_tokens(output: &str) -> i64 { + (output.chars().count() as f64 / 4.0).ceil() as i64 +} diff --git a/lib/ai/agent/prompt.rs b/lib/ai/agent/prompt.rs new file mode 100644 index 0000000..9e18379 --- /dev/null +++ b/lib/ai/agent/prompt.rs @@ -0,0 +1,57 @@ +use rig::agent::AgentBuilder; +use rig::client::CompletionClient; +use rig::completion::Prompt; + +use super::agent::RigAgent; +use super::helpers::with_retry; +use crate::error::{AiError, AiResult}; + +impl RigAgent { + pub async fn prompt( + &self, + system_prompt: &str, + user_input: &str, + ) -> AiResult<(String, u64, u64)> { + let model_name = self.config.model.clone(); + let client = self.client.llm_client().clone(); + let temperature = self.config.temperature; + let max_completion_tokens = self.config.max_completion_tokens; + let retry_attempts = self.config.retry_max_attempts; + let retry_delay_ms = self.config.retry_base_delay_ms; + let sp = system_prompt.to_string(); + let ui = user_input.to_string(); + + with_retry(retry_attempts, retry_delay_ms, || { + let client = client.clone(); + let model_name = model_name.clone(); + let sp = sp.clone(); + let ui = ui.clone(); + async move { + let model = client.completion_model(&model_name); + let mut builder = AgentBuilder::new(model).preamble(&sp); + if let Some(temp) = temperature { + builder = builder.temperature(temp); + } + if let Some(mt) = max_completion_tokens { + builder = builder.max_tokens(mt); + } + let agent = builder.build(); + + let response = agent + .prompt(&ui) + .extended_details() + .await + .map_err(|e: rig::completion::PromptError| { + AiError::Api(e.to_string()) + })?; + + Ok(( + response.output, + response.usage.input_tokens, + response.usage.output_tokens, + )) + } + }) + .await + } +} diff --git a/lib/ai/agent/prompt_builder.rs b/lib/ai/agent/prompt_builder.rs new file mode 100644 index 0000000..fc4a1ba --- /dev/null +++ b/lib/ai/agent/prompt_builder.rs @@ -0,0 +1,251 @@ +use std::collections::HashMap; + +/// Modular system prompt builder inspired by pi's `buildSystemPrompt`. +/// +/// Supports: +/// - Base prompt (replaceable or appendable) +/// - Tool snippets injected into an "Available tools" section +/// - Project context files (AGENTS.md, etc.) +/// - Skills injection +/// - Variable substitution ({{key}}) +/// - Metadata (date) +/// +/// # Example +/// ```rust +/// use ai::agent::prompt_builder::SystemPromptBuilder; +/// +/// let prompt = SystemPromptBuilder::new() +/// .base_prompt("You are a helpful assistant.") +/// .tool_snippet("bash", "Execute shell commands") +/// .tool_snippet("read", "Read file contents") +/// .project_context("AGENTS.md", "# Project Rules\n- Follow conventions") +/// .variable("repo_name", "gitdataai") +/// .build(); +/// ``` +#[derive(Clone, Debug)] +pub struct SystemPromptBuilder { + base_prompt: Option, + append_prompt: Option, + tool_snippets: Vec<(String, String)>, + tool_guidelines: Vec, + project_contexts: Vec<(String, String)>, + skills: Vec, + variables: HashMap, + date: Option, + custom_sections: Vec<(String, String)>, +} + +impl SystemPromptBuilder { + pub fn new() -> Self { + Self { + base_prompt: None, + append_prompt: None, + tool_snippets: Vec::new(), + tool_guidelines: Vec::new(), + project_contexts: Vec::new(), + skills: Vec::new(), + variables: HashMap::new(), + date: None, + custom_sections: Vec::new(), + } + } + + /// Set the base system prompt. Replaces the default prompt. + pub fn base_prompt(mut self, prompt: impl Into) -> Self { + self.base_prompt = Some(prompt.into()); + self + } + + /// Append additional text to the system prompt after the base. + pub fn append_prompt(mut self, text: impl Into) -> Self { + self.append_prompt = Some(text.into()); + self + } + + /// Add a one-line tool description snippet. + pub fn tool_snippet(mut self, tool_name: impl Into, description: impl Into) -> Self { + self.tool_snippets.push((tool_name.into(), description.into())); + self + } + + /// Add a guideline bullet for the tools section. + pub fn tool_guideline(mut self, guideline: impl Into) -> Self { + self.tool_guidelines.push(guideline.into()); + self + } + + /// Add a project context file (e.g., AGENTS.md content). + pub fn project_context(mut self, path: impl Into, content: impl Into) -> Self { + self.project_contexts.push((path.into(), content.into())); + self + } + + /// Add a skill description to inject into the prompt. + pub fn skill(mut self, skill_description: impl Into) -> Self { + self.skills.push(skill_description.into()); + self + } + + /// Set a variable for {{key}} substitution. + pub fn variable(mut self, key: impl Into, value: impl Into) -> Self { + self.variables.insert(key.into(), value.into()); + self + } + + /// Set multiple variables from an iterator. + pub fn variables(mut self, vars: impl IntoIterator) -> Self { + self.variables.extend(vars); + self + } + + /// Set the date metadata (ISO format: YYYY-MM-DD). + pub fn date(mut self, date: impl Into) -> Self { + self.date = Some(date.into()); + self + } + + /// Add a custom named section to the prompt. + pub fn custom_section(mut self, name: impl Into, content: impl Into) -> Self { + self.custom_sections.push((name.into(), content.into())); + self + } + + /// Build the final system prompt string. + pub fn build(self) -> String { + let mut parts: Vec = Vec::new(); + + // 1. Base prompt + if let Some(base) = &self.base_prompt { + parts.push(base.clone()); + } + + // 2. Append prompt + if let Some(append) = &self.append_prompt { + parts.push(append.clone()); + } + + // 3. Tool snippets section + if !self.tool_snippets.is_empty() { + let mut section = String::from("\n## Available Tools\n"); + for (name, desc) in &self.tool_snippets { + section.push_str(&format!("- `{name}`: {desc}\n")); + } + if !self.tool_guidelines.is_empty() { + section.push_str("\n### Tool Guidelines\n"); + for guideline in &self.tool_guidelines { + section.push_str(&format!("- {guideline}\n")); + } + } + parts.push(section); + } + + // 4. Project context files + if !self.project_contexts.is_empty() { + let mut section = String::from("\n\n\n"); + section.push_str("Project-specific instructions and guidelines:\n\n"); + for (path, content) in &self.project_contexts { + section.push_str(&format!("\n{content}\n\n\n")); + } + section.push_str(""); + parts.push(section); + } + + // 5. Skills section + if !self.skills.is_empty() { + let mut section = String::from("\n## Available Skills\n"); + for skill in &self.skills { + section.push_str(&format!("{skill}\n")); + } + parts.push(section); + } + + // 6. Custom sections + for (name, content) in &self.custom_sections { + parts.push(format!("\n## {name}\n{content}")); + } + + // 7. Metadata footer + if let Some(date) = &self.date { + parts.push(format!("\nCurrent date: {date}")); + } + + let mut result = parts.join("\n"); + + // 8. Variable substitution + for (key, value) in &self.variables { + let placeholder = format!("{{{{{}}}}}", key); + result = result.replace(&placeholder, value); + } + + result + } +} + +impl Default for SystemPromptBuilder { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic_build() { + let prompt = SystemPromptBuilder::new() + .base_prompt("You are a helpful assistant.") + .date("2026-05-29") + .build(); + + assert!(prompt.contains("You are a helpful assistant.")); + assert!(prompt.contains("Current date: 2026-05-29")); + } + + #[test] + fn test_variable_substitution() { + let prompt = SystemPromptBuilder::new() + .base_prompt("Repo: {{repo_name}}, User: {{user}}") + .variable("repo_name", "gitdataai") + .variable("user", "zhenyi") + .build(); + + assert!(prompt.contains("Repo: gitdataai")); + assert!(prompt.contains("User: zhenyi")); + } + + #[test] + fn test_tool_snippets() { + let prompt = SystemPromptBuilder::new() + .base_prompt("Agent prompt.") + .tool_snippet("bash", "Execute shell commands") + .tool_snippet("read", "Read file contents") + .build(); + + assert!(prompt.contains("## Available Tools")); + assert!(prompt.contains("`bash`: Execute shell commands")); + } + + #[test] + fn test_project_context() { + let prompt = SystemPromptBuilder::new() + .base_prompt("Base.") + .project_context("AGENTS.md", "# Rules\n- Follow conventions") + .build(); + + assert!(prompt.contains("")); + assert!(prompt.contains("AGENTS.md")); + assert!(prompt.contains("Follow conventions")); + } + + #[test] + fn test_custom_section() { + let prompt = SystemPromptBuilder::new() + .base_prompt("Base.") + .custom_section("Memory", "Remember: user prefers Rust") + .build(); + + assert!(prompt.contains("## Memory")); + assert!(prompt.contains("user prefers Rust")); + } +} diff --git a/lib/ai/agent/request.rs b/lib/ai/agent/request.rs new file mode 100644 index 0000000..e9ed174 --- /dev/null +++ b/lib/ai/agent/request.rs @@ -0,0 +1,240 @@ +use std::time::Duration; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tokio_util::sync::CancellationToken; + +use super::persistence::AgentRunContext; +use crate::error::{AiError, AiResult}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AgentRequest { + pub input: String, + pub messages: Vec, + pub context: Vec, + pub experts: Vec, + pub run_context: Option, + #[serde(skip)] + pub prefill_messages: Vec, + #[serde(skip)] + pub cancellation_token: Option, + #[serde(skip)] + pub timeout: Option, +} + +impl AgentRequest { + pub fn new(input: impl Into) -> Self { + Self { + input: input.into(), + messages: Vec::new(), + context: Vec::new(), + experts: Vec::new(), + run_context: None, + prefill_messages: Vec::new(), + cancellation_token: None, + timeout: None, + } + } + + pub fn validate(&self) -> AiResult<()> { + if self.input.trim().is_empty() { + return Err(AiError::Config("agent request input is required".to_string())); + } + if self.input.len() > 1_000_000 { + return Err(AiError::Config( + "agent request input exceeds maximum length (1MB)".to_string(), + )); + } + if self.experts.len() > 32 { + return Err(AiError::Config( + "agent request experts count exceeds maximum (32)".to_string(), + )); + } + Ok(()) + } + + pub fn with_messages(mut self, messages: Vec) -> Self { + self.messages = messages; + self + } + + pub fn with_context(mut self, context: Vec) -> Self { + self.context = context; + self + } + + pub fn add_context(mut self, chunk: AgentContextChunk) -> Self { + self.context.push(chunk); + self + } + + pub fn with_experts(mut self, experts: Vec) -> Self { + self.experts = experts; + self + } + + pub fn add_expert(mut self, expert: AgentExpert) -> Self { + self.experts.push(expert); + self + } + + pub fn with_run_context(mut self, run_context: AgentRunContext) -> Self { + self.run_context = Some(run_context); + self + } + + pub fn with_prefill_messages(mut self, prefill_messages: Vec) -> Self { + self.prefill_messages = prefill_messages; + self + } + + pub fn with_cancellation_token(mut self, cancellation_token: CancellationToken) -> Self { + self.cancellation_token = Some(cancellation_token); + self + } + + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum AgentMessage { + User(String), + Assistant(String), +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AgentExpert { + pub id: String, + pub role: String, + pub task: String, + pub system_prompt: Option, + pub context: Vec, + /// Override the master agent's temperature for this subagent. + pub temperature: Option, + /// Override the master agent's max_completion_tokens for this subagent. + pub max_completion_tokens: Option, +} + +impl AgentExpert { + pub fn new(id: impl Into, role: impl Into, task: impl Into) -> Self { + Self { + id: id.into(), + role: role.into(), + task: task.into(), + system_prompt: None, + context: Vec::new(), + temperature: None, + max_completion_tokens: None, + } + } + + pub fn with_system_prompt(mut self, system_prompt: impl Into) -> Self { + self.system_prompt = Some(system_prompt.into()); + self + } + + pub fn with_context(mut self, context: Vec) -> Self { + self.context = context; + self + } + + pub fn with_temperature(mut self, temperature: f64) -> Self { + self.temperature = Some(temperature); + self + } + + pub fn with_max_completion_tokens(mut self, max_tokens: u64) -> Self { + self.max_completion_tokens = Some(max_tokens); + self + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AgentContextChunk { + pub id: String, + pub content: String, + pub source: Option, + pub score: Option, + pub metadata: Value, +} + +impl AgentContextChunk { + pub fn new(id: impl Into, content: impl Into) -> Self { + Self { + id: id.into(), + content: content.into(), + source: None, + score: None, + metadata: Value::Null, + } + } +} + +impl From for AgentContextChunk { + fn from(hit: crate::rag::RagSearchHit) -> Self { + Self { + id: hit.id, + content: hit.content, + source: Some(hit.session_id), + score: Some(hit.score), + metadata: Value::Object(hit.metadata.into_iter().collect()), + } + } +} + +impl From<&AgentExpertOutput> for AgentContextChunk { + fn from(output: &AgentExpertOutput) -> Self { + Self { + id: format!("subagent:{}", output.id), + content: output.output.clone(), + source: Some(output.role.clone()), + score: None, + metadata: serde_json::json!({ + "kind": "subagent", + "task": output.task, + }), + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AgentResult { + pub output: String, + pub steps: Vec, + pub expert_outputs: Vec, + pub input_tokens: i64, + pub output_tokens: i64, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AgentStep { + pub index: usize, + pub assistant: Option, + pub reasoning_content: Option, + pub tool_calls: Vec, + pub reflection: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ToolCallRecord { + pub id: String, + pub name: String, + pub arguments: Value, + pub output: Option, + pub error: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub elapsed_ms: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AgentExpertOutput { + pub id: String, + pub role: String, + pub task: String, + pub output: String, + pub input_tokens: i64, + pub output_tokens: i64, +} diff --git a/lib/ai/agent/session.rs b/lib/ai/agent/session.rs new file mode 100644 index 0000000..972a980 --- /dev/null +++ b/lib/ai/agent/session.rs @@ -0,0 +1,535 @@ +use std::time::SystemTime; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use uuid::Uuid; + +use crate::error::{AiError, AiResult}; + +/// Current session file format version. +pub const SESSION_VERSION: u32 = 2; + +/// Session metadata header. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionHeader { + pub version: u32, + pub id: Uuid, + pub created_at: String, + pub parent_session: Option, + pub name: Option, +} + +impl SessionHeader { + pub fn new() -> Self { + Self { + version: SESSION_VERSION, + id: Uuid::new_v4(), + created_at: iso_now(), + parent_session: None, + name: None, + } + } + + pub fn with_parent(mut self, parent: Uuid) -> Self { + self.parent_session = Some(parent); + self + } + + pub fn with_name(mut self, name: impl Into) -> Self { + self.name = Some(name.into()); + self + } +} + +/// Typed session entry — each entry in a session transcript is one of these variants. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum SessionEntry { + /// A user or assistant message. + Message { + id: Uuid, + parent_id: Option, + timestamp: String, + role: SessionMessageRole, + content: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + tool_result: Option, + }, + + /// A context compaction event (older messages summarized). + Compaction { + id: Uuid, + parent_id: Option, + timestamp: String, + summary: String, + first_kept_entry_id: Uuid, + messages_compacted: usize, + tokens_saved: i64, + #[serde(default, skip_serializing_if = "Option::is_none")] + details: Option, + }, + + /// A branch summary (created when forking from a different point in the tree). + BranchSummary { + id: Uuid, + parent_id: Option, + timestamp: String, + from_entry_id: Uuid, + summary: String, + entries_summarized: usize, + #[serde(default, skip_serializing_if = "Option::is_none")] + label: Option, + }, + + /// Model change during a session. + ModelChange { + id: Uuid, + parent_id: Option, + timestamp: String, + provider: String, + model_id: String, + }, + + /// Thinking level change during a session. + ThinkingLevelChange { + id: Uuid, + parent_id: Option, + timestamp: String, + level: String, + }, + + /// Custom extension data (not sent to LLM). + Custom { + id: Uuid, + parent_id: Option, + timestamp: String, + custom_type: String, + data: Option, + }, +} + +impl SessionEntry { + pub fn id(&self) -> Uuid { + match self { + Self::Message { id, .. } + | Self::Compaction { id, .. } + | Self::BranchSummary { id, .. } + | Self::ModelChange { id, .. } + | Self::ThinkingLevelChange { id, .. } + | Self::Custom { id, .. } => *id, + } + } + + pub fn parent_id(&self) -> Option { + match self { + Self::Message { parent_id, .. } + | Self::Compaction { parent_id, .. } + | Self::BranchSummary { parent_id, .. } + | Self::ModelChange { parent_id, .. } + | Self::ThinkingLevelChange { parent_id, .. } + | Self::Custom { parent_id, .. } => *parent_id, + } + } + + pub fn timestamp(&self) -> &str { + match self { + Self::Message { timestamp, .. } + | Self::Compaction { timestamp, .. } + | Self::BranchSummary { timestamp, .. } + | Self::ModelChange { timestamp, .. } + | Self::ThinkingLevelChange { timestamp, .. } + | Self::Custom { timestamp, .. } => timestamp, + } + } + + /// Create a user message entry. + pub fn user_message(parent_id: Option, content: impl Into) -> Self { + Self::Message { + id: Uuid::new_v4(), + parent_id, + timestamp: iso_now(), + role: SessionMessageRole::User, + content: content.into(), + tool_calls: None, + tool_result: None, + } + } + + /// Create an assistant message entry. + pub fn assistant_message( + parent_id: Option, + content: impl Into, + tool_calls: Option>, + ) -> Self { + Self::Message { + id: Uuid::new_v4(), + parent_id, + timestamp: iso_now(), + role: SessionMessageRole::Assistant, + content: content.into(), + tool_calls, + tool_result: None, + } + } + + /// Create a compaction entry. + pub fn compaction( + parent_id: Option, + summary: impl Into, + first_kept_entry_id: Uuid, + messages_compacted: usize, + tokens_saved: i64, + ) -> Self { + Self::Compaction { + id: Uuid::new_v4(), + parent_id, + timestamp: iso_now(), + summary: summary.into(), + first_kept_entry_id, + messages_compacted, + tokens_saved, + details: None, + } + } + + /// Create a branch summary entry. + pub fn branch_summary( + parent_id: Option, + from_entry_id: Uuid, + summary: impl Into, + entries_summarized: usize, + label: Option, + ) -> Self { + Self::BranchSummary { + id: Uuid::new_v4(), + parent_id, + timestamp: iso_now(), + from_entry_id, + summary: summary.into(), + entries_summarized, + label, + } + } + + /// Create a model change entry. + pub fn model_change( + parent_id: Option, + provider: impl Into, + model_id: impl Into, + ) -> Self { + Self::ModelChange { + id: Uuid::new_v4(), + parent_id, + timestamp: iso_now(), + provider: provider.into(), + model_id: model_id.into(), + } + } + + /// Create a custom extension entry. + pub fn custom( + parent_id: Option, + custom_type: impl Into, + data: Option, + ) -> Self { + Self::Custom { + id: Uuid::new_v4(), + parent_id, + timestamp: iso_now(), + custom_type: custom_type.into(), + data, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum SessionMessageRole { + User, + Assistant, + ToolResult, + System, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionToolCall { + pub id: String, + pub name: String, + pub arguments: Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionToolResult { + pub tool_call_id: String, + pub tool_name: String, + pub content: String, + pub is_error: bool, +} + +/// A full session: header + ordered list of entries forming a tree. +/// +/// The tree structure supports forking: entries share `parent_id` links, +/// and the "active branch" is determined by following from the leaf +/// back to the root. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Session { + pub header: SessionHeader, + pub entries: Vec, +} + +impl Session { + pub fn new() -> Self { + Self { + header: SessionHeader::new(), + entries: Vec::new(), + } + } + + pub fn with_name(mut self, name: impl Into) -> Self { + self.header = self.header.with_name(name); + self + } + + /// Append an entry to the session. + pub fn push(&mut self, entry: SessionEntry) { + self.entries.push(entry); + } + + /// Get the last entry's id (used as parent_id for the next entry). + pub fn last_entry_id(&self) -> Option { + self.entries.last().map(|e| e.id()) + } + + /// Get all entries on the active branch (from root to leaf). + pub fn active_branch(&self) -> Vec<&SessionEntry> { + if self.entries.is_empty() { + return Vec::new(); + } + + let mut branch = Vec::new(); + let mut current_id = Some(self.entries.last().unwrap().id()); + + while let Some(id) = current_id { + if let Some(entry) = self.entries.iter().find(|e| e.id() == id) { + branch.push(entry); + current_id = entry.parent_id(); + } else { + break; + } + } + + branch.reverse(); + branch + } + + /// Get all message entries on the active branch (for LLM context). + pub fn active_messages(&self) -> Vec<&SessionEntry> { + self.active_branch() + .into_iter() + .filter(|e| matches!(e, SessionEntry::Message { .. } | SessionEntry::Compaction { .. })) + .collect() + } + + /// Find all children of a given entry (for tree navigation). + pub fn children_of(&self, parent_id: Uuid) -> Vec<&SessionEntry> { + self.entries + .iter() + .filter(|e| e.parent_id() == Some(parent_id)) + .collect() + } + + /// Get all leaf entries (entries with no children). + pub fn leaves(&self) -> Vec<&SessionEntry> { + let parent_ids: std::collections::HashSet = self + .entries + .iter() + .filter_map(|e| e.parent_id()) + .collect(); + + self.entries + .iter() + .filter(|e| !parent_ids.contains(&e.id())) + .collect() + } + + /// Count total entries. + pub fn entry_count(&self) -> usize { + self.entries.len() + } + + /// Fork from a specific entry, creating entries that belong to a new branch. + /// Returns the entries that should be in the new branch (from root to fork point). + pub fn fork_from(&self, fork_entry_id: Uuid) -> AiResult { + let fork_idx = self + .entries + .iter() + .position(|e| e.id() == fork_entry_id) + .ok_or_else(|| { + AiError::Config(format!("fork entry {fork_entry_id} not found in session")) + })?; + + let mut new_session = Session::new(); + new_session.header = new_session.header.with_parent(self.header.id); + + // Copy entries up to and including the fork point + for entry in &self.entries[..=fork_idx] { + new_session.entries.push(entry.clone()); + } + + Ok(new_session) + } + + /// Find the common ancestor of two entries. + pub fn common_ancestor(&self, id_a: Uuid, id_b: Uuid) -> Option { + let ancestors_a = self.ancestor_chain(id_a); + let ancestors_b: std::collections::HashSet = + self.ancestor_chain(id_b).into_iter().collect(); + + for ancestor in ancestors_a { + if ancestors_b.contains(&ancestor) { + return Some(ancestor); + } + } + None + } + + /// Get the chain of ancestor IDs from an entry back to the root. + fn ancestor_chain(&self, entry_id: Uuid) -> Vec { + let mut chain = Vec::new(); + let mut current_id = Some(entry_id); + + while let Some(id) = current_id { + chain.push(id); + current_id = self + .entries + .iter() + .find(|e| e.id() == id) + .and_then(|e| e.parent_id()); + } + + chain + } +} + +/// Options for session compaction. +#[derive(Debug, Clone)] +pub struct CompactionOptions { + /// Custom instructions for the compaction LLM call. + pub custom_instructions: Option, + /// Reserve this many tokens for the prompt + LLM response. + pub reserve_tokens: i64, + /// Keep this many recent message pairs untouched. + pub keep_recent_pairs: usize, + /// Whether to generate branch summaries for forked branches. + pub branch_summarization: bool, +} + +impl Default for CompactionOptions { + fn default() -> Self { + Self { + custom_instructions: None, + reserve_tokens: 16_384, + keep_recent_pairs: 4, + branch_summarization: true, + } + } +} + +fn iso_now() -> String { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map(|d| { + let secs = d.as_secs(); + // Simple ISO 8601 format (UTC) + let days = secs / 86400; + let years = (days * 400) / 146097; + let remaining_days = days - (years * 365 + years / 4 - years / 100 + years / 400); + let month_days = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]; + let is_leap = (years % 4 == 0 && years % 100 != 0) || years % 400 == 0; + let mut month = 0usize; + let mut day_acc = remaining_days as i64; + for (i, &md) in month_days.iter().enumerate() { + let md = if i == 1 && is_leap { md + 1 } else { md }; + if day_acc < md as i64 { + month = i; + break; + } + day_acc -= md as i64; + } + let day = day_acc + 1; + let hour = (secs % 86400) / 3600; + let minute = (secs % 3600) / 60; + let second = secs % 60; + format!( + "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z", + 1970 + years, + month + 1, + day, + hour, + minute, + second, + ) + }) + .unwrap_or_else(|_| "1970-01-01T00:00:00Z".to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_session_basic() { + let mut session = Session::new(); + + let msg1 = SessionEntry::user_message(None, "Hello"); + let msg1_id = msg1.id(); + session.push(msg1); + + let msg2 = SessionEntry::assistant_message(Some(msg1_id), "Hi there!", None); + session.push(msg2); + + assert_eq!(session.entry_count(), 2); + assert_eq!(session.active_branch().len(), 2); + } + + #[test] + fn test_session_fork() { + let mut session = Session::new(); + + let msg1 = SessionEntry::user_message(None, "First"); + let msg1_id = msg1.id(); + session.push(msg1); + + let msg2 = SessionEntry::assistant_message(Some(msg1_id), "Reply 1", None); + let msg2_id = msg2.id(); + session.push(msg2); + + let msg3 = SessionEntry::user_message(Some(msg2_id), "Second"); + session.push(msg3); + + // Fork from msg2 + let forked = session.fork_from(msg2_id).unwrap(); + assert_eq!(forked.entry_count(), 2); + assert_eq!(forked.header.parent_session, Some(session.header.id)); + } + + #[test] + fn test_session_leaves() { + let mut session = Session::new(); + + let msg1 = SessionEntry::user_message(None, "Root"); + let msg1_id = msg1.id(); + session.push(msg1); + + // Two children branching from root + let msg2a = SessionEntry::assistant_message(Some(msg1_id), "Branch A", None); + let msg2b = SessionEntry::assistant_message(Some(msg1_id), "Branch B", None); + session.push(msg2a); + session.push(msg2b); + + let leaves = session.leaves(); + assert_eq!(leaves.len(), 2); + } +} diff --git a/lib/ai/agent/subagent.rs b/lib/ai/agent/subagent.rs new file mode 100644 index 0000000..f23758d --- /dev/null +++ b/lib/ai/agent/subagent.rs @@ -0,0 +1,203 @@ +use rig::agent::AgentBuilder; +use rig::client::CompletionClient; +use rig::completion::Prompt; +use tracing::{debug, warn}; + +use super::config::AgentConfig; +use super::helpers::with_retry; +use super::persistence::{ + ActiveAgentRun, AgentRealtime, AgentRuntime, AgentStreamEvent, + estimate_output_tokens, +}; +use super::request::{AgentExpert, AgentExpertOutput}; +use crate::client::AiClient; +use crate::error::{AiError, AiResult}; + +pub async fn run_experts( + client: &AiClient, + config: &AgentConfig, + experts: &[AgentExpert], + realtime: Option<&AgentRealtime>, + run: &ActiveAgentRun, +) -> AiResult> { + let mut outputs = Vec::with_capacity(experts.len()); + let mut failed_count = 0; + + for expert in experts { + match run_single(client, config, expert, realtime, run).await { + Ok(output) => { + debug!(subagent_id = %output.id, role = %output.role, "subagent completed"); + outputs.push(output); + } + Err(error) => { + warn!(subagent_id = %expert.id, role = %expert.role, error = %error, "subagent failed"); + let _ = publish_subagent_failed(realtime, run, expert, &error.to_string()).await; + failed_count += 1; + } + } + } + + debug!(total = experts.len(), ok = outputs.len(), failed = failed_count, "experts done"); + Ok(outputs) +} + +async fn run_single( + client: &AiClient, + config: &AgentConfig, + expert: &AgentExpert, + realtime: Option<&AgentRealtime>, + run: &ActiveAgentRun, +) -> AiResult { + publish_subagent_started(realtime, run, config, expert).await?; + + let rig_client = client.llm_client().clone(); + let model_name = config.model.clone(); + let temperature = expert.temperature.or(config.temperature); + let max_completion_tokens = expert.max_completion_tokens.or(config.max_completion_tokens); + let retry_attempts = config.retry_max_attempts; + let retry_delay_ms = config.retry_base_delay_ms; + + let prompt = expert.system_prompt.clone().unwrap_or_else(|| { + format!( + "You are a specialist subagent. Role: {}. Produce a concise expert answer for the parent chat agent.", + expert.role + ) + }); + + let task = build_expert_task(expert); + + let (output, input_tokens_usage, output_tokens_usage) = with_retry( + retry_attempts, + retry_delay_ms, + || { + let rig_client = rig_client.clone(); + let model_name = model_name.clone(); + let prompt = prompt.clone(); + let task = task.clone(); + async move { + let model = rig_client.completion_model(&model_name); + let mut builder = AgentBuilder::new(model).preamble(&prompt); + if let Some(temp) = temperature { + builder = builder.temperature(temp); + } + if let Some(mt) = max_completion_tokens { + builder = builder.max_tokens(mt); + } + let agent = builder.build(); + + let response = agent + .prompt(&task) + .extended_details() + .await + .map_err(|e: rig::completion::PromptError| { + AiError::Api(e.to_string()) + })?; + + Ok(( + response.output, + response.usage.input_tokens, + response.usage.output_tokens, + )) + } + }, + ) + .await?; + + let input_tokens = input_tokens_usage as i64; + let output_tokens = if output_tokens_usage > 0 { + output_tokens_usage as i64 + } else { + estimate_output_tokens(&output) + }; + + let result = AgentExpertOutput { + id: expert.id.clone(), + role: expert.role.clone(), + task: expert.task.clone(), + output, + input_tokens, + output_tokens, + }; + + publish_subagent_completed(realtime, run, config, &result).await?; + Ok(result) +} + +fn build_expert_task(expert: &AgentExpert) -> String { + let mut task = String::new(); + + if !expert.context.is_empty() { + task.push_str("Retrieved context for this specialist task:\n"); + for (index, chunk) in expert.context.iter().enumerate() { + task.push_str(&format!( + "\n[{}] id={} source={}\n{}\n", + index + 1, + chunk.id, + chunk.source.as_deref().unwrap_or("unknown"), + chunk.content + )); + } + task.push('\n'); + } + + task.push_str(&expert.task); + task +} + +async fn publish_subagent_started( + realtime: Option<&AgentRealtime>, + run: &ActiveAgentRun, + config: &AgentConfig, + expert: &AgentExpert, +) -> AiResult<()> { + AgentRuntime::default().publish( + realtime, + &AgentStreamEvent::SubagentStarted { + conversation_id: run.conversation_id, + message_id: run.message_id, + subagent_id: expert.id.clone(), + role: expert.role.clone(), + task: expert.task.clone(), + model: config.model.clone(), + }, + ).await +} + +async fn publish_subagent_completed( + realtime: Option<&AgentRealtime>, + run: &ActiveAgentRun, + config: &AgentConfig, + output: &AgentExpertOutput, +) -> AiResult<()> { + AgentRuntime::default().publish( + realtime, + &AgentStreamEvent::SubagentCompleted { + conversation_id: run.conversation_id, + message_id: run.message_id, + subagent_id: output.id.clone(), + role: output.role.clone(), + task: output.task.clone(), + output: output.output.clone(), + input_tokens: output.input_tokens, + output_tokens: output.output_tokens, + model: config.model.clone(), + }, + ).await +} + +async fn publish_subagent_failed( + realtime: Option<&AgentRealtime>, + run: &ActiveAgentRun, + expert: &AgentExpert, + error: &str, +) -> AiResult<()> { + AgentRuntime::default().publish( + realtime, + &AgentStreamEvent::SubagentFailed { + conversation_id: run.conversation_id, + message_id: run.message_id, + subagent_id: expert.id.clone(), + error: error.to_string(), + }, + ).await +} diff --git a/lib/ai/agent/tool.rs b/lib/ai/agent/tool.rs new file mode 100644 index 0000000..08d9c75 --- /dev/null +++ b/lib/ai/agent/tool.rs @@ -0,0 +1,158 @@ +use std::pin::Pin; +use std::sync::Arc; + +use rig::completion::ToolDefinition as RigToolDefinition; +use rig::tool::ToolDyn; +use serde_json::Value; +use tokio::sync::Mutex; + +use crate::tool::tools::FunctionCall; + +pub struct RigTool +where + C: Clone + Send + Sync + 'static, +{ + context: Arc>, + tool: Arc>, + name: String, + description: String, + schema: Value, +} + +impl RigTool +where + C: Clone + Send + Sync + 'static, +{ + pub fn new(tool: Arc>, context: Arc>) -> Self { + let name = tool.name().to_string(); + let description = tool.description().to_string(); + let schema = tool.schema(); + + Self { + context, + tool, + name, + description, + schema, + } + } +} + +impl ToolDyn for RigTool +where + C: Clone + Send + Sync + 'static, +{ + fn name(&self) -> String { + self.name.clone() + } + + fn definition<'a>( + &'a self, + _prompt: String, + ) -> Pin + Send + 'a>> { + let name = self.name.clone(); + let description = self.description.clone(); + let params = self.schema.clone(); + + Box::pin(async move { + RigToolDefinition { + name, + description, + parameters: params, + } + }) + } + + fn call<'a>( + &'a self, + args: String, + ) -> Pin< + Box> + Send + 'a>, + > { + let tool = self.tool.clone(); + let context = self.context.clone(); + + Box::pin(async move { + let args_value: Value = + serde_json::from_str(&args).map_err(rig::tool::ToolError::JsonError)?; + + let mut ctx = context.lock().await; + + match tool.call(&mut *ctx, args_value).await { + Ok(value) => serde_json::to_string(&value) + .map_err(rig::tool::ToolError::JsonError), + Err(ai_err) => Err(rig::tool::ToolError::ToolCallError(Box::new( + std::io::Error::other(ai_err.to_string()), + ))), + } + }) + } +} + +pub struct RigToolSet +where + C: Clone + Send + Sync + 'static, +{ + tools: Vec>, + context: Option>>, +} + +impl RigToolSet +where + C: Clone + Send + Sync + 'static, +{ + pub fn new() -> Self { + Self { + tools: Vec::new(), + context: None, + } + } + + pub fn from_register( + register: &crate::tool::register::ToolRegister, + context: Arc>, + ) -> Self { + let mut tools: Vec> = Vec::with_capacity(register.len()); + + for tool_arc in ®ister.tools { + tools.push(Box::new(RigTool::new(tool_arc.clone(), context.clone()))); + } + + Self { + tools, + context: Some(context), + } + } + + pub fn is_empty(&self) -> bool { + self.tools.is_empty() + } + + pub fn len(&self) -> usize { + self.tools.len() + } + + pub fn context(&self) -> Option<&Arc>> { + self.context.as_ref() + } + + pub fn take_tools(&mut self) -> Vec> { + std::mem::take(&mut self.tools) + } + + pub fn into_context(mut self) -> C { + self.context + .take() + .and_then(|arc| Arc::try_unwrap(arc).ok().map(|m| m.into_inner())) + .unwrap_or_else(|| unreachable!("context must be available")) + } +} + +impl Default for RigToolSet +where + C: Clone + Send + Sync + 'static, +{ + fn default() -> Self { + Self::new() + } +} diff --git a/lib/ai/client.rs b/lib/ai/client.rs new file mode 100644 index 0000000..25a2678 --- /dev/null +++ b/lib/ai/client.rs @@ -0,0 +1,219 @@ +use std::fmt; +use std::sync::Arc; + +use config::AppConfig; +use rig::providers::openai; + +use crate::error::{AiError, AiResult}; + +fn validate_required(scope: &str, field: &str, value: &str) -> AiResult<()> { + if value.trim().is_empty() { + return Err(AiError::Config(format!("{scope} {field} is required"))); + } + Ok(()) +} + +fn config_error(error: impl fmt::Display) -> AiError { + AiError::Config(error.to_string()) +} + +#[derive(Clone)] +pub struct EndpointConfig { + pub base_url: String, + pub api_key: String, +} + +impl EndpointConfig { + pub fn new(base_url: impl Into, api_key: impl Into) -> AiResult { + let config = Self { + base_url: base_url.into(), + api_key: api_key.into(), + }; + config.validate("endpoint")?; + Ok(config) + } + + fn validate(&self, scope: &str) -> AiResult<()> { + validate_required(scope, "base_url", &self.base_url)?; + validate_required(scope, "api_key", &self.api_key)?; + if !self.base_url.trim().starts_with("http://") + && !self.base_url.trim().starts_with("https://") + { + return Err(AiError::Config(format!( + "{scope} base_url must start with http:// or https://" + ))); + } + Ok(()) + } + + pub fn build_client(&self) -> AiResult { + openai::Client::builder() + .api_key(&self.api_key) + .base_url(self.base_url.trim()) + .build() + .map_err(|e| AiError::Config(format!("failed to build rig OpenAI client: {e}"))) + } +} + +impl fmt::Debug for EndpointConfig { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter + .debug_struct("EndpointConfig") + .field("base_url", &self.base_url) + .field("api_key", &"") + .finish() + } +} + +#[derive(Clone, Debug)] +pub struct EmbedConfig { + pub endpoint: EndpointConfig, + pub model: String, + pub dimensions: u64, +} + +impl EmbedConfig { + pub fn new( + endpoint: EndpointConfig, + model: impl Into, + dimensions: u64, + ) -> AiResult { + let config = Self { + endpoint, + model: model.into(), + dimensions, + }; + config.validate()?; + Ok(config) + } + + fn validate(&self) -> AiResult<()> { + self.endpoint.validate("embed endpoint")?; + validate_required("embed", "model", &self.model)?; + if self.dimensions == 0 { + return Err(AiError::Config( + "embed dimensions must be greater than 0".to_string(), + )); + } + Ok(()) + } +} + +#[derive(Clone, Debug)] +pub struct AiClientConfig { + pub llm: EndpointConfig, + pub embed: EmbedConfig, +} + +impl AiClientConfig { + pub fn new(llm: EndpointConfig, embed: EmbedConfig) -> AiResult { + let config = Self { llm, embed }; + config.validate()?; + Ok(config) + } + + pub fn validate(&self) -> AiResult<()> { + self.llm.validate("llm endpoint")?; + self.embed.validate()?; + Ok(()) + } +} + +impl TryFrom<&AppConfig> for AiClientConfig { + type Error = AiError; + + fn try_from(config: &AppConfig) -> Result { + let llm = EndpointConfig::new( + config.ai_basic_url().map_err(config_error)?, + config.ai_api_key().map_err(config_error)?, + )?; + + let embed_endpoint = EndpointConfig::new( + config.get_embed_model_base_url().map_err(config_error)?, + config.get_embed_model_api_key().map_err(config_error)?, + )?; + + let embed = EmbedConfig::new( + embed_endpoint, + config.get_embed_model_name().map_err(config_error)?, + config.get_embed_model_dimensions().map_err(config_error)?, + )?; + + Self::new(llm, embed) + } +} + +#[derive(Clone, Debug)] +pub struct AiClient { + pub(super) llm_client: openai::Client, + pub(super) embed_client: openai::Client, + pub(super) config: Arc, +} + +impl AiClient { + pub fn new(config: AiClientConfig) -> AiResult { + config.validate()?; + + Ok(Self { + llm_client: config.llm.build_client()?, + embed_client: config.embed.endpoint.build_client()?, + config: Arc::new(config), + }) + } + + pub fn from_app_config(config: &AppConfig) -> AiResult { + Self::new(AiClientConfig::try_from(config)?) + } + + pub fn llm_client(&self) -> &openai::Client { + &self.llm_client + } + + pub fn embed_client(&self) -> &openai::Client { + &self.embed_client + } + + pub fn config(&self) -> &AiClientConfig { + &self.config + } + + pub fn llm_config(&self) -> &EndpointConfig { + &self.config.llm + } + + pub fn embed_config(&self) -> &EmbedConfig { + &self.config.embed + } + + pub fn embed_model(&self) -> &str { + self.config.embed.model.as_str() + } + + pub fn embed_dimensions(&self) -> u64 { + self.config.embed.dimensions + } + + pub fn embed_dimensions_u32(&self) -> u32 { + u32::try_from(self.config.embed.dimensions).unwrap_or(u32::MAX) + } +} + +pub fn build_http_client() -> Result { + let mut builder = reqwest::Client::builder(); + + if let Ok(proxy_url) = std::env::var("HTTPS_PROXY") + .or_else(|_| std::env::var("https_proxy")) + .or_else(|_| std::env::var("HTTP_PROXY")) + .or_else(|_| std::env::var("http_proxy")) + { + let proxy_url = proxy_url.trim().trim_matches('"').trim_matches('\''); + let proxy = reqwest::Proxy::all(proxy_url).map_err(|e| { + AiError::Config(format!("Invalid proxy URL '{}': {}", proxy_url, e)) + })?; + builder = builder.proxy(proxy); + } + + builder.build().map_err(|e| { + AiError::Config(format!("Failed to build HTTP client: {}", e)) + }) +} diff --git a/lib/ai/embed/client.rs b/lib/ai/embed/client.rs new file mode 100644 index 0000000..f250cf4 --- /dev/null +++ b/lib/ai/embed/client.rs @@ -0,0 +1,76 @@ +use rig::client::EmbeddingsClient; +use rig::embeddings::EmbeddingModel; + +use crate::{client::AiClient, error::{AiError, AiResult}}; + +#[derive(Clone)] +pub struct EmbedClient { + model_name: String, + client: rig::providers::openai::Client, +} + +impl EmbedClient { + pub fn new(ai_client: &AiClient) -> AiResult { + Ok(Self { + model_name: ai_client.embed_model().to_string(), + client: ai_client.embed_client().clone(), + }) + } + + fn embedding_model(&self) -> impl EmbeddingModel + '_ { + self.client.embedding_model(&self.model_name) + } + + pub async fn embed_text(&self, text: String) -> AiResult> { + let model = self.embedding_model(); + let mut embeddings = model.embed_texts(vec![text]) + .await + .map_err(|e| AiError::Api(e.to_string()))?; + embeddings.pop() + .map(|e| e.vec.into_iter().map(|v| v as f32).collect()) + .ok_or_else(|| AiError::Response("no embedding returned".to_string())) + } + + pub async fn embed_texts(&self, texts: Vec) -> AiResult>> { + if texts.is_empty() { + return Ok(Vec::new()); + } + let model = self.embedding_model(); + let embeddings = model.embed_texts(texts) + .await + .map_err(|e| AiError::Api(e.to_string()))?; + Ok(embeddings.into_iter() + .map(|e| e.vec.into_iter().map(|v| v as f32).collect()) + .collect()) + } + + pub async fn embed_texts_chunked( + &self, + texts: Vec, + batch_size: usize, + ) -> AiResult>> { + if batch_size == 0 { + return Err(AiError::Config("batch_size must be > 0".to_string())); + } + let mut embeddings: Vec> = Vec::with_capacity(texts.len()); + for chunk in texts.chunks(batch_size) { + let model = self.embedding_model(); + let chunk_embeddings = model.embed_texts(chunk.to_vec()) + .await + .map_err(|e| AiError::Api(e.to_string()))?; + embeddings.extend(chunk_embeddings.into_iter() + .map(|e| e.vec.into_iter().map(|v| v as f32).collect())); + } + Ok(embeddings) + } +} + +pub trait AiClientEmbedExt { + fn embedder(&self) -> AiResult; +} + +impl AiClientEmbedExt for AiClient { + fn embedder(&self) -> AiResult { + EmbedClient::new(self) + } +} diff --git a/lib/ai/embed/mod.rs b/lib/ai/embed/mod.rs new file mode 100644 index 0000000..18a300e --- /dev/null +++ b/lib/ai/embed/mod.rs @@ -0,0 +1,3 @@ +mod client; + +pub use client::{AiClientEmbedExt, EmbedClient}; diff --git a/lib/ai/error.rs b/lib/ai/error.rs new file mode 100644 index 0000000..f41f408 --- /dev/null +++ b/lib/ai/error.rs @@ -0,0 +1,52 @@ +pub type AiResult = Result; + +#[derive(Debug, thiserror::Error)] +pub enum AiError { + #[error("ai config error: {0}")] + Config(String), + + #[error("ai api error: {0}")] + Api(String), + + #[error("qdrant error: {0}")] + Qdrant(Box), + + #[error("database error: {0}")] + Database(#[from] db::sqlx::Error), + + #[error("cache error: {0}")] + Cache(#[from] cache::CacheError), + + #[error("redis error: {0}")] + Redis(#[from] redis::RedisError), + + #[error("ai response error: {0}")] + Response(String), + + #[error("model retries exhausted after {attempts} attempts: {last_error}")] + ModelRetriesExhausted { + attempts: usize, + last_error: String, + }, + + #[error("agent timeout after {seconds}s")] + Timeout { seconds: u64 }, + + #[error("tool not found: {tool}")] + ToolNotFound { tool: String }, + + #[error("tool execution failed: {cause}")] + ToolExecutionFailed { cause: String }, + + #[error("invalid input in '{field}': {reason}")] + InvalidInput { field: String, reason: String }, + + #[error("token budget exceeded: used ~{estimated} tokens, limit {limit}")] + TokenBudgetExceeded { estimated: u64, limit: i64 }, +} + +impl From for AiError { + fn from(e: qdrant_client::QdrantError) -> Self { + AiError::Qdrant(Box::new(e)) + } +} diff --git a/lib/ai/lib.rs b/lib/ai/lib.rs new file mode 100644 index 0000000..bc51824 --- /dev/null +++ b/lib/ai/lib.rs @@ -0,0 +1,8 @@ +pub mod agent; +pub mod client; +pub mod embed; +pub mod error; +pub mod memory; +pub mod rag; +pub mod sync; +pub mod tool; diff --git a/lib/ai/memory/mod.rs b/lib/ai/memory/mod.rs new file mode 100644 index 0000000..45a5c31 --- /dev/null +++ b/lib/ai/memory/mod.rs @@ -0,0 +1,46 @@ +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::error::AiResult; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemoryEntry { + pub key: String, + pub value: String, + pub importance: i32, + pub last_used_at: Option, +} +#[async_trait] +pub trait MemoryProvider: Send + Sync { + fn name(&self) -> &'static str; + async fn save( + &self, + session_id: Uuid, + key: &str, + value: &str, + importance: i32, + ) -> AiResult<()>; + async fn recall( + &self, + session_id: Uuid, + query: &str, + limit: usize, + ) -> AiResult>; + async fn forget(&self, session_id: Uuid, key: &str) -> AiResult<()>; + async fn prefetch( + &self, + _session_id: Uuid, + _query: &str, + ) -> AiResult> { + Ok(Vec::new()) + } + async fn build_context_block( + &self, + _session_id: Uuid, + ) -> AiResult { + Ok(String::new()) + } + async fn setup(&self) -> AiResult<()> { + Ok(()) + } +} diff --git a/lib/ai/rag/client.rs b/lib/ai/rag/client.rs new file mode 100644 index 0000000..c4da99b --- /dev/null +++ b/lib/ai/rag/client.rs @@ -0,0 +1,263 @@ +use config::AppConfig; +use qdrant_client::qdrant::{ + CreateCollectionBuilder, CreateFieldIndexCollectionBuilder, + DeletePointsBuilder, FieldType, PointStruct, QueryPointsBuilder, + SearchParamsBuilder, UpsertPointsBuilder, VectorParamsBuilder, +}; +use qdrant_client::{Qdrant, QdrantError}; + +use super::{ + config::RagConfig, + document::{RagDocument, RagSearchHit}, + payload::{ + SESSION_ID_KEY, document_payload, hit_from_scored_point, point_id, + }, + search::RagSearchOptions, + session::{session_filter, validate_session_id}, +}; +use crate::{ + client::AiClient, + embed::{AiClientEmbedExt, EmbedClient}, + error::{AiError, AiResult}, +}; + +#[derive(Clone)] +pub struct RagClient { + qdrant: Qdrant, + embedder: EmbedClient, + config: RagConfig, +} + +impl RagClient { + pub fn new( + qdrant: Qdrant, + embedder: EmbedClient, + config: RagConfig, + ) -> AiResult { + config.validate()?; + Ok(Self { + qdrant, + embedder, + config, + }) + } + + pub fn connect( + ai_client: &AiClient, + config: RagConfig, + ) -> AiResult { + config.validate()?; + let mut builder = + Qdrant::from_url(config.url.trim()).timeout(config.timeout); + if let Some(api_key) = config + .api_key + .as_deref() + .filter(|api_key| !api_key.trim().is_empty()) + { + builder = builder.api_key(api_key); + } + + Self::new(builder.build()?, ai_client.embedder()?, config) + } + + pub fn from_app_config( + ai_client: &AiClient, + config: &AppConfig, + collection_name: impl Into, + ) -> AiResult { + Self::connect( + ai_client, + RagConfig::from_app_config(config, collection_name)?, + ) + } + + pub fn qdrant(&self) -> &Qdrant { + &self.qdrant + } + + pub fn embedder(&self) -> &EmbedClient { + &self.embedder + } + + pub fn config(&self) -> &RagConfig { + &self.config + } + + pub async fn ensure_collection(&self) -> AiResult<()> { + if !self + .qdrant + .collection_exists(&self.config.collection_name) + .await? + { + self.qdrant + .create_collection( + CreateCollectionBuilder::new(&self.config.collection_name) + .vectors_config(VectorParamsBuilder::new( + self.config.vector_size, + self.config.distance, + )), + ) + .await?; + } + + match self + .qdrant + .create_field_index(CreateFieldIndexCollectionBuilder::new( + &self.config.collection_name, + SESSION_ID_KEY, + FieldType::Keyword, + )) + .await + { + Ok(_) => Ok(()), + Err(QdrantError::ResponseError { .. }) => Ok(()), + Err(error) => Err(error.into()), + } + } + + pub async fn upsert_document( + &self, + session_id: impl AsRef, + document: RagDocument, + ) -> AiResult<()> { + self.upsert_documents(session_id, vec![document]).await + } + + pub async fn upsert_documents( + &self, + session_id: impl AsRef, + documents: Vec, + ) -> AiResult<()> { + let session_id = session_id.as_ref(); + validate_session_id(session_id)?; + validate_documents(&documents)?; + + let texts: Vec = documents + .iter() + .map(|d| d.content.clone()) + .collect(); + let vectors = self + .embedder + .embed_texts_chunked(texts, self.config.upsert_batch_size) + .await?; + + let points = documents + .iter() + .zip(vectors) + .map(|(document, vector)| { + Ok(PointStruct::new( + point_id(session_id, &document.id), + vector, + document_payload(session_id, document)?, + )) + }) + .collect::>>()?; + + self.qdrant + .upsert_points( + UpsertPointsBuilder::new(&self.config.collection_name, points) + .wait(true), + ) + .await?; + + Ok(()) + } + + pub async fn search_session( + &self, + session_id: impl AsRef, + query: impl Into, + ) -> AiResult> { + let options = RagSearchOptions { + limit: self.config.default_search_limit, + exact: self.config.exact_session_search, + }; + self.search_session_with_options(session_id, query, options) + .await + } + + pub async fn search_session_with_options( + &self, + session_id: impl AsRef, + query: impl Into, + options: RagSearchOptions, + ) -> AiResult> { + let vector = self.embedder.embed_text(query.into()).await?; + self.search_session_by_vector(session_id, vector, options) + .await + } + + pub async fn search_session_by_vector( + &self, + session_id: impl AsRef, + vector: Vec, + options: RagSearchOptions, + ) -> AiResult> { + let session_id = session_id.as_ref(); + validate_session_id(session_id)?; + if options.limit == 0 { + return Err(AiError::Config( + "rag search limit must be greater than 0".to_string(), + )); + } + + let response = self + .qdrant + .query( + QueryPointsBuilder::new(&self.config.collection_name) + .query(vector) + .limit(options.limit) + .filter(session_filter(session_id)) + .with_payload(true) + .params( + SearchParamsBuilder::default().exact(options.exact), + ), + ) + .await?; + + Ok(response + .result + .into_iter() + .map(hit_from_scored_point) + .collect()) + } + + pub async fn clear_session( + &self, + session_id: impl AsRef, + ) -> AiResult<()> { + let session_id = session_id.as_ref(); + validate_session_id(session_id)?; + + self.qdrant + .delete_points( + DeletePointsBuilder::new(&self.config.collection_name) + .points(session_filter(session_id)) + .wait(true), + ) + .await?; + + Ok(()) + } +} + +fn validate_documents(documents: &[RagDocument]) -> AiResult<()> { + if documents.is_empty() { + return Err(AiError::Config("rag documents are required".to_string())); + } + + for document in documents { + if document.id.trim().is_empty() { + return Err(AiError::Config( + "rag document id is required".to_string(), + )); + } + if document.content.trim().is_empty() { + return Err(AiError::Config( + "rag document content is required".to_string(), + )); + } + } + + Ok(()) +} diff --git a/lib/ai/rag/config.rs b/lib/ai/rag/config.rs new file mode 100644 index 0000000..bd9bcab --- /dev/null +++ b/lib/ai/rag/config.rs @@ -0,0 +1,134 @@ +use std::time::Duration; + +use config::AppConfig; +use qdrant_client::qdrant::Distance; + +use crate::error::{AiError, AiResult}; + +#[derive(Clone, Debug)] +pub struct RagConfig { + pub url: String, + pub api_key: Option, + pub collection_name: String, + pub vector_size: u64, + pub distance: Distance, + pub timeout: Duration, + pub upsert_batch_size: usize, + pub default_search_limit: u64, + pub exact_session_search: bool, +} + +impl RagConfig { + pub fn new( + url: impl Into, + collection_name: impl Into, + vector_size: u64, + ) -> AiResult { + let config = Self { + url: url.into(), + api_key: None, + collection_name: collection_name.into(), + vector_size, + distance: Distance::Cosine, + timeout: Duration::from_secs(10), + upsert_batch_size: 64, + default_search_limit: 8, + exact_session_search: true, + }; + config.validate()?; + Ok(config) + } + + pub fn with_api_key(mut self, api_key: Option) -> Self { + self.api_key = api_key; + self + } + + pub fn with_distance(mut self, distance: Distance) -> Self { + self.distance = distance; + self + } + + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + pub fn with_upsert_batch_size(mut self, upsert_batch_size: usize) -> Self { + self.upsert_batch_size = upsert_batch_size; + self + } + + pub fn with_default_search_limit( + mut self, + default_search_limit: u64, + ) -> Self { + self.default_search_limit = default_search_limit; + self + } + + pub fn with_exact_session_search( + mut self, + exact_session_search: bool, + ) -> Self { + self.exact_session_search = exact_session_search; + self + } + + pub fn validate(&self) -> AiResult<()> { + if self.url.trim().is_empty() { + return Err(AiError::Config("qdrant url is required".to_string())); + } + if !self.url.trim().starts_with("http://") + && !self.url.trim().starts_with("https://") + { + return Err(AiError::Config( + "qdrant url must start with http:// or https://".to_string(), + )); + } + if self.collection_name.trim().is_empty() { + return Err(AiError::Config( + "qdrant collection_name is required".to_string(), + )); + } + if self.vector_size == 0 { + return Err(AiError::Config( + "qdrant vector_size must be greater than 0".to_string(), + )); + } + if self.upsert_batch_size == 0 { + return Err(AiError::Config( + "qdrant upsert_batch_size must be greater than 0".to_string(), + )); + } + if self.default_search_limit == 0 { + return Err(AiError::Config( + "qdrant default_search_limit must be greater than 0" + .to_string(), + )); + } + Ok(()) + } +} + +impl RagConfig { + pub fn from_app_config( + config: &AppConfig, + collection_name: impl Into, + ) -> AiResult { + Ok(Self::new( + config + .qdrant_url() + .map_err(|error| AiError::Config(error.to_string()))?, + collection_name, + config + .get_embed_model_dimensions() + .map_err(|error| AiError::Config(error.to_string()))?, + )? + .with_api_key( + config + .qdrant_api_key() + .map_err(|error| AiError::Config(error.to_string()))?, + )) + } +} diff --git a/lib/ai/rag/document.rs b/lib/ai/rag/document.rs new file mode 100644 index 0000000..c56422c --- /dev/null +++ b/lib/ai/rag/document.rs @@ -0,0 +1,44 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RagDocument { + pub id: String, + pub content: String, + pub metadata: HashMap, +} + +impl RagDocument { + pub fn new(id: impl Into, content: impl Into) -> Self { + Self { + id: id.into(), + content: content.into(), + metadata: HashMap::new(), + } + } + + pub fn with_metadata(mut self, metadata: HashMap) -> Self { + self.metadata = metadata; + self + } + + pub fn metadata_value( + mut self, + key: impl Into, + value: impl Into, + ) -> Self { + self.metadata.insert(key.into(), value.into()); + self + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RagSearchHit { + pub id: String, + pub session_id: String, + pub score: f32, + pub content: String, + pub metadata: HashMap, +} diff --git a/lib/ai/rag/mod.rs b/lib/ai/rag/mod.rs new file mode 100644 index 0000000..ba16d88 --- /dev/null +++ b/lib/ai/rag/mod.rs @@ -0,0 +1,11 @@ +mod client; +mod config; +mod document; +mod payload; +mod search; +mod session; + +pub use client::RagClient; +pub use config::RagConfig; +pub use document::{RagDocument, RagSearchHit}; +pub use search::RagSearchOptions; diff --git a/lib/ai/rag/payload.rs b/lib/ai/rag/payload.rs new file mode 100644 index 0000000..d8bb6e9 --- /dev/null +++ b/lib/ai/rag/payload.rs @@ -0,0 +1,110 @@ +use std::collections::HashMap; + +use qdrant_client::Payload; +use qdrant_client::qdrant::{ + PointId, ScoredPoint, point_id::PointIdOptions, value::Kind, +}; +use serde_json::{Map, Value, json}; +use uuid::Uuid; + +use super::document::{RagDocument, RagSearchHit}; +use crate::error::{AiError, AiResult}; + +pub(super) const SESSION_ID_KEY: &str = "session_id"; +pub(super) const DOCUMENT_ID_KEY: &str = "document_id"; +pub(super) const CONTENT_KEY: &str = "content"; +pub(super) const METADATA_KEY: &str = "metadata"; +pub(super) fn point_id(session_id: &str, document_id: &str) -> u64 { + let ns = Uuid::NAMESPACE_DNS; + let key = format!("{session_id}:{document_id}"); + let uuid = Uuid::new_v5(&ns, key.as_bytes()); + let bytes = uuid.as_bytes(); + u64::from_be_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], + bytes[4], bytes[5], bytes[6], bytes[7], + ]) +} + +pub(super) fn document_payload( + session_id: &str, + document: &RagDocument, +) -> AiResult { + Payload::try_from(json!({ + SESSION_ID_KEY: session_id, + DOCUMENT_ID_KEY: document.id, + CONTENT_KEY: document.content, + METADATA_KEY: document.metadata, + })) + .map_err(|error| AiError::Config(error.to_string())) +} + +pub(super) fn hit_from_scored_point(point: ScoredPoint) -> RagSearchHit { + let id = point_id_to_string(point.id); + let mut payload = qdrant_payload_to_json(point.payload); + let session_id = take_string(&mut payload, SESSION_ID_KEY); + let document_id = take_string(&mut payload, DOCUMENT_ID_KEY); + let content = take_string(&mut payload, CONTENT_KEY); + let metadata = payload + .remove(METADATA_KEY) + .and_then(|value| match value { + Value::Object(object) => Some(object.into_iter().collect()), + _ => None, + }) + .unwrap_or_default(); + + RagSearchHit { + id: if document_id.is_empty() { + id + } else { + document_id + }, + session_id, + score: point.score, + content, + metadata, + } +} + +fn point_id_to_string(id: Option) -> String { + match id.and_then(|id| id.point_id_options) { + Some(PointIdOptions::Num(id)) => id.to_string(), + Some(PointIdOptions::Uuid(id)) => id, + None => String::new(), + } +} + +fn qdrant_payload_to_json( + payload: HashMap, +) -> Map { + payload + .into_iter() + .map(|(key, value)| (key, value_to_json(value))) + .collect() +} + +fn value_to_json(value: qdrant_client::qdrant::Value) -> Value { + match value.kind { + Some(Kind::NullValue(_)) | None => Value::Null, + Some(Kind::DoubleValue(value)) => json!(value), + Some(Kind::IntegerValue(value)) => json!(value), + Some(Kind::StringValue(value)) => json!(value), + Some(Kind::BoolValue(value)) => json!(value), + Some(Kind::StructValue(value)) => Value::Object( + value + .fields + .into_iter() + .map(|(key, value)| (key, value_to_json(value))) + .collect(), + ), + Some(Kind::ListValue(value)) => { + Value::Array(value.values.into_iter().map(value_to_json).collect()) + } + } +} + +fn take_string(payload: &mut Map, key: &str) -> String { + payload + .remove(key) + .and_then(|value| value.as_str().map(ToOwned::to_owned)) + .unwrap_or_default() +} diff --git a/lib/ai/rag/search.rs b/lib/ai/rag/search.rs new file mode 100644 index 0000000..95fcd90 --- /dev/null +++ b/lib/ai/rag/search.rs @@ -0,0 +1,16 @@ +#[derive(Clone, Debug)] +pub struct RagSearchOptions { + pub limit: u64, + pub exact: bool, +} + +impl RagSearchOptions { + pub fn new(limit: u64) -> Self { + Self { limit, exact: true } + } + + pub fn with_exact(mut self, exact: bool) -> Self { + self.exact = exact; + self + } +} diff --git a/lib/ai/rag/session.rs b/lib/ai/rag/session.rs new file mode 100644 index 0000000..f049b88 --- /dev/null +++ b/lib/ai/rag/session.rs @@ -0,0 +1,15 @@ +use qdrant_client::qdrant::{Condition, Filter}; + +use super::payload::SESSION_ID_KEY; +use crate::error::{AiError, AiResult}; + +pub(super) fn validate_session_id(session_id: &str) -> AiResult<()> { + if session_id.trim().is_empty() { + return Err(AiError::Config("rag session_id is required".to_string())); + } + Ok(()) +} + +pub(super) fn session_filter(session_id: &str) -> Filter { + Filter::all([Condition::matches(SESSION_ID_KEY, session_id.to_string())]) +} diff --git a/lib/ai/sync.rs b/lib/ai/sync.rs new file mode 100644 index 0000000..4f7dae5 --- /dev/null +++ b/lib/ai/sync.rs @@ -0,0 +1,126 @@ +use std::error::Error; +use std::sync::LazyLock; + +use tracing::{debug, warn}; + +use crate::{ + client::EndpointConfig, + error::{AiError, AiResult}, +}; +#[derive(Debug, serde::Deserialize)] +struct ModelsListResponse { + data: Vec, +} +#[derive(Debug, Clone, serde::Deserialize)] +pub struct UpstreamModel { + pub id: String, + #[serde(default)] + pub name: Option, + #[serde(default)] + pub owned_by: Option, + #[serde(default)] + pub context_length: Option, + #[serde(default)] + pub max_output_tokens: Option, + #[serde(default)] + pub capabilities: Option, + #[serde(default)] + pub pricing: Option, +} + +#[derive(Debug, Clone, serde::Deserialize)] +pub struct UpstreamCapabilities { + #[serde(default)] + pub vision: Option, + #[serde(default)] + pub tool_call: Option, + #[serde(default)] + pub reasoning: Option, +} + +#[derive(Debug, Clone, serde::Deserialize)] +pub struct UpstreamPricing { + #[serde(default)] + pub prompt: Option, + #[serde(default)] + pub completion: Option, + #[serde(default)] + pub input: Option, + #[serde(default)] + pub output: Option, + #[serde(default)] + pub cache_read: Option, + #[serde(default)] + pub unit: Option, + #[serde(default)] + pub currency: Option, +} +static HTTP_CLIENT: LazyLock = LazyLock::new(|| { + let mut builder = reqwest::Client::builder(); + let proxy_url = std::env::var("HTTPS_PROXY") + .or_else(|_| std::env::var("https_proxy")) + .or_else(|_| std::env::var("HTTP_PROXY")) + .or_else(|_| std::env::var("http_proxy")) + .ok(); + if let Some(raw) = &proxy_url { + let url = raw.trim().trim_matches('"').trim_matches('\''); + match reqwest::Proxy::all(url) { + Ok(proxy) => { + debug!(proxy_url = %url, "sync: using proxy"); + builder = builder.proxy(proxy); + } + Err(e) => { + warn!(proxy_url = %url, error = %e, "sync: invalid proxy URL, skipping"); + } + } + } + #[allow(clippy::expect_used)] + builder.build().expect("failed to build reqwest HTTP client — check system TLS configuration") +}); +pub async fn list_models( + config: &EndpointConfig, +) -> AiResult> { + let base = config.base_url.trim_end_matches('/'); + let url = if base.ends_with("/v1") { + format!("{}/models", base) + } else { + format!("{}/v1/models", base) + }; + + debug!(url = %url, "listing models from upstream"); + let resp = HTTP_CLIENT + .get(&url) + .header("Authorization", format!("Bearer {}", config.api_key.trim())) + .send() + .await + .map_err(|e| { + tracing::error!( + error = %e, + source = ?e.source(), + "list_models: request failed with full cause chain" + ); + AiError::Response(format!("failed to list models: {}", e)) + })?; + + let body = resp + .text() + .await + .map_err(|e| AiError::Response(format!("failed to read models body: {}", e)))?; + if let Ok(parsed) = serde_json::from_str::(&body) { + debug!(count = parsed.data.len(), "parsed models in standard format"); + return Ok(parsed.data); + } + if let Ok(parsed) = serde_json::from_str::>(&body) { + debug!(count = parsed.len(), "parsed models in array format"); + return Ok(parsed); + } + + warn!( + body = %body.chars().take(500).collect::(), + "list_models: unknown response format" + ); + Err(AiError::Response(format!( + "unexpected /v1/models response format (first 200 chars): {}", + body.chars().take(200).collect::() + ))) +} diff --git a/lib/ai/tool/mod.rs b/lib/ai/tool/mod.rs new file mode 100644 index 0000000..379b9e7 --- /dev/null +++ b/lib/ai/tool/mod.rs @@ -0,0 +1,5 @@ +pub mod register; +pub mod tools; +pub mod toolset; + +pub use toolset::{Toolset, ToolsetRegistry, toolset_names}; diff --git a/lib/ai/tool/register.rs b/lib/ai/tool/register.rs new file mode 100644 index 0000000..1c7d812 --- /dev/null +++ b/lib/ai/tool/register.rs @@ -0,0 +1,65 @@ +use crate::tool::tools::FunctionCall; +use std::collections::HashMap; +use std::sync::Arc; + +#[derive(Clone)] +pub struct ToolRegister +where + C: Clone + Send + Sync + 'static, +{ + pub tools: Vec>>, + index: HashMap, +} + +impl ToolRegister +where + C: Clone + Send + Sync + 'static, +{ + pub fn new() -> Self { + ToolRegister { + tools: Vec::new(), + index: HashMap::new(), + } + } + + pub fn register(&mut self, tool: T) + where + T: FunctionCall + 'static, + { + let idx = self.tools.len(); + self.index.insert(tool.name().to_string(), idx); + self.tools.push(Arc::new(tool)); + } + + pub fn with_tool(mut self, tool: T) -> Self + where + T: FunctionCall + 'static, + { + self.register(tool); + self + } + + pub fn get( + &self, + name: &str, + ) -> Option>> { + self.index.get(name).map(|&idx| self.tools[idx].clone()) + } + + pub fn is_empty(&self) -> bool { + self.tools.is_empty() + } + + pub fn len(&self) -> usize { + self.tools.len() + } +} + +impl Default for ToolRegister +where + C: Clone + Send + Sync + 'static, +{ + fn default() -> Self { + Self::new() + } +} diff --git a/lib/ai/tool/tools.rs b/lib/ai/tool/tools.rs new file mode 100644 index 0000000..be630f2 --- /dev/null +++ b/lib/ai/tool/tools.rs @@ -0,0 +1,18 @@ +use crate::error::AiResult; +use async_trait::async_trait; +use serde_json::Value; + +#[async_trait] +pub trait FunctionCall: Send + Sync { + type Context; + fn name(&self) -> &'static str; + fn description(&self) -> &'static str { + "" + } + fn schema(&self) -> Value; + async fn call( + &self, + context: &mut Self::Context, + args: Value, + ) -> AiResult; +} diff --git a/lib/ai/tool/toolset.rs b/lib/ai/tool/toolset.rs new file mode 100644 index 0000000..7e9da8e --- /dev/null +++ b/lib/ai/tool/toolset.rs @@ -0,0 +1,146 @@ +use std::collections::{HashMap, HashSet}; + +use serde::{Deserialize, Serialize}; +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Toolset { + pub name: String, + pub description: String, + pub tools: Vec, + pub requires_env: Vec, +} + +impl Toolset { + pub fn new( + name: impl Into, + description: impl Into, + ) -> Self { + Self { + name: name.into(), + description: description.into(), + tools: Vec::new(), + requires_env: Vec::new(), + } + } + + pub fn with_tool(mut self, tool_name: impl Into) -> Self { + self.tools.push(tool_name.into()); + self + } + + pub fn with_tools(mut self, tool_names: impl IntoIterator>) -> Self { + self.tools.extend(tool_names.into_iter().map(Into::into)); + self + } + + pub fn with_required_env( + mut self, + env_vars: impl IntoIterator>, + ) -> Self { + self.requires_env.extend(env_vars.into_iter().map(Into::into)); + self + } + pub fn is_available(&self) -> bool { + for env_var in &self.requires_env { + if std::env::var(env_var).is_err() { + return false; + } + } + true + } + + pub fn contains(&self, tool_name: &str) -> bool { + self.tools.iter().any(|t| t == tool_name) + } +} +pub mod toolset_names { + pub const CORE: &str = "core"; + pub const TERMINAL: &str = "terminal"; + pub const WEB: &str = "web"; + pub const FILE: &str = "file"; + pub const MEMORY: &str = "memory"; + pub const VISION: &str = "vision"; + pub const SEARCH: &str = "search"; + pub const BROWSER: &str = "browser"; + pub const CODE_EXECUTION: &str = "code_execution"; + pub const DELEGATION: &str = "delegation"; +} +#[derive(Clone, Debug, Default)] +pub struct ToolsetRegistry { + toolsets: HashMap, + tool_index: HashMap, +} + +impl ToolsetRegistry { + pub fn new() -> Self { + Self::default() + } + pub fn register(&mut self, toolset: Toolset) { + let name = toolset.name.clone(); + for tool in &toolset.tools { + self.tool_index.insert(tool.clone(), name.clone()); + } + self.toolsets.insert(name, toolset); + } + + pub fn get(&self, name: &str) -> Option<&Toolset> { + self.toolsets.get(name) + } + pub fn toolset_for(&self, tool_name: &str) -> Option<&str> { + self.tool_index.get(tool_name).map(String::as_str) + } + pub fn resolve_tool_names( + &self, + enabled: &[String], + disabled: &[String], + default_all: bool, + ) -> Vec { + let mut names = HashSet::new(); + let mut denied = HashSet::new(); + + for ts_name in disabled { + if let Some(ts) = self.toolsets.get(ts_name) { + for tool in &ts.tools { + denied.insert(tool.clone()); + } + } + } + + if enabled.is_empty() && default_all { + for ts in self.toolsets.values() { + if !disabled.contains(&ts.name) && ts.is_available() { + for tool in &ts.tools { + if !denied.contains(tool) { + names.insert(tool.clone()); + } + } + } + } + } else { + for ts_name in enabled { + if let Some(ts) = self.toolsets.get(ts_name) { + if ts.is_available() { + for tool in &ts.tools { + if !denied.contains(tool) { + names.insert(tool.clone()); + } + } + } + } + } + } + + let mut sorted: Vec = names.into_iter().collect(); + sorted.sort(); + sorted + } + + pub fn iter(&self) -> impl Iterator { + self.toolsets.values() + } + + pub fn all_tool_names(&self) -> Vec { + let mut names: Vec = self.tool_index.keys().cloned().collect(); + names.sort(); + names + } +} diff --git a/lib/api/Cargo.toml b/lib/api/Cargo.toml new file mode 100644 index 0000000..a507a27 --- /dev/null +++ b/lib/api/Cargo.toml @@ -0,0 +1,44 @@ +[package] +name = "api" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[lib] +path = "src/lib.rs" +name = "api" + +[dependencies] +service = { workspace = true } +session = { workspace = true } +config = { workspace = true } +db = { workspace = true } +model = { workspace = true } +git = { workspace = true } +channel = { workspace = true } +socketio = { workspace = true } + +actix-web = { workspace = true } +actix-ws = { workspace = true } +utoipa = { workspace = true, features = ["chrono", "uuid", "actix_extras"] } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +uuid = { workspace = true, features = ["v4", "v7", "serde"] } +chrono = { workspace = true } +async-stream = { workspace = true } +tokio-stream = { workspace = true } +base64 = { workspace = true } +comrak = { workspace = true } +redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp"] } +storage = { workspace = true } + +[lints] +workspace = true diff --git a/lib/api/src/agent/conversation.rs b/lib/api/src/agent/conversation.rs new file mode 100644 index 0000000..1663a77 --- /dev/null +++ b/lib/api/src/agent/conversation.rs @@ -0,0 +1,296 @@ +use actix_web::{HttpResponse, web, web::ServiceConfig}; +use service::AppService; +use service::agent::conversation::{ + ConversationResponse, ConversationWithSessionResponse, CreateConversation, MessageResponse, UpdateConversation, +}; +use service::agent::types::{AgentRunRequest, AgentRunResponse}; +use session::Session; +use tokio_stream::StreamExt; +use tokio_stream::wrappers::UnboundedReceiverStream; +use uuid::Uuid; + +use crate::error::{ApiError, ok_json}; + +pub fn configure(cfg: &mut ServiceConfig) { + cfg.service( + web::resource("/sessions/{session_id}/conversations") + .route(web::get().to(list_conversations)) + .route(web::post().to(create_conversation)), + ) + .service( + web::resource("/conversations") + .route(web::get().to(list_all_conversations)), + ) + .service( + web::resource("/conversations/{id}") + .route(web::get().to(get_conversation)) + .route(web::patch().to(update_conversation)) + .route(web::delete().to(delete_conversation)), + ) + .service( + web::resource("/conversations/{id}/messages") + .route(web::get().to(list_messages)) + .route(web::post().to(send_message)), + ) + .service( + web::resource("/conversations/{id}/stream") + .route(web::post().to(stream_agent)), + ) + .service( + web::resource("/conversations/{id}/fork") + .route(web::post().to(fork_conversation)), + ); +} + +#[utoipa::path( + get, path = "/api/v1/agent/sessions/{session_id}/conversations", + params(("session_id" = Uuid, Path)), + responses((status = 200, body = Vec)), + security(("session" = [])) +)] +pub async fn list_conversations( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json(service.agent_conversation_list(user_id, path.into_inner()).await?) +} + +#[utoipa::path( + post, path = "/api/v1/agent/sessions/{session_id}/conversations", + params(("session_id" = Uuid, Path)), + request_body = CreateConversation, + responses((status = 200, body = ConversationResponse)), + security(("session" = [])) +)] +pub async fn create_conversation( + session: Session, + service: web::Data, + path: web::Path, + body: web::Json, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_conversation_create(user_id, path.into_inner(), body.into_inner()) + .await?, + ) +} + +#[derive(Debug, serde::Deserialize, utoipa::IntoParams)] +pub struct ListAllConversationsQuery { + pub wk: Option, +} + +#[utoipa::path( + get, path = "/api/v1/agent/conversations", + params(("wk" = Option, Query, description = "Filter by workspace name")), + responses((status = 200, body = Vec)), + security(("session" = [])) +)] +pub async fn list_all_conversations( + session: Session, + service: web::Data, + query: web::Query, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_conversation_list_all(user_id, query.wk.as_deref()) + .await?, + ) +} + +#[utoipa::path( + get, path = "/api/v1/agent/conversations/{id}", + params(("id" = Uuid, Path)), + responses((status = 200, body = ConversationResponse)), + security(("session" = [])) +)] +pub async fn get_conversation( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json(service.agent_conversation_get(user_id, path.into_inner()).await?) +} + +#[utoipa::path( + patch, path = "/api/v1/agent/conversations/{id}", + params(("id" = Uuid, Path)), + request_body = UpdateConversation, + responses((status = 200, body = ConversationResponse)), + security(("session" = [])) +)] +pub async fn update_conversation( + session: Session, + service: web::Data, + path: web::Path, + body: web::Json, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_conversation_update(user_id, path.into_inner(), body.into_inner()) + .await?, + ) +} + +#[utoipa::path( + delete, path = "/api/v1/agent/conversations/{id}", + params(("id" = Uuid, Path)), + responses((status = 200)), + security(("session" = [])) +)] +pub async fn delete_conversation( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + service.agent_conversation_delete(user_id, path.into_inner()).await?; + Ok(HttpResponse::Ok().json(serde_json::json!({ "deleted": true }))) +} + +#[utoipa::path( + post, path = "/api/v1/agent/conversations/{id}/archive", + params(("id" = Uuid, Path)), + responses((status = 200, body = ConversationResponse)), + security(("session" = [])) +)] +pub async fn archive_conversation( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json(service.agent_conversation_archive(user_id, path.into_inner()).await?) +} + +#[utoipa::path( + post, path = "/api/v1/agent/conversations/{id}/unarchive", + params(("id" = Uuid, Path)), + responses((status = 200, body = ConversationResponse)), + security(("session" = [])) +)] +pub async fn unarchive_conversation( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json(service.agent_conversation_unarchive(user_id, path.into_inner()).await?) +} +#[utoipa::path( + get, path = "/api/v1/agent/conversations/{id}/messages", + params(("id" = Uuid, Path), ("before" = Option, Query), ("limit" = Option, Query)), + responses((status = 200, body = Vec)), + security(("session" = [])) +)] +pub async fn list_messages( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_message_list(user_id, path.into_inner(), query.limit, query.before) + .await?, + ) +} + +#[derive(Debug, serde::Deserialize, utoipa::IntoParams)] +pub struct MessageListQuery { + pub limit: Option, + pub before: Option, +} +#[utoipa::path( + post, path = "/api/v1/agent/conversations/{id}/messages", + params(("id" = Uuid, Path)), + request_body = AgentRunRequest, + responses((status = 200, body = AgentRunResponse)), + security(("session" = [])) +)] +pub async fn send_message( + session: Session, + service: web::Data, + path: web::Path, + body: web::Json, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + let conversation_id = path.into_inner(); + let mut req = body.into_inner(); + req.conversation_id = Some(conversation_id); + ok_json(service.agent_run(user_id, req).await?) +} +#[utoipa::path( + post, path = "/api/v1/agent/conversations/{id}/stream", + params(("id" = Uuid, Path)), + request_body = AgentRunRequest, + responses((status = 200, description = "SSE stream")), + security(("session" = [])) +)] +pub async fn stream_agent( + session: Session, + service: web::Data, + path: web::Path, + body: web::Json, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + let conversation_id = path.into_inner(); + let mut req = body.into_inner(); + req.conversation_id = Some(conversation_id); + + let rx = service.agent_run_streaming(user_id, req).await?; + + let stream = UnboundedReceiverStream::new(rx).map(|payload| { + let frame = if payload.starts_with("data:") { + payload + } else { + format!("data: {}\n\n", payload) + }; + Ok::<_, actix_web::Error>(actix_web::web::Bytes::from(frame)) + }); + + Ok(HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header(("Cache-Control", "no-cache")) + .insert_header(("Connection", "keep-alive")) + .insert_header(("X-Accel-Buffering", "no")) + .streaming(stream)) +} + +#[derive(Debug, serde::Deserialize, utoipa::ToSchema)] +pub struct ForkConversationRequest { + pub message_id: Option, + pub title: Option, +} +#[utoipa::path( + post, path = "/api/v1/agent/conversations/{id}/fork", + params(("id" = Uuid, Path)), + request_body = ForkConversationRequest, + responses((status = 200, body = ConversationResponse)), + security(("session" = [])) +)] +pub async fn fork_conversation( + session: Session, + service: web::Data, + path: web::Path, + body: web::Json, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_conversation_fork( + user_id, + path.into_inner(), + body.message_id, + body.title.as_deref(), + ) + .await?, + ) +} diff --git a/lib/api/src/agent/mod.rs b/lib/api/src/agent/mod.rs new file mode 100644 index 0000000..3c47877 --- /dev/null +++ b/lib/api/src/agent/mod.rs @@ -0,0 +1,12 @@ +pub mod conversation; +pub mod session; + +use actix_web::{web, web::ServiceConfig}; + +pub fn configure(cfg: &mut ServiceConfig) { + cfg.service( + web::scope("/agent") + .configure(session::configure) + .configure(conversation::configure), + ); +} diff --git a/lib/api/src/agent/session.rs b/lib/api/src/agent/session.rs new file mode 100644 index 0000000..db67e26 --- /dev/null +++ b/lib/api/src/agent/session.rs @@ -0,0 +1,162 @@ +use actix_web::{HttpResponse, web, web::ServiceConfig}; +use service::AppService; +use service::agent::session::{ + AgentSessionResponse, CreateAgentSession, UpdateAgentSession, +}; +use session::Session; +use uuid::Uuid; + +use crate::error::{ApiError, ok_json}; + +pub fn configure(cfg: &mut ServiceConfig) { + cfg.service( + web::resource("/sessions") + .route(web::get().to(list_sessions)) + .route(web::post().to(create_session)), + ) + .service( + web::resource("/sessions/search") + .route(web::get().to(search_sessions)), + ) + .service( + web::resource("/sessions/{id}") + .route(web::get().to(get_session)) + .route(web::patch().to(update_session)) + .route(web::delete().to(delete_session)), + ) + .service( + web::resource("/sessions/{id}/toolsets") + .route(web::patch().to(update_session_toolsets)), + ); +} +#[utoipa::path( + get, path = "/api/v1/agent/sessions", + responses((status = 200, body = Vec)), + security(("session" = [])) +)] +pub async fn list_sessions( + session: Session, + service: web::Data, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json(service.agent_session_list(user_id).await?) +} +#[utoipa::path( + post, path = "/api/v1/agent/sessions", + request_body = CreateAgentSession, + responses((status = 200, body = AgentSessionResponse)), + security(("session" = [])) +)] +pub async fn create_session( + session: Session, + service: web::Data, + body: web::Json, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json(service.agent_session_create(user_id, body.into_inner()).await?) +} +#[utoipa::path( + get, path = "/api/v1/agent/sessions/{id}", + params(("id" = Uuid, Path)), + responses((status = 200, body = AgentSessionResponse)), + security(("session" = [])) +)] +pub async fn get_session( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json(service.agent_session_get(user_id, path.into_inner()).await?) +} +#[utoipa::path( + patch, path = "/api/v1/agent/sessions/{id}", + params(("id" = Uuid, Path)), + request_body = UpdateAgentSession, + responses((status = 200, body = AgentSessionResponse)), + security(("session" = [])) +)] +pub async fn update_session( + session: Session, + service: web::Data, + path: web::Path, + body: web::Json, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json(service.agent_session_update(user_id, path.into_inner(), body.into_inner()).await?) +} +#[utoipa::path( + delete, path = "/api/v1/agent/sessions/{id}", + params(("id" = Uuid, Path)), + responses((status = 200)), + security(("session" = [])) +)] +pub async fn delete_session( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + service.agent_session_delete(user_id, path.into_inner()).await?; + Ok(HttpResponse::Ok().json(serde_json::json!({ "deleted": true }))) +} + +#[derive(Debug, serde::Deserialize, utoipa::IntoParams)] +pub struct SearchQuery { + pub q: String, + #[serde(default = "default_limit")] + pub limit: u32, +} + +const fn default_limit() -> u32 { + 20 +} +#[utoipa::path( + get, path = "/api/v1/agent/sessions/search", + params(("q" = String, Query), ("limit" = Option, Query)), + responses((status = 200, body = Vec)), + security(("session" = [])) +)] +pub async fn search_sessions( + session: Session, + service: web::Data, + query: web::Query, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_session_search(user_id, &query.q, query.limit) + .await?, + ) +} + +#[derive(Debug, serde::Deserialize, utoipa::ToSchema)] +pub struct UpdateToolsetsRequest { + pub enabled: Option>, + pub disabled: Option>, +} +#[utoipa::path( + patch, path = "/api/v1/agent/sessions/{id}/toolsets", + params(("id" = Uuid, Path)), + request_body = UpdateToolsetsRequest, + responses((status = 200, body = AgentSessionResponse)), + security(("session" = [])) +)] +pub async fn update_session_toolsets( + session: Session, + service: web::Data, + path: web::Path, + body: web::Json, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json( + service + .agent_session_update_toolsets( + user_id, + path.into_inner(), + body.enabled.clone(), + body.disabled.clone(), + ) + .await?, + ) +} diff --git a/lib/api/src/ai/mod.rs b/lib/api/src/ai/mod.rs new file mode 100644 index 0000000..09d0a59 --- /dev/null +++ b/lib/api/src/ai/mod.rs @@ -0,0 +1,47 @@ +pub mod model; +pub mod provider; + +use actix_web::web; +use actix_web::web::ServiceConfig; + +pub fn configure(cfg: &mut ServiceConfig) { + cfg.service( + web::scope("/ai") + .service( + web::resource("/providers") + .route(web::get().to(provider::list_providers)), + ) + .service( + web::resource("/providers/{id}") + .route(web::get().to(provider::get_provider)), + ) + .service( + web::resource("/models") + .route(web::get().to(model::list_models)), + ) + .service( + web::resource("/models/{id}") + .route(web::get().to(model::get_model)), + ) + .service( + web::resource("/models/{id}/versions") + .route(web::get().to(model::list_versions)), + ) + .service( + web::resource("/models/{id}/card") + .route(web::get().to(model::get_card)), + ) + .service( + web::resource("/models/{id}/tags") + .route(web::get().to(model::list_tags)), + ) + .service( + web::resource("/models/{id}/discussions") + .route(web::get().to(model::list_discussions)), + ) + .service( + web::resource("/models/{id}/likes") + .route(web::get().to(model::list_likes)), + ), + ); +} diff --git a/lib/api/src/ai/model.rs b/lib/api/src/ai/model.rs new file mode 100644 index 0000000..d75bfa1 --- /dev/null +++ b/lib/api/src/ai/model.rs @@ -0,0 +1,139 @@ +use crate::error::{ApiError, ok_json}; +use actix_web::{HttpResponse, web}; +use serde::Deserialize; +use service::AppService; +use service::Pagination; +use service::ai::types::{ + AiDiscussionResponse, AiLikeResponse, AiModelCardResponse, AiModelFilter, + AiModelListItem, AiModelResponse, AiModelVersionResponse, +}; +use session::Session; +use uuid::Uuid; + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct ModelIdPath { + pub id: Uuid, +} +#[utoipa::path( + get, path = "/api/v1/ai/models", + params(AiModelFilter, Pagination), + responses((status = 200, body = Vec), (status = 401, description = "Unauthorized")), + security(("session" = [])) +)] +pub async fn list_models( + session: Session, + service: web::Data, + filter: web::Query, + pagination: web::Query, +) -> Result { + ok_json( + service + .ai_model_list( + &session, + filter.into_inner(), + pagination.into_inner(), + ) + .await?, + ) +} +#[utoipa::path( + get, path = "/api/v1/ai/models/{id}", + params(("id" = String, Path)), responses((status = 200, body = AiModelResponse), + (status = 401, description = "Unauthorized"), (status = 404, description = "Not found")), + security(("session" = [])) +)] +pub async fn get_model( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + ok_json(service.ai_model_get(&session, path.into_inner().id).await?) +} +#[utoipa::path( + get, path = "/api/v1/ai/models/{id}/versions", + params(("id" = String, Path)), responses((status = 200, body = Vec), + (status = 401, description = "Unauthorized")), + security(("session" = [])) +)] +pub async fn list_versions( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + ok_json( + service + .ai_model_versions(&session, path.into_inner().id) + .await?, + ) +} +#[utoipa::path( + get, path = "/api/v1/ai/models/{id}/card", + params(("id" = String, Path)), responses((status = 200, body = Option), + (status = 401, description = "Unauthorized")), + security(("session" = [])) +)] +pub async fn get_card( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + ok_json( + service + .ai_model_card(&session, path.into_inner().id) + .await?, + ) +} +#[utoipa::path( + get, path = "/api/v1/ai/models/{id}/tags", + params(("id" = String, Path)), responses((status = 200, body = Vec)), + security(("session" = [])) +)] +pub async fn list_tags( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + ok_json( + service + .ai_model_tags(&session, path.into_inner().id) + .await?, + ) +} +#[utoipa::path( + get, path = "/api/v1/ai/models/{id}/discussions", + params(("id" = String, Path), Pagination), + responses((status = 200, body = Vec)), + security(("session" = [])) +)] +pub async fn list_discussions( + session: Session, + service: web::Data, + path: web::Path, + pagination: web::Query, +) -> Result { + ok_json( + service + .ai_model_discussions( + &session, + path.into_inner().id, + pagination.into_inner(), + ) + .await?, + ) +} +#[utoipa::path( + get, path = "/api/v1/ai/models/{id}/likes", + params(("id" = String, Path)), responses((status = 200, body = Vec)), + security(("session" = [])) +)] +pub async fn list_likes( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + ok_json( + service + .ai_model_likes(&session, path.into_inner().id) + .await?, + ) +} diff --git a/lib/api/src/ai/provider.rs b/lib/api/src/ai/provider.rs new file mode 100644 index 0000000..98a3e48 --- /dev/null +++ b/lib/api/src/ai/provider.rs @@ -0,0 +1,39 @@ +use crate::error::{ApiError, ok_json}; +use actix_web::{HttpResponse, web}; +use serde::Deserialize; +use service::AppService; +use service::ai::types::AiProviderResponse; +use session::Session; + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct ProviderIdPath { + pub id: uuid::Uuid, +} +#[utoipa::path( + get, path = "/api/v1/ai/providers", + responses((status = 200, body = Vec), (status = 401, description = "Unauthorized")), + security(("session" = [])) +)] +pub async fn list_providers( + session: Session, + service: web::Data, +) -> Result { + ok_json(service.ai_provider_list(&session).await?) +} +#[utoipa::path( + get, path = "/api/v1/ai/providers/{id}", + params(("id" = String, Path)), responses((status = 200, body = AiProviderResponse), + (status = 401, description = "Unauthorized"), (status = 404, description = "Not found")), + security(("session" = [])) +)] +pub async fn get_provider( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + ok_json( + service + .ai_provider_get(&session, path.into_inner().id) + .await?, + ) +} diff --git a/lib/api/src/auth/captcha.rs b/lib/api/src/auth/captcha.rs new file mode 100644 index 0000000..d80fd44 --- /dev/null +++ b/lib/api/src/auth/captcha.rs @@ -0,0 +1,29 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, + auth::captcha::{CaptchaQuery, CaptchaResponse}, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + get, + path = "/api/v1/auth/captcha", + params(CaptchaQuery), + responses((status = 200, body = CaptchaResponse)), + tag = "auth" +)] +pub async fn captcha( + session: Session, + query: web::Query, + service: web::Data, +) -> Result { + let result = service.auth_captcha(&session, query.into_inner()).await?; + ok_json(result) +} diff --git a/lib/api/src/auth/email.rs b/lib/api/src/auth/email.rs new file mode 100644 index 0000000..3850714 --- /dev/null +++ b/lib/api/src/auth/email.rs @@ -0,0 +1,64 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, + auth::email::{EmailChangeRequest, EmailResponse, EmailVerifyRequest}, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[utoipa::path( + get, + path = "/api/v1/auth/email", + responses((status = 200, body = EmailResponse)), + tag = "auth" +)] +pub async fn get_email( + session: Session, + service: web::Data, +) -> Result { + let result = service.auth_get_email(&session).await?; + ok_json(result) +} + +#[utoipa::path( + post, + path = "/api/v1/auth/email", + request_body = EmailChangeRequest, + responses((status = 200)), + tag = "auth" +)] +pub async fn email_change_request( + session: Session, + params: web::Json, + service: web::Data, +) -> Result { + service + .auth_email_change_request(&session, params.into_inner()) + .await?; + ok() +} + +#[utoipa::path( + post, + path = "/api/v1/auth/email/verify", + request_body = EmailVerifyRequest, + responses((status = 200)), + tag = "auth" +)] +pub async fn email_verify( + params: web::Json, + service: web::Data, +) -> Result { + service.auth_email_verify(params.into_inner()).await?; + ok() +} diff --git a/lib/api/src/auth/login.rs b/lib/api/src/auth/login.rs new file mode 100644 index 0000000..2b48e70 --- /dev/null +++ b/lib/api/src/auth/login.rs @@ -0,0 +1,25 @@ +use actix_web::{HttpResponse, web}; +use service::{AppService, auth::login::LoginParams}; +use session::Session; + +use crate::error::ApiError; + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[utoipa::path( + post, + path = "/api/v1/auth/login", + request_body = LoginParams, + responses((status = 200)), + tag = "auth" +)] +pub async fn login( + session: Session, + params: web::Json, + service: web::Data, +) -> Result { + service.auth_login(params.into_inner(), session).await?; + ok() +} diff --git a/lib/api/src/auth/logout.rs b/lib/api/src/auth/logout.rs new file mode 100644 index 0000000..2e519a2 --- /dev/null +++ b/lib/api/src/auth/logout.rs @@ -0,0 +1,23 @@ +use actix_web::{HttpResponse, web}; +use service::AppService; +use session::Session; + +use crate::error::ApiError; + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[utoipa::path( + post, + path = "/api/v1/auth/logout", + responses((status = 200)), + tag = "auth" +)] +pub async fn logout( + session: Session, + service: web::Data, +) -> Result { + service.auth_logout(&session).await?; + ok() +} diff --git a/lib/api/src/auth/me.rs b/lib/api/src/auth/me.rs new file mode 100644 index 0000000..9e40beb --- /dev/null +++ b/lib/api/src/auth/me.rs @@ -0,0 +1,24 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{AppService, auth::me::ContextMe}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + get, + path = "/api/v1/auth/me", + responses((status = 200, body = ContextMe)), + tag = "auth" +)] +pub async fn me( + session: Session, + service: web::Data, +) -> Result { + let result = service.auth_me(session).await?; + ok_json(result) +} diff --git a/lib/api/src/auth/mod.rs b/lib/api/src/auth/mod.rs new file mode 100644 index 0000000..7067c69 --- /dev/null +++ b/lib/api/src/auth/mod.rs @@ -0,0 +1,75 @@ +pub mod captcha; +pub mod email; +pub mod login; +pub mod logout; +pub mod me; +pub mod register; +pub mod reset_pass; +pub mod rsa; +pub mod totp; + +use actix_web::{web, web::ServiceConfig}; + +pub fn configure(cfg: &mut ServiceConfig) { + cfg.service( + web::scope("/auth") + .service( + web::resource("/captcha") + .route(web::get().to(captcha::captcha)), + ) + .service( + web::resource("/login").route(web::post().to(login::login)), + ) + .service( + web::resource("/logout").route(web::post().to(logout::logout)), + ) + .service(web::resource("/me").route(web::get().to(me::me))) + .service( + web::resource("/register") + .route(web::post().to(register::register)), + ) + .service( + web::scope("/reset-password") + .service(web::resource("/request").route( + web::post().to(reset_pass::reset_password_request), + )) + .service(web::resource("/verify").route( + web::post().to(reset_pass::reset_password_verify), + )), + ) + .service(web::resource("/public-key").route(web::get().to(rsa::rsa))) + .service( + web::scope("/2fa") + .service( + web::resource("/enable") + .route(web::post().to(totp::enable_2fa)), + ) + .service( + web::resource("/verify") + .route(web::post().to(totp::verify_2fa)), + ) + .service( + web::resource("") + .route(web::get().to(totp::status_2fa)) + .route(web::delete().to(totp::disable_2fa)), + ) + .service( + web::resource("/backup-codes").route( + web::post().to(totp::regenerate_backup_codes), + ), + ), + ) + .service( + web::scope("/email") + .service( + web::resource("") + .route(web::get().to(email::get_email)) + .route(web::put().to(email::email_change_request)), + ) + .service( + web::resource("/verify") + .route(web::post().to(email::email_verify)), + ), + ), + ); +} diff --git a/lib/api/src/auth/register.rs b/lib/api/src/auth/register.rs new file mode 100644 index 0000000..3918cb5 --- /dev/null +++ b/lib/api/src/auth/register.rs @@ -0,0 +1,26 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{AppService, auth::register::RegisterParams}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + post, + path = "/api/v1/auth/register", + request_body = RegisterParams, + responses((status = 200)), + tag = "auth" +)] +pub async fn register( + session: Session, + params: web::Json, + service: web::Data, +) -> Result { + let result = service.auth_register(params.into_inner(), &session).await?; + ok_json(result) +} diff --git a/lib/api/src/auth/reset_pass.rs b/lib/api/src/auth/reset_pass.rs new file mode 100644 index 0000000..ed2fe9f --- /dev/null +++ b/lib/api/src/auth/reset_pass.rs @@ -0,0 +1,47 @@ +use actix_web::{HttpResponse, web}; +use service::{ + AppService, + auth::reset_pass::{ResetPasswordRequest, ResetPasswordVerifyParams}, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[utoipa::path( + post, + path = "/api/v1/auth/reset-password/request", + request_body = ResetPasswordRequest, + responses((status = 200)), + tag = "auth" +)] +pub async fn reset_password_request( + params: web::Json, + service: web::Data, +) -> Result { + service + .auth_reset_password_request(params.into_inner()) + .await?; + ok() +} + +#[utoipa::path( + post, + path = "/api/v1/auth/reset-password/verify", + request_body = ResetPasswordVerifyParams, + responses((status = 200)), + tag = "auth" +)] +pub async fn reset_password_verify( + session: Session, + params: web::Json, + service: web::Data, +) -> Result { + service + .auth_reset_password_verify(&session, params.into_inner()) + .await?; + ok() +} diff --git a/lib/api/src/auth/rsa.rs b/lib/api/src/auth/rsa.rs new file mode 100644 index 0000000..66ae85b --- /dev/null +++ b/lib/api/src/auth/rsa.rs @@ -0,0 +1,24 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{AppService, auth::rsa::RsaResponse}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + get, + path = "/api/v1/auth/public-key", + responses((status = 200, body = RsaResponse)), + tag = "auth" +)] +pub async fn rsa( + session: Session, + service: web::Data, +) -> Result { + let result = service.auth_rsa(&session).await?; + ok_json(result) +} diff --git a/lib/api/src/auth/totp.rs b/lib/api/src/auth/totp.rs new file mode 100644 index 0000000..c1e81ab --- /dev/null +++ b/lib/api/src/auth/totp.rs @@ -0,0 +1,102 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, + auth::totp::{ + Disable2FAParams, Enable2FAResponse, Get2FAStatusResponse, + Verify2FAParams, + }, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[utoipa::path( + post, + path = "/api/v1/auth/2fa/enable", + responses((status = 200, body = Enable2FAResponse)), + tag = "auth" +)] +pub async fn enable_2fa( + session: Session, + service: web::Data, +) -> Result { + let result = service.auth_2fa_enable(&session).await?; + ok_json(result) +} + +#[utoipa::path( + post, + path = "/api/v1/auth/2fa/verify", + request_body = Verify2FAParams, + responses((status = 200)), + tag = "auth" +)] +pub async fn verify_2fa( + session: Session, + params: web::Json, + service: web::Data, +) -> Result { + service + .auth_2fa_verify_and_enable(&session, params.into_inner()) + .await?; + ok() +} + +#[utoipa::path( + delete, + path = "/api/v1/auth/2fa", + request_body = Disable2FAParams, + responses((status = 200)), + tag = "auth" +)] +pub async fn disable_2fa( + session: Session, + params: web::Json, + service: web::Data, +) -> Result { + service + .auth_2fa_disable(&session, params.into_inner()) + .await?; + ok() +} + +#[utoipa::path( + get, + path = "/api/v1/auth/2fa", + responses((status = 200, body = Get2FAStatusResponse)), + tag = "auth" +)] +pub async fn status_2fa( + session: Session, + service: web::Data, +) -> Result { + let result = service.auth_2fa_status(&session).await?; + ok_json(result) +} + +#[utoipa::path( + post, + path = "/api/v1/auth/2fa/backup-codes", + request_body = String, + responses((status = 200)), + tag = "auth" +)] +pub async fn regenerate_backup_codes( + session: Session, + params: web::Json, + service: web::Data, +) -> Result { + let result = service + .auth_2fa_regenerate_backup_codes(&session, params.into_inner()) + .await?; + ok_json(result) +} diff --git a/lib/api/src/channel/mod.rs b/lib/api/src/channel/mod.rs new file mode 100644 index 0000000..fb14f87 --- /dev/null +++ b/lib/api/src/channel/mod.rs @@ -0,0 +1,207 @@ +pub mod rest; +pub mod rest_ai; +pub mod rest_interact; +pub mod rest_member; +pub mod rest_message; +pub mod rest_room; +pub mod rest_voice; +pub mod token; + +pub use channel::ChannelBus; + +use actix_web::web::ServiceConfig; + +pub fn configure(cfg: &mut ServiceConfig, bus: ChannelBus) { + socketio::configure_at(cfg, "/socket.io", bus.io().clone()); + socketio::configure_at(cfg, "/socket.io/", bus.io().clone()); + cfg.service( + actix_web::web::resource("/ping") + .route(actix_web::web::get().to(rest::ping)), + ) + .service( + actix_web::web::resource("/csrf") + .route(actix_web::web::get().to(rest::csrf_token)), + ); + cfg.service( + actix_web::web::resource("/rooms/{room_id}/messages") + .route(actix_web::web::get().to(rest_message::list_messages)) + .route(actix_web::web::post().to(rest_message::create_message)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/messages/around") + .route(actix_web::web::get().to(rest_message::messages_around)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/messages/missed") + .route(actix_web::web::get().to(rest_message::missed_messages)), + ) + .service( + actix_web::web::resource("/messages/{message_id}") + .route(actix_web::web::patch().to(rest_message::update_message)) + .route(actix_web::web::delete().to(rest_message::revoke_message)), + ) + .service( + actix_web::web::resource("/search") + .route(actix_web::web::get().to(rest_message::search)), + ); + cfg.service( + actix_web::web::resource("/rooms") + .route(actix_web::web::get().to(rest_room::list_rooms)) + .route(actix_web::web::post().to(rest_room::room_create)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}") + .route(actix_web::web::get().to(rest_room::room_get)) + .route(actix_web::web::patch().to(rest_room::room_update)) + .route(actix_web::web::delete().to(rest_room::room_delete)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/subscribe") + .route(actix_web::web::post().to(rest_room::subscribe)) + .route(actix_web::web::delete().to(rest_room::unsubscribe)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/members") + .route(actix_web::web::post().to(rest_room::access_grant)), + ) + .service( + actix_web::web::resource("/workspaces/{workspace_id}/members") + .route(actix_web::web::get().to(rest_member::list_workspace_members)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/members/{user_id}") + .route(actix_web::web::delete().to(rest_room::access_revoke)), + ) + .service( + actix_web::web::resource("/workspaces/{workspace_id}/categories") + .route(actix_web::web::post().to(rest_room::category_create)), + ) + .service( + actix_web::web::resource("/categories/{category_id}") + .route(actix_web::web::patch().to(rest_room::category_update)) + .route(actix_web::web::delete().to(rest_room::category_delete)), + ); + cfg.service( + actix_web::web::resource("/rooms/{room_id}/reactions") + .route(actix_web::web::post().to(rest_interact::reaction_add)) + .route(actix_web::web::delete().to(rest_interact::reaction_remove)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/threads") + .route(actix_web::web::post().to(rest_interact::thread_create)), + ) + .service( + actix_web::web::resource("/threads/{thread_id}/resolve") + .route(actix_web::web::patch().to(rest_interact::thread_resolve)), + ) + .service( + actix_web::web::resource("/threads/{thread_id}/archive") + .route(actix_web::web::patch().to(rest_interact::thread_archive)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/pins") + .route(actix_web::web::post().to(rest_interact::pin_add)) + .route(actix_web::web::delete().to(rest_interact::pin_remove)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/drafts") + .route(actix_web::web::put().to(rest_interact::draft_save)) + .route(actix_web::web::delete().to(rest_interact::draft_clear)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/typing") + .route(actix_web::web::post().to(rest_interact::typing)), + ); + cfg.service( + actix_web::web::resource("/rooms/{room_id}/read-receipt") + .route(actix_web::web::post().to(rest_member::read_receipt)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/dnd") + .route(actix_web::web::patch().to(rest_member::dnd_update)), + ) + .service( + actix_web::web::resource("/notifications/{id}/read").route( + actix_web::web::patch().to(rest_member::notification_mark_read), + ), + ) + .service(actix_web::web::resource("/notifications/read-all").route( + actix_web::web::post().to(rest_member::notification_mark_all_read), + )) + .service( + actix_web::web::resource("/notifications/{id}").route( + actix_web::web::delete().to(rest_member::notification_archive), + ), + ) + .service( + actix_web::web::resource("/presence") + .route(actix_web::web::post().to(rest_member::presence_update)), + ) + .service( + actix_web::web::resource("/custom-status").route( + actix_web::web::post().to(rest_member::custom_status_update), + ), + ) + .service( + actix_web::web::resource("/invites") + .route(actix_web::web::post().to(rest_member::invite_create)), + ) + .service( + actix_web::web::resource("/invites/accept") + .route(actix_web::web::post().to(rest_member::invite_accept)), + ) + .service( + actix_web::web::resource("/invites/{id}") + .route(actix_web::web::delete().to(rest_member::invite_revoke)), + ) + .service( + actix_web::web::resource("/workspaces/{workspace_id}/bans") + .route(actix_web::web::post().to(rest_member::ban_create)), + ) + .service( + actix_web::web::resource("/workspaces/{workspace_id}/bans/{user_id}") + .route(actix_web::web::delete().to(rest_member::ban_remove)), + ); + cfg.service( + actix_web::web::resource("/rooms/{room_id}/voice/join") + .route(actix_web::web::post().to(rest_voice::voice_join)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/voice/leave") + .route(actix_web::web::post().to(rest_voice::voice_leave)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/voice/mute") + .route(actix_web::web::post().to(rest_voice::voice_mute)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/voice/deaf") + .route(actix_web::web::post().to(rest_voice::voice_deaf)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/screen-share") + .route(actix_web::web::post().to(rest_voice::screen_share)), + ); + cfg.service( + actix_web::web::resource("/rooms/{room_id}/ai/stop") + .route(actix_web::web::post().to(rest_ai::ai_stop)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/ai") + .route(actix_web::web::get().to(rest_ai::ai_list)) + .route(actix_web::web::post().to(rest_ai::ai_add)), + ) + .service( + actix_web::web::resource("/rooms/{room_id}/ai/{agent_session_id}") + .route(actix_web::web::delete().to(rest_ai::ai_remove)), + ) + .service( + actix_web::web::resource("/users/summary/{username}") + .route(actix_web::web::get().to(rest_ai::user_summary)), + ); + cfg.service( + actix_web::web::resource("/token") + .route(actix_web::web::post().to(token::generate_token)), + ); + cfg.app_data(actix_web::web::Data::new(bus)); +} diff --git a/lib/api/src/channel/rest.rs b/lib/api/src/channel/rest.rs new file mode 100644 index 0000000..2f37e56 --- /dev/null +++ b/lib/api/src/channel/rest.rs @@ -0,0 +1,110 @@ +use actix_web::{HttpRequest, HttpResponse, web}; +use channel::http::{WsHandler, WsInMessage, WsOutEvent}; +use channel::{ChannelBus, ChannelError}; +use session::SessionExt; +use uuid::Uuid; + +use crate::error::ApiError; + +pub(crate) fn extract_user(req: &HttpRequest) -> Result { + req.get_session() + .user() + .ok_or_else(|| ApiError(service::error::AppError::Unauthorized)) +} + +pub(crate) fn channel_err(e: ChannelError) -> ApiError { + ApiError(match e { + ChannelError::Unauthorized | ChannelError::TokenInvalidOrExpired => { + service::error::AppError::Unauthorized + } + ChannelError::AccessDenied => { + service::error::AppError::PermissionDenied + } + ChannelError::Validation(msg) => { + service::error::AppError::BadRequest(msg) + } + ChannelError::RateLimitExceeded => { + service::error::AppError::BadRequest("rate limit exceeded".into()) + } + ChannelError::RenewalLimitExceeded => { + service::error::AppError::BadRequest( + "renewal limit exceeded".into(), + ) + } + ChannelError::RoomNotFound => { + service::error::AppError::NotFound("room not found".into()) + } + ChannelError::UserNotFound => { + service::error::AppError::NotFound("user not found".into()) + } + ChannelError::Internal(msg) => { + service::error::AppError::InternalServerError(msg) + } + ChannelError::Database(e) => { + service::error::AppError::InternalServerError(e.to_string()) + } + ChannelError::Cache(e) => { + service::error::AppError::InternalServerError(e.to_string()) + } + ChannelError::SocketIo(e) => { + service::error::AppError::InternalServerError(e.to_string()) + } + ChannelError::Serialization(e) => { + service::error::AppError::InternalServerError(e.to_string()) + } + ChannelError::Redis(e) => { + service::error::AppError::InternalServerError(e.to_string()) + } + ChannelError::Storage(e) => { + service::error::AppError::InternalServerError(e.to_string()) + } + }) +} + +pub(crate) fn ok_json(event: Option) -> HttpResponse { + match event { + Some(e) => HttpResponse::Ok().json(e), + None => HttpResponse::NoContent().finish(), + } +} + +pub(crate) fn created_json(event: Option) -> HttpResponse { + match event { + Some(e) => HttpResponse::Created().json(e), + None => HttpResponse::NoContent().finish(), + } +} + +#[utoipa::path( + get, + path = "/api/v1/ws/ping", + responses((status = 200, description = "Pong with protocol version")), + tag = "channel", +)] +pub async fn ping( + req: HttpRequest, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let result = WsHandler::handle(&bus, user_id, WsInMessage::Ping) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + get, + path = "/api/v1/ws/csrf", + responses((status = 200, description = "CSRF token")), + tag = "channel", +)] +pub async fn csrf_token( + req: HttpRequest, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let result = WsHandler::handle(&bus, user_id, WsInMessage::CsrfToken) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} diff --git a/lib/api/src/channel/rest_ai.rs b/lib/api/src/channel/rest_ai.rs new file mode 100644 index 0000000..a29b5cf --- /dev/null +++ b/lib/api/src/channel/rest_ai.rs @@ -0,0 +1,120 @@ +use actix_web::{HttpRequest, HttpResponse, web}; +use channel::ChannelBus; +use channel::http::{WsHandler, WsInMessage}; +use serde::Deserialize; +use uuid::Uuid; + +use super::rest::{channel_err, created_json, extract_user, ok_json}; +use crate::error::ApiError; + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct AiAddRequest { + pub agent_session: Uuid, +} + +#[utoipa::path( + get, + path = "/api/v1/ws/rooms/{room_id}/ai", + responses((status = 200, description = "AI agents in room")), + tag = "channel", +)] +pub async fn ai_list( + req: HttpRequest, + room_id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::AiList { + room: room_id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/ai", + request_body = AiAddRequest, + responses((status = 201, description = "AI agent added to room")), + tag = "channel", +)] +pub async fn ai_add( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::AiUpsert { + room: room_id.into_inner(), + model: body.agent_session, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(created_json(result)) +} + +#[utoipa::path( + delete, + path = "/api/v1/ws/rooms/{room_id}/ai/{agent_session_id}", + responses((status = 200, description = "AI agent removed from room")), + tag = "channel", +)] +pub async fn ai_remove( + req: HttpRequest, + path: web::Path<(Uuid, Uuid)>, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let (room, agent_id) = path.into_inner(); + let msg = WsInMessage::AiDelete { room, agent_id }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/ai/stop", + responses((status = 204, description = "AI agent stopped")), + tag = "channel", +)] +pub async fn ai_stop( + req: HttpRequest, + room_id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::AiStop { + room: room_id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + get, + path = "/api/v1/ws/users/summary/{username}", + responses((status = 200, description = "User summary")), + tag = "channel", +)] +pub async fn user_summary( + req: HttpRequest, + username: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::UserSummary { + username: username.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} diff --git a/lib/api/src/channel/rest_interact.rs b/lib/api/src/channel/rest_interact.rs new file mode 100644 index 0000000..e160431 --- /dev/null +++ b/lib/api/src/channel/rest_interact.rs @@ -0,0 +1,274 @@ +use actix_web::{HttpRequest, HttpResponse, web}; +use channel::ChannelBus; +use channel::http::{WsHandler, WsInMessage}; +use serde::Deserialize; +use uuid::Uuid; + +use super::rest::{channel_err, created_json, extract_user, ok_json}; +use crate::error::ApiError; + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct ReactionRequest { + pub message: Uuid, + pub emoji: String, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct ThreadCreateRequest { + pub parent: i64, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct TypingRequest { + pub action: TypingAction, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub enum TypingAction { + Start, + Stop, +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/reactions", + request_body = ReactionRequest, + responses((status = 204, description = "Reaction added")), + tag = "channel", +)] +pub async fn reaction_add( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::ReactionAdd { + room: room_id.into_inner(), + message: body.message, + emoji: body.emoji.clone(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + delete, + path = "/api/v1/ws/rooms/{room_id}/reactions", + request_body = ReactionRequest, + responses((status = 204, description = "Reaction removed")), + tag = "channel", +)] +pub async fn reaction_remove( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::ReactionRemove { + room: room_id.into_inner(), + message: body.message, + emoji: body.emoji.clone(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/threads", + request_body = ThreadCreateRequest, + responses((status = 201, description = "Thread created")), + tag = "channel", +)] +pub async fn thread_create( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::ThreadCreate { + room: room_id.into_inner(), + parent: body.parent, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(created_json(result)) +} + +#[utoipa::path( + patch, + path = "/api/v1/ws/threads/{thread_id}/resolve", + responses((status = 200, description = "Thread resolved")), + tag = "channel", +)] +pub async fn thread_resolve( + req: HttpRequest, + thread_id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::ThreadResolve { + thread_id: thread_id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + patch, + path = "/api/v1/ws/threads/{thread_id}/archive", + responses((status = 200, description = "Thread archived")), + tag = "channel", +)] +pub async fn thread_archive( + req: HttpRequest, + thread_id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::ThreadArchive { + thread_id: thread_id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct PinRequest { + pub message: Uuid, +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/pins", + request_body = PinRequest, + responses((status = 204, description = "Message pinned")), + tag = "channel", +)] +pub async fn pin_add( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::PinAdd { + room: room_id.into_inner(), + message: body.message, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + delete, + path = "/api/v1/ws/rooms/{room_id}/pins", + request_body = PinRequest, + responses((status = 204, description = "Pin removed")), + tag = "channel", +)] +pub async fn pin_remove( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::PinRemove { + room: room_id.into_inner(), + message: body.message, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct DraftSaveRequest { + pub content: String, +} + +#[utoipa::path( + put, + path = "/api/v1/ws/rooms/{room_id}/drafts", + request_body = DraftSaveRequest, + responses((status = 204, description = "Draft saved")), + tag = "channel", +)] +pub async fn draft_save( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::DraftSave { + room: room_id.into_inner(), + content: body.content.clone(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + delete, + path = "/api/v1/ws/rooms/{room_id}/drafts", + responses((status = 204, description = "Draft cleared")), + tag = "channel", +)] +pub async fn draft_clear( + req: HttpRequest, + room_id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::DraftClear { + room: room_id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/typing", + request_body = TypingRequest, + responses((status = 204, description = "Typing indicator broadcasted")), + tag = "channel", +)] +pub async fn typing( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let room = room_id.into_inner(); + let msg = match body.action { + TypingAction::Start => WsInMessage::TypingStart { room }, + TypingAction::Stop => WsInMessage::TypingStop { room }, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} diff --git a/lib/api/src/channel/rest_member.rs b/lib/api/src/channel/rest_member.rs new file mode 100644 index 0000000..4a224f0 --- /dev/null +++ b/lib/api/src/channel/rest_member.rs @@ -0,0 +1,375 @@ +use actix_web::{HttpRequest, HttpResponse, web}; +use channel::ChannelBus; +use channel::http::{WsHandler, WsInMessage}; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use super::rest::{channel_err, created_json, extract_user, ok_json}; +use crate::error::ApiError; + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct ReadReceiptRequest { + pub last_read_seq: i64, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct DndRequest { + pub do_not_disturb: Option, + pub dnd_start_hour: Option, + pub dnd_end_hour: Option, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct NotificationMarkAllReadRequest { + pub workspace_id: Option, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct PresenceUpdateRequest { + #[schema(example = "online")] + pub status: String, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct CustomStatusRequest { + pub emoji: Option, + pub text: Option, + pub expires_at: Option>, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct InviteCreateRequest { + pub workspace: Uuid, + pub room: Option, + pub max_uses: Option, + pub expires_at: Option>, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct InviteAcceptRequest { + pub code: String, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct BanCreateRequest { + pub user: Uuid, + pub reason: Option, + pub expires_at: Option>, +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/read-receipt", + request_body = ReadReceiptRequest, + responses((status = 200, description = "Read receipt saved")), + tag = "channel", +)] +pub async fn read_receipt( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::ReadReceipt { + room: room_id.into_inner(), + last_read_seq: body.last_read_seq, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + patch, + path = "/api/v1/ws/rooms/{room_id}/dnd", + request_body = DndRequest, + responses((status = 204, description = "DND updated")), + tag = "channel", +)] +pub async fn dnd_update( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::StateUpdateDnd { + room: room_id.into_inner(), + do_not_disturb: body.do_not_disturb, + dnd_start_hour: body.dnd_start_hour, + dnd_end_hour: body.dnd_end_hour, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + patch, + path = "/api/v1/ws/notifications/{id}/read", + responses((status = 204, description = "Notification marked read")), + tag = "channel", +)] +pub async fn notification_mark_read( + req: HttpRequest, + id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::NotificationMarkRead { + id: id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/notifications/read-all", + request_body = NotificationMarkAllReadRequest, + responses((status = 204, description = "All notifications marked read")), + tag = "channel", +)] +pub async fn notification_mark_all_read( + req: HttpRequest, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::NotificationMarkAllRead { + workspace_id: body.workspace_id, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + delete, + path = "/api/v1/ws/notifications/{id}", + responses((status = 204, description = "Notification archived")), + tag = "channel", +)] +pub async fn notification_archive( + req: HttpRequest, + id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::NotificationArchive { + id: id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/presence", + request_body = PresenceUpdateRequest, + responses((status = 204, description = "Presence updated")), + tag = "channel", +)] +pub async fn presence_update( + req: HttpRequest, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let status: channel::event::presence::UserPresenceStatus = + serde_json::from_value(serde_json::Value::String(body.status.clone())) + .map_err(|e| { + ApiError(service::error::AppError::BadRequest(e.to_string())) + })?; + let msg = WsInMessage::PresenceUpdate { status }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/custom-status", + request_body = CustomStatusRequest, + responses((status = 204, description = "Custom status updated")), + tag = "channel", +)] +pub async fn custom_status_update( + req: HttpRequest, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::CustomStatusUpdate { + emoji: body.emoji.clone(), + text: body.text.clone(), + expires_at: body.expires_at, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/invites", + request_body = InviteCreateRequest, + responses((status = 201, description = "Invite created")), + tag = "channel", +)] +pub async fn invite_create( + req: HttpRequest, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::InviteCreate { + workspace: body.workspace, + room: body.room, + max_uses: body.max_uses, + expires_at: body.expires_at, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(created_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/invites/accept", + request_body = InviteAcceptRequest, + responses((status = 200, description = "Invite accepted")), + tag = "channel", +)] +pub async fn invite_accept( + req: HttpRequest, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::InviteAccept { + code: body.code.clone(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + delete, + path = "/api/v1/ws/invites/{id}", + responses((status = 204, description = "Invite revoked")), + tag = "channel", +)] +pub async fn invite_revoke( + req: HttpRequest, + id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::InviteRevoke { + id: id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/workspaces/{workspace_id}/bans", + request_body = BanCreateRequest, + responses((status = 201, description = "User banned")), + tag = "channel", +)] +pub async fn ban_create( + req: HttpRequest, + workspace_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::BanCreate { + workspace: workspace_id.into_inner(), + user: body.user, + reason: body.reason.clone(), + expires_at: body.expires_at, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + delete, + path = "/api/v1/ws/workspaces/{workspace_id}/bans/{user_id}", + responses((status = 204, description = "User unbanned")), + tag = "channel", +)] +pub async fn ban_remove( + req: HttpRequest, + path: web::Path<(Uuid, Uuid)>, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let (workspace, target_user) = path.into_inner(); + let msg = WsInMessage::BanRemove { + workspace, + user: target_user, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[derive(Debug, Serialize, utoipa::ToSchema)] +pub struct RoomMember { + pub id: Uuid, + pub username: String, + pub display_name: String, + pub avatar_url: String, +} + +#[utoipa::path( + get, + path = "/api/v1/ws/workspaces/{workspace_id}/members", + responses((status = 200, description = "Workspace members list")), + tag = "channel", +)] +pub async fn list_workspace_members( + req: HttpRequest, + workspace_id: web::Path, + bus: web::Data, +) -> Result { + let _user_id = extract_user(&req)?; + let workspace = workspace_id.into_inner(); + + let members = bus.list_workspace_members(workspace).await.map_err(channel_err)?; + let result: Vec = members + .into_iter() + .map(|(id, username, display_name, avatar_url)| RoomMember { + id, + username, + display_name, + avatar_url, + }) + .collect(); + + Ok(HttpResponse::Ok().json(result)) +} diff --git a/lib/api/src/channel/rest_message.rs b/lib/api/src/channel/rest_message.rs new file mode 100644 index 0000000..516ac60 --- /dev/null +++ b/lib/api/src/channel/rest_message.rs @@ -0,0 +1,225 @@ +use actix_web::{HttpRequest, HttpResponse, web}; +use channel::ChannelBus; +use channel::http::{WsHandler, WsInMessage}; +use serde::Deserialize; +use uuid::Uuid; + +use super::rest::{channel_err, created_json, extract_user, ok_json}; +use crate::error::ApiError; + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct CreateMessageRequest { + pub content: String, + pub content_type: Option, + pub thread: Option, + pub in_reply_to: Option, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct UpdateMessageRequest { + pub content: String, +} + +#[derive(Debug, Deserialize, utoipa::IntoParams)] +pub struct MessageListParams { + pub before_seq: Option, + pub after_seq: Option, + pub limit: Option, +} + +#[derive(Debug, Deserialize, utoipa::IntoParams)] +pub struct MessageAroundParams { + pub seq: i64, + pub limit: Option, +} + +#[derive(Debug, Deserialize, utoipa::IntoParams)] +pub struct MissedMessagesParams { + pub after_seq: i64, + pub limit: Option, +} + +#[derive(Debug, Deserialize, utoipa::IntoParams)] +pub struct SearchParams { + pub q: String, + pub room: Option, + pub limit: Option, + pub offset: Option, +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/messages", + request_body = CreateMessageRequest, + responses((status = 201, description = "Message created")), + tag = "channel", +)] +pub async fn create_message( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::MessageCreate { + room: room_id.into_inner(), + content: body.content.clone(), + content_type: body.content_type.clone(), + thread: body.thread, + in_reply_to: body.in_reply_to, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(created_json(result)) +} + +#[utoipa::path( + patch, + path = "/api/v1/ws/messages/{message_id}", + request_body = UpdateMessageRequest, + responses((status = 200, description = "Message updated")), + tag = "channel", +)] +pub async fn update_message( + req: HttpRequest, + message_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::MessageUpdate { + message: message_id.into_inner(), + content: body.content.clone(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + delete, + path = "/api/v1/ws/messages/{message_id}", + responses((status = 200, description = "Message revoked")), + tag = "channel", +)] +pub async fn revoke_message( + req: HttpRequest, + message_id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::MessageRevoke { + message: message_id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + get, + path = "/api/v1/ws/rooms/{room_id}/messages", + params(MessageListParams), + responses((status = 200, description = "Message list")), + tag = "channel", +)] +pub async fn list_messages( + req: HttpRequest, + room_id: web::Path, + params: web::Query, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::MessageList { + room: room_id.into_inner(), + before_seq: params.before_seq, + after_seq: params.after_seq, + limit: params.limit, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + get, + path = "/api/v1/ws/rooms/{room_id}/messages/around", + params(MessageAroundParams), + responses((status = 200, description = "Messages around seq")), + tag = "channel", +)] +pub async fn messages_around( + req: HttpRequest, + room_id: web::Path, + params: web::Query, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::MessageAround { + room: room_id.into_inner(), + seq: params.seq, + limit: params.limit, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + get, + path = "/api/v1/ws/rooms/{room_id}/messages/missed", + params(MissedMessagesParams), + responses((status = 200, description = "Missed messages")), + tag = "channel", +)] +pub async fn missed_messages( + req: HttpRequest, + room_id: web::Path, + params: web::Query, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::MissedMessages { + room: room_id.into_inner(), + after_seq: params.after_seq, + limit: params.limit, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + get, + path = "/api/v1/ws/search", + params(SearchParams), + responses((status = 200, description = "Search results")), + tag = "channel", +)] +pub async fn search( + req: HttpRequest, + params: web::Query, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::Search { + q: params.q.clone(), + room: params.room, + start_time: None, + end_time: None, + sender_id: None, + content_type: None, + limit: params.limit, + offset: params.offset, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} diff --git a/lib/api/src/channel/rest_room.rs b/lib/api/src/channel/rest_room.rs new file mode 100644 index 0000000..2c969fd --- /dev/null +++ b/lib/api/src/channel/rest_room.rs @@ -0,0 +1,321 @@ +use actix_web::{HttpRequest, HttpResponse, web}; +use channel::ChannelBus; +use channel::http::{WsHandler, WsInMessage}; +use serde::Deserialize; +use uuid::Uuid; + +use super::rest::{channel_err, created_json, extract_user, ok_json}; +use crate::error::ApiError; + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct RoomCreateRequest { + pub workspace: Uuid, + pub room_name: String, + pub public: bool, + pub category: Option, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct RoomUpdateRequest { + pub room_name: Option, + pub public: Option, + pub category: Option, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct AccessRequest { + pub user: Uuid, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct CategoryCreateRequest { + pub name: String, + pub position: Option, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct CategoryUpdateRequest { + pub name: Option, + pub position: Option, +} +#[utoipa::path( + get, + path = "/api/v1/ws/rooms", + responses((status = 200, description = "List of rooms")), + tag = "channel", +)] +pub async fn list_rooms( + req: HttpRequest, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let rooms = bus.list_user_rooms(user_id) + .await + .map_err(channel_err)?; + let categories = bus.list_user_categories(user_id) + .await + .map_err(channel_err)?; + let workspace_id = if let Some(r) = rooms.first() { + Some(r.workspace_id) + } else { + bus.first_workspace_id(user_id).await.unwrap_or(None) + }; + Ok(HttpResponse::Ok().json(serde_json::json!({ + "rooms": rooms, + "categories": categories, + "workspace_id": workspace_id, + }))) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/subscribe", + responses((status = 204, description = "Subscribed, user room cache refreshed")), + tag = "channel", +)] +pub async fn subscribe( + req: HttpRequest, + room_id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::Subscribe { + room: room_id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + delete, + path = "/api/v1/ws/rooms/{room_id}/subscribe", + responses((status = 204, description = "Unsubscribed")), + tag = "channel", +)] +pub async fn unsubscribe( + req: HttpRequest, + room_id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::Unsubscribe { + room: room_id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + get, + path = "/api/v1/ws/rooms/{room_id}", + responses((status = 200, description = "Room info")), + tag = "channel", +)] +pub async fn room_get( + req: HttpRequest, + room_id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::RoomGet { + room: room_id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms", + request_body = RoomCreateRequest, + responses((status = 201, description = "Room created")), + tag = "channel", +)] +pub async fn room_create( + req: HttpRequest, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::RoomCreate { + workspace: body.workspace, + room_name: body.room_name.clone(), + public: body.public, + category: body.category, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(created_json(result)) +} + +#[utoipa::path( + patch, + path = "/api/v1/ws/rooms/{room_id}", + request_body = RoomUpdateRequest, + responses((status = 200, description = "Room updated")), + tag = "channel", +)] +pub async fn room_update( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::RoomUpdate { + room: room_id.into_inner(), + room_name: body.room_name.clone(), + public: body.public, + category: body.category, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + delete, + path = "/api/v1/ws/rooms/{room_id}", + responses((status = 204, description = "Room deleted")), + tag = "channel", +)] +pub async fn room_delete( + req: HttpRequest, + room_id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::RoomDelete { + room: room_id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/members", + request_body = AccessRequest, + responses((status = 204, description = "Access granted")), + tag = "channel", +)] +pub async fn access_grant( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::AccessGrant { + room: room_id.into_inner(), + user: body.user, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + delete, + path = "/api/v1/ws/rooms/{room_id}/members/{user_id}", + responses((status = 204, description = "Access revoked")), + tag = "channel", +)] +pub async fn access_revoke( + req: HttpRequest, + path: web::Path<(Uuid, Uuid)>, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let (room, target_user) = path.into_inner(); + let msg = WsInMessage::AccessRevoke { + room, + user: target_user, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/workspaces/{workspace_id}/categories", + request_body = CategoryCreateRequest, + responses((status = 201, description = "Category created")), + tag = "channel", +)] +pub async fn category_create( + req: HttpRequest, + workspace_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::CategoryCreate { + workspace: workspace_id.into_inner(), + name: body.name.clone(), + position: body.position, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(created_json(result)) +} + +#[utoipa::path( + patch, + path = "/api/v1/ws/categories/{category_id}", + request_body = CategoryUpdateRequest, + responses((status = 200, description = "Category updated")), + tag = "channel", +)] +pub async fn category_update( + req: HttpRequest, + category_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::CategoryUpdate { + id: category_id.into_inner(), + name: body.name.clone(), + position: body.position, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + delete, + path = "/api/v1/ws/categories/{category_id}", + responses((status = 204, description = "Category deleted")), + tag = "channel", +)] +pub async fn category_delete( + req: HttpRequest, + category_id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::CategoryDelete { + id: category_id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} diff --git a/lib/api/src/channel/rest_voice.rs b/lib/api/src/channel/rest_voice.rs new file mode 100644 index 0000000..4b3c011 --- /dev/null +++ b/lib/api/src/channel/rest_voice.rs @@ -0,0 +1,137 @@ +use actix_web::{HttpRequest, HttpResponse, web}; +use channel::ChannelBus; +use channel::http::{WsHandler, WsInMessage}; +use serde::Deserialize; +use uuid::Uuid; + +use super::rest::{channel_err, extract_user, ok_json}; +use crate::error::ApiError; + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct VoiceMuteRequest { + pub muted: bool, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct VoiceDeafRequest { + pub deafened: bool, +} + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct ScreenShareRequest { + pub start: bool, +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/voice/join", + responses((status = 204, description = "Joined voice channel")), + tag = "channel", +)] +pub async fn voice_join( + req: HttpRequest, + room_id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::VoiceJoin { + room: room_id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/voice/leave", + responses((status = 204, description = "Left voice channel")), + tag = "channel", +)] +pub async fn voice_leave( + req: HttpRequest, + room_id: web::Path, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::VoiceLeave { + room: room_id.into_inner(), + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/voice/mute", + request_body = VoiceMuteRequest, + responses((status = 204, description = "Mute toggled")), + tag = "channel", +)] +pub async fn voice_mute( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::VoiceMute { + room: room_id.into_inner(), + muted: body.muted, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/voice/deaf", + request_body = VoiceDeafRequest, + responses((status = 204, description = "Deaf toggled")), + tag = "channel", +)] +pub async fn voice_deaf( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::VoiceDeaf { + room: room_id.into_inner(), + deafened: body.deafened, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} + +#[utoipa::path( + post, + path = "/api/v1/ws/rooms/{room_id}/screen-share", + request_body = ScreenShareRequest, + responses((status = 204, description = "Screen share toggled")), + tag = "channel", +)] +pub async fn screen_share( + req: HttpRequest, + room_id: web::Path, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = extract_user(&req)?; + let msg = WsInMessage::ScreenShare { + room: room_id.into_inner(), + start: body.start, + }; + let result = WsHandler::handle(&bus, user_id, msg) + .await + .map_err(channel_err)?; + Ok(ok_json(result)) +} diff --git a/lib/api/src/channel/token.rs b/lib/api/src/channel/token.rs new file mode 100644 index 0000000..fc12931 --- /dev/null +++ b/lib/api/src/channel/token.rs @@ -0,0 +1,53 @@ +use actix_web::{HttpRequest, HttpResponse, web}; +use channel::{ChannelBus, ChannelTokenApply, TOKEN_TTL_SECS}; +use serde::Deserialize; +use session::SessionExt; + +use crate::error::ApiError; + +use super::rest::channel_err; + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct TokenRequest { + pub device_id: String, + pub client_id: String, +} + +#[derive(Debug, serde::Serialize, utoipa::ToSchema)] +pub struct TokenResponse { + pub access_token: String, + pub expires_in_secs: u64, +} + +#[utoipa::path( + post, + path = "/api/v1/ws/token", + request_body = TokenRequest, + responses((status = 200, body = TokenResponse)), + tag = "channel" +)] +pub async fn generate_token( + req: HttpRequest, + body: web::Json, + bus: web::Data, +) -> Result { + let user_id = req + .get_session() + .user() + .ok_or_else(|| ApiError(service::error::AppError::Unauthorized))?; + + let apply = ChannelTokenApply { + device_id: body.device_id.clone(), + client_id: body.client_id.clone(), + }; + + let token = bus + .apply_access_token(user_id, apply) + .await + .map_err(channel_err)?; + + Ok(HttpResponse::Ok().json(TokenResponse { + access_token: token.access_token, + expires_in_secs: TOKEN_TTL_SECS, + })) +} diff --git a/lib/api/src/error.rs b/lib/api/src/error.rs new file mode 100644 index 0000000..8c3cc40 --- /dev/null +++ b/lib/api/src/error.rs @@ -0,0 +1,91 @@ +use actix_web::{HttpResponse, error::ResponseError, http::StatusCode}; +use serde::Serialize; +use service::error::AppError; + +pub fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +pub struct ApiError(pub AppError); + +impl From for ApiError { + fn from(err: AppError) -> Self { + ApiError(err) + } +} + +impl std::fmt::Display for ApiError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl std::fmt::Debug for ApiError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl ResponseError for ApiError { + fn status_code(&self) -> StatusCode { + match &self.0 { + AppError::Unauthorized => StatusCode::UNAUTHORIZED, + AppError::UserNotFound => StatusCode::NOT_FOUND, + AppError::InvalidPassword => StatusCode::UNAUTHORIZED, + AppError::PasswordTooWeak => StatusCode::BAD_REQUEST, + AppError::CaptchaError => StatusCode::BAD_REQUEST, + AppError::TwoFactorRequired => { + StatusCode::from_u16(402).unwrap_or(StatusCode::BAD_REQUEST) + } + AppError::TwoFactorAlreadyEnabled => StatusCode::CONFLICT, + AppError::TwoFactorNotSetup => StatusCode::NOT_FOUND, + AppError::InvalidTwoFactorCode => StatusCode::BAD_REQUEST, + AppError::TwoFactorNotEnabled => StatusCode::NOT_FOUND, + AppError::RsaGenerationError => StatusCode::INTERNAL_SERVER_ERROR, + AppError::RsaDecodeError => StatusCode::BAD_REQUEST, + AppError::UserNameExists => StatusCode::CONFLICT, + AppError::EmailExists => StatusCode::CONFLICT, + AppError::AccountAlreadyExists => StatusCode::CONFLICT, + AppError::TxnError => StatusCode::INTERNAL_SERVER_ERROR, + AppError::PasswordHashError(_) => StatusCode::INTERNAL_SERVER_ERROR, + AppError::DatabaseError(_) => StatusCode::INTERNAL_SERVER_ERROR, + AppError::DoMainNotSet => StatusCode::INTERNAL_SERVER_ERROR, + AppError::InternalError => StatusCode::INTERNAL_SERVER_ERROR, + AppError::InternalServerError(_) => { + StatusCode::INTERNAL_SERVER_ERROR + } + AppError::PermissionDenied => StatusCode::FORBIDDEN, + AppError::ProjectNotFound => StatusCode::NOT_FOUND, + AppError::NoPower => StatusCode::FORBIDDEN, + AppError::RoleParseError => StatusCode::BAD_REQUEST, + AppError::ProjectNameAlreadyExists => StatusCode::CONFLICT, + AppError::RepoNameAlreadyExists => StatusCode::CONFLICT, + AppError::AvatarUploadError(_) => StatusCode::BAD_REQUEST, + AppError::RepoNotFound => StatusCode::NOT_FOUND, + AppError::RepoForBidAccess => StatusCode::FORBIDDEN, + AppError::SerdeError(_) => StatusCode::BAD_REQUEST, + AppError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR, + AppError::BadRequest(_) => StatusCode::BAD_REQUEST, + AppError::Forbidden(_) => StatusCode::FORBIDDEN, + AppError::Conflict(_) => StatusCode::CONFLICT, + AppError::NotFound(_) => StatusCode::NOT_FOUND, + AppError::InvalidResetToken => StatusCode::BAD_REQUEST, + AppError::ResetTokenExpired => StatusCode::BAD_REQUEST, + AppError::ResetTokenUsed => StatusCode::BAD_REQUEST, + AppError::IssueNotFound => StatusCode::NOT_FOUND, + AppError::LabelNotFound => StatusCode::NOT_FOUND, + AppError::MilestoneNotFound => StatusCode::NOT_FOUND, + AppError::PullRequestNotFound => StatusCode::NOT_FOUND, + AppError::CommentNotFound => StatusCode::NOT_FOUND, + AppError::GitRpcError(_) => StatusCode::INTERNAL_SERVER_ERROR, + AppError::AiError(_) => StatusCode::INTERNAL_SERVER_ERROR, + } + } + + fn error_response(&self) -> HttpResponse { + let status = self.status_code(); + let message = self.0.to_string(); + HttpResponse::build(status) + .json(serde_json::json!({ "error": message })) + } +} diff --git a/lib/api/src/git/archive.rs b/lib/api/src/git/archive.rs new file mode 100644 index 0000000..dd4427f --- /dev/null +++ b/lib/api/src/git/archive.rs @@ -0,0 +1,47 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct ArchiveQuery { + #[serde(default = "default_format")] + pub format: String, + pub tree: Option, + pub prefix: Option, + pub pathspec: Option>, +} + +fn default_format() -> String { + "tar".to_string() +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/archive", + params(WkRepoPath, ArchiveQuery), + responses((status = 200, description = "Archive download")), + security(("session" = [])) +)] +pub async fn archive( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + match query.format.as_str() { + "zip" => ok_json(service.git_archive_zip(&session, &wk, &repo, None).await?), + _ => ok_json(service.git_archive_tar(&session, &wk, &repo, None).await?), + } +} diff --git a/lib/api/src/git/blame.rs b/lib/api/src/git/blame.rs new file mode 100644 index 0000000..46d6bb1 --- /dev/null +++ b/lib/api/src/git/blame.rs @@ -0,0 +1,62 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use session::Session; + +use crate::error::ApiError; +use crate::git::dto; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct BlameQuery { + pub path: String, + pub rev: Option, + pub start_line: Option, + pub end_line: Option, +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/blame", + params(WkRepoPath, BlameQuery), + responses((status = 200, description = "Blame result", body = dto::BlameFileResponseDto)), + security(("session" = [])) +)] +pub async fn blame_file( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + + match (query.start_line, query.end_line) { + (Some(start), Some(end)) => { + let data: dto::BlameFileResponseDto = service + .git_blame_hunk( + &session, &wk, &repo, query.path.clone(), + query.rev.clone(), start, end, + ) + .await? + .into(); + ok_json(data) + } + _ => { + let data: dto::BlameFileResponseDto = service + .git_blame_file( + &session, &wk, &repo, query.path.clone(), + query.rev.clone(), None, + ) + .await? + .into(); + ok_json(data) + } + } +} diff --git a/lib/api/src/git/blob.rs b/lib/api/src/git/blob.rs new file mode 100644 index 0000000..95ab810 --- /dev/null +++ b/lib/api/src/git/blob.rs @@ -0,0 +1,97 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use session::Session; + +use crate::error::ApiError; +use crate::git::dto; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoBlobPath { + pub wk: String, + pub repo: String, + pub oid: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct BlobPathQuery { + pub path: Option, +} +#[derive(Serialize, utoipa::ToSchema)] +pub struct BlobInfoResponse { + #[serde(flatten)] + pub load: dto::BlobLoadResponseDto, + pub size: u64, + pub is_binary: bool, +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/blobs/{oid}", + params(WkRepoBlobPath, BlobPathQuery), + responses((status = 200, description = "Blob info", body = BlobInfoResponse)), + security(("session" = [])) +)] +pub async fn blob_info( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let WkRepoBlobPath { wk, repo, oid } = path.into_inner(); + let path_opt = query.path.clone().unwrap_or_default(); + + let load: dto::BlobLoadResponseDto = service + .git_blob_load(&session, &wk, &repo, oid.clone(), path_opt.clone()) + .await? + .into(); + let size_resp: dto::BlobSizeResponseDto = service + .git_blob_size(&session, &wk, &repo, oid.clone(), path_opt) + .await? + .into(); + let binary_resp: dto::BlobIsBinaryResponseDto = service + .git_blob_is_binary(&session, &wk, &repo, oid) + .await? + .into(); + + ok_json(BlobInfoResponse { + load, + size: size_resp.size, + is_binary: binary_resp.is_binary, + }) +} + +#[derive(Deserialize, utoipa::ToSchema)] +pub struct BlobUploadBody { + pub path: String, + pub blob: Vec, +} +#[utoipa::path( + post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/blobs", + params(WkRepoPath), + request_body = BlobUploadBody, + responses((status = 200, description = "Upload result", body = dto::BlobUploadResponseDto)), + security(("session" = [])) +)] +pub async fn blob_upload( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let p = params.into_inner(); + let data: dto::BlobUploadResponseDto = service + .git_blob_upload(&session, &wk, &repo, p.path, p.blob) + .await? + .into(); + ok_json(data) +} diff --git a/lib/api/src/git/branch.rs b/lib/api/src/git/branch.rs new file mode 100644 index 0000000..4d8063d --- /dev/null +++ b/lib/api/src/git/branch.rs @@ -0,0 +1,205 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{AppService, Pagination}; +use session::Session; + +use crate::error::ApiError; +use crate::git::dto; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoBranchPath { + pub wk: String, + pub repo: String, + pub name: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct BranchDeleteQuery { + #[serde(default)] + pub force: bool, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct BranchListQuery { + #[serde(default)] + pub summary: bool, + #[serde(default)] + pub default_only: bool, +} + +#[derive(Deserialize, utoipa::ToSchema)] +pub struct RenameBranchBody { + pub new_branch: String, + #[serde(default)] + pub force: bool, +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches", + params(WkRepoPath, Pagination, BranchListQuery), + responses((status = 200, description = "Branch list or summary", body = dto::BranchListResponseDto)), + security(("session" = [])) +)] +pub async fn list_branches( + session: Session, + service: web::Data, + path: web::Path, + pagination: web::Query, + query: web::Query, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + if query.summary { + let data: dto::BranchSummaryResponseDto = service + .git_branch_summary(&session, &wk, &repo) + .await? + .into(); + return ok_json(data); + } + if query.default_only { + let data: dto::BranchHeadResponseDto = service + .git_branch_head(&session, &wk, &repo) + .await? + .into(); + return ok_json(data); + } + let data: dto::BranchListResponseDto = service + .git_branch_list(&session, &wk, &repo, pagination.into_inner()) + .await? + .into(); + ok_json(data) +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}", + params(WkRepoBranchPath), + responses((status = 200, description = "Branch info", body = dto::BranchInfoResponseDto)), + security(("session" = [])) +)] +pub async fn branch_info( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoBranchPath { wk, repo, name } = path.into_inner(); + let data: dto::BranchInfoResponseDto = service + .git_branch_info(&session, &wk, &repo, name) + .await? + .into(); + ok_json(data) +} +#[utoipa::path( + post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches", + params(WkRepoPath), + request_body = Object, description = "BranchForkParams { name, oid, force }", + responses((status = 200, description = "Branch created")), + security(("session" = [])) +)] +pub async fn fork_branch( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service + .git_branch_fork(&session, &wk, &repo, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + patch, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}", + params(WkRepoBranchPath), + request_body = RenameBranchBody, + responses((status = 200, description = "Branch renamed")), + security(("session" = [])) +)] +pub async fn rename_branch( + session: Session, + service: web::Data, + path: web::Path, + body: web::Json, +) -> Result { + let WkRepoBranchPath { wk, repo, name } = path.into_inner(); + let body = body.into_inner(); + let params = git::rpc::proto::BranchReNameParams { + old_branch: name, + new_branch: body.new_branch, + force: body.force, + }; + let data = service + .git_branch_rename(&session, &wk, &repo, params) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}", + params(WkRepoBranchPath, BranchDeleteQuery), + responses((status = 200, description = "Branch deleted")), + security(("session" = [])) +)] +pub async fn delete_branch( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let WkRepoBranchPath { wk, repo, name } = path.into_inner(); + let params = git::rpc::proto::BranchDeleteParams { + name, + force: query.force, + }; + let data = service + .git_branch_delete(&session, &wk, &repo, params) + .await?; + ok_json(data) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct AheadBehindQuery { + pub remote_branch: String, +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}/ahead-behind", + params(WkRepoBranchPath, AheadBehindQuery), + responses((status = 200, description = "Ahead/behind counts", body = dto::BranchAheadBehindResponseDto)), + security(("session" = [])) +)] +pub async fn ahead_behind( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let WkRepoBranchPath { wk, repo, name } = path.into_inner(); + let data: dto::BranchAheadBehindResponseDto = service + .git_branch_ahead_behind(&session, &wk, &repo, name, query.remote_branch.clone()) + .await? + .into(); + ok_json(data) +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}/upstream", + params(WkRepoBranchPath), + responses((status = 200, description = "Upstream branch", body = dto::BranchUpstreamResponseDto)), + security(("session" = [])) +)] +pub async fn branch_upstream( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoBranchPath { wk, repo, name } = path.into_inner(); + let data: dto::BranchUpstreamResponseDto = service + .git_branch_upstream(&session, &wk, &repo, name) + .await? + .into(); + ok_json(data) +} diff --git a/lib/api/src/git/commit.rs b/lib/api/src/git/commit.rs new file mode 100644 index 0000000..099fbaa --- /dev/null +++ b/lib/api/src/git/commit.rs @@ -0,0 +1,169 @@ +use actix_web::{HttpResponse, web}; +use git::rpc::proto as p; +use serde::{Deserialize, Serialize}; +use service::AppService; +use session::Session; + +use crate::error::ApiError; +use crate::git::dto; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoCommitPath { + pub wk: String, + pub repo: String, + pub oid: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct HistoryQuery { + pub limit: Option, + pub skip: Option, + pub sort: Option, + pub branch: Option, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct CommitListQuery { + #[serde(default)] + pub summary: bool, + #[serde(default)] + pub refs: bool, + pub prefix: Option, +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits", + params(WkRepoPath, CommitListQuery), + responses( + (status = 200, description = "Commit list / summary / refs / prefix", body = dto::CommitHistoryResponseDto), + ), + security(("session" = [])) +)] +pub async fn list_commits( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + + if let Some(prefix) = &query.prefix { + let data: dto::CommitPrefixResponseDto = service + .git_commit_prefix(&session, &wk, &repo, prefix.clone()) + .await? + .into(); + return ok_json(data); + } + if query.refs { + let data: dto::CommitRefsResponseDto = service + .git_commit_refs(&session, &wk, &repo) + .await? + .into(); + return ok_json(data); + } + if query.summary { + let data: dto::CommitSummaryResponseDto = service + .git_commit_summary(&session, &wk, &repo) + .await? + .into(); + return ok_json(data); + } + let data: dto::CommitHistoryResponseDto = service + .git_commit_history(&session, &wk, &repo, 20, 0, 0, None) + .await? + .into(); + ok_json(data) +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits/history", + params(WkRepoPath, HistoryQuery), + responses((status = 200, description = "Commit history", body = dto::CommitHistoryResponseDto)), + security(("session" = [])) +)] +pub async fn commit_history( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data: dto::CommitHistoryResponseDto = service + .git_commit_history( + &session, &wk, &repo, + query.limit.unwrap_or(20), + query.skip.unwrap_or(0), + query.sort.unwrap_or(0), + query.branch.clone(), + ) + .await? + .into(); + ok_json(data) +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits/{oid}", + params(WkRepoCommitPath), + responses((status = 200, description = "Commit info", body = dto::CommitInfoResponseDto)), + security(("session" = [])) +)] +pub async fn commit_info( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoCommitPath { wk, repo, oid } = path.into_inner(); + let data: dto::CommitInfoResponseDto = service + .git_commit_info(&session, &wk, &repo, oid) + .await? + .into(); + ok_json(data) +} +#[utoipa::path( + post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits/walk", + params(WkRepoPath), + request_body = Object, description = "CommitWalkParams", + responses((status = 200, description = "Walk result", body = dto::CommitHistoryResponseDto)), + security(("session" = [])) +)] +pub async fn commit_walk( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let proto_resp = service + .git_commit_walk(&session, &wk, &repo, params.into_inner()) + .await?; + ok_json(dto::CommitHistoryResponseDto { + commits: proto_resp.commits.into_iter().map(Into::into).collect(), + }) +} +#[utoipa::path( + post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits/cherry-pick", + params(WkRepoPath), + request_body = Object, description = "CommitCherryPickParams", + responses((status = 200, description = "Cherry-pick result", body = dto::CherryPickResponseDto)), + security(("session" = [])) +)] +pub async fn cherry_pick( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data: dto::CherryPickResponseDto = service + .git_cherry_pick(&session, &wk, &repo, params.into_inner()) + .await? + .into(); + ok_json(data) +} diff --git a/lib/api/src/git/commit_status.rs b/lib/api/src/git/commit_status.rs new file mode 100644 index 0000000..cdf3c55 --- /dev/null +++ b/lib/api/src/git/commit_status.rs @@ -0,0 +1,81 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use service::git::commit_status::{ + CombinedCommitStatus, CommitStatusResponse, CreateCommitStatus, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok_created(data: T) -> Result { + Ok(HttpResponse::Created().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct CommitShaPath { + pub wk: String, + pub repo: String, + pub sha: String, +} + +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/commits/{sha}/statuses", + params(CommitShaPath), + responses((status = 200, body = Vec)), + security(("session" = [])) +)] +pub async fn list_statuses( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + ok_json(service.git_commit_status_list_by_name( + &session, &path.wk, &path.repo, &path.sha, + ).await?) +} + +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/commits/{sha}/status", + params(CommitShaPath), + responses((status = 200, body = CombinedCommitStatus)), + security(("session" = [])) +)] +pub async fn combined_status( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + ok_json(service.git_commit_status_combined_by_name( + &session, &path.wk, &path.repo, &path.sha, + ).await?) +} + +#[utoipa::path( + post, path = "/api/v1/workspace/{wk}/repos/{repo}/statuses/{sha}", + params(CommitShaPath), + request_body = CreateCommitStatus, + responses((status = 201, body = CommitStatusResponse)), + security(("session" = [])) +)] +pub async fn create_status( + session: Session, + service: web::Data, + path: web::Path, + body: web::Json, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_created(service.git_commit_status_create_by_name( + &session, user_id, &path.wk, &path.repo, &path.sha, body.into_inner(), + ).await?) +} diff --git a/lib/api/src/git/compare.rs b/lib/api/src/git/compare.rs new file mode 100644 index 0000000..9a3af8c --- /dev/null +++ b/lib/api/src/git/compare.rs @@ -0,0 +1,43 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use service::git::compare::CompareResponse; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/compare/{basehead}", + params( + ("wk" = String, Path), + ("repo" = String, Path), + ("basehead" = String, Path, description = "Comparison in format 'base...head'"), + ), + responses((status = 200, body = CompareResponse)), + security(("session" = [])) +)] +pub async fn compare( + session: Session, + service: web::Data, + path: web::Path<(String, String, String)>, +) -> Result { + let (wk, repo_name, basehead) = path.into_inner(); + + let (base, head) = basehead + .split_once("...") + .ok_or_else(|| ApiError(service::error::AppError::BadRequest( + "basehead must be in format 'base...head'".to_string(), + )))?; + + ok_json(service.git_compare(&session, &wk, &repo_name, base, head).await?) +} diff --git a/lib/api/src/git/contents.rs b/lib/api/src/git/contents.rs new file mode 100644 index 0000000..abb0145 --- /dev/null +++ b/lib/api/src/git/contents.rs @@ -0,0 +1,105 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use service::git::contents::{ContentResponse, CreateContent, UpdateContent}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct ContentQuery { + pub r#ref: Option, +} + +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/contents/{path}", + params(WkRepoPath, ("path" = String, Path), ContentQuery), + responses((status = 200, body = ContentResponse)), + security(("session" = [])) +)] +pub async fn get_contents( + session: Session, + service: web::Data, + info: web::Path<(String, String, String)>, + query: web::Query, +) -> Result { + let (wk, repo_name, file_path) = info.into_inner(); + ok_json(service.git_contents_get_by_name( + &session, &wk, &repo_name, &file_path, query.r#ref.as_deref(), + ).await?) +} + +#[utoipa::path( + post, path = "/api/v1/workspace/{wk}/repos/{repo}/contents/{path}", + params(WkRepoPath, ("path" = String, Path)), + request_body = CreateContent, + responses((status = 201, body = ContentResponse)), + security(("session" = [])) +)] +pub async fn create_contents( + session: Session, + service: web::Data, + info: web::Path<(String, String, String)>, + body: web::Json, +) -> Result { + let (wk, repo_name, file_path) = info.into_inner(); + let resp = service.git_contents_create_by_name( + &session, &wk, &repo_name, &file_path, body.into_inner(), + ).await?; + Ok(HttpResponse::Created().json(resp)) +} + +#[utoipa::path( + put, path = "/api/v1/workspace/{wk}/repos/{repo}/contents/{path}", + params(WkRepoPath, ("path" = String, Path)), + request_body = UpdateContent, + responses((status = 200, body = ContentResponse)), + security(("session" = [])) +)] +pub async fn update_contents( + session: Session, + service: web::Data, + info: web::Path<(String, String, String)>, + body: web::Json, +) -> Result { + let (wk, repo_name, file_path) = info.into_inner(); + ok_json(service.git_contents_update_by_name( + &session, &wk, &repo_name, &file_path, body.into_inner(), + ).await?) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct DeleteContentQuery { + pub message: String, + pub sha: String, + pub branch: Option, +} + +#[utoipa::path( + delete, path = "/api/v1/workspace/{wk}/repos/{repo}/contents/{path}", + params(WkRepoPath, ("path" = String, Path), DeleteContentQuery), + responses((status = 204)), + security(("session" = [])) +)] +pub async fn delete_contents( + session: Session, + service: web::Data, + info: web::Path<(String, String, String)>, + query: web::Query, +) -> Result { + let (wk, repo_name, file_path) = info.into_inner(); + service.git_contents_delete_by_name( + &session, &wk, &repo_name, &file_path, &query.message, &query.sha, query.branch.as_deref(), + ).await?; + Ok(HttpResponse::NoContent().finish()) +} diff --git a/lib/api/src/git/contributor.rs b/lib/api/src/git/contributor.rs new file mode 100644 index 0000000..6e75901 --- /dev/null +++ b/lib/api/src/git/contributor.rs @@ -0,0 +1,34 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{AppService, Pagination}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/contributors", + params(WkRepoPath, Pagination), + responses((status = 200, description = "Contributor list", body = Vec)), + security(("session" = [])) +)] +pub async fn list_contributors( + session: Session, + service: web::Data, + path: web::Path, + pagination: web::Query, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service + .git_repo_contributors(&session, &wk, &repo, pagination.into_inner()) + .await?; + ok_json(data) +} diff --git a/lib/api/src/git/diff.rs b/lib/api/src/git/diff.rs new file mode 100644 index 0000000..0a8d60f --- /dev/null +++ b/lib/api/src/git/diff.rs @@ -0,0 +1,87 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use session::Session; + +use crate::error::ApiError; +use crate::git::dto; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct DiffQuery { + pub old_oid: Option, + pub new_oid: Option, + pub old_tree: Option, + pub new_tree: Option, + pub tree_oid: Option, + #[serde(default = "default_mode")] + pub mode: String, + pub path: Option, +} + +fn default_mode() -> String { + "patch".to_string() +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/diff", + params(WkRepoPath, DiffQuery), + responses((status = 200, description = "Diff result")), + security(("session" = [])) +)] +pub async fn diff( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + if let (Some(old_tree), Some(new_tree)) = (&query.old_tree, &query.new_tree) { + let proto_resp = service + .git_diff_tree_to_tree(&session, &wk, &repo, old_tree.clone(), new_tree.clone(), None) + .await?; + let data: dto::DiffResultDto = proto_resp.result.unwrap_or_default().into(); + return ok_json(data); + } + if let Some(tree_oid) = &query.tree_oid { + let proto_resp = service + .git_diff_index_to_tree(&session, &wk, &repo, tree_oid.clone(), None) + .await?; + let data: dto::DiffResultDto = proto_resp.result.unwrap_or_default().into(); + return ok_json(data); + } + let old_oid = query.old_oid.clone().unwrap_or_default(); + let new_oid = query.new_oid.clone().unwrap_or_default(); + + match query.mode.as_str() { + "stats" => { + let proto_resp = service + .git_diff_stats(&session, &wk, &repo, old_oid, new_oid, None) + .await?; + let data: dto::DiffStatsDto = proto_resp.result.and_then(|r| r.stats).unwrap_or_default().into(); + ok_json(data) + } + "side-by-side" => { + let proto_resp = service + .git_diff_patch_side_by_side(&session, &wk, &repo, old_oid, new_oid, None) + .await?; + let data: dto::SideBySideDiffResultDto = proto_resp.result.unwrap_or_default().into(); + ok_json(data) + } + _ => { + let proto_resp = service + .git_diff_patch(&session, &wk, &repo, old_oid, new_oid, None) + .await?; + let data: dto::DiffResultDto = proto_resp.result.unwrap_or_default().into(); + ok_json(data) + } + } +} diff --git a/lib/api/src/git/dto.rs b/lib/api/src/git/dto.rs new file mode 100644 index 0000000..f12aa73 --- /dev/null +++ b/lib/api/src/git/dto.rs @@ -0,0 +1,888 @@ +use base64::Engine; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; +use git::rpc::proto as p; + +fn oid_val(oid: Option) -> String { + oid.map(|o| o.value).unwrap_or_default() +} + +fn oid_opt(oid: Option) -> Option { + oid.map(|o| o.value) +} + +fn oid_vec(oids: Vec) -> Vec { + oids.into_iter().map(|o| o.value).collect() +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct ObjectIdDto { + pub value: String, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BranchListItemDto { + pub name: String, + pub oid: String, + pub is_head: bool, + pub is_remote: bool, + pub is_current: bool, + pub upstream: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BranchListResponseDto { + pub branches: Vec, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BranchInfoResponseDto { + pub branch: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BranchSummaryDto { + pub local_count: u64, + pub remote_count: u64, + pub all_count: u64, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BranchSummaryResponseDto { + pub summary: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BranchHeadResponseDto { + pub head_name: String, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BranchAheadBehindResponseDto { + pub ahead: u64, + pub behind: u64, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BranchUpstreamResponseDto { + pub upstream_name: String, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct CommitSignatureDto { + pub name: String, + pub email: String, + pub time_secs: i64, + pub offset_minutes: i32, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct CommitMetaDto { + pub oid: String, + pub message: String, + pub summary: String, + pub author: Option, + pub committer: Option, + pub tree_id: Option, + pub parent_ids: Vec, + pub encoding: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct CommitInfoResponseDto { + pub commit: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct CommitHistoryResponseDto { + pub commits: Vec, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct CommitSummaryDto { + pub head: Option, + pub count: u64, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct CommitSummaryResponseDto { + pub summary: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct CommitRefInfoDto { + pub name: String, + pub target: String, + pub is_remote: bool, + pub is_tag: bool, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct CommitRefsResponseDto { + pub refs: Vec, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct CommitPrefixResponseDto { + pub oid: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct CommitExistsResponseDto { + pub exists: bool, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct CherryPickResponseDto { + pub oid: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "snake_case")] +pub enum TreeKindDto { + Blob, + Tree, + LfsPointer, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct TreeEntryDto { + pub name: String, + pub oid: String, + pub kind: TreeKindDto, + pub filemode: u32, + pub is_binary: bool, + pub is_lfs: bool, + pub last_commit_message: Option, + pub last_commit_time: Option, + pub last_commit_author_name: Option, + pub last_commit_author_email: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct TreeEntriesResponseDto { + pub entries: Vec, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct TreeEntryByPathResponseDto { + pub entry: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct TreeInfoDto { + pub oid: String, + pub entry_count: u64, + pub is_empty: bool, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct ResolveTreeResponseDto { + pub info: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BlobLoadResponseDto { + pub blob: String, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BlobSizeResponseDto { + pub size: u64, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BlobExistsResponseDto { + pub exists: bool, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BlobIsBinaryResponseDto { + pub is_binary: bool, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BlobUploadResponseDto { + pub id: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct TagItemDto { + pub name: String, + pub oid: String, + pub target: String, + pub is_annotated: bool, + pub message: Option, + pub tagger: Option, + pub tagger_email: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct TagListResponseDto { + pub tags: Vec, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct TagInfoResponseDto { + pub tag: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct TagSummaryDto { + pub total_count: u64, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct TagSummaryResponseDto { + pub summary: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct TagInitResponseDto { + pub oid: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct TagUpdateMessageResponseDto { + pub oid: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct CommitBlameHunkDto { + pub commit_oid: Option, + pub final_start_line: u32, + pub final_lines: u32, + pub orig_start_line: u32, + pub orig_lines: u32, + pub boundary: bool, + pub orig_path: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct CommitBlameLineDto { + pub commit_oid: Option, + pub line_no: u32, + pub content: String, + pub orig_path: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BlameFileResponseDto { + pub hunks: Vec, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct BlameLinesResponseDto { + pub lines: Vec, +} + +#[derive(Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "snake_case")] +pub enum DiffDeltaStatusDto { + Unmodified, + Added, + Deleted, + Modified, + Renamed, + Copied, + Typechange, + Conflicted, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct DiffFileDto { + pub oid: Option, + pub path: Option, + pub size: u64, + pub is_binary: bool, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct DiffHunkDto { + pub old_start: u32, + pub old_lines: u32, + pub new_start: u32, + pub new_lines: u32, + pub header: String, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct DiffLineDto { + pub content: String, + pub origin: String, + pub old_lineno: Option, + pub new_lineno: Option, + pub num_lines: u32, + pub content_offset: i64, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct DiffDeltaDto { + pub status: i32, + pub old_file: Option, + pub new_file: Option, + pub nfiles: u32, + pub hunks: Vec, + pub lines: Vec, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct DiffStatsDto { + pub files_changed: u64, + pub insertions: u64, + pub deletions: u64, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct DiffResultDto { + pub stats: Option, + pub deltas: Vec, +} + +#[derive(Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "snake_case")] +pub enum SideBySideChangeTypeDto { + Unchanged, + Added, + Removed, + Modified, + Empty, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct SideBySideLineDto { + pub left_line_no: Option, + pub right_line_no: Option, + pub left_content: String, + pub right_content: String, + pub change_type: i32, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct SideBySideFileDto { + pub path: String, + pub additions: u64, + pub deletions: u64, + pub is_binary: bool, + pub is_rename: bool, + pub lines: Vec, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub struct SideBySideDiffResultDto { + pub files: Vec, + pub total_additions: u64, + pub total_deletions: u64, +} + +impl From for ObjectIdDto { + fn from(o: p::ObjectId) -> Self { + ObjectIdDto { value: o.value } + } +} + +impl From for BranchListItemDto { + fn from(b: p::BranchListItem) -> Self { + BranchListItemDto { + name: b.name, + oid: oid_val(b.oid), + is_head: b.is_head, + is_remote: b.is_remote, + is_current: b.is_current, + upstream: b.upstream, + } + } +} + +impl From for BranchListResponseDto { + fn from(r: p::BranchListResponse) -> Self { + BranchListResponseDto { + branches: r.branches.into_iter().map(Into::into).collect(), + } + } +} + +impl From for BranchInfoResponseDto { + fn from(r: p::BranchInfoResponse) -> Self { + BranchInfoResponseDto { + branch: r.branch.map(Into::into), + } + } +} + +impl From for BranchSummaryDto { + fn from(s: p::BranchSummary) -> Self { + BranchSummaryDto { + local_count: s.local_count, + remote_count: s.remote_count, + all_count: s.all_count, + } + } +} + +impl From for BranchSummaryResponseDto { + fn from(r: p::BranchSummaryResponse) -> Self { + BranchSummaryResponseDto { + summary: r.summary.map(Into::into), + } + } +} + +impl From for BranchHeadResponseDto { + fn from(r: p::BranchHeadResponse) -> Self { + BranchHeadResponseDto { head_name: r.head_name } + } +} + +impl From for BranchAheadBehindResponseDto { + fn from(r: p::BranchAheadBehindResponse) -> Self { + BranchAheadBehindResponseDto { + ahead: r.ahead, + behind: r.behind, + } + } +} + +impl From for BranchUpstreamResponseDto { + fn from(r: p::BranchUpstreamResponse) -> Self { + BranchUpstreamResponseDto { + upstream_name: r.upstream_name, + } + } +} + +impl From for CommitSignatureDto { + fn from(s: p::CommitSignature) -> Self { + CommitSignatureDto { + name: s.name, + email: s.email, + time_secs: s.time_secs, + offset_minutes: s.offset_minutes, + } + } +} + +impl From for CommitMetaDto { + fn from(c: p::CommitMeta) -> Self { + CommitMetaDto { + oid: oid_val(c.oid), + message: c.message, + summary: c.summary, + author: c.author.map(Into::into), + committer: c.committer.map(Into::into), + tree_id: oid_opt(c.tree_id), + parent_ids: oid_vec(c.parent_ids), + encoding: c.encoding, + } + } +} + +impl From for CommitInfoResponseDto { + fn from(r: p::CommitInfoResponse) -> Self { + CommitInfoResponseDto { + commit: r.commit.map(Into::into), + } + } +} + +impl From for CommitHistoryResponseDto { + fn from(r: p::CommitHistoryResponse) -> Self { + CommitHistoryResponseDto { + commits: r.commits.into_iter().map(Into::into).collect(), + } + } +} + +impl From for CommitSummaryDto { + fn from(s: p::CommitSummary) -> Self { + CommitSummaryDto { + head: s.head.map(Into::into), + count: s.count, + } + } +} + +impl From for CommitSummaryResponseDto { + fn from(r: p::CommitSummaryResponse) -> Self { + CommitSummaryResponseDto { + summary: r.summary.map(Into::into), + } + } +} + +impl From for CommitRefInfoDto { + fn from(r: p::CommitRefInfo) -> Self { + CommitRefInfoDto { + name: r.name, + target: oid_val(r.target), + is_remote: r.is_remote, + is_tag: r.is_tag, + } + } +} + +impl From for CommitRefsResponseDto { + fn from(r: p::CommitRefsResponse) -> Self { + CommitRefsResponseDto { + refs: r.refs.into_iter().map(Into::into).collect(), + } + } +} + +impl From for CommitPrefixResponseDto { + fn from(r: p::CommitPrefixResponse) -> Self { + CommitPrefixResponseDto { oid: oid_opt(r.oid) } + } +} + +impl From for CommitExistsResponseDto { + fn from(r: p::CommitExistsResponse) -> Self { + CommitExistsResponseDto { exists: r.exists } + } +} + +impl From for CherryPickResponseDto { + fn from(r: p::CherryPickResponse) -> Self { + CherryPickResponseDto { oid: oid_opt(r.oid) } + } +} + +impl From for CherryPickResponseDto { + fn from(r: p::CherryPickSequenceResponse) -> Self { + CherryPickResponseDto { oid: oid_opt(r.oid) } + } +} + +fn tree_kind_from_proto(kind: i32) -> TreeKindDto { + match kind { + 0 => TreeKindDto::Blob, + 1 => TreeKindDto::Tree, + 2 => TreeKindDto::LfsPointer, + _ => TreeKindDto::Blob, + } +} + +impl From for TreeEntryDto { + fn from(e: p::TreeEntry) -> Self { + fn opt(s: String) -> Option { + if s.is_empty() { None } else { Some(s) } + } + TreeEntryDto { + name: e.name, + oid: oid_val(e.oid), + kind: tree_kind_from_proto(e.kind), + filemode: e.filemode, + is_binary: e.is_binary, + is_lfs: e.is_lfs, + last_commit_message: opt(e.last_commit_message), + last_commit_time: opt(e.last_commit_time), + last_commit_author_name: opt(e.last_commit_author_name), + last_commit_author_email: opt(e.last_commit_author_email), + } + } +} + +impl From for TreeEntriesResponseDto { + fn from(r: p::TreeEntriesResponse) -> Self { + TreeEntriesResponseDto { + entries: r.entries.into_iter().map(Into::into).collect(), + } + } +} + +impl From for TreeEntryByPathResponseDto { + fn from(r: p::TreeEntryByPathResponse) -> Self { + TreeEntryByPathResponseDto { + entry: r.entry.map(Into::into), + } + } +} + +impl From for TreeEntryByPathResponseDto { + fn from(r: p::TreeEntryByPathFromCommitResponse) -> Self { + TreeEntryByPathResponseDto { + entry: r.entry.map(Into::into), + } + } +} + +impl From for TreeInfoDto { + fn from(i: p::TreeInfo) -> Self { + TreeInfoDto { + oid: oid_val(i.oid), + entry_count: i.entry_count, + is_empty: i.is_empty, + } + } +} + +impl From for ResolveTreeResponseDto { + fn from(r: p::ResolveTreeResponse) -> Self { + ResolveTreeResponseDto { + info: r.info.map(Into::into), + } + } +} + +impl From for BlobLoadResponseDto { + fn from(r: p::BlobLoadResponse) -> Self { + BlobLoadResponseDto { + blob: base64::engine::general_purpose::STANDARD.encode(&r.blob), + } + } +} + +impl From for BlobSizeResponseDto { + fn from(r: p::BlobSizeResponse) -> Self { + BlobSizeResponseDto { size: r.size } + } +} + +impl From for BlobExistsResponseDto { + fn from(r: p::BlobExistsResponse) -> Self { + BlobExistsResponseDto { exists: r.exists } + } +} + +impl From for BlobIsBinaryResponseDto { + fn from(r: p::BlobIsBinaryResponse) -> Self { + BlobIsBinaryResponseDto { is_binary: r.is_binary } + } +} + +impl From for BlobUploadResponseDto { + fn from(r: p::BlobUploadResponse) -> Self { + BlobUploadResponseDto { id: oid_opt(r.id) } + } +} + +impl From for TagItemDto { + fn from(t: p::TagItem) -> Self { + TagItemDto { + name: t.name, + oid: oid_val(t.oid), + target: oid_val(t.target), + is_annotated: t.is_annotated, + message: t.message, + tagger: t.tagger, + tagger_email: t.tagger_email, + } + } +} + +impl From for TagListResponseDto { + fn from(r: p::TagListResponse) -> Self { + TagListResponseDto { + tags: r.tags.into_iter().map(Into::into).collect(), + } + } +} + +impl From for TagInfoResponseDto { + fn from(r: p::TagInfoResponse) -> Self { + TagInfoResponseDto { tag: r.tag.map(Into::into) } + } +} + +impl From for TagSummaryDto { + fn from(s: p::TagSummary) -> Self { + TagSummaryDto { total_count: s.total_count } + } +} + +impl From for TagSummaryResponseDto { + fn from(r: p::TagSummaryResponse) -> Self { + TagSummaryResponseDto { summary: r.summary.map(Into::into) } + } +} + +impl From for TagInitResponseDto { + fn from(r: p::TagInitResponse) -> Self { + TagInitResponseDto { oid: oid_opt(r.oid) } + } +} + +impl From for TagUpdateMessageResponseDto { + fn from(r: p::TagUpdateMessageResponse) -> Self { + TagUpdateMessageResponseDto { oid: oid_opt(r.oid) } + } +} + +impl From for CommitBlameHunkDto { + fn from(h: p::CommitBlameHunk) -> Self { + CommitBlameHunkDto { + commit_oid: oid_opt(h.commit_oid), + final_start_line: h.final_start_line, + final_lines: h.final_lines, + orig_start_line: h.orig_start_line, + orig_lines: h.orig_lines, + boundary: h.boundary, + orig_path: h.orig_path, + } + } +} + +impl From for CommitBlameLineDto { + fn from(l: p::CommitBlameLine) -> Self { + CommitBlameLineDto { + commit_oid: oid_opt(l.commit_oid), + line_no: l.line_no, + content: l.content, + orig_path: l.orig_path, + } + } +} + +impl From for BlameFileResponseDto { + fn from(r: p::BlameFileResponse) -> Self { + BlameFileResponseDto { + hunks: r.hunks.into_iter().map(Into::into).collect(), + } + } +} + +impl From for BlameFileResponseDto { + fn from(r: p::BlameHunkResponse) -> Self { + BlameFileResponseDto { + hunks: r.hunks.into_iter().map(Into::into).collect(), + } + } +} + +impl From for BlameLinesResponseDto { + fn from(r: p::BlameLinesResponse) -> Self { + BlameLinesResponseDto { + lines: r.lines.into_iter().map(Into::into).collect(), + } + } +} + +impl From for DiffDeltaStatusDto { + fn from(s: p::DiffDeltaStatus) -> Self { + match s { + p::DiffDeltaStatus::Unmodified => DiffDeltaStatusDto::Unmodified, + p::DiffDeltaStatus::Added => DiffDeltaStatusDto::Added, + p::DiffDeltaStatus::Deleted => DiffDeltaStatusDto::Deleted, + p::DiffDeltaStatus::Modified => DiffDeltaStatusDto::Modified, + p::DiffDeltaStatus::Renamed => DiffDeltaStatusDto::Renamed, + p::DiffDeltaStatus::Copied => DiffDeltaStatusDto::Copied, + p::DiffDeltaStatus::Typechange => DiffDeltaStatusDto::Typechange, + p::DiffDeltaStatus::Conflicted => DiffDeltaStatusDto::Conflicted, + } + } +} + +impl From for DiffFileDto { + fn from(f: p::DiffFile) -> Self { + DiffFileDto { + oid: oid_opt(f.oid), + path: f.path, + size: f.size, + is_binary: f.is_binary, + } + } +} + +impl From for DiffHunkDto { + fn from(h: p::DiffHunk) -> Self { + DiffHunkDto { + old_start: h.old_start, + old_lines: h.old_lines, + new_start: h.new_start, + new_lines: h.new_lines, + header: h.header, + } + } +} + +impl From for DiffLineDto { + fn from(l: p::DiffLine) -> Self { + DiffLineDto { + content: l.content, + origin: l.origin, + old_lineno: l.old_lineno, + new_lineno: l.new_lineno, + num_lines: l.num_lines, + content_offset: l.content_offset, + } + } +} + +impl From for DiffDeltaDto { + fn from(d: p::DiffDelta) -> Self { + DiffDeltaDto { + status: d.status, + old_file: d.old_file.map(Into::into), + new_file: d.new_file.map(Into::into), + nfiles: d.nfiles, + hunks: d.hunks.into_iter().map(Into::into).collect(), + lines: d.lines.into_iter().map(Into::into).collect(), + } + } +} + +impl From for DiffStatsDto { + fn from(s: p::DiffStats) -> Self { + DiffStatsDto { + files_changed: s.files_changed, + insertions: s.insertions, + deletions: s.deletions, + } + } +} + +impl From for DiffResultDto { + fn from(r: p::DiffResult) -> Self { + DiffResultDto { + stats: r.stats.map(Into::into), + deltas: r.deltas.into_iter().map(Into::into).collect(), + } + } +} + +impl From for SideBySideChangeTypeDto { + fn from(t: p::SideBySideChangeType) -> Self { + match t { + p::SideBySideChangeType::Unchanged => SideBySideChangeTypeDto::Unchanged, + p::SideBySideChangeType::Added => SideBySideChangeTypeDto::Added, + p::SideBySideChangeType::Removed => SideBySideChangeTypeDto::Removed, + p::SideBySideChangeType::Modified => SideBySideChangeTypeDto::Modified, + p::SideBySideChangeType::Empty => SideBySideChangeTypeDto::Empty, + } + } +} + +impl From for SideBySideLineDto { + fn from(l: p::SideBySideLine) -> Self { + SideBySideLineDto { + left_line_no: l.left_line_no, + right_line_no: l.right_line_no, + left_content: l.left_content, + right_content: l.right_content, + change_type: l.change_type, + } + } +} + +impl From for SideBySideFileDto { + fn from(f: p::SideBySideFile) -> Self { + SideBySideFileDto { + path: f.path, + additions: f.additions, + deletions: f.deletions, + is_binary: f.is_binary, + is_rename: f.is_rename, + lines: f.lines.into_iter().map(Into::into).collect(), + } + } +} + +impl From for SideBySideDiffResultDto { + fn from(r: p::SideBySideDiffResult) -> Self { + SideBySideDiffResultDto { + files: r.files.into_iter().map(Into::into).collect(), + total_additions: r.total_additions, + total_deletions: r.total_deletions, + } + } +} diff --git a/lib/api/src/git/fork.rs b/lib/api/src/git/fork.rs new file mode 100644 index 0000000..79e39c8 --- /dev/null +++ b/lib/api/src/git/fork.rs @@ -0,0 +1,62 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, Pagination, + git::fork::{CreateFork, ForkResponse}, +}; +use session::Session; + +use crate::error::ApiError; + +use super::repo::WkRepoPath; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/forks", + params(WkRepoPath), + request_body = CreateFork, + responses( + (status = 200, body = ForkResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Cannot fork private repo"), + (status = 409, description = "Fork already exists"), + ), + security(("session" = [])) +)] +pub async fn create_fork( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service + .repo_fork_create(&session, &wk, &repo, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/repos/{repo}/forks", + params(WkRepoPath, Pagination), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security(("session" = [])) +)] +pub async fn list_forks( + session: Session, + service: web::Data, + path: web::Path, + pagination: web::Query, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service + .repo_fork_list(&session, &wk, &repo, pagination.into_inner()) + .await?; + ok_json(data) +} diff --git a/lib/api/src/git/init.rs b/lib/api/src/git/init.rs new file mode 100644 index 0000000..eb13bfe --- /dev/null +++ b/lib/api/src/git/init.rs @@ -0,0 +1,56 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{AppService, git::init::{CloneRepo, CreateRepo}}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkPath { + pub wk: String, +} +#[utoipa::path( + post, path = "/api/v1/workspace/{wk}/repos", + params(("wk" = String, Path, description = "Workspace name")), + request_body = CreateRepo, + responses((status = 200, description = "Repo created"), (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), (status = 409, description = "Repo name exists")), + security(("session" = [])) +)] +pub async fn create_repo( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let wk = path.into_inner().wk; + let data = service + .git_init_bare(&session, &wk, params.into_inner()) + .await?; + ok_json(data) +} + +#[utoipa::path( + post, path = "/api/v1/workspace/{wk}/repos/clone", + params(("wk" = String, Path, description = "Workspace name")), + request_body = CloneRepo, + responses((status = 200, description = "Repo cloned"), (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), (status = 409, description = "Repo name exists")), + security(("session" = [])) +)] +pub async fn clone_repo( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let wk = path.into_inner().wk; + let data = service + .git_clone_bare(&session, &wk, params.into_inner()) + .await?; + ok_json(data) +} diff --git a/lib/api/src/git/language.rs b/lib/api/src/git/language.rs new file mode 100644 index 0000000..36a98b2 --- /dev/null +++ b/lib/api/src/git/language.rs @@ -0,0 +1,31 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/languages", + params(WkRepoPath), + responses((status = 200, description = "Language stats", body = Vec)), + security(("session" = [])) +)] +pub async fn get_languages( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service.git_repo_languages(&session, &wk, &repo).await?; + ok_json(data) +} diff --git a/lib/api/src/git/mod.rs b/lib/api/src/git/mod.rs new file mode 100644 index 0000000..c6703ff --- /dev/null +++ b/lib/api/src/git/mod.rs @@ -0,0 +1,241 @@ +pub mod archive; +pub mod blame; +pub mod blob; +pub mod branch; +pub mod commit; +pub mod commit_status; +pub mod compare; +pub mod contents; +pub mod contributor; +pub mod diff; +pub mod dto; +pub mod fork; +pub mod init; +pub mod language; +pub mod protect; +pub mod readme; +pub mod refs; +pub mod release; +pub mod repo; +pub mod star; +pub mod tag; +pub mod tree; +pub mod watch; +pub mod webhook; + +use actix_web::{web, web::ServiceConfig}; +pub fn configure(cfg: &mut ServiceConfig) { + cfg.service( + web::resource("") + .route(web::post().to(init::create_repo)) + .route(web::get().to(repo::list_repos)), + ); + cfg.service( + web::resource("/clone") + .route(web::post().to(init::clone_repo)), + ); + cfg.service( + web::resource("/{repo}") + .route(web::get().to(repo::get_repo)) + .route(web::put().to(repo::update_repo)) + .route(web::delete().to(repo::delete_repo)), + ); + cfg.service( + web::resource("/{repo}/archive") + .route(web::post().to(repo::archive_repo)), + ); + cfg.service( + web::resource("/{repo}/transfer") + .route(web::post().to(repo::transfer_repo)), + ); + cfg.service( + web::resource("/{repo}/topics") + .route(web::get().to(repo::get_topics)) + .route(web::put().to(repo::update_topics)), + ); + cfg.service( + web::resource("/{repo}/forks") + .route(web::get().to(fork::list_forks)) + .route(web::post().to(fork::create_fork)), + ); + cfg.service( + web::resource("/{repo}/protect") + .route(web::get().to(protect::list_protects)) + .route(web::post().to(protect::create_protect)), + ); + cfg.service( + web::resource("/{repo}/protect/{protect_id}") + .route(web::put().to(protect::update_protect)) + .route(web::delete().to(protect::delete_protect)), + ); + cfg.service( + web::resource("/{repo}/webhooks") + .route(web::get().to(webhook::list_webhooks)) + .route(web::post().to(webhook::create_webhook)), + ); + cfg.service( + web::resource("/{repo}/webhooks/{webhook_id}") + .route(web::put().to(webhook::update_webhook)) + .route(web::delete().to(webhook::delete_webhook)), + ); + cfg.service( + web::resource("/{repo}/webhooks/{webhook_id}/deliveries") + .route(web::get().to(webhook::list_deliveries)), + ); + cfg.service( + web::scope("/{repo}/git") + .service( + web::resource("/branches") + .route(web::get().to(branch::list_branches)) + .route(web::post().to(branch::fork_branch)), + ) + .service( + web::resource("/branches/{name}") + .route(web::get().to(branch::branch_info)) + .route(web::patch().to(branch::rename_branch)) + .route(web::delete().to(branch::delete_branch)), + ) + .service( + web::resource("/branches/{name}/ahead-behind") + .route(web::get().to(branch::ahead_behind)), + ) + .service( + web::resource("/branches/{name}/upstream") + .route(web::get().to(branch::branch_upstream)), + ) + .service( + web::resource("/commits") + .route(web::get().to(commit::list_commits)), + ) + .service( + web::resource("/commits/history") + .route(web::get().to(commit::commit_history)), + ) + .service( + web::resource("/commits/{oid}") + .route(web::get().to(commit::commit_info)), + ) + .service( + web::resource("/commits/walk") + .route(web::post().to(commit::commit_walk)), + ) + .service( + web::resource("/commits/cherry-pick") + .route(web::post().to(commit::cherry_pick)), + ) + .service( + web::resource("/blobs") + .route(web::post().to(blob::blob_upload)), + ) + .service( + web::resource("/blobs/{oid}") + .route(web::get().to(blob::blob_info)), + ) + .service( + web::resource("/blame") + .route(web::get().to(blame::blame_file)), + ) + .service( + web::resource("/trees/{oid}") + .route(web::get().to(tree::tree_entries)), + ) + .service( + web::resource("/trees/{tree_oid}/entries") + .route(web::get().to(tree::tree_entry_by_path)), + ) + .service( + web::resource("/commits/{oid}/tree") + .route(web::get().to(tree::tree_entry_by_path_from_commit)), + ) + .service( + web::resource("/diff") + .route(web::get().to(diff::diff)), + ) + .service( + web::resource("/diff/branches") + .route(web::get().to(readme::diff_branches)), + ) + .service( + web::resource("/tags") + .route(web::get().to(tag::list_tags)) + .route(web::post().to(tag::init_tag)), + ) + .service( + web::resource("/tags/{name}") + .route(web::get().to(tag::tag_info)) + .route(web::patch().to(tag::update_tag)) + .route(web::delete().to(tag::delete_tag)), + ) + .service( + web::resource("/archive") + .route(web::get().to(archive::archive)), + ) + .service( + web::resource("/star") + .route(web::get().to(star::star_status)) + .route(web::post().to(star::star_repo)) + .route(web::delete().to(star::unstar_repo)), + ) + .service( + web::resource("/watch") + .route(web::get().to(watch::watch_status)) + .route(web::post().to(watch::watch_repo)) + .route(web::delete().to(watch::unwatch_repo)), + ) + .service( + web::resource("/contributors") + .route(web::get().to(contributor::list_contributors)), + ) + .service( + web::resource("/languages") + .route(web::get().to(language::get_languages)), + ) + .service( + web::resource("/readme") + .route(web::get().to(readme::get_readme)), + ) + .service( + web::resource("/refs") + .route(web::get().to(refs::list_refs)), + ), + ); + cfg.service( + web::resource("/{repo}/releases") + .route(web::get().to(release::list_releases)) + .route(web::post().to(release::create_release)), + ) + .service( + web::resource("/{repo}/releases/{id}") + .route(web::get().to(release::get_release)) + .route(web::patch().to(release::update_release)) + .route(web::delete().to(release::delete_release)), + ) + .service( + web::resource("/{repo}/releases/tags/{tag}") + .route(web::get().to(release::get_release_by_tag)) + .route(web::delete().to(release::delete_release_by_tag)), + ) + .service( + web::resource("/{repo}/statuses/{sha}") + .route(web::post().to(commit_status::create_status)), + ) + .service( + web::resource("/{repo}/commits/{sha}/status") + .route(web::get().to(commit_status::combined_status)), + ) + .service( + web::resource("/{repo}/commits/{sha}/statuses") + .route(web::get().to(commit_status::list_statuses)), + ) + .service( + web::resource("/{repo}/compare/{basehead}") + .route(web::get().to(compare::compare)), + ) + .service( + web::resource("/{repo}/contents/{path:.*}") + .route(web::get().to(contents::get_contents)) + .route(web::post().to(contents::create_contents)) + .route(web::put().to(contents::update_contents)) + .route(web::delete().to(contents::delete_contents)), + ); +} diff --git a/lib/api/src/git/protect.rs b/lib/api/src/git/protect.rs new file mode 100644 index 0000000..81e60a1 --- /dev/null +++ b/lib/api/src/git/protect.rs @@ -0,0 +1,134 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, Pagination, + git::protect::{CreateProtect, ProtectResponse, UpdateProtect}, +}; +use session::Session; + +use crate::error::ApiError; + +use super::repo::WkRepoPath; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct ProtectPath { + pub wk: String, + pub repo: String, + pub protect_id: uuid::Uuid, +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/repos/{repo}/protect", + params(WkRepoPath, Pagination), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security(("session" = [])) +)] +pub async fn list_protects( + session: Session, + service: web::Data, + path: web::Path, + pagination: web::Query, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service + .repo_protect_list(&session, &wk, &repo, pagination.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/protect", + params(WkRepoPath), + request_body = CreateProtect, + responses( + (status = 200, body = ProtectResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + ), + security(("session" = [])) +)] +pub async fn create_protect( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service + .repo_protect_create(&session, &wk, &repo, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + put, + path = "/api/v1/workspace/{wk}/repos/{repo}/protect/{protect_id}", + params(ProtectPath), + request_body = UpdateProtect, + responses( + (status = 200, body = ProtectResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Rule not found"), + ), + security(("session" = [])) +)] +pub async fn update_protect( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let ProtectPath { + wk, + repo, + protect_id, + } = path.into_inner(); + let data = service + .repo_protect_update( + &session, + &wk, + &repo, + protect_id, + params.into_inner(), + ) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/repos/{repo}/protect/{protect_id}", + params(ProtectPath), + responses( + (status = 200, description = "Rule deleted"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Rule not found"), + ), + security(("session" = [])) +)] +pub async fn delete_protect( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let ProtectPath { + wk, + repo, + protect_id, + } = path.into_inner(); + service + .repo_protect_delete(&session, &wk, &repo, protect_id) + .await?; + ok() +} diff --git a/lib/api/src/git/readme.rs b/lib/api/src/git/readme.rs new file mode 100644 index 0000000..36e98c5 --- /dev/null +++ b/lib/api/src/git/readme.rs @@ -0,0 +1,87 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use session::Session; + +use crate::error::ApiError; +use crate::git::dto; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/readme", + params(WkRepoPath), + responses((status = 200, description = "README content", body = Option)), + security(("session" = [])) +)] +pub async fn get_readme( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service.git_repo_readme(&session, &wk, &repo).await?; + ok_json(data) +} + +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/diff/branches", + params(("wk" = String, Path), ("repo" = String, Path), ("old_branch" = String, Query), ("new_branch" = String, Query)), + responses((status = 200, description = "Branch diff", body = dto::DiffResultDto)), + security(("session" = [])) +)] +pub async fn diff_branches( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let q = query.into_inner(); + + let old_info = service + .git_branch_info(&session, &wk, &repo, q.old_branch.clone()) + .await?; + let new_info = service + .git_branch_info(&session, &wk, &repo, q.new_branch.clone()) + .await?; + + let old_oid = old_info + .branch + .and_then(|b| b.oid) + .map(|o| o.value) + .unwrap_or_default(); + let new_oid = new_info + .branch + .and_then(|b| b.oid) + .map(|o| o.value) + .unwrap_or_default(); + + if old_oid.is_empty() || new_oid.is_empty() { + return Err(service::error::AppError::NotFound( + "could not resolve one or both branches".to_string(), + ) + .into()); + } + + let proto_resp = service + .git_diff_patch(&session, &wk, &repo, old_oid, new_oid, None) + .await?; + + let data: dto::DiffResultDto = proto_resp.result.unwrap_or_default().into(); + ok_json(data) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct BranchDiffQuery { + pub old_branch: String, + pub new_branch: String, +} diff --git a/lib/api/src/git/refs.rs b/lib/api/src/git/refs.rs new file mode 100644 index 0000000..96f95d8 --- /dev/null +++ b/lib/api/src/git/refs.rs @@ -0,0 +1,41 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use service::git::refs::GitRefResponse; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct RefQuery { + pub r#ref: Option, +} + +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/refs", + params(WkRepoPath, RefQuery), + responses((status = 200, body = Vec)), + security(("session" = [])) +)] +pub async fn list_refs( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + if let Some(ref_name) = &query.r#ref { + let r = service.git_ref_get_by_name(&session, &path.wk, &path.repo, ref_name).await?; + return ok_json(vec![r]); + } + ok_json(service.git_ref_list_by_name(&session, &path.wk, &path.repo).await?) +} diff --git a/lib/api/src/git/release.rs b/lib/api/src/git/release.rs new file mode 100644 index 0000000..060788a --- /dev/null +++ b/lib/api/src/git/release.rs @@ -0,0 +1,146 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use service::git::release::{CreateRelease, ReleaseResponse, UpdateRelease}; +use session::Session; +use uuid::Uuid; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok_created(data: T) -> Result { + Ok(HttpResponse::Created().json(data)) +} + +fn ok_empty() -> Result { + Ok(HttpResponse::NoContent().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct ReleaseIdPath { + pub wk: String, + pub repo: String, + pub id: Uuid, +} + +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/releases", + params(WkRepoPath), + responses((status = 200, body = Vec)), + security(("session" = [])) +)] +pub async fn list_releases( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json(service.git_release_list_by_name(&session, user_id, &path.wk, &path.repo).await?) +} + +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/releases/{id}", + params(ReleaseIdPath), + responses((status = 200, body = ReleaseResponse)), + security(("session" = [])) +)] +pub async fn get_release( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json(service.git_release_get_by_name(&session, user_id, &path.wk, &path.repo, path.id).await?) +} + +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/releases/tags/{tag}", + params(WkRepoPath, ("tag" = String, Path)), + responses((status = 200, body = ReleaseResponse)), + security(("session" = [])) +)] +pub async fn get_release_by_tag( + session: Session, + service: web::Data, + path: web::Path<(String, String, String)>, +) -> Result { + let (wk, repo_name, tag) = path.into_inner(); + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json(service.git_release_get_by_tag_name(&session, user_id, &wk, &repo_name, &tag).await?) +} + +#[utoipa::path( + post, path = "/api/v1/workspace/{wk}/repos/{repo}/releases", + params(WkRepoPath), + request_body = CreateRelease, + responses((status = 201, body = ReleaseResponse)), + security(("session" = [])) +)] +pub async fn create_release( + session: Session, + service: web::Data, + path: web::Path, + body: web::Json, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_created(service.git_release_create_by_name(&session, user_id, &path.wk, &path.repo, body.into_inner()).await?) +} + +#[utoipa::path( + patch, path = "/api/v1/workspace/{wk}/repos/{repo}/releases/{id}", + params(ReleaseIdPath), + request_body = UpdateRelease, + responses((status = 200, body = ReleaseResponse)), + security(("session" = [])) +)] +pub async fn update_release( + session: Session, + service: web::Data, + path: web::Path, + body: web::Json, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + ok_json(service.git_release_update_by_name(&session, user_id, &path.wk, &path.repo, path.id, body.into_inner()).await?) +} + +#[utoipa::path( + delete, path = "/api/v1/workspace/{wk}/repos/{repo}/releases/{id}", + params(ReleaseIdPath), + responses((status = 204)), + security(("session" = [])) +)] +pub async fn delete_release( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + service.git_release_delete_by_name(&session, user_id, &path.wk, &path.repo, path.id).await?; + ok_empty() +} + +#[utoipa::path( + delete, path = "/api/v1/workspace/{wk}/repos/{repo}/releases/tags/{tag}", + params(WkRepoPath, ("tag" = String, Path)), + responses((status = 204)), + security(("session" = [])) +)] +pub async fn delete_release_by_tag( + session: Session, + service: web::Data, + path: web::Path<(String, String, String)>, +) -> Result { + let (wk, repo_name, tag) = path.into_inner(); + let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?; + service.git_release_delete_by_tag_name(&session, user_id, &wk, &repo_name, &tag).await?; + ok_empty() +} diff --git a/lib/api/src/git/repo.rs b/lib/api/src/git/repo.rs new file mode 100644 index 0000000..c177df4 --- /dev/null +++ b/lib/api/src/git/repo.rs @@ -0,0 +1,203 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, Pagination, + git::repo::{RepoFilter, RepoResponse, TransferRepo, UpdateRepo}, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/repos", + params( + ("wk" = String, Path, description = "Workspace name"), + RepoFilter, + Pagination, + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security(("session" = [])) +)] +pub async fn list_repos( + session: Session, + service: web::Data, + path: web::Path, + filter: web::Query, + pagination: web::Query, +) -> Result { + let wk = path.into_inner(); + let data = service + .repo_list(&session, &wk, filter.into_inner(), pagination.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/repos/{repo}", + params(WkRepoPath), + responses( + (status = 200, body = RepoResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Repo not found"), + ), + security(("session" = [])) +)] +pub async fn get_repo( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service.repo_get(&session, &wk, &repo).await?; + ok_json(data) +} +#[utoipa::path( + put, + path = "/api/v1/workspace/{wk}/repos/{repo}", + params(WkRepoPath), + request_body = UpdateRepo, + responses( + (status = 200, body = RepoResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Repo not found"), + ), + security(("session" = [])) +)] +pub async fn update_repo( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service + .repo_update(&session, &wk, &repo, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/archive", + params(WkRepoPath), + responses( + (status = 200, body = RepoResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + ), + security(("session" = [])) +)] +pub async fn archive_repo( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service.repo_archive(&session, &wk, &repo).await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/repos/{repo}", + params(WkRepoPath), + responses( + (status = 200, description = "Repo deleted"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied - owner only"), + ), + security(("session" = [])) +)] +pub async fn delete_repo( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + service.repo_delete(&session, &wk, &repo).await?; + ok() +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/transfer", + params(WkRepoPath), + request_body = TransferRepo, + responses( + (status = 200, body = RepoResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied - owner only"), + (status = 409, description = "Name conflict in target workspace"), + ), + security(("session" = [])) +)] +pub async fn transfer_repo( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service + .repo_transfer(&session, &wk, &repo, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/repos/{repo}/topics", + params(WkRepoPath), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security(("session" = [])) +)] +pub async fn get_topics( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service.repo_topics(&session, &wk, &repo).await?; + ok_json(data) +} +#[utoipa::path( + put, + path = "/api/v1/workspace/{wk}/repos/{repo}/topics", + params(WkRepoPath), + request_body = Vec, + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + ), + security(("session" = [])) +)] +pub async fn update_topics( + session: Session, + service: web::Data, + path: web::Path, + topics: web::Json>, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service + .repo_update_topics(&session, &wk, &repo, topics.into_inner()) + .await?; + ok_json(data) +} diff --git a/lib/api/src/git/star.rs b/lib/api/src/git/star.rs new file mode 100644 index 0000000..bb0375c --- /dev/null +++ b/lib/api/src/git/star.rs @@ -0,0 +1,63 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} +#[utoipa::path( + post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/star", + params(WkRepoPath), + responses((status = 200, description = "Star/unstar result")), + security(("session" = [])) +)] +pub async fn star_repo( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service.git_repo_star(&session, &wk, &repo).await?; + ok_json(data) +} + +#[utoipa::path( + delete, path = "/api/v1/workspace/{wk}/repos/{repo}/git/star", + params(WkRepoPath), + responses((status = 200, description = "Star/unstar result")), + security(("session" = [])) +)] +pub async fn unstar_repo( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service.git_repo_unstar(&session, &wk, &repo).await?; + ok_json(data) +} + +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/star", + params(WkRepoPath), + responses((status = 200, description = "Star status")), + security(("session" = [])) +)] +pub async fn star_status( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service.git_repo_star_status(&session, &wk, &repo).await?; + ok_json(data) +} diff --git a/lib/api/src/git/tag.rs b/lib/api/src/git/tag.rs new file mode 100644 index 0000000..66e96b5 --- /dev/null +++ b/lib/api/src/git/tag.rs @@ -0,0 +1,170 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{AppService, Pagination}; +use session::Session; + +use crate::error::ApiError; +use crate::git::dto; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoTagPath { + pub wk: String, + pub repo: String, + pub name: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct TagListQuery { + #[serde(default)] + pub summary: bool, +} +#[derive(Deserialize, utoipa::ToSchema)] +pub struct RenameTagBody { + pub new_name: String, +} +#[derive(Deserialize, utoipa::ToSchema)] +pub struct UpdateTagMessageBody { + pub message: String, +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/tags", + params(WkRepoPath, Pagination, TagListQuery), + responses( + (status = 200, description = "Tag list", body = dto::TagListResponseDto), + (status = 200, description = "Tag summary", body = dto::TagSummaryResponseDto), + ), + security(("session" = [])) +)] +pub async fn list_tags( + session: Session, + service: web::Data, + path: web::Path, + pagination: web::Query, + query: web::Query, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + if query.summary { + let data: dto::TagSummaryResponseDto = service + .git_tag_summary(&session, &wk, &repo) + .await? + .into(); + return ok_json(data); + } + let data: dto::TagListResponseDto = service + .git_tag_list(&session, &wk, &repo, pagination.into_inner()) + .await? + .into(); + ok_json(data) +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/tags/{name}", + params(WkRepoTagPath), + responses((status = 200, description = "Tag info", body = dto::TagInfoResponseDto)), + security(("session" = [])) +)] +pub async fn tag_info( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoTagPath { wk, repo, name } = path.into_inner(); + let data: dto::TagInfoResponseDto = service + .git_tag_info(&session, &wk, &repo, name) + .await? + .into(); + ok_json(data) +} +#[utoipa::path( + post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/tags", + params(WkRepoPath), + request_body = Object, description = "TagInitParams { name, oid, message, tagger, force }", + responses((status = 200, description = "Tag created", body = dto::TagInitResponseDto)), + security(("session" = [])) +)] +pub async fn init_tag( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data: dto::TagInitResponseDto = service + .git_tag_init(&session, &wk, &repo, params.into_inner()) + .await? + .into(); + ok_json(data) +} +#[utoipa::path( + delete, path = "/api/v1/workspace/{wk}/repos/{repo}/git/tags/{name}", + params(WkRepoTagPath), + responses((status = 200, description = "Tag deleted")), + security(("session" = [])) +)] +pub async fn delete_tag( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoTagPath { wk, repo, name } = path.into_inner(); + let params = git::rpc::proto::TagDeleteParams { name }; + let _ = service + .git_tag_delete(&session, &wk, &repo, params) + .await?; + ok_json(serde_json::json!({})) +} +#[utoipa::path( + patch, path = "/api/v1/workspace/{wk}/repos/{repo}/git/tags/{name}", + params(WkRepoTagPath), + request_body = Object, + responses( + (status = 200, description = "Tag renamed"), + (status = 200, description = "Tag message updated"), + ), + security(("session" = [])) +)] +pub async fn update_tag( + session: Session, + service: web::Data, + path: web::Path, + body: web::Json, +) -> Result { + let WkRepoTagPath { wk, repo, name } = path.into_inner(); + let body = body.into_inner(); + if let Some(new_name) = body.get("new_name").and_then(|v| v.as_str()) { + let params = git::rpc::proto::TagRenameParams { + old_name: name, + new_name: new_name.to_string(), + force: false, + }; + let _ = service + .git_tag_rename(&session, &wk, &repo, params) + .await?; + return ok_json(serde_json::json!({})); + } + let message = body + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let params = git::rpc::proto::TagUpdateMessageParams { + name, + message, + tagger: None, + force: false, + }; + let data: dto::TagUpdateMessageResponseDto = service + .git_tag_update_message(&session, &wk, &repo, params) + .await? + .into(); + ok_json(data) +} diff --git a/lib/api/src/git/tree.rs b/lib/api/src/git/tree.rs new file mode 100644 index 0000000..ec25ffd --- /dev/null +++ b/lib/api/src/git/tree.rs @@ -0,0 +1,119 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use session::Session; + +use crate::error::ApiError; +use crate::git::dto; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoTreePath { + pub wk: String, + pub repo: String, + pub oid: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct TreeQuery { + pub path: Option, + #[serde(default)] + pub last: bool, + #[serde(default)] + pub resolve: bool, +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/trees/{oid}", + params(WkRepoTreePath, TreeQuery), + responses((status = 200, description = "Tree entries", body = dto::TreeEntriesResponseDto)), + security(("session" = [])) +)] +pub async fn tree_entries( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let WkRepoTreePath { wk, repo, oid } = path.into_inner(); + + if query.resolve { + let data: dto::ResolveTreeResponseDto = service + .git_resolve_tree(&session, &wk, &repo, oid) + .await? + .into(); + return ok_json(data); + } + + let base_path = query.path.clone().unwrap_or_default(); + let data: dto::TreeEntriesResponseDto = service + .git_tree_entries(&session, &wk, &repo, oid, base_path, query.last) + .await? + .into(); + ok_json(data) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoTreeSubPath { + pub wk: String, + pub repo: String, + pub tree_oid: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct TreeEntryQuery { + pub path: String, +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/trees/{tree_oid}/entries", + params(WkRepoTreeSubPath, TreeEntryQuery), + responses((status = 200, description = "Tree entry", body = dto::TreeEntryByPathResponseDto)), + security(("session" = [])) +)] +pub async fn tree_entry_by_path( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let WkRepoTreeSubPath { wk, repo, tree_oid } = path.into_inner(); + let data: dto::TreeEntryByPathResponseDto = service + .git_tree_entry_by_path(&session, &wk, &repo, tree_oid, query.path.clone()) + .await? + .into(); + ok_json(data) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoCommitPath { + pub wk: String, + pub repo: String, + pub oid: String, +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits/{oid}/tree", + params(WkRepoCommitPath, TreeEntryQuery), + responses((status = 200, description = "Tree entry", body = dto::TreeEntryByPathResponseDto)), + security(("session" = [])) +)] +pub async fn tree_entry_by_path_from_commit( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let WkRepoCommitPath { wk, repo, oid } = path.into_inner(); + let data: dto::TreeEntryByPathResponseDto = service + .git_tree_entry_by_path_from_commit(&session, &wk, &repo, oid, query.path.clone()) + .await? + .into(); + ok_json(data) +} diff --git a/lib/api/src/git/watch.rs b/lib/api/src/git/watch.rs new file mode 100644 index 0000000..c01cd96 --- /dev/null +++ b/lib/api/src/git/watch.rs @@ -0,0 +1,63 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::AppService; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WkRepoPath { + pub wk: String, + pub repo: String, +} +#[utoipa::path( + post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/watch", + params(WkRepoPath), request_body = Object, description = "WatchLevel {level: String}", + responses((status = 200, description = "Watch result")), + security(("session" = [])) +)] +pub async fn watch_repo( + session: Session, + service: web::Data, + path: web::Path, + body: web::Json, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let level = body.get("level").and_then(|v| v.as_str()).map(String::from); + let data = service.git_repo_watch(&session, &wk, &repo, level).await?; + ok_json(data) +} +#[utoipa::path( + delete, path = "/api/v1/workspace/{wk}/repos/{repo}/git/watch", + params(WkRepoPath), + responses((status = 200, description = "Watch result")), + security(("session" = [])) +)] +pub async fn unwatch_repo( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service.git_repo_unwatch(&session, &wk, &repo).await?; + ok_json(data) +} +#[utoipa::path( + get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/watch", + params(WkRepoPath), + responses((status = 200, description = "Watch status")), + security(("session" = [])) +)] +pub async fn watch_status( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service.git_repo_watch_status(&session, &wk, &repo).await?; + ok_json(data) +} diff --git a/lib/api/src/git/webhook.rs b/lib/api/src/git/webhook.rs new file mode 100644 index 0000000..40b0995 --- /dev/null +++ b/lib/api/src/git/webhook.rs @@ -0,0 +1,168 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, Pagination, + git::webhook::{ + CreateWebhook, UpdateWebhook, WebhookDeliveryResponse, WebhookResponse, + }, +}; +use session::Session; + +use crate::error::ApiError; + +use super::repo::WkRepoPath; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct WebhookPath { + pub wk: String, + pub repo: String, + pub webhook_id: uuid::Uuid, +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/repos/{repo}/webhooks", + params(WkRepoPath, Pagination), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security(("session" = [])) +)] +pub async fn list_webhooks( + session: Session, + service: web::Data, + path: web::Path, + pagination: web::Query, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service + .repo_webhook_list(&session, &wk, &repo, pagination.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/webhooks", + params(WkRepoPath), + request_body = CreateWebhook, + responses( + (status = 200, body = WebhookResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + ), + security(("session" = [])) +)] +pub async fn create_webhook( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let WkRepoPath { wk, repo } = path.into_inner(); + let data = service + .repo_webhook_create(&session, &wk, &repo, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + put, + path = "/api/v1/workspace/{wk}/repos/{repo}/webhooks/{webhook_id}", + params(WebhookPath), + request_body = UpdateWebhook, + responses( + (status = 200, body = WebhookResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Webhook not found"), + ), + security(("session" = [])) +)] +pub async fn update_webhook( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let WebhookPath { + wk, + repo, + webhook_id, + } = path.into_inner(); + let data = service + .repo_webhook_update( + &session, + &wk, + &repo, + webhook_id, + params.into_inner(), + ) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/repos/{repo}/webhooks/{webhook_id}", + params(WebhookPath), + responses( + (status = 200, description = "Webhook deleted"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Webhook not found"), + ), + security(("session" = [])) +)] +pub async fn delete_webhook( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let WebhookPath { + wk, + repo, + webhook_id, + } = path.into_inner(); + service + .repo_webhook_delete(&session, &wk, &repo, webhook_id) + .await?; + ok() +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/repos/{repo}/webhooks/{webhook_id}/deliveries", + params(WebhookPath, Pagination), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security(("session" = [])) +)] +pub async fn list_deliveries( + session: Session, + service: web::Data, + path: web::Path, + pagination: web::Query, +) -> Result { + let WebhookPath { + wk, + repo, + webhook_id, + } = path.into_inner(); + let data = service + .repo_webhook_deliveries( + &session, + &wk, + &repo, + webhook_id, + pagination.into_inner(), + ) + .await?; + ok_json(data) +} diff --git a/lib/api/src/issues/assignee.rs b/lib/api/src/issues/assignee.rs new file mode 100644 index 0000000..2822ff5 --- /dev/null +++ b/lib/api/src/issues/assignee.rs @@ -0,0 +1,75 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, + issues::{assignee::AssignIssueUser, types::IssueAuthor}, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct AssigneePath { + pub wk: String, + pub number: i64, +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/issues/{number}/assignees", + params(AssigneePath), + request_body = AssignIssueUser, + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn assign_user( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let AssigneePath { wk, number } = path.into_inner(); + let data = service + .issue_assign(&session, &wk, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/issues/{number}/assignees", + params(AssigneePath), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn unassign_user( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let AssigneePath { wk, number } = path.into_inner(); + let data = service + .issue_unassign(&session, &wk, number, &query.username) + .await?; + ok_json(data) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct UnassignQuery { + pub username: String, +} diff --git a/lib/api/src/issues/binding.rs b/lib/api/src/issues/binding.rs new file mode 100644 index 0000000..7a2cdb9 --- /dev/null +++ b/lib/api/src/issues/binding.rs @@ -0,0 +1,141 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, + issues::{ + binding::{BindIssuePullRequest, BindIssueRepo}, + types::{IssuePullRequestResponse, IssueRepoResponse}, + }, +}; +use session::Session; +use uuid::Uuid; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/issues/{number}/repos", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ), + request_body = BindIssueRepo, + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue or repo not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn bind_repo( + session: Session, + service: web::Data, + path: web::Path<(String, i64)>, + params: web::Json, +) -> Result { + let (wk, number) = path.into_inner(); + let data = service + .issue_bind_repo(&session, &wk, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/issues/{number}/repos", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue or repo not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn unbind_repo( + session: Session, + service: web::Data, + path: web::Path<(String, i64)>, + query: web::Query, +) -> Result { + let (wk, number) = path.into_inner(); + let data = service + .issue_unbind_repo(&session, &wk, number, query.repo_id) + .await?; + ok_json(data) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct UnbindRepoQuery { + pub repo_id: Uuid, +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/issues/{number}/pull-requests", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ), + request_body = BindIssuePullRequest, + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue or PR not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn bind_pull_request( + session: Session, + service: web::Data, + path: web::Path<(String, i64)>, + params: web::Json, +) -> Result { + let (wk, number) = path.into_inner(); + let data = service + .issue_bind_pull_request(&session, &wk, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/issues/{number}/pull-requests", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue or PR not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn unbind_pull_request( + session: Session, + service: web::Data, + path: web::Path<(String, i64)>, + query: web::Query, +) -> Result { + let (wk, number) = path.into_inner(); + let data = service + .issue_unbind_pull_request(&session, &wk, number, query.pull_request_id) + .await?; + ok_json(data) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct UnbindPrQuery { + pub pull_request_id: Uuid, +} diff --git a/lib/api/src/issues/comment.rs b/lib/api/src/issues/comment.rs new file mode 100644 index 0000000..05f06c0 --- /dev/null +++ b/lib/api/src/issues/comment.rs @@ -0,0 +1,145 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, + issues::{ + comment::{CreateComment, UpdateComment}, + types::IssueCommentResponse, + }, +}; +use session::Session; +use uuid::Uuid; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct CommentPath { + pub wk: String, + pub number: i64, + pub comment_id: Uuid, +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/issues/{number}/comments", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ), + request_body = CreateComment, + responses( + (status = 200, body = IssueCommentResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn create_comment( + session: Session, + service: web::Data, + path: web::Path<(String, i64)>, + params: web::Json, +) -> Result { + let (wk, number) = path.into_inner(); + let data = service + .issue_comment_create(&session, &wk, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/issues/{number}/comments", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security( + ("session" = []) + ) +)] +pub async fn list_comments( + session: Session, + service: web::Data, + path: web::Path<(String, i64)>, +) -> Result { + let (wk, number) = path.into_inner(); + let data = service.issue_comment_list(&session, &wk, number).await?; + ok_json(data) +} +#[utoipa::path( + put, + path = "/api/v1/workspace/{wk}/issues/{number}/comments/{comment_id}", + params(CommentPath), + request_body = UpdateComment, + responses( + (status = 200, body = IssueCommentResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Comment not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn update_comment( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let CommentPath { + wk, + number, + comment_id, + } = path.into_inner(); + let data = service + .issue_comment_update( + &session, + &wk, + number, + comment_id, + params.into_inner(), + ) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/issues/{number}/comments/{comment_id}", + params(CommentPath), + responses( + (status = 200, description = "Comment deleted"), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Comment not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn delete_comment( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let CommentPath { + wk, + number, + comment_id, + } = path.into_inner(); + service + .issue_comment_delete(&session, &wk, number, comment_id) + .await?; + ok() +} diff --git a/lib/api/src/issues/event.rs b/lib/api/src/issues/event.rs new file mode 100644 index 0000000..d62af73 --- /dev/null +++ b/lib/api/src/issues/event.rs @@ -0,0 +1,35 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{AppService, issues::types::IssueEventResponse}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/issues/{number}/events", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn list_events( + session: Session, + service: web::Data, + path: web::Path<(String, i64)>, +) -> Result { + let (wk, number) = path.into_inner(); + let data = service.issue_events(&session, &wk, number).await?; + ok_json(data) +} diff --git a/lib/api/src/issues/issue.rs b/lib/api/src/issues/issue.rs new file mode 100644 index 0000000..67b27f6 --- /dev/null +++ b/lib/api/src/issues/issue.rs @@ -0,0 +1,198 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, Pagination, + issues::{ + issue::{CreateIssue, UpdateIssue}, + types::{IssueFilter, IssueResponse}, + }, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct IssuePath { + pub wk: String, + pub number: i64, +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/issues", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + request_body = CreateIssue, + responses( + (status = 200, body = IssueResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Workspace not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn create_issue( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let wk = path.into_inner(); + let data = service + .issue_create(&session, &wk, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/issues", + params( + ("wk" = String, Path, description = "Workspace name"), + IssueFilter, + Pagination, + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security( + ("session" = []) + ) +)] +pub async fn list_issues( + session: Session, + service: web::Data, + path: web::Path, + filter: web::Query, + pagination: web::Query, +) -> Result { + let wk = path.into_inner(); + let data = service + .issue_list(&session, &wk, filter.into_inner(), pagination.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/issues/{number}", + params(IssuePath), + responses( + (status = 200, body = IssueResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn get_issue( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let IssuePath { wk, number } = path.into_inner(); + let data = service.issue_get(&session, &wk, number).await?; + ok_json(data) +} +#[utoipa::path( + put, + path = "/api/v1/workspace/{wk}/issues/{number}", + params(IssuePath), + request_body = UpdateIssue, + responses( + (status = 200, body = IssueResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn update_issue( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let IssuePath { wk, number } = path.into_inner(); + let data = service + .issue_update(&session, &wk, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/issues/{number}", + params(IssuePath), + responses( + (status = 200, description = "Issue deleted"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Issue not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn delete_issue( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let IssuePath { wk, number } = path.into_inner(); + service.issue_delete(&session, &wk, number).await?; + ok() +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/issues/{number}/close", + params(IssuePath), + responses( + (status = 200, body = IssueResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn close_issue( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let IssuePath { wk, number } = path.into_inner(); + let data = service.issue_close(&session, &wk, number).await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/issues/{number}/reopen", + params(IssuePath), + responses( + (status = 200, body = IssueResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn reopen_issue( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let IssuePath { wk, number } = path.into_inner(); + let data = service.issue_reopen(&session, &wk, number).await?; + ok_json(data) +} diff --git a/lib/api/src/issues/label.rs b/lib/api/src/issues/label.rs new file mode 100644 index 0000000..f6f034a --- /dev/null +++ b/lib/api/src/issues/label.rs @@ -0,0 +1,190 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, + issues::{ + label::{AddIssueLabel, CreateLabel, UpdateLabel}, + types::LabelResponse, + }, +}; +use session::Session; +use uuid::Uuid; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct LabelIdPath { + pub wk: String, + pub label_id: Uuid, +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/labels", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + request_body = CreateLabel, + responses( + (status = 200, body = LabelResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + ), + security( + ("session" = []) + ) +)] +pub async fn create_label( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let wk = path.into_inner(); + let data = service + .label_create(&session, &wk, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/labels", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security( + ("session" = []) + ) +)] +pub async fn list_labels( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let wk = path.into_inner(); + let data = service.label_list(&session, &wk).await?; + ok_json(data) +} +#[utoipa::path( + put, + path = "/api/v1/workspace/{wk}/labels/{label_id}", + params(LabelIdPath), + request_body = UpdateLabel, + responses( + (status = 200, body = LabelResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Label not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn update_label( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let LabelIdPath { wk, label_id } = path.into_inner(); + let data = service + .label_update(&session, &wk, label_id, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/labels/{label_id}", + params(LabelIdPath), + responses( + (status = 200, description = "Label deleted"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Label not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn delete_label( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let LabelIdPath { wk, label_id } = path.into_inner(); + service.label_delete(&session, &wk, label_id).await?; + ok() +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/issues/{number}/labels", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ), + request_body = AddIssueLabel, + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue or label not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn add_issue_label( + session: Session, + service: web::Data, + path: web::Path<(String, i64)>, + params: web::Json, +) -> Result { + let (wk, number) = path.into_inner(); + let data = service + .issue_add_label(&session, &wk, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/issues/{number}/labels", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue or label not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn remove_issue_label( + session: Session, + service: web::Data, + path: web::Path<(String, i64)>, + query: web::Query, +) -> Result { + let (wk, number) = path.into_inner(); + let data = service + .issue_remove_label(&session, &wk, number, query.label_id) + .await?; + ok_json(data) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct RemoveLabelQuery { + pub label_id: Uuid, +} diff --git a/lib/api/src/issues/milestone.rs b/lib/api/src/issues/milestone.rs new file mode 100644 index 0000000..8b4b790 --- /dev/null +++ b/lib/api/src/issues/milestone.rs @@ -0,0 +1,184 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, + issues::{ + milestone::{CreateMilestone, SetIssueMilestone, UpdateMilestone}, + types::MilestoneResponse, + }, +}; +use session::Session; +use uuid::Uuid; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct MilestoneIdPath { + pub wk: String, + pub milestone_id: Uuid, +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/milestones", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + request_body = CreateMilestone, + responses( + (status = 200, body = MilestoneResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + ), + security( + ("session" = []) + ) +)] +pub async fn create_milestone( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let wk = path.into_inner(); + let data = service + .milestone_create(&session, &wk, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/milestones", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security( + ("session" = []) + ) +)] +pub async fn list_milestones( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let wk = path.into_inner(); + let data = service.milestone_list(&session, &wk).await?; + ok_json(data) +} +#[utoipa::path( + put, + path = "/api/v1/workspace/{wk}/milestones/{milestone_id}", + params(MilestoneIdPath), + request_body = UpdateMilestone, + responses( + (status = 200, body = MilestoneResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Milestone not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn update_milestone( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let MilestoneIdPath { wk, milestone_id } = path.into_inner(); + let data = service + .milestone_update(&session, &wk, milestone_id, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/milestones/{milestone_id}", + params(MilestoneIdPath), + responses( + (status = 200, description = "Milestone deleted"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Milestone not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn delete_milestone( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let MilestoneIdPath { wk, milestone_id } = path.into_inner(); + service + .milestone_delete(&session, &wk, milestone_id) + .await?; + ok() +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/issues/{number}/milestone", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ), + request_body = SetIssueMilestone, + responses( + (status = 200, body = Option), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue or milestone not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn set_issue_milestone( + session: Session, + service: web::Data, + path: web::Path<(String, i64)>, + params: web::Json, +) -> Result { + let (wk, number) = path.into_inner(); + let data = service + .issue_set_milestone(&session, &wk, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/issues/{number}/milestone", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ), + responses( + (status = 200, description = "Milestone cleared"), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn clear_issue_milestone( + session: Session, + service: web::Data, + path: web::Path<(String, i64)>, +) -> Result { + let (wk, number) = path.into_inner(); + service.issue_clear_milestone(&session, &wk, number).await?; + ok() +} diff --git a/lib/api/src/issues/mod.rs b/lib/api/src/issues/mod.rs new file mode 100644 index 0000000..78a14d3 --- /dev/null +++ b/lib/api/src/issues/mod.rs @@ -0,0 +1,104 @@ +pub mod assignee; +pub mod binding; +pub mod comment; +pub mod event; +pub mod issue; +pub mod label; +pub mod milestone; +pub mod reaction; + +use actix_web::{web, web::ServiceConfig}; +pub fn configure(cfg: &mut ServiceConfig) { + cfg.service( + web::resource("") + .route(web::post().to(issue::create_issue)) + .route(web::get().to(issue::list_issues)), + ); + cfg.service( + web::resource("/{number}") + .route(web::get().to(issue::get_issue)) + .route(web::put().to(issue::update_issue)) + .route(web::delete().to(issue::delete_issue)), + ); + cfg.service( + web::resource("/{number}/close") + .route(web::post().to(issue::close_issue)), + ); + cfg.service( + web::resource("/{number}/reopen") + .route(web::post().to(issue::reopen_issue)), + ); + cfg.service( + web::resource("/{number}/comments") + .route(web::get().to(comment::list_comments)) + .route(web::post().to(comment::create_comment)), + ); + cfg.service( + web::resource("/{number}/comments/{comment_id}") + .route(web::put().to(comment::update_comment)) + .route(web::delete().to(comment::delete_comment)), + ); + cfg.service( + web::resource("/{number}/assignees") + .route(web::post().to(assignee::assign_user)) + .route(web::delete().to(assignee::unassign_user)), + ); + cfg.service( + web::resource("/{number}/labels") + .route(web::post().to(label::add_issue_label)) + .route(web::delete().to(label::remove_issue_label)), + ); + cfg.service( + web::resource("/{number}/milestone") + .route(web::post().to(milestone::set_issue_milestone)) + .route(web::delete().to(milestone::clear_issue_milestone)), + ); + cfg.service( + web::resource("/{number}/repos") + .route(web::post().to(binding::bind_repo)) + .route(web::delete().to(binding::unbind_repo)), + ); + cfg.service( + web::resource("/{number}/pull-requests") + .route(web::post().to(binding::bind_pull_request)) + .route(web::delete().to(binding::unbind_pull_request)), + ); + cfg.service( + web::resource("/{number}/reactions") + .route(web::post().to(reaction::add_reaction)) + .route(web::delete().to(reaction::remove_reaction)), + ); + cfg.service( + web::resource("/{number}/comments/{comment_id}/reactions") + .route(web::post().to(reaction::add_comment_reaction)) + .route(web::delete().to(reaction::remove_comment_reaction)), + ); + cfg.service( + web::resource("/{number}/events") + .route(web::get().to(event::list_events)), + ); +} +pub fn configure_labels(cfg: &mut ServiceConfig) { + cfg.service( + web::resource("") + .route(web::get().to(label::list_labels)) + .route(web::post().to(label::create_label)), + ); + cfg.service( + web::resource("/{label_id}") + .route(web::put().to(label::update_label)) + .route(web::delete().to(label::delete_label)), + ); +} +pub fn configure_milestones(cfg: &mut ServiceConfig) { + cfg.service( + web::resource("") + .route(web::get().to(milestone::list_milestones)) + .route(web::post().to(milestone::create_milestone)), + ); + cfg.service( + web::resource("/{milestone_id}") + .route(web::put().to(milestone::update_milestone)) + .route(web::delete().to(milestone::delete_milestone)), + ); +} diff --git a/lib/api/src/issues/reaction.rs b/lib/api/src/issues/reaction.rs new file mode 100644 index 0000000..ab2859c --- /dev/null +++ b/lib/api/src/issues/reaction.rs @@ -0,0 +1,152 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, + issues::{reaction::AddReaction, types::IssueReactionResponse}, +}; +use session::Session; +use uuid::Uuid; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize)] +pub struct ReactionPath { + pub wk: String, + pub number: i64, + pub comment_id: Option, +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/issues/{number}/reactions", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ), + request_body = AddReaction, + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Issue not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn add_reaction( + session: Session, + service: web::Data, + path: web::Path<(String, i64)>, + params: web::Json, +) -> Result { + let (wk, number) = path.into_inner(); + let data = service + .issue_add_reaction(&session, &wk, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/issues/{number}/reactions", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security( + ("session" = []) + ) +)] +pub async fn remove_reaction( + session: Session, + service: web::Data, + path: web::Path<(String, i64)>, + query: web::Query, +) -> Result { + let (wk, number) = path.into_inner(); + let data = service + .issue_remove_reaction(&session, &wk, number, &query.reaction) + .await?; + ok_json(data) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct RemoveReactionQuery { + pub reaction: String, +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/issues/{number}/comments/{comment_id}/reactions", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ("comment_id" = Uuid, Path, description = "Comment ID"), + ), + request_body = AddReaction, + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Comment not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn add_comment_reaction( + session: Session, + service: web::Data, + path: web::Path<(String, i64, Uuid)>, + params: web::Json, +) -> Result { + let (wk, number, comment_id) = path.into_inner(); + let data = service + .issue_comment_add_reaction( + &session, + &wk, + number, + comment_id, + params.into_inner(), + ) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/issues/{number}/comments/{comment_id}/reactions", + params( + ("wk" = String, Path, description = "Workspace name"), + ("number" = i64, Path, description = "Issue number"), + ("comment_id" = Uuid, Path, description = "Comment ID"), + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security( + ("session" = []) + ) +)] +pub async fn remove_comment_reaction( + session: Session, + service: web::Data, + path: web::Path<(String, i64, Uuid)>, + query: web::Query, +) -> Result { + let (wk, number, comment_id) = path.into_inner(); + let data = service + .issue_comment_remove_reaction( + &session, + &wk, + number, + comment_id, + &query.reaction, + ) + .await?; + ok_json(data) +} diff --git a/lib/api/src/lib.rs b/lib/api/src/lib.rs new file mode 100644 index 0000000..df87920 --- /dev/null +++ b/lib/api/src/lib.rs @@ -0,0 +1,59 @@ +pub mod agent; +pub mod ai; +pub mod auth; +pub mod channel; +pub mod error; +pub mod git; +pub mod issues; +pub mod openapi; +pub mod pull_request; +pub mod search; +pub mod user; +pub mod users; +pub mod workspace; + +use actix_web::web::{self, ServiceConfig}; + +pub fn configure(cfg: &mut ServiceConfig, channel_bus: channel::ChannelBus) { + cfg.service( + web::scope("/api/v1") + .configure(auth::configure) + .configure(user::configure) + .configure(users::configure) + .configure(ai::configure) + .configure(agent::configure) + .service( + web::scope("/workspace") + .configure(workspace::configure) + .service( + web::scope("/{wk}") + .configure(workspace::configure_wk) + .service( + web::scope("/repos") + .configure(git::configure) + .configure(pull_request::configure) + ) + .service( + web::scope("/issues") + .configure(issues::configure) + ) + .service( + web::scope("/labels") + .configure(issues::configure_labels) + ) + .service( + web::scope("/milestones") + .configure(issues::configure_milestones) + ) + ) + ) + .service( + web::scope("/ws") + .configure(|cfg| channel::configure(cfg, channel_bus)), + ) + .service( + web::resource("/search") + .route(web::get().to(search::search)), + ) + ); +} diff --git a/lib/api/src/openapi.rs b/lib/api/src/openapi.rs new file mode 100644 index 0000000..da4109c --- /dev/null +++ b/lib/api/src/openapi.rs @@ -0,0 +1,395 @@ +use utoipa::Modify; +use utoipa::OpenApi; +use utoipa::openapi::security::{ + ApiKey, ApiKeyValue, SecurityRequirement, SecurityScheme, +}; + +#[derive(utoipa::OpenApi)] +#[openapi( + info( + title = "GitDataAI API", + version = "1.0.0", + description = "GitDataAI platform REST API" + ), + tags( + (name = "auth", description = "Authentication & 2FA"), + (name = "user", description = "Current user settings & profile"), + (name = "users", description = "Public user info & relationships"), + (name = "workspace", description = "Workspace management"), + (name = "issues", description = "Issue tracking"), + (name = "pull_request", description = "Pull request operations"), + (name = "git", description = "Git repository operations"), + (name = "repos", description = "Repo management & configuration"), + (name = "ai", description = "AI model catalog"), + (name = "agent", description = "AI Agent sessions, conversations & streaming"), + (name = "channel", description = "Channel WebSocket & REST fallback"), + (name = "search", description = "Global search across workspaces, repos, channels, issues") + ), + paths( + crate::auth::captcha::captcha, + crate::auth::login::login, + crate::auth::logout::logout, + crate::auth::me::me, + crate::auth::register::register, + crate::auth::reset_pass::reset_password_request, + crate::auth::reset_pass::reset_password_verify, + crate::auth::rsa::rsa, + crate::auth::totp::enable_2fa, + crate::auth::totp::verify_2fa, + crate::auth::totp::disable_2fa, + crate::auth::totp::status_2fa, + crate::auth::totp::regenerate_backup_codes, + crate::auth::email::get_email, + crate::auth::email::email_change_request, + crate::auth::email::email_verify, + crate::user::config::user_config, + crate::user::accessibility::update_accessibility, + crate::user::appearance::update_appearance, + crate::user::notification::update_notification, + crate::user::privacy::update_privacy, + crate::user::profile::update_profile, + crate::user::profile::upload_avatar, + crate::user::access_token::list_access_tokens, + crate::user::access_token::create_access_token, + crate::user::access_token::update_access_token, + crate::user::access_token::revoke_access_token, + crate::user::sshkey::list_ssh_keys, + crate::user::sshkey::add_ssh_key, + crate::user::sshkey::update_ssh_key, + crate::user::sshkey::revoke_ssh_key, + crate::user::chpc::contribution_heatmap, + crate::user::chpc::invalidate_chpc_cache, + crate::users::summary::user_summary, + crate::users::public::user_public, + crate::users::chpc::user_chpc, + crate::users::relation::follow_user, + crate::users::relation::unfollow_user, + crate::users::relation::block_user, + crate::users::relation::unblock_user, + crate::users::relation::relation_status, + crate::users::relation::followers, + crate::users::relation::following, + crate::users::relation::blocked_list, + crate::users::relation::relation_counts, + crate::users::avatar::user_avatar, + crate::workspace::workspace::create_workspace, + crate::workspace::workspace::my_workspaces, + crate::workspace::workspace::get_workspace, + crate::workspace::workspace::update_workspace, + crate::workspace::workspace::get_avatar, + crate::workspace::workspace::upload_avatar, + crate::workspace::member::list_members, + crate::workspace::member::add_member, + crate::workspace::member::update_member, + crate::workspace::member::remove_member, + crate::workspace::group::list_groups, + crate::workspace::group::create_group, + crate::workspace::group::update_group, + crate::workspace::group::delete_group, + crate::workspace::group::add_group_member, + crate::workspace::group::remove_group_member, + crate::workspace::group::list_group_members, + crate::workspace::join::join_strategy, + crate::workspace::join::update_join_strategy, + crate::workspace::join::my_join_applies, + crate::workspace::join::apply_join, + crate::workspace::join::cancel_join, + crate::workspace::join::list_join_applies, + crate::workspace::join::approve_join, + crate::issues::issue::create_issue, + crate::issues::issue::list_issues, + crate::issues::issue::get_issue, + crate::issues::issue::update_issue, + crate::issues::issue::delete_issue, + crate::issues::issue::close_issue, + crate::issues::issue::reopen_issue, + crate::issues::comment::list_comments, + crate::issues::comment::create_comment, + crate::issues::comment::update_comment, + crate::issues::comment::delete_comment, + crate::issues::assignee::assign_user, + crate::issues::assignee::unassign_user, + crate::issues::label::add_issue_label, + crate::issues::label::remove_issue_label, + crate::issues::label::list_labels, + crate::issues::label::create_label, + crate::issues::label::update_label, + crate::issues::label::delete_label, + crate::issues::milestone::set_issue_milestone, + crate::issues::milestone::clear_issue_milestone, + crate::issues::milestone::list_milestones, + crate::issues::milestone::create_milestone, + crate::issues::milestone::update_milestone, + crate::issues::milestone::delete_milestone, + crate::issues::binding::bind_repo, + crate::issues::binding::unbind_repo, + crate::issues::binding::bind_pull_request, + crate::issues::binding::unbind_pull_request, + crate::issues::reaction::add_reaction, + crate::issues::reaction::remove_reaction, + crate::issues::reaction::add_comment_reaction, + crate::issues::reaction::remove_comment_reaction, + crate::issues::event::list_events, + crate::pull_request::pull_request::list_prs, + crate::pull_request::pull_request::create_pr, + crate::pull_request::pull_request::get_pr, + crate::pull_request::pull_request::update_pr, + crate::pull_request::pull_request::delete_pr, + crate::pull_request::merge::merge_pr, + crate::pull_request::merge::merge_analysis, + crate::pull_request::merge::update_branch, + crate::pull_request::assignee::assign_user, + crate::pull_request::assignee::unassign_user, + crate::pull_request::comment::list_comments, + crate::pull_request::comment::create_comment, + crate::pull_request::comment::update_comment, + crate::pull_request::comment::delete_comment, + crate::pull_request::label::add_label, + crate::pull_request::label::remove_label, + crate::pull_request::review::list_reviews, + crate::pull_request::review::create_review, + crate::pull_request::review::dismiss_review, + crate::pull_request::review::create_review_comment, + crate::pull_request::reaction::add_reaction, + crate::pull_request::reaction::remove_reaction, + crate::pull_request::reaction::add_comment_reaction, + crate::pull_request::reaction::remove_comment_reaction, + crate::git::init::create_repo, + crate::git::init::clone_repo, + crate::git::repo::list_repos, + crate::git::repo::get_repo, + crate::git::repo::update_repo, + crate::git::repo::delete_repo, + crate::git::repo::archive_repo, + crate::git::repo::transfer_repo, + crate::git::repo::get_topics, + crate::git::repo::update_topics, + crate::git::fork::create_fork, + crate::git::fork::list_forks, + crate::git::protect::list_protects, + crate::git::protect::create_protect, + crate::git::protect::update_protect, + crate::git::protect::delete_protect, + crate::git::webhook::list_webhooks, + crate::git::webhook::create_webhook, + crate::git::webhook::update_webhook, + crate::git::webhook::delete_webhook, + crate::git::webhook::list_deliveries, + crate::git::branch::list_branches, + crate::git::branch::fork_branch, + crate::git::branch::branch_info, + crate::git::branch::ahead_behind, + crate::git::branch::branch_upstream, + crate::git::branch::delete_branch, + crate::git::branch::rename_branch, + crate::git::commit::list_commits, + crate::git::commit::commit_info, + crate::git::commit::commit_history, + crate::git::commit::commit_walk, + crate::git::commit::cherry_pick, + crate::git::blob::blob_info, + crate::git::blob::blob_upload, + crate::git::blame::blame_file, + crate::git::tree::tree_entries, + crate::git::tree::tree_entry_by_path, + crate::git::tree::tree_entry_by_path_from_commit, + crate::git::diff::diff, + crate::git::tag::list_tags, + crate::git::tag::init_tag, + crate::git::tag::tag_info, + crate::git::tag::delete_tag, + crate::git::tag::update_tag, + crate::git::archive::archive, + crate::git::star::star_repo, + crate::git::star::unstar_repo, + crate::git::star::star_status, + crate::git::watch::watch_repo, + crate::git::watch::unwatch_repo, + crate::git::watch::watch_status, + crate::git::contributor::list_contributors, + crate::git::language::get_languages, + crate::git::readme::get_readme, + crate::git::readme::diff_branches, + crate::git::refs::list_refs, + crate::git::contents::get_contents, + crate::git::contents::create_contents, + crate::git::contents::update_contents, + crate::git::contents::delete_contents, + crate::git::compare::compare, + crate::git::release::list_releases, + crate::git::release::get_release, + crate::git::release::get_release_by_tag, + crate::git::release::create_release, + crate::git::release::update_release, + crate::git::release::delete_release, + crate::git::release::delete_release_by_tag, + crate::git::commit_status::list_statuses, + crate::git::commit_status::combined_status, + crate::git::commit_status::create_status, + crate::ai::model::list_models, + crate::ai::model::get_model, + crate::ai::model::list_versions, + crate::ai::model::get_card, + crate::ai::model::list_tags, + crate::ai::model::list_discussions, + crate::ai::model::list_likes, + crate::ai::provider::list_providers, + crate::ai::provider::get_provider, + crate::agent::session::create_session, + crate::agent::session::list_sessions, + crate::agent::session::get_session, + crate::agent::session::update_session, + crate::agent::session::delete_session, + crate::agent::conversation::list_all_conversations, + crate::agent::conversation::create_conversation, + crate::agent::conversation::list_conversations, + crate::agent::conversation::get_conversation, + crate::agent::conversation::update_conversation, + crate::agent::conversation::delete_conversation, + crate::agent::conversation::list_messages, + crate::agent::conversation::send_message, + crate::agent::conversation::stream_agent, + crate::channel::token::generate_token, + crate::channel::rest::ping, + crate::channel::rest::csrf_token, + crate::channel::rest_message::create_message, + crate::channel::rest_message::update_message, + crate::channel::rest_message::revoke_message, + crate::channel::rest_message::list_messages, + crate::channel::rest_message::messages_around, + crate::channel::rest_message::missed_messages, + crate::channel::rest_message::search, + crate::channel::rest_room::subscribe, + crate::channel::rest_room::unsubscribe, + crate::channel::rest_room::room_get, + crate::channel::rest_room::room_create, + crate::channel::rest_room::room_update, + crate::channel::rest_room::room_delete, + crate::channel::rest_room::access_grant, + crate::channel::rest_room::access_revoke, + crate::channel::rest_room::category_create, + crate::channel::rest_room::category_update, + crate::channel::rest_room::category_delete, + crate::channel::rest_interact::reaction_add, + crate::channel::rest_interact::reaction_remove, + crate::channel::rest_interact::thread_create, + crate::channel::rest_interact::thread_resolve, + crate::channel::rest_interact::thread_archive, + crate::channel::rest_interact::pin_add, + crate::channel::rest_interact::pin_remove, + crate::channel::rest_interact::draft_save, + crate::channel::rest_interact::draft_clear, + crate::channel::rest_interact::typing, + crate::channel::rest_member::read_receipt, + crate::channel::rest_member::dnd_update, + crate::channel::rest_member::notification_mark_read, + crate::channel::rest_member::notification_mark_all_read, + crate::channel::rest_member::notification_archive, + crate::channel::rest_member::presence_update, + crate::channel::rest_member::custom_status_update, + crate::channel::rest_member::invite_create, + crate::channel::rest_member::invite_accept, + crate::channel::rest_member::invite_revoke, + crate::channel::rest_member::ban_create, + crate::channel::rest_member::ban_remove, + crate::channel::rest_voice::voice_join, + crate::channel::rest_voice::voice_leave, + crate::channel::rest_voice::voice_mute, + crate::channel::rest_voice::voice_deaf, + crate::channel::rest_voice::screen_share, + crate::channel::rest_ai::ai_list, + crate::channel::rest_ai::ai_add, + crate::channel::rest_ai::ai_remove, + crate::channel::rest_ai::ai_stop, + crate::channel::rest_ai::user_summary, + crate::search::search, + ), + modifiers(&SecurityAddon) +)] +pub struct ApiDoc; + +struct SecurityAddon; + +impl Modify for SecurityAddon { + fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { + if let Some(components) = openapi.components.as_mut() { + components.add_security_scheme( + "session", + SecurityScheme::ApiKey(ApiKey::Cookie(ApiKeyValue::new("id"))), + ); + } + openapi.security = Some(vec![SecurityRequirement::new( + "session", + Vec::::new(), + )]); + openapi.info.license = None; + let remap = |t: &str| -> String { + if t.starts_with("crate::") { + let parts: Vec<&str> = t.split("::").collect(); + parts.get(1).unwrap_or(&t).to_string() + } else { + t.to_string() + } + }; + fn clean_tags( + op: &mut utoipa::openapi::path::Operation, + remap: impl Fn(&str) -> String, + ) { + if let Some(tags) = op.tags.as_mut() { + for t in tags.iter_mut() { + *t = remap(t); + } + tags.dedup(); + } + } + fn prefix_operation_id(op: &mut utoipa::openapi::path::Operation) { + if let Some(tags) = &op.tags { + if let Some(tag) = tags.first() { + if let Some(id) = op.operation_id.as_mut() { + if !id.starts_with(tag) { + *id = format!("{}_{}", tag, id); + } + } + } + } + } + for (_path, item) in openapi.paths.paths.iter_mut() { + if let Some(op) = item.get.as_mut() { + clean_tags(op, remap); + prefix_operation_id(op); + } + if let Some(op) = item.put.as_mut() { + clean_tags(op, remap); + prefix_operation_id(op); + } + if let Some(op) = item.post.as_mut() { + clean_tags(op, remap); + prefix_operation_id(op); + } + if let Some(op) = item.delete.as_mut() { + clean_tags(op, remap); + prefix_operation_id(op); + } + if let Some(op) = item.patch.as_mut() { + clean_tags(op, remap); + prefix_operation_id(op); + } + if let Some(op) = item.head.as_mut() { + clean_tags(op, remap); + prefix_operation_id(op); + } + if let Some(op) = item.options.as_mut() { + clean_tags(op, remap); + prefix_operation_id(op); + } + if let Some(op) = item.trace.as_mut() { + clean_tags(op, remap); + prefix_operation_id(op); + } + } + } +} + +pub fn openapi_json() -> String { + ApiDoc::openapi().to_pretty_json().unwrap() +} diff --git a/lib/api/src/pull_request/assignee.rs b/lib/api/src/pull_request/assignee.rs new file mode 100644 index 0000000..f848bc7 --- /dev/null +++ b/lib/api/src/pull_request/assignee.rs @@ -0,0 +1,72 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, issues::types::IssueAuthor, + pull_request::assignee::AssignPrUser, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct PrPath { + pub wk: String, + pub repo: String, + pub number: i64, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct UnassignQuery { + pub username: String, +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/assignees", + params(PrPath), + request_body = AssignPrUser, + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "PR not found"), + ), + security(("session" = [])) +)] +pub async fn assign_user( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service + .pr_assign(&session, &wk, &repo, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/assignees", + params(PrPath), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "PR not found"), + ), + security(("session" = [])) +)] +pub async fn unassign_user( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service + .pr_unassign(&session, &wk, &repo, number, &query.username) + .await?; + ok_json(data) +} diff --git a/lib/api/src/pull_request/comment.rs b/lib/api/src/pull_request/comment.rs new file mode 100644 index 0000000..367a6fb --- /dev/null +++ b/lib/api/src/pull_request/comment.rs @@ -0,0 +1,145 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, + pull_request::{ + comment::{CreatePrComment, UpdatePrComment}, + types::PullRequestCommentResponse, + }, +}; +use session::Session; +use uuid::Uuid; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct PrPath { + pub wk: String, + pub repo: String, + pub number: i64, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct CommentPath { + pub wk: String, + pub repo: String, + pub number: i64, + pub comment_id: Uuid, +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/comments", + params(PrPath), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security(("session" = [])) +)] +pub async fn list_comments( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service + .pr_comment_list(&session, &wk, &repo, number) + .await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/comments", + params(PrPath), + request_body = CreatePrComment, + responses( + (status = 200, body = PullRequestCommentResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "PR not found"), + ), + security(("session" = [])) +)] +pub async fn create_comment( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service + .pr_comment_create(&session, &wk, &repo, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + put, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/comments/{comment_id}", + params(CommentPath), + request_body = UpdatePrComment, + responses( + (status = 200, body = PullRequestCommentResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Only author can update"), + (status = 404, description = "Comment not found"), + ), + security(("session" = [])) +)] +pub async fn update_comment( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let CommentPath { + wk, + repo, + number, + comment_id, + } = path.into_inner(); + let data = service + .pr_comment_update( + &session, + &wk, + &repo, + number, + comment_id, + params.into_inner(), + ) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/comments/{comment_id}", + params(CommentPath), + responses( + (status = 200, description = "Comment deleted"), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Comment not found"), + ), + security(("session" = [])) +)] +pub async fn delete_comment( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let CommentPath { + wk, + repo, + number: _, + comment_id, + } = path.into_inner(); + service + .pr_comment_delete(&session, &wk, &repo, comment_id) + .await?; + ok() +} diff --git a/lib/api/src/pull_request/label.rs b/lib/api/src/pull_request/label.rs new file mode 100644 index 0000000..0d758bd --- /dev/null +++ b/lib/api/src/pull_request/label.rs @@ -0,0 +1,72 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, issues::types::LabelResponse, pull_request::label::AddPrLabel, +}; +use session::Session; +use uuid::Uuid; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct PrPath { + pub wk: String, + pub repo: String, + pub number: i64, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct RemoveLabelQuery { + pub label_id: Uuid, +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/labels", + params(PrPath), + request_body = AddPrLabel, + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "PR or label not found"), + ), + security(("session" = [])) +)] +pub async fn add_label( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service + .pr_add_label(&session, &wk, &repo, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/labels", + params(PrPath), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "PR or label not found"), + ), + security(("session" = [])) +)] +pub async fn remove_label( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service + .pr_remove_label(&session, &wk, &repo, number, query.label_id) + .await?; + ok_json(data) +} diff --git a/lib/api/src/pull_request/merge.rs b/lib/api/src/pull_request/merge.rs new file mode 100644 index 0000000..eb1933a --- /dev/null +++ b/lib/api/src/pull_request/merge.rs @@ -0,0 +1,89 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, + pull_request::{merge::MergePullRequest, types::PullRequestResponse}, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct PrPath { + pub wk: String, + pub repo: String, + pub number: i64, +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/merge", + params(PrPath), + responses( + (status = 200, description = "Merge analysis result"), + (status = 401, description = "Unauthorized"), + (status = 404, description = "PR not found"), + ), + security(("session" = [])) +)] +pub async fn merge_analysis( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service + .pr_merge_analysis(&session, &wk, &repo, number) + .await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/merge", + params(PrPath), + request_body = MergePullRequest, + responses( + (status = 200, body = PullRequestResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "PR not found"), + ), + security(("session" = [])) +)] +pub async fn merge_pr( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service + .pr_merge(&session, &wk, &repo, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/update-branch", + params(PrPath), + responses( + (status = 200, body = PullRequestResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "PR not found"), + ), + security(("session" = [])) +)] +pub async fn update_branch( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service + .pr_update_branch(&session, &wk, &repo, number) + .await?; + ok_json(data) +} diff --git a/lib/api/src/pull_request/mod.rs b/lib/api/src/pull_request/mod.rs new file mode 100644 index 0000000..5d84f91 --- /dev/null +++ b/lib/api/src/pull_request/mod.rs @@ -0,0 +1,78 @@ +pub mod assignee; +pub mod comment; +pub mod label; +pub mod merge; +pub mod pull_request; +pub mod reaction; +pub mod review; + +use actix_web::{web, web::ServiceConfig}; + +pub fn configure(cfg: &mut ServiceConfig) { + cfg.service( + web::scope("/{repo}/pull-requests") + .service( + web::resource("") + .route(web::get().to(pull_request::list_prs)) + .route(web::post().to(pull_request::create_pr)), + ) + .service( + web::resource("/{number}") + .route(web::get().to(pull_request::get_pr)) + .route(web::patch().to(pull_request::update_pr)) + .route(web::delete().to(pull_request::delete_pr)), + ) + .service( + web::resource("/{number}/merge") + .route(web::get().to(merge::merge_analysis)) + .route(web::post().to(merge::merge_pr)), + ) + .service( + web::resource("/{number}/update-branch") + .route(web::post().to(merge::update_branch)), + ) + .service( + web::resource("/{number}/assignees") + .route(web::post().to(assignee::assign_user)) + .route(web::delete().to(assignee::unassign_user)), + ) + .service( + web::resource("/{number}/comments") + .route(web::get().to(comment::list_comments)) + .route(web::post().to(comment::create_comment)), + ) + .service( + web::resource("/{number}/comments/{comment_id}") + .route(web::put().to(comment::update_comment)) + .route(web::delete().to(comment::delete_comment)), + ) + .service( + web::resource("/{number}/labels") + .route(web::post().to(label::add_label)) + .route(web::delete().to(label::remove_label)), + ) + .service( + web::resource("/{number}/reviews") + .route(web::get().to(review::list_reviews)) + .route(web::post().to(review::create_review)), + ) + .service( + web::resource("/{number}/reviews/{review_id}/dismiss") + .route(web::post().to(review::dismiss_review)), + ) + .service( + web::resource("/{number}/review-comments") + .route(web::post().to(review::create_review_comment)), + ) + .service( + web::resource("/{number}/reactions") + .route(web::post().to(reaction::add_reaction)) + .route(web::delete().to(reaction::remove_reaction)), + ) + .service( + web::resource("/{number}/comments/{comment_id}/reactions") + .route(web::post().to(reaction::add_comment_reaction)) + .route(web::delete().to(reaction::remove_comment_reaction)), + ), + ); +} diff --git a/lib/api/src/pull_request/pull_request.rs b/lib/api/src/pull_request/pull_request.rs new file mode 100644 index 0000000..16b444d --- /dev/null +++ b/lib/api/src/pull_request/pull_request.rs @@ -0,0 +1,160 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, Pagination, + pull_request::{ + pull_request::{CreatePullRequest, UpdatePullRequest}, + types::{PullRequestFilter, PullRequestResponse}, + }, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct PrPath { + pub wk: String, + pub repo: String, + pub number: i64, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct PrRepoPath { + pub wk: String, + pub repo: String, +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests", + params( + ("wk" = String, Path, description = "Workspace name"), + ("repo" = String, Path, description = "Repo name"), + ), + request_body = CreatePullRequest, + responses( + (status = 200, body = PullRequestResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Repo not found"), + ), + security(("session" = [])) +)] +pub async fn create_pr( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let PrRepoPath { wk, repo } = path.into_inner(); + let data = service + .pr_create(&session, &wk, &repo, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests", + params( + ("wk" = String, Path, description = "Workspace name"), + ("repo" = String, Path, description = "Repo name"), + PullRequestFilter, + Pagination, + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security(("session" = [])) +)] +pub async fn list_prs( + session: Session, + service: web::Data, + path: web::Path, + filter: web::Query, + pagination: web::Query, +) -> Result { + let PrRepoPath { wk, repo } = path.into_inner(); + let data = service + .pr_list( + &session, + &wk, + &repo, + filter.into_inner(), + pagination.into_inner(), + ) + .await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}", + params(PrPath), + responses( + (status = 200, body = PullRequestResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "PR not found"), + ), + security(("session" = [])) +)] +pub async fn get_pr( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service.pr_get(&session, &wk, &repo, number).await?; + ok_json(data) +} +#[utoipa::path( + patch, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}", + params(PrPath), + request_body = UpdatePullRequest, + responses( + (status = 200, body = PullRequestResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Only author can update"), + (status = 404, description = "PR not found"), + ), + security(("session" = [])) +)] +pub async fn update_pr( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service + .pr_update(&session, &wk, &repo, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}", + params(PrPath), + responses( + (status = 200, description = "PR deleted"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "PR not found"), + ), + security(("session" = [])) +)] +pub async fn delete_pr( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + service.pr_delete(&session, &wk, &repo, number).await?; + ok() +} diff --git a/lib/api/src/pull_request/reaction.rs b/lib/api/src/pull_request/reaction.rs new file mode 100644 index 0000000..6a5ca4c --- /dev/null +++ b/lib/api/src/pull_request/reaction.rs @@ -0,0 +1,155 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, + pull_request::{ + reaction::AddPrReaction, types::PullRequestReactionResponse, + }, +}; +use session::Session; +use uuid::Uuid; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct PrPath { + pub wk: String, + pub repo: String, + pub number: i64, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct RemoveReactionQuery { + pub reaction_id: Uuid, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct CommentReactionPath { + pub wk: String, + pub repo: String, + pub number: i64, + pub comment_id: Uuid, +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/reactions", + params(PrPath), + request_body = AddPrReaction, + responses( + (status = 200, body = PullRequestReactionResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "PR not found"), + ), + security(("session" = [])) +)] +pub async fn add_reaction( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service + .pr_add_reaction(&session, &wk, &repo, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/reactions", + params(PrPath), + responses( + (status = 200, description = "Reaction removed"), + (status = 401, description = "Unauthorized"), + ), + security(("session" = [])) +)] +pub async fn remove_reaction( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + service + .pr_remove_reaction(&session, &wk, &repo, number, query.reaction_id) + .await?; + ok() +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/comments/{comment_id}/reactions", + params(CommentReactionPath), + request_body = AddPrReaction, + responses( + (status = 200, body = PullRequestReactionResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Comment not found"), + ), + security(("session" = [])) +)] +pub async fn add_comment_reaction( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let CommentReactionPath { + wk, + repo, + number, + comment_id, + } = path.into_inner(); + let data = service + .pr_comment_add_reaction( + &session, + &wk, + &repo, + number, + comment_id, + params.into_inner(), + ) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/comments/{comment_id}/reactions", + params(CommentReactionPath), + responses( + (status = 200, description = "Reaction removed"), + (status = 401, description = "Unauthorized"), + ), + security(("session" = [])) +)] +pub async fn remove_comment_reaction( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let CommentReactionPath { + wk, + repo, + number, + comment_id: _, + } = path.into_inner(); + service + .pr_comment_remove_reaction( + &session, + &wk, + &repo, + number, + query.reaction_id, + ) + .await?; + ok() +} diff --git a/lib/api/src/pull_request/review.rs b/lib/api/src/pull_request/review.rs new file mode 100644 index 0000000..fa22156 --- /dev/null +++ b/lib/api/src/pull_request/review.rs @@ -0,0 +1,146 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, + pull_request::{ + review::{CreatePrReview, CreatePrReviewComment, DismissPrReview}, + types::{PullRequestReviewCommentResponse, PullRequestReviewResponse}, + }, +}; +use session::Session; +use uuid::Uuid; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct PrPath { + pub wk: String, + pub repo: String, + pub number: i64, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct DismissPath { + pub wk: String, + pub repo: String, + pub number: i64, + pub review_id: Uuid, +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/reviews", + params(PrPath), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security(("session" = [])) +)] +pub async fn list_reviews( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service.pr_review_list(&session, &wk, &repo, number).await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/reviews", + params(PrPath), + request_body = CreatePrReview, + responses( + (status = 200, body = PullRequestReviewResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "PR not found"), + ), + security(("session" = [])) +)] +pub async fn create_review( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service + .pr_review_create(&session, &wk, &repo, number, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/reviews/{review_id}/dismiss", + params(DismissPath), + request_body = DismissPrReview, + responses( + (status = 200, description = "Review dismissed"), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Review not found"), + ), + security(("session" = [])) +)] +pub async fn dismiss_review( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let DismissPath { + wk, + repo, + number, + review_id, + } = path.into_inner(); + service + .pr_review_dismiss( + &session, + &wk, + &repo, + number, + review_id, + params.into_inner(), + ) + .await?; + ok() +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/repos/{repo}/pull-requests/{number}/review-comments", + params(PrPath), + request_body = CreatePrReviewComment, + responses( + (status = 200, body = PullRequestReviewCommentResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "PR not found"), + ), + security(("session" = [])) +)] +pub async fn create_review_comment( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let PrPath { wk, repo, number } = path.into_inner(); + let data = service + .pr_review_comment_create( + &session, + &wk, + &repo, + number, + None, + params.into_inner(), + ) + .await?; + ok_json(data) +} diff --git a/lib/api/src/search.rs b/lib/api/src/search.rs new file mode 100644 index 0000000..2dca0cc --- /dev/null +++ b/lib/api/src/search.rs @@ -0,0 +1,259 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use utoipa::ToSchema; + +use crate::error::ApiError; +use crate::channel::ChannelBus; +use service::AppService; +use session::Session; + +#[derive(Debug, serde::Deserialize, utoipa::IntoParams)] +pub struct SearchQuery { + pub q: Option, +} + +#[derive(Debug, Serialize, ToSchema)] +pub struct SearchResponse { + pub workspaces: SearchGroup, + pub repos: SearchGroup, + pub rooms: SearchGroup, + pub issues: SearchGroup, +} + +#[derive(Debug, Serialize, ToSchema)] +pub struct SearchGroup { + pub items: Vec, + pub total: usize, + pub has_more: bool, +} + +#[derive(Debug, Serialize, ToSchema)] +pub struct WorkspaceHit { + pub name: String, + pub description: Option, +} + +#[derive(Debug, Serialize, ToSchema)] +pub struct RepoHit { + pub name: String, + pub workspace: String, + pub description: Option, +} + +#[derive(Debug, Serialize, ToSchema)] +pub struct RoomHit { + pub id: String, + pub name: String, + pub workspace: String, +} + +#[derive(Debug, Serialize, ToSchema)] +pub struct IssueHit { + pub number: i32, + pub title: String, + pub state: String, + pub workspace: String, +} + +const MAX_PER_GROUP: usize = 10; + +fn session_user(session: &Session) -> Result { + session + .user() + .ok_or(ApiError(service::error::AppError::Unauthorized)) +} + +#[utoipa::path( + get, + path = "/api/v1/search", + params(SearchQuery), + responses((status = 200, body = SearchResponse)), + security(("session" = [])), + tag = "search", +)] +pub async fn search( + session: Session, + service: web::Data, + bus: web::Data, + query: web::Query, +) -> Result { + let user_id = session_user(&session)?; + let q = query.q.as_deref().unwrap_or("").trim().to_lowercase(); + + let workspaces = search_workspaces(&service, user_id, &q).await?; + let repos = search_repos(&service, user_id, &q).await?; + let rooms = search_rooms(&bus, user_id, &q).await?; + let issues = search_issues(&service, user_id, &q).await?; + + Ok(HttpResponse::Ok().json(SearchResponse { + workspaces, + repos, + rooms, + issues, + })) +} + +async fn search_workspaces( + service: &AppService, + user_id: uuid::Uuid, + q: &str, +) -> Result, ApiError> { + let all: Vec = service + .workspace_my_inner(user_id) + .await? + .into_iter() + .map(|(name, desc)| WorkspaceHit { + name, + description: desc, + }) + .collect(); + + let filtered: Vec = if q.is_empty() { + all + } else { + all.into_iter() + .filter(|w| { + w.name.to_lowercase().contains(q) + || w.description + .as_deref() + .unwrap_or("") + .to_lowercase() + .contains(q) + }) + .collect() + }; + + Ok(group_result(filtered)) +} + +async fn search_repos( + service: &AppService, + user_id: uuid::Uuid, + q: &str, +) -> Result, ApiError> { + let workspaces: Vec = service + .workspace_my_inner(user_id) + .await? + .into_iter() + .map(|(name, _)| name) + .collect(); + + let mut all: Vec = Vec::new(); + for wk in &workspaces { + if let Ok(repos) = service.repo_list_inner(wk).await { + for (name, desc) in repos { + all.push(RepoHit { + name, + workspace: wk.clone(), + description: desc, + }); + } + } + } + + let filtered: Vec = if q.is_empty() { + all + } else { + let q = q; + all.into_iter() + .filter(|r| { + format!( + "{} {} {}", + r.workspace, + r.name, + r.description.as_deref().unwrap_or("") + ) + .to_lowercase() + .contains(q) + }) + .collect() + }; + + Ok(group_result(filtered)) +} + +async fn search_rooms( + bus: &ChannelBus, + user_id: uuid::Uuid, + q: &str, +) -> Result, ApiError> { + let rooms = bus + .list_user_rooms(user_id) + .await + .map_err(|e| { + ApiError(service::error::AppError::InternalServerError( + e.to_string(), + )) + })?; + + let all: Vec = rooms + .into_iter() + .map(|r| RoomHit { + id: r.id.to_string(), + name: r.name, + workspace: r.workspace_id.to_string(), + }) + .collect(); + + let filtered: Vec = if q.is_empty() { + all + } else { + all.into_iter() + .filter(|r| r.name.to_lowercase().contains(q)) + .collect() + }; + + Ok(group_result(filtered)) +} + +async fn search_issues( + service: &AppService, + user_id: uuid::Uuid, + q: &str, +) -> Result, ApiError> { + let workspaces: Vec = service + .workspace_my_inner(user_id) + .await? + .into_iter() + .map(|(name, _)| name) + .collect(); + + let mut all: Vec = Vec::new(); + for wk in &workspaces { + if let Ok(issues) = service.issue_list_inner(wk).await { + for (number, title, state) in issues { + all.push(IssueHit { + number, + title, + state, + workspace: wk.clone(), + }); + } + } + } + + let filtered: Vec = if q.is_empty() { + all + } else { + all.into_iter() + .filter(|i| { + format!("#{} {}", i.number, i.title) + .to_lowercase() + .contains(q) + }) + .collect() + }; + + Ok(group_result(filtered)) +} + +fn group_result(mut items: Vec) -> SearchGroup { + let total = items.len(); + let has_more = total > MAX_PER_GROUP; + items.truncate(MAX_PER_GROUP); + SearchGroup { + items, + total, + has_more, + } +} diff --git a/lib/api/src/user/access_token.rs b/lib/api/src/user/access_token.rs new file mode 100644 index 0000000..4a1c75e --- /dev/null +++ b/lib/api/src/user/access_token.rs @@ -0,0 +1,106 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, + user::access_token::{ + CreateUserAccessToken, CreatedUserAccessToken, UpdateUserAccessToken, + UserAccessToken, + }, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[utoipa::path( + get, + path = "/api/v1/user/access-tokens", + responses( + (status = 200, body = Vec) + ), + tag = "user" +)] +pub async fn list_access_tokens( + service: web::Data, + session: Session, +) -> Result { + let tokens = service.user_access_tokens(&session).await?; + ok_json(tokens) +} + +#[utoipa::path( + post, + path = "/api/v1/user/access-tokens", + request_body = CreateUserAccessToken, + responses( + (status = 200, body = CreatedUserAccessToken) + ), + tag = "user" +)] +pub async fn create_access_token( + service: web::Data, + session: Session, + body: web::Json, +) -> Result { + let token = service + .user_create_access_token(&session, body.into_inner()) + .await?; + ok_json(token) +} + +#[utoipa::path( + put, + path = "/api/v1/user/access-tokens/{id}", + request_body = UpdateUserAccessToken, + responses( + (status = 200, body = UserAccessToken) + ), + params( + ("id" = i64, Path, description = "Access token ID") + ), + tag = "user" +)] +pub async fn update_access_token( + service: web::Data, + session: Session, + token_id: web::Path, + body: web::Json, +) -> Result { + let token = service + .user_update_access_token( + &session, + token_id.into_inner(), + body.into_inner(), + ) + .await?; + ok_json(token) +} + +#[utoipa::path( + delete, + path = "/api/v1/user/access-tokens/{id}", + responses( + (status = 200) + ), + params( + ("id" = i64, Path, description = "Access token ID") + ), + tag = "user" +)] +pub async fn revoke_access_token( + service: web::Data, + session: Session, + token_id: web::Path, +) -> Result { + service + .user_revoke_access_token(&session, token_id.into_inner()) + .await?; + ok() +} diff --git a/lib/api/src/user/accessibility.rs b/lib/api/src/user/accessibility.rs new file mode 100644 index 0000000..bf80ec1 --- /dev/null +++ b/lib/api/src/user/accessibility.rs @@ -0,0 +1,35 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, + user::accessibility::{ + UpdateUserAccessibilityConfig, UserAccessibilityConfig, + }, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + put, + path = "/api/v1/user/config/accessibility", + request_body = UpdateUserAccessibilityConfig, + responses( + (status = 200, body = UserAccessibilityConfig) + ), + tag = "user" +)] +pub async fn update_accessibility( + service: web::Data, + session: Session, + body: web::Json, +) -> Result { + let config = service + .user_update_accessibility_config(&session, body.into_inner()) + .await?; + ok_json(config) +} diff --git a/lib/api/src/user/appearance.rs b/lib/api/src/user/appearance.rs new file mode 100644 index 0000000..cf809f0 --- /dev/null +++ b/lib/api/src/user/appearance.rs @@ -0,0 +1,33 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, + user::appearance::{UpdateUserAppearanceConfig, UserAppearanceConfig}, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + put, + path = "/api/v1/user/config/appearance", + request_body = UpdateUserAppearanceConfig, + responses( + (status = 200, body = UserAppearanceConfig) + ), + tag = "user" +)] +pub async fn update_appearance( + service: web::Data, + session: Session, + body: web::Json, +) -> Result { + let config = service + .user_update_appearance_config(&session, body.into_inner()) + .await?; + ok_json(config) +} diff --git a/lib/api/src/user/chpc.rs b/lib/api/src/user/chpc.rs new file mode 100644 index 0000000..e495384 --- /dev/null +++ b/lib/api/src/user/chpc.rs @@ -0,0 +1,53 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, + user::chpc::{ContributionHeatmapQuery, ContributionHeatmapResponse}, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[utoipa::path( + get, + path = "/api/v1/user/contribution-heatmap", + params( + ContributionHeatmapQuery + ), + responses( + (status = 200, body = ContributionHeatmapResponse) + ), + tag = "user" +)] +pub async fn contribution_heatmap( + service: web::Data, + session: Session, + query: web::Query, +) -> Result { + let response = service.user_chpc(&session, query.into_inner()).await?; + ok_json(response) +} + +#[utoipa::path( + delete, + path = "/api/v1/user/contribution-heatmap/cache", + responses( + (status = 200) + ), + tag = "user" +)] +pub async fn invalidate_chpc_cache( + service: web::Data, + session: Session, +) -> Result { + service.user_invalidate_chpc_cache(&session).await?; + ok() +} diff --git a/lib/api/src/user/config.rs b/lib/api/src/user/config.rs new file mode 100644 index 0000000..a484e56 --- /dev/null +++ b/lib/api/src/user/config.rs @@ -0,0 +1,26 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{AppService, user::config::UserConfigResponse}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + get, + path = "/api/v1/user/config", + responses( + (status = 200, body = UserConfigResponse) + ), + tag = "user" +)] +pub async fn user_config( + service: web::Data, + session: Session, +) -> Result { + let config = service.user_config(&session).await?; + ok_json(config) +} diff --git a/lib/api/src/user/mod.rs b/lib/api/src/user/mod.rs new file mode 100644 index 0000000..4165893 --- /dev/null +++ b/lib/api/src/user/mod.rs @@ -0,0 +1,77 @@ +pub mod access_token; +pub mod accessibility; +pub mod appearance; +pub mod chpc; +pub mod config; +pub mod notification; +pub mod privacy; +pub mod profile; +pub mod sshkey; + +use actix_web::{web, web::ServiceConfig}; + +pub fn configure(cfg: &mut ServiceConfig) { + cfg.service( + web::scope("/user") + .service( + web::resource("/notifications") + .route(web::get().to(notification::list_notifications)), + ) + .service( + web::resource("/config") + .route(web::get().to(config::user_config)), + ) + .service( + web::resource("/config/accessibility") + .route(web::put().to(accessibility::update_accessibility)), + ) + .service( + web::resource("/config/appearance") + .route(web::put().to(appearance::update_appearance)), + ) + .service( + web::resource("/config/notification") + .route(web::put().to(notification::update_notification)), + ) + .service( + web::resource("/config/privacy") + .route(web::put().to(privacy::update_privacy)), + ) + .service( + web::resource("/config/profile") + .route(web::put().to(profile::update_profile)), + ) + .service( + web::resource("/avatar") + .route(web::post().to(profile::upload_avatar)), + ) + .service( + web::resource("/access-tokens") + .route(web::get().to(access_token::list_access_tokens)) + .route(web::post().to(access_token::create_access_token)), + ) + .service( + web::resource("/access-tokens/{id}") + .route(web::put().to(access_token::update_access_token)) + .route(web::delete().to(access_token::revoke_access_token)), + ) + .service( + web::resource("/ssh-keys") + .route(web::get().to(sshkey::list_ssh_keys)) + .route(web::post().to(sshkey::add_ssh_key)), + ) + .service( + web::resource("/ssh-keys/{id}") + .route(web::put().to(sshkey::update_ssh_key)) + .route(web::delete().to(sshkey::revoke_ssh_key)), + ) + .service( + web::resource("/contribution-heatmap") + .route(web::get().to(chpc::contribution_heatmap)), + ) + .service( + web::resource("/contribution-heatmap/cache") + .route(web::delete().to(chpc::invalidate_chpc_cache)), + ), + ); +} diff --git a/lib/api/src/user/notification.rs b/lib/api/src/user/notification.rs new file mode 100644 index 0000000..813fdef --- /dev/null +++ b/lib/api/src/user/notification.rs @@ -0,0 +1,56 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, Pagination, + user::notification::{ + AppNotificationItem, UpdateUserNotificationConfig, + UserNotificationConfig, + }, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + get, + path = "/api/v1/user/notifications", + params(Pagination), + responses( + (status = 200, body = Vec) + ), + tag = "user" +)] +pub async fn list_notifications( + service: web::Data, + session: Session, + query: web::Query, +) -> Result { + let notifications = service + .list_notifications(&session, query.into_inner()) + .await?; + ok_json(notifications) +} + +#[utoipa::path( + put, + path = "/api/v1/user/config/notification", + request_body = UpdateUserNotificationConfig, + responses( + (status = 200, body = UserNotificationConfig) + ), + tag = "user" +)] +pub async fn update_notification( + service: web::Data, + session: Session, + body: web::Json, +) -> Result { + let config = service + .user_update_notification_config(&session, body.into_inner()) + .await?; + ok_json(config) +} diff --git a/lib/api/src/user/privacy.rs b/lib/api/src/user/privacy.rs new file mode 100644 index 0000000..e65d5aa --- /dev/null +++ b/lib/api/src/user/privacy.rs @@ -0,0 +1,33 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, + user::privacy::{UpdateUserPrivacyConfig, UserPrivacyConfig}, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + put, + path = "/api/v1/user/config/privacy", + request_body = UpdateUserPrivacyConfig, + responses( + (status = 200, body = UserPrivacyConfig) + ), + tag = "user" +)] +pub async fn update_privacy( + service: web::Data, + session: Session, + body: web::Json, +) -> Result { + let config = service + .user_update_privacy_config(&session, body.into_inner()) + .await?; + ok_json(config) +} diff --git a/lib/api/src/user/profile.rs b/lib/api/src/user/profile.rs new file mode 100644 index 0000000..87e287e --- /dev/null +++ b/lib/api/src/user/profile.rs @@ -0,0 +1,61 @@ +use actix_web::{HttpRequest, HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, + user::profile::{AvatarUploadResponse, UpdateUserProfileConfig, UserProfileConfig}, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + put, + path = "/api/v1/user/config/profile", + request_body = UpdateUserProfileConfig, + responses( + (status = 200, body = UserProfileConfig) + ), + tag = "user" +)] +pub async fn update_profile( + service: web::Data, + session: Session, + body: web::Json, +) -> Result { + let config = service + .user_update_profile_config(&session, body.into_inner()) + .await?; + ok_json(config) +} + +#[utoipa::path( + post, + path = "/api/v1/user/avatar", + request_body(content = Vec, content_type = "image/*"), + responses( + (status = 200, body = AvatarUploadResponse), + (status = 400, description = "Invalid file type or size") + ), + tag = "user" +)] +pub async fn upload_avatar( + service: web::Data, + session: Session, + body: web::Bytes, + req: HttpRequest, +) -> Result { + let content_type = req + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or("application/octet-stream"); + + let response = service + .user_upload_avatar(&session, body.to_vec(), content_type) + .await?; + ok_json(response) +} diff --git a/lib/api/src/user/sshkey.rs b/lib/api/src/user/sshkey.rs new file mode 100644 index 0000000..cd2e6d7 --- /dev/null +++ b/lib/api/src/user/sshkey.rs @@ -0,0 +1,99 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, + user::sshkey::{CreateUserSshKey, UpdateUserSshKey, UserSshKey}, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[utoipa::path( + get, + path = "/api/v1/user/ssh-keys", + responses( + (status = 200, body = Vec) + ), + tag = "user" +)] +pub async fn list_ssh_keys( + service: web::Data, + session: Session, +) -> Result { + let keys = service.user_ssh_keys(&session).await?; + ok_json(keys) +} + +#[utoipa::path( + post, + path = "/api/v1/user/ssh-keys", + request_body = CreateUserSshKey, + responses( + (status = 200, body = UserSshKey) + ), + tag = "user" +)] +pub async fn add_ssh_key( + service: web::Data, + session: Session, + body: web::Json, +) -> Result { + let key = service + .user_add_ssh_key(&session, body.into_inner()) + .await?; + ok_json(key) +} + +#[utoipa::path( + put, + path = "/api/v1/user/ssh-keys/{id}", + request_body = UpdateUserSshKey, + responses( + (status = 200, body = UserSshKey) + ), + params( + ("id" = i64, Path, description = "SSH key ID") + ), + tag = "user" +)] +pub async fn update_ssh_key( + service: web::Data, + session: Session, + key_id: web::Path, + body: web::Json, +) -> Result { + let key = service + .user_update_ssh_key(&session, key_id.into_inner(), body.into_inner()) + .await?; + ok_json(key) +} + +#[utoipa::path( + delete, + path = "/api/v1/user/ssh-keys/{id}", + responses( + (status = 200) + ), + params( + ("id" = i64, Path, description = "SSH key ID") + ), + tag = "user" +)] +pub async fn revoke_ssh_key( + service: web::Data, + session: Session, + key_id: web::Path, +) -> Result { + service + .user_revoke_ssh_key(&session, key_id.into_inner()) + .await?; + ok() +} diff --git a/lib/api/src/users/avatar.rs b/lib/api/src/users/avatar.rs new file mode 100644 index 0000000..8643862 --- /dev/null +++ b/lib/api/src/users/avatar.rs @@ -0,0 +1,27 @@ +use actix_web::{HttpResponse, web}; +use service::AppService; + +use crate::error::ApiError; + +#[utoipa::path( + get, + path = "/api/v1/users/avatar/{username}", + params( + ("username" = String, Path, description = "Username"), + ), + responses( + (status = 302, description = "Redirect to avatar image URL"), + (status = 404, description = "User or avatar not found"), + ), + tag = "users" +)] +pub async fn user_avatar( + service: web::Data, + path: web::Path, +) -> Result { + let username = path.into_inner(); + let url = service.users_get_avatar_url(&username).await?; + Ok(HttpResponse::Found() + .insert_header(("Location", url)) + .finish()) +} diff --git a/lib/api/src/users/chpc.rs b/lib/api/src/users/chpc.rs new file mode 100644 index 0000000..36f7a8e --- /dev/null +++ b/lib/api/src/users/chpc.rs @@ -0,0 +1,30 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, + user::chpc::{ContributionHeatmapQuery, ContributionHeatmapResponse}, +}; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + get, + path = "/api/v1/users/{username}/contribution-heatmap", + params(("username" = String, Path), ContributionHeatmapQuery), + responses((status = 200, body = ContributionHeatmapResponse)), + tag = "users" +)] +pub async fn user_chpc( + service: web::Data, + username: web::Path, + query: web::Query, +) -> Result { + let result = service + .users_chpc_by_username(&username, query.into_inner()) + .await?; + ok_json(result) +} diff --git a/lib/api/src/users/mod.rs b/lib/api/src/users/mod.rs new file mode 100644 index 0000000..2b1e3ac --- /dev/null +++ b/lib/api/src/users/mod.rs @@ -0,0 +1,65 @@ +pub mod avatar; +pub mod chpc; +pub mod public; +pub mod relation; +pub mod summary; + +use actix_web::{web, web::ServiceConfig}; + +pub fn configure(cfg: &mut ServiceConfig) { + cfg.service( + web::scope("/users") + .service( + web::resource("/avatar/{username}") + .route(web::get().to(avatar::user_avatar)), + ) + .service( + web::resource("/{username}/summary") + .route(web::get().to(summary::user_summary)), + ) + .service( + web::resource("/{username}/public") + .route(web::get().to(public::user_public)), + ) + .service( + web::resource("/{username}/contribution-heatmap") + .route(web::get().to(chpc::user_chpc)), + ) + .service( + web::resource("/{username}/follow") + .route(web::post().to(relation::follow_user)), + ) + .service( + web::resource("/{username}/unfollow") + .route(web::post().to(relation::unfollow_user)), + ) + .service( + web::resource("/{username}/block") + .route(web::post().to(relation::block_user)), + ) + .service( + web::resource("/{username}/unblock") + .route(web::post().to(relation::unblock_user)), + ) + .service( + web::resource("/{username}/relation") + .route(web::get().to(relation::relation_status)), + ) + .service( + web::resource("/{username}/followers") + .route(web::get().to(relation::followers)), + ) + .service( + web::resource("/{username}/following") + .route(web::get().to(relation::following)), + ) + .service( + web::resource("/blocked") + .route(web::get().to(relation::blocked_list)), + ) + .service( + web::resource("/{username}/relation-counts") + .route(web::get().to(relation::relation_counts)), + ), + ); +} diff --git a/lib/api/src/users/public.rs b/lib/api/src/users/public.rs new file mode 100644 index 0000000..bf66086 --- /dev/null +++ b/lib/api/src/users/public.rs @@ -0,0 +1,24 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{AppService, users::public::PublicUserResponse}; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + get, + path = "/api/v1/users/{username}/public", + params(("username" = String, Path)), + responses((status = 200, body = PublicUserResponse)), + tag = "users" +)] +pub async fn user_public( + service: web::Data, + username: web::Path, +) -> Result { + let result = service.users_public_by_username(&username).await?; + ok_json(result) +} diff --git a/lib/api/src/users/relation.rs b/lib/api/src/users/relation.rs new file mode 100644 index 0000000..82f8050 --- /dev/null +++ b/lib/api/src/users/relation.rs @@ -0,0 +1,178 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, Pagination, + users::relation::{ + UserRelationCard, UserRelationCounts, UserRelationStatus, + }, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + post, + path = "/api/v1/users/{username}/follow", + params(("username" = String, Path)), + responses((status = 200, body = UserRelationStatus)), + tag = "users" +)] +pub async fn follow_user( + service: web::Data, + session: Session, + username: web::Path, +) -> Result { + let result = service + .users_follow_by_username(&session, &username) + .await?; + ok_json(result) +} + +#[utoipa::path( + post, + path = "/api/v1/users/{username}/unfollow", + params(("username" = String, Path)), + responses((status = 200, body = UserRelationStatus)), + tag = "users" +)] +pub async fn unfollow_user( + service: web::Data, + session: Session, + username: web::Path, +) -> Result { + let result = service + .users_unfollow_by_username(&session, &username) + .await?; + ok_json(result) +} + +#[utoipa::path( + post, + path = "/api/v1/users/{username}/block", + params(("username" = String, Path)), + responses((status = 200, body = UserRelationStatus)), + tag = "users" +)] +pub async fn block_user( + service: web::Data, + session: Session, + username: web::Path, +) -> Result { + let result = service.users_block_by_username(&session, &username).await?; + ok_json(result) +} + +#[utoipa::path( + post, + path = "/api/v1/users/{username}/unblock", + params(("username" = String, Path)), + responses((status = 200, body = UserRelationStatus)), + tag = "users" +)] +pub async fn unblock_user( + service: web::Data, + session: Session, + username: web::Path, +) -> Result { + let result = service + .users_unblock_by_username(&session, &username) + .await?; + ok_json(result) +} + +#[utoipa::path( + get, + path = "/api/v1/users/{username}/relation", + params(("username" = String, Path)), + responses((status = 200, body = UserRelationStatus)), + tag = "users" +)] +pub async fn relation_status( + service: web::Data, + session: Session, + username: web::Path, +) -> Result { + let result = service + .users_relation_status_by_username(&session, &username) + .await?; + ok_json(result) +} + +#[utoipa::path( + get, + path = "/api/v1/users/{username}/followers", + params(("username" = String, Path), Pagination), + responses((status = 200, body = Vec)), + tag = "users" +)] +pub async fn followers( + service: web::Data, + session: Session, + username: web::Path, + pagination: web::Query, +) -> Result { + let result = service + .users_followers_by_username( + Some(&session), + &username, + pagination.into_inner(), + ) + .await?; + ok_json(result) +} + +#[utoipa::path( + get, + path = "/api/v1/users/{username}/following", + params(("username" = String, Path), Pagination), + responses((status = 200, body = Vec)), + tag = "users" +)] +pub async fn following( + service: web::Data, + session: Session, + username: web::Path, + pagination: web::Query, +) -> Result { + let result = service + .users_following_by_username( + Some(&session), + &username, + pagination.into_inner(), + ) + .await?; + ok_json(result) +} + +#[utoipa::path( + get, + path = "/api/v1/users/blocked", + responses((status = 200, body = Vec)), + tag = "users" +)] +pub async fn blocked_list( + service: web::Data, + session: Session, +) -> Result { + let result = service.users_blocked(&session).await?; + ok_json(result) +} + +#[utoipa::path( + get, + path = "/api/v1/users/{username}/relation-counts", + params(("username" = String, Path)), + responses((status = 200, body = UserRelationCounts)), + tag = "users" +)] +pub async fn relation_counts( + service: web::Data, + username: web::Path, +) -> Result { + let result = service.users_relation_counts_by_username(&username).await?; + ok_json(result) +} diff --git a/lib/api/src/users/summary.rs b/lib/api/src/users/summary.rs new file mode 100644 index 0000000..855b31f --- /dev/null +++ b/lib/api/src/users/summary.rs @@ -0,0 +1,24 @@ +use actix_web::{HttpResponse, web}; +use serde::Serialize; +use service::{AppService, users::summary::UserSummaryResponse}; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[utoipa::path( + get, + path = "/api/v1/users/{username}/summary", + params(("username" = String, Path)), + responses((status = 200, body = UserSummaryResponse)), + tag = "users" +)] +pub async fn user_summary( + service: web::Data, + username: web::Path, +) -> Result { + let result = service.users_summary_by_username(&username).await?; + ok_json(result) +} diff --git a/lib/api/src/workspace/group.rs b/lib/api/src/workspace/group.rs new file mode 100644 index 0000000..f7a0d54 --- /dev/null +++ b/lib/api/src/workspace/group.rs @@ -0,0 +1,222 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, + workspace::{ + group::{CreateWorkspaceGroup, UpdateWorkspaceGroup}, + types::{WorkspaceGroupResponse, WorkspaceMemberResponse}, + }, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct GroupPath { + pub wk: String, + pub group_name: String, +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct GroupMemberPath { + pub wk: String, + pub group_name: String, + pub username: String, +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/groups", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Workspace not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn list_groups( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let wk = path.into_inner(); + let data = service.workspace_groups(&session, &wk).await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/groups", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + request_body = CreateWorkspaceGroup, + responses( + (status = 200, body = WorkspaceGroupResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + ), + security( + ("session" = []) + ) +)] +pub async fn create_group( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let wk = path.into_inner(); + let data = service + .workspace_create_group(&session, &wk, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + put, + path = "/api/v1/workspace/{wk}/groups/{group_name}", + params(GroupPath), + request_body = UpdateWorkspaceGroup, + responses( + (status = 200, body = WorkspaceGroupResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Group not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn update_group( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let GroupPath { wk, group_name } = path.into_inner(); + let data = service + .workspace_update_group( + &session, + &wk, + &group_name, + params.into_inner(), + ) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/groups/{group_name}", + params(GroupPath), + responses( + (status = 200, description = "Group deleted"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Group not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn delete_group( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let GroupPath { wk, group_name } = path.into_inner(); + service + .workspace_delete_group(&session, &wk, &group_name) + .await?; + ok() +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/groups/{group_name}/members/{username}", + params(GroupMemberPath), + responses( + (status = 200, description = "Member added to group"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + ), + security( + ("session" = []) + ) +)] +pub async fn add_group_member( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let GroupMemberPath { + wk, + group_name, + username, + } = path.into_inner(); + service + .workspace_add_group_member(&session, &wk, &group_name, &username) + .await?; + ok() +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/groups/{group_name}/members/{username}", + params(GroupMemberPath), + responses( + (status = 200, description = "Member removed from group"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + ), + security( + ("session" = []) + ) +)] +pub async fn remove_group_member( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let GroupMemberPath { + wk, + group_name, + username, + } = path.into_inner(); + service + .workspace_remove_group_member(&session, &wk, &group_name, &username) + .await?; + ok() +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/groups/{group_name}/members", + params(GroupPath), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Group not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn list_group_members( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let GroupPath { wk, group_name } = path.into_inner(); + let data = service + .workspace_group_members(&session, &wk, &group_name) + .await?; + ok_json(data) +} diff --git a/lib/api/src/workspace/join.rs b/lib/api/src/workspace/join.rs new file mode 100644 index 0000000..7d87635 --- /dev/null +++ b/lib/api/src/workspace/join.rs @@ -0,0 +1,202 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, + workspace::join::{ + ApproveWorkspaceJoinApply, CreateWorkspaceJoinApply, + ListWorkspaceJoinApply, UpdateWorkspaceJoinStrategy, + WorkspaceJoinApplyResponse, WorkspaceJoinApprovalResponse, + WorkspaceJoinStrategyResponse, + }, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct ApplyApprovePath { + pub wk: String, + pub username: String, +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/join-strategy", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + responses( + (status = 200, body = WorkspaceJoinStrategyResponse), + (status = 404, description = "Workspace not found"), + ) +)] +pub async fn join_strategy( + service: web::Data, + path: web::Path, +) -> Result { + let wk = path.into_inner(); + let data = service.workspace_join_strategy(&wk).await?; + ok_json(data) +} +#[utoipa::path( + put, + path = "/api/v1/workspace/{wk}/join-strategy", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + request_body = UpdateWorkspaceJoinStrategy, + responses( + (status = 200, body = WorkspaceJoinStrategyResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Workspace not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn update_join_strategy( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let wk = path.into_inner(); + let data = service + .workspace_update_join_strategy(&session, &wk, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/join/apply", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + request_body = CreateWorkspaceJoinApply, + responses( + (status = 200, body = WorkspaceJoinApplyResponse), + (status = 401, description = "Unauthorized"), + (status = 409, description = "Already a member or pending application"), + ), + security( + ("session" = []) + ) +)] +pub async fn apply_join( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let wk = path.into_inner(); + let data = service + .workspace_apply_join(&session, &wk, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/join/my-applies", + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security( + ("session" = []) + ) +)] +pub async fn my_join_applies( + session: Session, + service: web::Data, +) -> Result { + let data = service.workspace_my_join_applies(&session).await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/join/cancel", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + responses( + (status = 200, body = WorkspaceJoinApplyResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Application not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn cancel_join( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let wk = path.into_inner(); + let data = service.workspace_cancel_join_apply(&session, &wk).await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/join/applies", + params( + ("wk" = String, Path, description = "Workspace name"), + ListWorkspaceJoinApply, + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + ), + security( + ("session" = []) + ) +)] +pub async fn list_join_applies( + session: Session, + service: web::Data, + path: web::Path, + query: web::Query, +) -> Result { + let wk = path.into_inner(); + let data = service + .workspace_join_applies(&session, &wk, query.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/join/applies/{username}/approve", + params(ApplyApprovePath), + request_body = ApproveWorkspaceJoinApply, + responses( + (status = 200, body = WorkspaceJoinApprovalResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Application not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn approve_join( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let ApplyApprovePath { wk, username } = path.into_inner(); + let data = service + .workspace_approve_join_apply( + &session, + &wk, + &username, + params.into_inner(), + ) + .await?; + ok_json(data) +} diff --git a/lib/api/src/workspace/member.rs b/lib/api/src/workspace/member.rs new file mode 100644 index 0000000..2845607 --- /dev/null +++ b/lib/api/src/workspace/member.rs @@ -0,0 +1,137 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; +use service::{ + AppService, Pagination, + workspace::{ + member::{AddWorkspaceMember, UpdateWorkspaceMember}, + types::WorkspaceMemberResponse, + }, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} + +fn ok() -> Result { + Ok(HttpResponse::Ok().finish()) +} + +#[derive(Deserialize, utoipa::IntoParams)] +pub struct MemberPath { + pub wk: String, + pub username: String, +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/members", + params( + ("wk" = String, Path, description = "Workspace name"), + Pagination, + ), + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Workspace not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn list_members( + session: Session, + service: web::Data, + path: web::Path, + pagination: web::Query, +) -> Result { + let wk = path.into_inner(); + let data = service + .workspace_members(&session, &wk, pagination.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/members", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + request_body = AddWorkspaceMember, + responses( + (status = 200, body = WorkspaceMemberResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + ), + security( + ("session" = []) + ) +)] +pub async fn add_member( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let wk = path.into_inner(); + let data = service + .workspace_add_member(&session, &wk, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + put, + path = "/api/v1/workspace/{wk}/members/{username}", + params(MemberPath), + request_body = UpdateWorkspaceMember, + responses( + (status = 200, body = WorkspaceMemberResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + ), + security( + ("session" = []) + ) +)] +pub async fn update_member( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let MemberPath { wk, username } = path.into_inner(); + let data = service + .workspace_update_member( + &session, + &wk, + &username, + params.into_inner(), + ) + .await?; + ok_json(data) +} +#[utoipa::path( + delete, + path = "/api/v1/workspace/{wk}/members/{username}", + params(MemberPath), + responses( + (status = 200, description = "Member removed"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + ), + security( + ("session" = []) + ) +)] +pub async fn remove_member( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let MemberPath { wk, username } = path.into_inner(); + service + .workspace_remove_member(&session, &wk, &username) + .await?; + ok() +} diff --git a/lib/api/src/workspace/mod.rs b/lib/api/src/workspace/mod.rs new file mode 100644 index 0000000..0c442eb --- /dev/null +++ b/lib/api/src/workspace/mod.rs @@ -0,0 +1,82 @@ +pub mod group; +pub mod join; +pub mod member; +pub mod workspace; + +use actix_web::{web, web::ServiceConfig}; +pub fn configure(cfg: &mut ServiceConfig) { + cfg.service( + web::resource("") + .route(web::post().to(workspace::create_workspace)), + ); + cfg.service( + web::resource("/my") + .route(web::get().to(workspace::my_workspaces)), + ); + cfg.service( + web::resource("/join/my-applies") + .route(web::get().to(join::my_join_applies)), + ); +} +pub fn configure_wk(cfg: &mut ServiceConfig) { + cfg.service( + web::resource("") + .route(web::get().to(workspace::get_workspace)) + .route(web::put().to(workspace::update_workspace)), + ); + cfg.service( + web::resource("/avatar") + .route(web::get().to(workspace::get_avatar)) + .route(web::post().to(workspace::upload_avatar)), + ); + cfg.service( + web::resource("/members") + .route(web::get().to(member::list_members)) + .route(web::post().to(member::add_member)), + ); + cfg.service( + web::resource("/members/{username}") + .route(web::put().to(member::update_member)) + .route(web::delete().to(member::remove_member)), + ); + cfg.service( + web::resource("/groups") + .route(web::get().to(group::list_groups)) + .route(web::post().to(group::create_group)), + ); + cfg.service( + web::resource("/groups/{group_name}") + .route(web::put().to(group::update_group)) + .route(web::delete().to(group::delete_group)), + ); + cfg.service( + web::resource("/groups/{group_name}/members/{username}") + .route(web::post().to(group::add_group_member)) + .route(web::delete().to(group::remove_group_member)), + ); + cfg.service( + web::resource("/groups/{group_name}/members") + .route(web::get().to(group::list_group_members)), + ); + cfg.service( + web::resource("/join-strategy") + .route(web::get().to(join::join_strategy)) + .route(web::put().to(join::update_join_strategy)), + ); + cfg.service( + web::resource("/join/apply") + .route(web::post().to(join::apply_join)), + ); + cfg.service( + web::resource("/join/cancel") + .route(web::post().to(join::cancel_join)), + ); + cfg.service( + web::resource("/join/applies") + .route(web::get().to(join::list_join_applies)), + ); + cfg.service( + web::resource("/join/applies/{username}/approve") + .route(web::post().to(join::approve_join)), + ); +} diff --git a/lib/api/src/workspace/workspace.rs b/lib/api/src/workspace/workspace.rs new file mode 100644 index 0000000..fcf443d --- /dev/null +++ b/lib/api/src/workspace/workspace.rs @@ -0,0 +1,169 @@ +use actix_web::{HttpRequest, HttpResponse, web}; +use serde::Serialize; +use service::{ + AppService, + workspace::{ + types::WorkspaceResponse, + workspace::{AvatarUploadResponse, CreateWorkspace, UpdateWorkspace}, + }, +}; +use session::Session; + +use crate::error::ApiError; + +fn ok_json(data: T) -> Result { + Ok(HttpResponse::Ok().json(data)) +} +#[utoipa::path( + post, + path = "/api/v1/workspace", + request_body = CreateWorkspace, + responses( + (status = 200, body = WorkspaceResponse), + (status = 401, description = "Unauthorized"), + (status = 409, description = "Workspace name already exists"), + ), + security( + ("session" = []) + ) +)] +pub async fn create_workspace( + session: Session, + service: web::Data, + params: web::Json, +) -> Result { + let data = service + .workspace_create(&session, params.into_inner()) + .await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/my", + responses( + (status = 200, body = Vec), + (status = 401, description = "Unauthorized"), + ), + security( + ("session" = []) + ) +)] +pub async fn my_workspaces( + session: Session, + service: web::Data, +) -> Result { + let data = service.workspace_my(&session).await?; + ok_json(data) +} +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + responses( + (status = 200, body = WorkspaceResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Workspace not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn get_workspace( + session: Session, + service: web::Data, + path: web::Path, +) -> Result { + let wk = path.into_inner(); + let data = service.workspace_get(&session, &wk).await?; + ok_json(data) +} +#[utoipa::path( + put, + path = "/api/v1/workspace/{wk}", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + request_body = UpdateWorkspace, + responses( + (status = 200, body = WorkspaceResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Permission denied"), + (status = 404, description = "Workspace not found"), + ), + security( + ("session" = []) + ) +)] +pub async fn update_workspace( + session: Session, + service: web::Data, + path: web::Path, + params: web::Json, +) -> Result { + let wk = path.into_inner(); + let data = service + .workspace_update(&session, &wk, params.into_inner()) + .await?; + ok_json(data) +} + +#[utoipa::path( + get, + path = "/api/v1/workspace/{wk}/avatar", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + responses( + (status = 302, description = "Redirect to avatar image URL"), + (status = 404, description = "Workspace or avatar not found"), + ), + tag = "workspace" +)] +pub async fn get_avatar( + service: web::Data, + path: web::Path, +) -> Result { + let wk = path.into_inner(); + let url = service.workspace_get_avatar_url(&wk).await?; + Ok(HttpResponse::Found() + .insert_header(("Location", url)) + .finish()) +} + +#[utoipa::path( + post, + path = "/api/v1/workspace/{wk}/avatar", + params( + ("wk" = String, Path, description = "Workspace name"), + ), + request_body(content = Vec, content_type = "image/*"), + responses( + (status = 200, body = AvatarUploadResponse), + (status = 400, description = "Invalid file type or size"), + (status = 403, description = "Permission denied"), + ), + security( + ("session" = []) + ) +)] +pub async fn upload_avatar( + session: Session, + service: web::Data, + path: web::Path, + body: web::Bytes, + req: HttpRequest, +) -> Result { + let wk = path.into_inner(); + let content_type = req + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or("application/octet-stream"); + + let response = service + .workspace_upload_avatar(&session, &wk, body.to_vec(), content_type) + .await?; + ok_json(response) +} diff --git a/lib/cache/Cargo.toml b/lib/cache/Cargo.toml new file mode 100644 index 0000000..c94c283 --- /dev/null +++ b/lib/cache/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "cache" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[lib] +path = "lib.rs" +name = "cache" +[dependencies] +redis = { workspace = true, features = ["cluster-async", "aio","cache-aio","tokio-comp","r2d2"] } +tokio = { workspace = true, features = ["full"] } +moka = { workspace = true, features = ["future"] } +serde_json = { workspace = true } +serde = { workspace = true } +config = { workspace = true } +[lints] +workspace = true diff --git a/lib/cache/cluster.rs b/lib/cache/cluster.rs new file mode 100644 index 0000000..4bde828 --- /dev/null +++ b/lib/cache/cluster.rs @@ -0,0 +1,205 @@ +use std::{sync::Arc, time::Duration}; + +use redis::{ + AsyncCommands, cluster::ClusterClient, cluster_async::ClusterConnection, +}; +use serde::{Serialize, de::DeserializeOwned}; +use tokio::time::timeout; + +use crate::{CacheError, CacheResult}; + +const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_secs(3); + +#[derive(Clone, Debug)] +pub struct ClusterCacheConfig { + pub urls: Vec, + pub key_prefix: Option, + pub command_timeout: Duration, +} + +impl ClusterCacheConfig { + pub fn new(urls: Vec) -> Self { + Self { + urls, + key_prefix: None, + command_timeout: DEFAULT_COMMAND_TIMEOUT, + } + } +} + +#[derive(Clone)] +pub struct ClusterCache { + connection: ClusterConnection, + key_prefix: Option>, + command_timeout: Duration, +} + +impl ClusterCache { + pub async fn connect(config: ClusterCacheConfig) -> CacheResult { + if config.urls.is_empty() { + return Err(CacheError::Config( + "redis cluster urls are empty".to_string(), + )); + } + + let client = + ClusterClient::new(config.urls).map_err(CacheError::Redis)?; + let connection = + timeout(config.command_timeout, client.get_async_connection()) + .await + .map_err(|_| CacheError::Timeout("connect redis cluster"))? + .map_err(CacheError::Redis)?; + + Ok(Self { + connection, + key_prefix: config.key_prefix.map(Arc::from), + command_timeout: config.command_timeout, + }) + } + + pub async fn get(&self, key: &str) -> CacheResult> + where + T: DeserializeOwned, + { + let key = self.key(key); + let mut connection = self.connection.clone(); + let value: Option> = self + .run(redis::cmd("GET").arg(&key).query_async(&mut connection)) + .await?; + + match value { + Some(value) => serde_json::from_slice(&value) + .map(Some) + .map_err(CacheError::Serialize), + None => Ok(None), + } + } + + pub async fn get_json( + &self, + key: &str, + ) -> CacheResult> { + self.get(key).await + } + + pub async fn set( + &self, + key: &str, + value: &T, + ttl: Option, + ) -> CacheResult<()> + where + T: Serialize + ?Sized, + { + let key = self.key(key); + let value = serde_json::to_vec(value).map_err(CacheError::Serialize)?; + let mut connection = self.connection.clone(); + + if let Some(ttl) = ttl { + let seconds = ttl.as_secs().max(1); + self.run::<(), _>(connection.set_ex(key, value, seconds)) + .await + } else { + self.run::<(), _>(connection.set(key, value)).await + } + } + + pub async fn remove(&self, key: &str) -> CacheResult { + let key = self.key(key); + let mut connection = self.connection.clone(); + let removed: u64 = self.run(connection.del(key)).await?; + Ok(removed > 0) + } + + pub async fn exists(&self, key: &str) -> CacheResult { + let key = self.key(key); + let mut connection = self.connection.clone(); + self.run(connection.exists(key)).await + } + + pub async fn set_nx_with_ttl( + &self, + key: &str, + value: &T, + ttl: Duration, + ) -> CacheResult + where + T: Serialize, + { + let key = self.key(key); + let value = serde_json::to_vec(value).map_err(CacheError::Serialize)?; + let mut connection = self.connection.clone(); + let result: Option = self + .run( + redis::cmd("SET") + .arg(&key) + .arg(&value) + .arg("NX") + .arg("EX") + .arg(ttl.as_secs().max(1)) + .query_async(&mut connection), + ) + .await?; + Ok(result.is_some()) + } + + pub async fn expire(&self, key: &str, ttl: Duration) -> CacheResult { + let key = self.key(key); + let mut connection = self.connection.clone(); + self.run(connection.expire(key, ttl.as_secs() as i64)).await + } + + pub async fn delete_pattern(&self, pattern: &str) -> CacheResult { + let pattern = self.key(pattern); + let mut connection = self.connection.clone(); + let keys: Vec = self + .run( + redis::cmd("KEYS") + .arg(&pattern) + .query_async(&mut connection), + ) + .await?; + + if keys.is_empty() { + return Ok(0); + } + + let mut connection = self.connection.clone(); + self.run(connection.del(keys)).await + } + + pub async fn ping(&self) -> CacheResult<()> { + let mut connection = self.connection.clone(); + let pong: String = self + .run(redis::cmd("PING").query_async(&mut connection)) + .await?; + if pong == "PONG" { + Ok(()) + } else { + Err(CacheError::Protocol(format!( + "unexpected redis PING response: {pong}" + ))) + } + } + + pub fn conn(&self) -> ClusterConnection { + self.connection.clone() + } + + fn key(&self, key: &str) -> String { + match &self.key_prefix { + Some(prefix) => format!("{prefix}:{key}"), + None => key.to_string(), + } + } + + async fn run(&self, future: F) -> CacheResult + where + F: std::future::Future>, + { + timeout(self.command_timeout, future) + .await + .map_err(|_| CacheError::Timeout("redis command"))? + .map_err(CacheError::Redis) + } +} diff --git a/lib/cache/error.rs b/lib/cache/error.rs new file mode 100644 index 0000000..54afe92 --- /dev/null +++ b/lib/cache/error.rs @@ -0,0 +1,44 @@ +use std::fmt; + +#[derive(Debug)] +pub enum CacheError { + Config(String), + Protocol(String), + Redis(redis::RedisError), + Serialize(serde_json::Error), + Timeout(&'static str), +} + +impl fmt::Display for CacheError { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Config(message) => { + write!(formatter, "cache config error: {message}") + } + Self::Protocol(message) => { + write!(formatter, "cache protocol error: {message}") + } + Self::Redis(error) => { + write!(formatter, "redis cache error: {error}") + } + Self::Serialize(error) => { + write!(formatter, "cache serialization error: {error}") + } + Self::Timeout(operation) => { + write!(formatter, "cache operation timed out: {operation}") + } + } + } +} + +impl std::error::Error for CacheError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Redis(error) => Some(error), + Self::Serialize(error) => Some(error), + Self::Config(_) | Self::Protocol(_) | Self::Timeout(_) => None, + } + } +} + +pub type CacheResult = Result; diff --git a/lib/cache/lib.rs b/lib/cache/lib.rs new file mode 100644 index 0000000..7eb326d --- /dev/null +++ b/lib/cache/lib.rs @@ -0,0 +1,209 @@ +pub mod cluster; +pub mod error; +pub mod local; + +use std::time::Duration; + +pub use crate::{ + cluster::{ClusterCache, ClusterCacheConfig}, + error::{CacheError, CacheResult}, + local::{LocalCacheConfig, MokaCache}, +}; + +#[derive(Clone, Debug)] +pub struct AppCacheConfig { + pub local: LocalCacheConfig, + pub cluster: Option, + pub default_ttl: Option, + pub cluster_write_through: bool, +} + +impl Default for AppCacheConfig { + fn default() -> Self { + Self { + local: LocalCacheConfig::default(), + cluster: None, + default_ttl: Some(Duration::from_secs(300)), + cluster_write_through: true, + } + } +} + +#[derive(Clone)] +pub struct AppCache { + pub local: MokaCache, + pub cluster: Option, + default_ttl: Option, + cluster_write_through: bool, +} + +impl AppCache { + pub async fn init(config: AppCacheConfig) -> CacheResult { + let local = MokaCache::with_config(config.local); + let cluster = match config.cluster { + Some(cluster) => Some(match ClusterCache::connect(cluster).await { + Ok(cluster) => cluster, + Err(e) => { + println!("cache:init:error with: {}", e); + return Err(e); + } + }), + None => None, + }; + + Ok(Self { + local, + cluster, + default_ttl: config.default_ttl, + cluster_write_through: config.cluster_write_through, + }) + } + + pub fn local_only(local: MokaCache) -> Self { + Self { + local, + cluster: None, + default_ttl: None, + cluster_write_through: false, + } + } + + pub async fn get(&self, key: &str) -> CacheResult> + where + T: serde::Serialize + serde::de::DeserializeOwned, + { + if let Some(value) = self.local.get(key).await? { + return Ok(Some(value)); + } + + let Some(cluster) = &self.cluster else { + return Ok(None); + }; + + let value = cluster.get::(key).await?; + if let Some(value) = &value { + self.local.set(key, value).await?; + } + Ok(value) + } + + pub async fn set(&self, key: &str, value: &T) -> CacheResult<()> + where + T: serde::Serialize + ?Sized, + { + self.local.set(key, value).await?; + if self.cluster_write_through + && let Some(cluster) = &self.cluster + { + cluster.set(key, value, self.default_ttl).await?; + } + Ok(()) + } + + pub async fn remove(&self, key: &str) -> CacheResult<()> { + self.local.remove(key).await; + if let Some(cluster) = &self.cluster { + cluster.remove(key).await?; + } + Ok(()) + } + pub async fn delete_pattern(&self, pattern: &str) -> CacheResult { + let pattern = pattern.to_string(); + let local_pattern = pattern.clone(); + self.local.invalidate_entries_if(move |key| { + simple_glob_match(&local_pattern, key) + }); + + let mut removed = 0u64; + if let Some(cluster) = &self.cluster { + removed = cluster.delete_pattern(&pattern).await?; + } + Ok(removed) + } + + pub async fn ping_cluster(&self) -> CacheResult<()> { + if let Some(cluster) = &self.cluster { + cluster.ping().await?; + } + Ok(()) + } + + pub fn conn(&self) -> Option { + self.cluster.as_ref().map(|c| c.conn()) + } +} + +impl TryFrom<&config::AppConfig> for AppCacheConfig { + type Error = CacheError; + + fn try_from(config: &config::AppConfig) -> Result { + let local = LocalCacheConfig { + max_capacity: config + .cache_local_max_capacity() + .map_err(|error| CacheError::Config(error.to_string()))?, + time_to_live: config + .cache_local_ttl() + .map_err(|error| CacheError::Config(error.to_string()))?, + time_to_idle: config + .cache_local_tti() + .map_err(|error| CacheError::Config(error.to_string()))?, + }; + + let cluster = if config + .cache_cluster_enabled() + .map_err(|error| CacheError::Config(error.to_string()))? + { + Some(ClusterCacheConfig { + urls: config + .redis_urls() + .map_err(|error| CacheError::Config(error.to_string()))?, + key_prefix: config.cache_cluster_key_prefix(), + command_timeout: config + .cache_cluster_command_timeout() + .map_err(|error| CacheError::Config(error.to_string()))?, + }) + } else { + None + }; + + Ok(Self { + local, + cluster, + default_ttl: config + .cache_default_ttl() + .map_err(|error| CacheError::Config(error.to_string()))?, + cluster_write_through: config + .cache_cluster_write_through() + .map_err(|error| CacheError::Config(error.to_string()))?, + }) + } +} +fn simple_glob_match(pattern: &str, key: &str) -> bool { + let p = pattern.as_bytes(); + let k = key.as_bytes(); + let (mut pi, mut ki) = (0usize, 0usize); + + let mut backtrack_p: Option = None; + let mut backtrack_k: usize = 0; + + loop { + if pi < p.len() && ki < k.len() && (p[pi] == b'?' || p[pi] == k[ki]) { + pi += 1; + ki += 1; + } else if pi < p.len() && p[pi] == b'*' { + backtrack_p = Some(pi); + backtrack_k = ki; + pi += 1; + } else if let Some(saved_pi) = backtrack_p { + backtrack_k += 1; + ki = backtrack_k; + pi = saved_pi + 1; + } else { + return pi == p.len() && ki == k.len(); + } + + if pi == p.len() && ki == k.len() { + return true; + } + } +} diff --git a/lib/cache/local.rs b/lib/cache/local.rs new file mode 100644 index 0000000..dd3d6ad --- /dev/null +++ b/lib/cache/local.rs @@ -0,0 +1,143 @@ +use std::{sync::Arc, time::Duration}; + +use moka::future::{Cache, CacheBuilder}; +use serde::{Serialize, de::DeserializeOwned}; + +use crate::{CacheError, CacheResult}; + +const DEFAULT_LOCAL_MAX_CAPACITY: u64 = 10_000; + +#[derive(Clone, Debug)] +pub struct LocalCacheConfig { + pub max_capacity: u64, + pub time_to_live: Option, + pub time_to_idle: Option, +} + +impl Default for LocalCacheConfig { + fn default() -> Self { + Self { + max_capacity: DEFAULT_LOCAL_MAX_CAPACITY, + time_to_live: Some(Duration::from_secs(300)), + time_to_idle: None, + } + } +} + +#[derive(Clone)] +pub struct MokaCache { + pub(crate) inner: Cache, Arc<[u8]>>, +} + +impl MokaCache { + pub fn init() -> Self { + Self::with_config(LocalCacheConfig::default()) + } + + pub fn with_config(config: LocalCacheConfig) -> Self { + let mut builder = CacheBuilder::new(config.max_capacity); + if let Some(time_to_live) = config.time_to_live { + builder = builder.time_to_live(time_to_live); + } + if let Some(time_to_idle) = config.time_to_idle { + builder = builder.time_to_idle(time_to_idle); + } + + Self { + inner: builder.build(), + } + } + + pub async fn get(&self, key: &str) -> CacheResult> + where + T: DeserializeOwned, + { + match self.inner.get(key).await { + Some(value) => serde_json::from_slice(value.as_ref()) + .map(Some) + .map_err(CacheError::Serialize), + None => Ok(None), + } + } + + pub async fn get_json( + &self, + key: &str, + ) -> CacheResult> { + self.get(key).await + } + + pub async fn set(&self, key: &str, value: &T) -> CacheResult<()> + where + T: Serialize + ?Sized, + { + let value = serde_json::to_vec(value).map_err(CacheError::Serialize)?; + self.inner + .insert(Arc::::from(key), Arc::<[u8]>::from(value)) + .await; + Ok(()) + } + + pub async fn remove(&self, key: &str) { + self.inner.remove(key).await; + } + + pub async fn contains_key(&self, key: &str) -> bool { + self.inner.contains_key(key) + } + + pub async fn invalidate_all(&self) { + self.inner.invalidate_all(); + } + + pub fn invalidate_entries_if(&self, predicate: F) + where + F: Fn(&str) -> bool + Send + Sync + 'static, + { + let _ = self + .inner + .invalidate_entries_if(move |key, _| predicate(key)); + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use serde::{Deserialize, Serialize}; + + use super::{LocalCacheConfig, MokaCache}; + + #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] + struct User { + id: u64, + name: String, + } + + #[tokio::test] + async fn stores_and_reads_typed_values() { + let cache = MokaCache::init(); + let user = User { + id: 7, + name: "alice".to_string(), + }; + + cache.set("user:7", &user).await.unwrap(); + + assert_eq!(cache.get::("user:7").await.unwrap(), Some(user)); + } + + #[tokio::test] + async fn expires_values_by_ttl() { + let cache = MokaCache::with_config(LocalCacheConfig { + max_capacity: 16, + time_to_live: Some(Duration::from_millis(25)), + time_to_idle: None, + }); + + cache.set("short", &"value").await.unwrap(); + tokio::time::sleep(Duration::from_millis(60)).await; + + assert_eq!(cache.get::("short").await.unwrap(), None); + } +} diff --git a/lib/channel/Cargo.toml b/lib/channel/Cargo.toml new file mode 100644 index 0000000..f52ccec --- /dev/null +++ b/lib/channel/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "channel" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true +[lib] +path = "lib.rs" +name = "channel" +[dependencies] +cache = { workspace = true } +chrono = { workspace = true, features = ["serde"] } +db = { workspace = true } +dashmap = { workspace = true } +futures = { workspace = true } +hmac = "0.13" +sha2 = { workspace = true } +base64 = { workspace = true } +model = { workspace = true } +redis = { workspace = true, features = ["cluster"] } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +socketio = { workspace = true } +sqlx = { workspace = true } +storage = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["sync", "time"] } +tokio-util = { workspace = true } +tracing = { workspace = true } +uuid = { workspace = true, features = ["serde", "v7"] } +[lints] +workspace = true diff --git a/lib/channel/ack.rs b/lib/channel/ack.rs new file mode 100644 index 0000000..f60e6c1 --- /dev/null +++ b/lib/channel/ack.rs @@ -0,0 +1,162 @@ +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::ChannelResult; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageAck { + pub message_id: Uuid, + pub room_id: Uuid, + pub seq: i64, + pub status: AckStatus, + pub timestamp: chrono::DateTime, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum AckStatus { + Pending, + Received, + Persisted, + Delivered, + Failed, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AckRequest { + pub message_id: Uuid, + pub room_id: Uuid, + pub client_timestamp: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AckResponse { + pub message_id: Uuid, + pub status: AckStatus, + pub seq: Option, + pub server_timestamp: chrono::DateTime, + pub error: Option, +} + +#[derive(Clone)] +pub struct AckTracker { + cache: cache::AppCache, + timeout: std::time::Duration, +} + +impl AckTracker { + pub fn new(cache: cache::AppCache) -> Self { + Self::with_config(cache, std::time::Duration::from_secs(30)) + } + + pub fn with_config( + cache: cache::AppCache, + timeout: std::time::Duration, + ) -> Self { + Self { cache, timeout } + } + async fn set_with_ttl(&self, key: &str, value: &T) -> ChannelResult<()> + where + T: serde::Serialize, + { + self.cache.set(key, value).await?; + if let Some(cluster) = &self.cache.cluster { + if let Err(e) = cluster.expire(key, self.timeout).await { + tracing::warn!(error = %e, "ack TTL override failed, using default cache TTL"); + } + } + Ok(()) + } + + pub async fn track_pending( + &self, + message_id: Uuid, + room_id: Uuid, + user_id: Uuid, + ) -> ChannelResult<()> { + let key = format!("ack:pending:{}:{}", room_id, message_id); + let ack = MessageAck { + message_id, + room_id, + seq: 0, + status: AckStatus::Pending, + timestamp: chrono::Utc::now(), + }; + let owner_key = format!("ack:owner:{}:{}", room_id, message_id); + self.set_with_ttl(&key, &ack).await?; + self.set_with_ttl(&owner_key, &user_id.to_string()).await?; + Ok(()) + } + + pub async fn mark_received( + &self, + message_id: Uuid, + room_id: Uuid, + ) -> ChannelResult<()> { + self.update_status(message_id, room_id, AckStatus::Received) + .await + } + + pub async fn mark_persisted( + &self, + message_id: Uuid, + room_id: Uuid, + seq: i64, + ) -> ChannelResult<()> { + let key = format!("ack:pending:{}:{}", room_id, message_id); + let ack = MessageAck { + message_id, + room_id, + seq, + status: AckStatus::Persisted, + timestamp: chrono::Utc::now(), + }; + self.set_with_ttl(&key, &ack).await?; + Ok(()) + } + + pub async fn mark_delivered( + &self, + message_id: Uuid, + room_id: Uuid, + ) -> ChannelResult<()> { + self.update_status(message_id, room_id, AckStatus::Delivered) + .await?; + let key = format!("ack:pending:{}:{}", room_id, message_id); + self.cache.remove(&key).await?; + Ok(()) + } + + pub async fn mark_failed( + &self, + message_id: Uuid, + room_id: Uuid, + ) -> ChannelResult<()> { + self.update_status(message_id, room_id, AckStatus::Failed) + .await + } + + pub async fn get_status( + &self, + message_id: Uuid, + room_id: Uuid, + ) -> ChannelResult> { + let key = format!("ack:pending:{}:{}", room_id, message_id); + Ok(self.cache.get::(&key).await?) + } + + async fn update_status( + &self, + message_id: Uuid, + room_id: Uuid, + status: AckStatus, + ) -> ChannelResult<()> { + if let Some(mut ack) = self.get_status(message_id, room_id).await? { + ack.status = status; + ack.timestamp = chrono::Utc::now(); + let key = format!("ack:pending:{}:{}", room_id, message_id); + self.set_with_ttl(&key, &ack).await?; + } + Ok(()) + } +} diff --git a/lib/channel/bus.rs b/lib/channel/bus.rs new file mode 100644 index 0000000..c412da9 --- /dev/null +++ b/lib/channel/bus.rs @@ -0,0 +1,608 @@ +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use cache::AppCache; +use dashmap::DashMap; +use db::AppDatabase; +use model::room::RoomMessageModel; +use serde::Deserialize; +use serde::Serialize; +use socketio::{Socket, SocketIo}; +use tokio::sync::{Mutex, RwLock}; +use tracing::warn; +use uuid::Uuid; + +use crate::{ + ChannelBusConfig, ChannelError, ChannelResult, + circuit_breaker::CircuitBreaker, + dedup::DeduplicationManager, + event::ChannelEvent, + metrics::ChannelMetrics, + reconnect::ReconnectManager, + rooms::{ + active_workspace_users, catchup_messages, refresh_user_rooms_cache, + room_socket_name, room_workspace, user_rooms, + }, + security::{CsrfProtection, RateLimiter}, + seq::SeqAllocator, +}; + +const ROOM_MESSAGE_EVENT: &str = "room.message"; + +#[derive(Clone)] +pub struct ChannelBus { + pub(crate) inner: Arc, +} + +pub(crate) struct Inner { + pub(crate) db: AppDatabase, + pub(crate) cache: AppCache, + pub(crate) io: SocketIo, + pub(crate) config: ChannelBusConfig, + pub(crate) online: RwLock>>, + pub(crate) user_sync_locks: DashMap>>, + pub(crate) typing_states: DashMap<(Uuid, Uuid), (crate::event::UserInfo, crate::event::RoomInfo, tokio_util::sync::CancellationToken)>, + pub(crate) seq: SeqAllocator, + pub(crate) dedup: DeduplicationManager, + pub(crate) metrics: ChannelMetrics, + pub(crate) reconnect: ReconnectManager, + pub(crate) rate_limiter: RateLimiter, + pub(crate) csrf: CsrfProtection, + pub(crate) circuit_breaker: CircuitBreaker, +} + +#[derive(Debug, Deserialize)] +struct ConnectAuth { + #[serde(default)] + last_seq: HashMap, +} + +impl ChannelBus { + pub fn io(&self) -> &SocketIo { + &self.inner.io + } + pub async fn first_workspace_id( + &self, + user: Uuid, + ) -> ChannelResult> { + let row = db::sqlx::query_as::<_, (Uuid,)>( + "SELECT wk FROM wk_member WHERE \"user\" = $1 AND leave_at IS NULL LIMIT 1", + ) + .bind(user) + .fetch_optional(self.inner.db.reader()) + .await?; + Ok(row.map(|r| r.0)) + } + pub async fn lookup_room( + &self, + room: Uuid, + ) -> ChannelResult { + let row = db::sqlx::query_as::<_, (String,)>( + "SELECT name FROM room WHERE id = $1", + ) + .bind(room) + .fetch_optional(self.inner.db.reader()) + .await? + .map(|(name,)| name) + .unwrap_or_default(); + Ok(crate::event::RoomInfo { + id: room, + name: row, + }) + } + pub async fn list_workspace_members( + &self, + workspace: Uuid, + ) -> ChannelResult> { + let rows = db::sqlx::query_as::<_, (Uuid, String, String, String)>( + r#"SELECT u.id, u.username, u.display_name, u.avatar_url + FROM wk_member wm + JOIN "user" u ON u.id = wm."user" + WHERE wm.wk = $1 AND wm.leave_at IS NULL + ORDER BY u.username"#, + ) + .bind(workspace) + .fetch_all(self.inner.db.reader()) + .await?; + Ok(rows) + } + pub async fn lookup_workspace( + &self, + wk: Uuid, + ) -> ChannelResult { + use db::sqlx::Row; + let row = db::sqlx::query( + "SELECT name, avatar_url FROM workspace WHERE id = $1", + ) + .bind(wk) + .fetch_optional(self.inner.db.reader()) + .await?; + let (name, avatar_url) = match row { + Some(r) => (r.get::(0), r.get::(1)), + None => (String::new(), String::new()), + }; + Ok(crate::event::WorkspaceInfo { + id: wk, + name, + avatar_url, + }) + } + pub async fn lookup_users( + &self, + users: &[Uuid], + ) -> ChannelResult> { + if users.is_empty() { + return Ok(std::collections::HashMap::new()); + } + let rows = db::sqlx::query_as::<_, model::users::UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, \ + allow_use, can_search, last_sign_in_at, created_at, updated_at \ + FROM \"user\" WHERE id = ANY($1)", + ) + .bind(users) + .fetch_all(self.inner.db.reader()) + .await?; + Ok(rows + .into_iter() + .map(|m| (m.id, crate::event::UserInfo::from_model(&m))) + .collect()) + } + pub async fn lookup_user( + &self, + user: Uuid, + ) -> ChannelResult { + let row = db::sqlx::query_as::<_, model::users::UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, \ + allow_use, can_search, last_sign_in_at, created_at, updated_at \ + FROM \"user\" WHERE id = $1", + ) + .bind(user) + .fetch_optional(self.inner.db.reader()) + .await? + .map(|m| crate::event::UserInfo::from_model(&m)) + .unwrap_or_else(|| crate::event::UserInfo::unknown(user)); + Ok(row) + } + pub async fn list_user_rooms( + &self, + user: Uuid, + ) -> ChannelResult> { + crate::rooms::user_rooms_for_api( + &self.inner.db, + &self.inner.cache, + &self.inner.config, + user, + ) + .await + } + pub async fn list_user_categories( + &self, + user: Uuid, + ) -> ChannelResult> { + crate::rooms::user_categories_for_api( + &self.inner.db, + &self.inner.cache, + &self.inner.config, + user, + ) + .await + } + + pub fn new( + db: AppDatabase, + cache: AppCache, + io: SocketIo, + config: ChannelBusConfig, + ) -> Self { + let seq = match config.seq_segment_size { + Some(size) => { + SeqAllocator::with_segment_size(cache.clone(), db.clone(), size) + } + None => SeqAllocator::new(cache.clone(), db.clone()), + }; + let dedup = DeduplicationManager::with_config( + cache.clone(), + std::time::Duration::from_secs( + config.dedup_window_secs.unwrap_or(300), + ), + ); + let reconnect = ReconnectManager::new(cache.clone(), db.clone()); + let rate_limiter = match ( + config.rate_limit_max_requests, + config.rate_limit_window_secs, + ) { + (Some(max), Some(secs)) => RateLimiter::with_config( + cache.clone(), + max, + std::time::Duration::from_secs(secs), + ), + _ => RateLimiter::new(cache.clone()), + }; + let csrf = CsrfProtection::new(cache.clone()); + let circuit_breaker = match ( + config.circuit_breaker_failure_threshold, + config.circuit_breaker_success_threshold, + config.circuit_breaker_timeout_secs, + config.circuit_breaker_half_open_max_calls, + ) { + (Some(failure), Some(success), Some(secs), Some(half_open)) => { + CircuitBreaker::with_config( + failure, + success, + std::time::Duration::from_secs(secs), + half_open, + ) + } + _ => CircuitBreaker::new(), + }; + Self { + inner: Arc::new(Inner { + db, + cache, + io, + config, + online: RwLock::new(HashMap::new()), + user_sync_locks: DashMap::new(), + typing_states: DashMap::new(), + seq, + dedup, + metrics: ChannelMetrics::new(), + reconnect, + rate_limiter, + csrf, + circuit_breaker, + }), + } + } + + pub async fn attach(&self) -> ChannelResult<()> { + let namespace = + self.inner.io.namespace(&self.inner.config.namespace).await; + + let auth_bus = self.clone(); + namespace + .use_middleware(move |socket, auth| { + let bus = auth_bus.clone(); + async move { + if socket.session_user().is_some() { + return Ok(()); + } + let token = auth + .as_ref() + .and_then(|v| v.get("access_token")) + .and_then(|v| v.as_str()); + if let Some(token) = token { + let ctx = bus + .check_access_token(token.to_owned()) + .await + .map_err(|_| { + socketio::SocketIoError::Adapter( + "token invalid or expired".to_owned(), + ) + })?; + socket.set_user(ctx.user_id); + return Ok(()); + } + Err(socketio::SocketIoError::Adapter( + "unauthorized".to_owned(), + )) + } + }) + .await; + + let on_connect_bus = self.clone(); + namespace + .on_connect(move |socket| { + let bus = on_connect_bus.clone(); + async move { + bus.inner.metrics.increment_connections(); + if let Err(error) = bus.handle_connect(socket.clone()).await { + warn!(%error, "channel socket connect failed, disconnecting"); + let _ = socket.disconnect().await; + } + } + }) + .await; + + let on_disconnect_bus = self.clone(); + namespace + .on_disconnect(move |socket, _reason| { + let bus = on_disconnect_bus.clone(); + async move { + bus.inner.metrics.decrement_connections(); + bus.handle_disconnect(&socket).await; + } + }) + .await; + crate::http::ws::register_message_handler(self).await?; + + Ok(()) + } + + pub async fn publish_room_message( + &self, + message: RoomMessageModel, + sender: Option, + ) -> ChannelResult<()> { + let is_new = self + .inner + .dedup + .check_and_mark(message.id, message.room) + .await?; + if !is_new { + return Ok(()); + } + let event = match sender { + Some(s) => ChannelEvent::message_created_with_sender(message, s), + None => ChannelEvent::message_created(message), + }; + self.publish_event(event).await + } + + pub async fn publish_room_event( + &self, + room: Uuid, + event_type: impl Into, + payload: T, + ) -> ChannelResult<()> + where + T: Serialize, + { + let payload = serde_json::to_value(payload)?; + self.publish_event(ChannelEvent::custom(room, event_type, payload)) + .await + } + pub async fn emit_to_user( + &self, + user: Uuid, + event: &str, + data: &T, + ) -> ChannelResult<()> { + let sockets = self + .inner + .online + .read() + .await + .get(&user) + .map(|sockets| sockets.values().cloned().collect::>()) + .unwrap_or_default(); + + for socket in sockets { + socket.emit(event, data).await?; + } + Ok(()) + } + + pub async fn refresh_user(&self, user: Uuid) -> ChannelResult<()> { + let rooms = refresh_user_rooms_cache( + &self.inner.db, + &self.inner.cache, + &self.inner.config, + user, + ) + .await?; + self.sync_online_user_rooms(user, &rooms).await + } + + pub async fn workspace_changed(&self, wk: Uuid) -> ChannelResult<()> { + let users = active_workspace_users(&self.inner.db, wk).await?; + let bus = self.clone(); + let results = + futures::future::join_all(users.into_iter().map(|user| { + let bus = bus.clone(); + async move { bus.refresh_user(user).await } + })) + .await; + let mut first_error = None; + for result in results { + if let Err(e) = result { + tracing::warn!(error = %e, "workspace refresh failed for user"); + if first_error.is_none() { + first_error = Some(e); + } + } + } + if let Some(e) = first_error { + Err(e) + } else { + Ok(()) + } + } + + pub async fn room_changed(&self, room: Uuid) -> ChannelResult<()> { + if let Some(wk) = room_workspace(&self.inner.db, room).await? { + self.workspace_changed(wk).await?; + } + Ok(()) + } + + async fn publish_event(&self, event: ChannelEvent) -> ChannelResult<()> { + self.inner.metrics.increment_sent(); + + let result = self + .inner + .circuit_breaker + .call(async { + self.inner + .io + .namespace(&self.inner.config.namespace) + .await + .to(room_socket_name(event.room)) + .emit(ROOM_MESSAGE_EVENT, event) + .await + .map_err(ChannelError::SocketIo) + }) + .await; + + match result { + Ok(()) => { + self.inner.metrics.increment_received(); + Ok(()) + } + Err(e) => { + self.inner.metrics.increment_failed(); + match e { + crate::circuit_breaker::CircuitBreakerError::Open => { + Err(ChannelError::Internal( + "circuit breaker open".to_string(), + )) + } + crate::circuit_breaker::CircuitBreakerError::Inner(e) => { + Err(e) + } + } + } + } + } + + async fn handle_connect(&self, socket: Socket) -> ChannelResult<()> { + let user = socket.session_user().ok_or(ChannelError::Unauthorized)?; + + if !self + .inner + .rate_limiter + .check_rate_limit(user, "connect") + .await? + { + return Err(ChannelError::RateLimitExceeded); + } + + let auth = socket + .auth() + .await + .and_then(|value| serde_json::from_value::(value).ok()) + .unwrap_or_else(|| ConnectAuth { + last_seq: HashMap::new(), + }); + let rooms = user_rooms( + &self.inner.db, + &self.inner.cache, + &self.inner.config, + user, + ) + .await?; + + for room in &rooms { + socket.join(room_socket_name(*room)).await?; + } + self.register_socket(user, socket.clone()).await; + self.catchup(&socket, &rooms, &auth.last_seq).await?; + Ok(()) + } + + async fn handle_disconnect(&self, socket: &Socket) { + let Some(user) = socket.session_user() else { + return; + }; + let rooms = socket.rooms().await; + for room_name in &rooms { + if let Some(room_str) = room_name.strip_prefix("room:") { + if let Ok(room_id) = Uuid::parse_str(room_str) { + let _ = self.publish_room_event( + room_id, + "voice.channel_left", + serde_json::json!({"user_id": user, "disconnected": true}), + ) + .await; + } + } + } + let mut online = self.inner.online.write().await; + if let Some(sockets) = online.get_mut(&user) { + sockets.remove(socket.id()); + if sockets.is_empty() { + online.remove(&user); + self.inner.user_sync_locks.remove(&user); + } + } + } + + async fn register_socket(&self, user: Uuid, socket: Socket) { + self.inner + .online + .write() + .await + .entry(user) + .or_default() + .insert(socket.id().to_owned(), socket); + } + + async fn sync_online_user_rooms( + &self, + user: Uuid, + desired_rooms: &[Uuid], + ) -> ChannelResult<()> { + let lock = self + .inner + .user_sync_locks + .entry(user) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone(); + let _guard = lock.lock().await; + + let sockets = self + .inner + .online + .read() + .await + .get(&user) + .map(|sockets| sockets.values().cloned().collect::>()) + .unwrap_or_default(); + + let desired = desired_rooms + .iter() + .map(|room| room_socket_name(*room)) + .collect::>(); + + for socket in sockets { + let current = socket + .rooms() + .await + .into_iter() + .filter(|room| room.starts_with("room:")) + .collect::>(); + + for room in desired.difference(¤t) { + socket.join(room.clone()).await?; + } + for room in current.difference(&desired) { + socket.leave(room).await?; + } + } + Ok(()) + } + + async fn catchup( + &self, + socket: &Socket, + rooms: &[Uuid], + last_seq: &HashMap, + ) -> ChannelResult<()> { + for room in rooms { + let Some(seq) = last_seq.get(room) else { + continue; + }; + let messages = catchup_messages( + &self.inner.db, + &self.inner.config, + *room, + *seq, + ) + .await?; + for message in messages { + let sender = match self.lookup_user(message.author).await { + Ok(s) => Some(s), + Err(_) => None, + }; + let event = match sender { + Some(s) => ChannelEvent::message_created_with_sender(message, s), + None => ChannelEvent::message_created(message), + }; + socket.emit(ROOM_MESSAGE_EVENT, event).await?; + } + } + Ok(()) + } +} diff --git a/lib/channel/cdn.rs b/lib/channel/cdn.rs new file mode 100644 index 0000000..08c26d1 --- /dev/null +++ b/lib/channel/cdn.rs @@ -0,0 +1,136 @@ +use std::time::Duration; +use uuid::Uuid; + +use storage::ObjectStorage; + +use crate::{ChannelError, ChannelResult}; + +const ATTACHMENT_KEY_PREFIX: &str = "attachments"; +const DEFAULT_PRESIGNED_TTL: Duration = Duration::from_secs(3600); +const DEFAULT_MAX_FILE_SIZE: u64 = 50 * 1024 * 1024; // 50 MB + +#[derive(Clone)] +pub struct CdnManager { + storage: storage::AppStorage, + presigned_ttl: Duration, + max_file_size: u64, +} + +impl CdnManager { + pub fn new(storage: storage::AppStorage) -> Self { + Self { + storage, + presigned_ttl: DEFAULT_PRESIGNED_TTL, + max_file_size: DEFAULT_MAX_FILE_SIZE, + } + } + + pub fn with_config( + storage: storage::AppStorage, + presigned_ttl: Duration, + max_file_size: u64, + ) -> Self { + Self { + storage, + presigned_ttl, + max_file_size, + } + } + + pub fn max_file_size(&self) -> u64 { + self.max_file_size + } + + pub async fn upload_file( + &self, + room_id: Uuid, + message_id: Uuid, + file_data: &[u8], + filename: &str, + content_type: Option, + ) -> ChannelResult { + if file_data.len() as u64 > self.max_file_size { + return Err(ChannelError::Internal( + "file exceeds max size".to_string(), + )); + } + + let filename = sanitize_filename(filename); + let key = format!( + "{}/{}/{}/{}", + ATTACHMENT_KEY_PREFIX, room_id, message_id, filename + ); + + let options = storage::PutObjectOptions { + content_type, + content_length: Some(file_data.len() as i64), + cache_control: Some("public, max-age=86400".to_string()), + }; + + let stored = self + .storage + .put_bytes(&key, file_data.to_vec(), options) + .await?; + + Ok(CdnStoredFile { + key: stored.key, + url: stored.url, + e_tag: stored.e_tag, + size: file_data.len() as i64, + }) + } + + pub async fn get_file(&self, key: &str) -> ChannelResult { + let object = self.storage.get_bytes(key).await?; + Ok(CdnFileContent { + bytes: object.bytes, + content_type: object.content_type, + content_length: object.content_length, + }) + } + + pub async fn delete_file(&self, key: &str) -> ChannelResult<()> { + self.storage.delete(key).await?; + Ok(()) + } + + pub fn public_url(&self, key: &str) -> ChannelResult> { + self.storage.public_url(key).map_err(Into::into) + } + + pub async fn presigned_url(&self, key: &str) -> ChannelResult { + Ok(self + .storage + .presigned_get_url(key, self.presigned_ttl) + .await?) + } +} +fn sanitize_filename(name: &str) -> String { + let cleaned: String = name + .chars() + .filter(|&c| c.is_ascii_graphic() && c != '/' && c != '\\' && c != '\0') + .take(255) + .collect::() + .trim() + .to_owned(); + if cleaned.is_empty() { + uuid::Uuid::new_v4().to_string() + } else { + cleaned + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct CdnStoredFile { + pub key: String, + pub url: String, + pub e_tag: Option, + pub size: i64, +} + +#[derive(Debug)] +pub struct CdnFileContent { + pub bytes: Vec, + pub content_type: Option, + pub content_length: Option, +} diff --git a/lib/channel/circuit_breaker.rs b/lib/channel/circuit_breaker.rs new file mode 100644 index 0000000..e9dc288 --- /dev/null +++ b/lib/channel/circuit_breaker.rs @@ -0,0 +1,165 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Mutex; + +use crate::ChannelError; + +const STATUS_CLOSED: u8 = 0; +const STATUS_OPEN: u8 = 1; +const STATUS_HALF_OPEN: u8 = 2; + +#[derive(Clone)] +pub struct CircuitBreaker { + inner: Arc, +} + +struct CircuitState { + status: u8, + failure_count: u32, + success_count: u32, + half_open_calls: u32, + last_failure_time: Option, +} + +struct Inner { + state: Mutex, + config: CircuitConfig, +} + +#[derive(Clone)] +struct CircuitConfig { + failure_threshold: u32, + success_threshold: u32, + timeout: Duration, + half_open_max_calls: u32, +} + +impl CircuitBreaker { + pub fn new() -> Self { + Self::with_config(5, 2, Duration::from_secs(60), 3) + } + + pub fn with_config( + failure_threshold: u32, + success_threshold: u32, + timeout: Duration, + half_open_max_calls: u32, + ) -> Self { + Self { + inner: Arc::new(Inner { + state: Mutex::new(CircuitState { + status: STATUS_CLOSED, + failure_count: 0, + success_count: 0, + half_open_calls: 0, + last_failure_time: None, + }), + config: CircuitConfig { + failure_threshold, + success_threshold, + timeout, + half_open_max_calls, + }, + }), + } + } + + pub async fn call(&self, f: F) -> Result + where + F: std::future::Future>, + { + let slot_reserved = { + let mut state = self.inner.state.lock().await; + match state.status { + STATUS_OPEN => { + match state.last_failure_time { + Some(t) if t.elapsed() > self.inner.config.timeout => { + state.status = STATUS_HALF_OPEN; + state.half_open_calls = 1; + state.success_count = 0; + true + } + _ => false, + } + } + STATUS_HALF_OPEN => { + if state.half_open_calls + < self.inner.config.half_open_max_calls + { + state.half_open_calls += 1; + true + } else { + false + } + } + _ => true, // Closed → allow + } + }; // Lock released before executing the call. + + if !slot_reserved { + return Err(CircuitBreakerError::Open); + } + + match f.await { + Ok(result) => { + self.on_success().await; + Ok(result) + } + Err(e) => { + self.on_failure().await; + Err(CircuitBreakerError::Inner(e)) + } + } + } + + async fn on_success(&self) { + let mut state = self.inner.state.lock().await; + state.failure_count = 0; + + if state.status == STATUS_HALF_OPEN { + state.success_count += 1; + if state.success_count >= self.inner.config.success_threshold { + state.status = STATUS_CLOSED; + state.success_count = 0; + state.half_open_calls = 0; + } + } + } + async fn on_failure(&self) { + let mut state = self.inner.state.lock().await; + state.failure_count += 1; + state.last_failure_time = Some(Instant::now()); + + if state.status == STATUS_HALF_OPEN { + state.status = STATUS_OPEN; + state.success_count = 0; + state.half_open_calls = 0; + } else if state.status == STATUS_CLOSED + && state.failure_count >= self.inner.config.failure_threshold + { + state.status = STATUS_OPEN; + } + } + + pub async fn is_open(&self) -> bool { + let state = self.inner.state.lock().await; + state.status == STATUS_OPEN + } +} + +#[derive(Debug)] +pub enum CircuitBreakerError { + Open, + Inner(ChannelError), +} + +impl std::fmt::Display for CircuitBreakerError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CircuitBreakerError::Open => write!(f, "Circuit breaker is open"), + CircuitBreakerError::Inner(e) => write!(f, "{}", e), + } + } +} + +impl std::error::Error for CircuitBreakerError {} diff --git a/lib/channel/config.rs b/lib/channel/config.rs new file mode 100644 index 0000000..1781791 --- /dev/null +++ b/lib/channel/config.rs @@ -0,0 +1,38 @@ +#[derive(Clone, Debug)] +pub struct ChannelBusConfig { + pub namespace: String, + pub room_cache_ttl_hint: Option, + pub catchup_limit: i64, + pub signing_secret: Option, + pub seq_segment_size: Option, + pub rate_limit_max_requests: Option, + pub rate_limit_window_secs: Option, + pub dedup_window_secs: Option, + pub ack_timeout_secs: Option, + pub revoke_window_secs: Option, + pub circuit_breaker_failure_threshold: Option, + pub circuit_breaker_success_threshold: Option, + pub circuit_breaker_timeout_secs: Option, + pub circuit_breaker_half_open_max_calls: Option, +} + +impl Default for ChannelBusConfig { + fn default() -> Self { + Self { + namespace: "/channel".to_owned(), + room_cache_ttl_hint: Some(std::time::Duration::from_secs(300)), + catchup_limit: 100, + signing_secret: None, + seq_segment_size: None, + rate_limit_max_requests: None, + rate_limit_window_secs: None, + dedup_window_secs: None, + ack_timeout_secs: None, + revoke_window_secs: None, + circuit_breaker_failure_threshold: None, + circuit_breaker_success_threshold: None, + circuit_breaker_timeout_secs: None, + circuit_breaker_half_open_max_calls: None, + } + } +} diff --git a/lib/channel/dedup.rs b/lib/channel/dedup.rs new file mode 100644 index 0000000..3eeab2e --- /dev/null +++ b/lib/channel/dedup.rs @@ -0,0 +1,55 @@ +use std::time::Duration; +use uuid::Uuid; + +use crate::{ChannelError, ChannelResult, security::require_cluster}; + +#[derive(Clone)] +pub struct DeduplicationManager { + cache: cache::AppCache, + window: Duration, +} + +impl DeduplicationManager { + pub fn new(cache: cache::AppCache) -> Self { + Self { + cache, + window: Duration::from_secs(300), + } + } + + pub fn with_config(cache: cache::AppCache, window: Duration) -> Self { + Self { cache, window } + } + + pub async fn check_and_mark( + &self, + message_id: Uuid, + room_id: Uuid, + ) -> ChannelResult { + let cluster = require_cluster(&self.cache)?; + let key = format!("dedup:{}:{}", room_id, message_id); + let mut conn = cluster.conn(); + + let result: Option = redis::cmd("SET") + .arg(&key) + .arg("1") + .arg("NX") + .arg("EX") + .arg(self.window.as_secs()) + .query_async(&mut conn) + .await + .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + + Ok(result.is_some()) + } + + pub async fn is_duplicate( + &self, + message_id: Uuid, + room_id: Uuid, + ) -> ChannelResult { + let cluster = require_cluster(&self.cache)?; + let key = format!("dedup:{}:{}", room_id, message_id); + Ok(cluster.exists(&key).await?) + } +} diff --git a/lib/channel/e2e.rs b/lib/channel/e2e.rs new file mode 100644 index 0000000..cd57193 --- /dev/null +++ b/lib/channel/e2e.rs @@ -0,0 +1,36 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EncryptedMessage { + pub ciphertext: Vec, + pub nonce: Vec, + pub recipient_key_id: String, +} + +pub struct E2EEncryption; + +impl E2EEncryption { + pub fn new() -> Self { + Self {} + } + + pub fn encrypt( + &self, + _plaintext: &[u8], + _recipient_public_key: &[u8], + ) -> crate::ChannelResult { + Err(crate::ChannelError::Internal( + "e2e not implemented".to_string(), + )) + } + + pub fn decrypt( + &self, + _encrypted: &EncryptedMessage, + _private_key: &[u8], + ) -> crate::ChannelResult> { + Err(crate::ChannelError::Internal( + "e2e not implemented".to_string(), + )) + } +} diff --git a/lib/channel/envelope.rs b/lib/channel/envelope.rs new file mode 100644 index 0000000..be4b919 --- /dev/null +++ b/lib/channel/envelope.rs @@ -0,0 +1,37 @@ +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChannelEnvelope { + #[serde(default = "uuid::Uuid::now_v7")] + pub message_id: Uuid, + pub user_id: Uuid, + pub payload: T, + #[serde(default = "default_timestamp")] + pub created_at: chrono::DateTime, + #[serde(default)] + pub attempt: u8, +} + +fn default_timestamp() -> chrono::DateTime { + chrono::Utc::now() +} + +impl ChannelEnvelope { + pub fn new(user_id: Uuid, payload: T) -> Self { + Self { + message_id: Uuid::now_v7(), + user_id, + payload, + created_at: chrono::Utc::now(), + attempt: 1, + } + } + + pub fn retry(self) -> Self { + Self { + attempt: self.attempt + 1, + ..self + } + } +} diff --git a/lib/channel/error.rs b/lib/channel/error.rs new file mode 100644 index 0000000..e6773ef --- /dev/null +++ b/lib/channel/error.rs @@ -0,0 +1,66 @@ +#[derive(Debug, thiserror::Error)] +pub enum ChannelError { + #[error("unauthorized")] + Unauthorized, + #[error("token invalid or expired")] + TokenInvalidOrExpired, + #[error("renewal limit exceeded")] + RenewalLimitExceeded, + #[error("rate limit exceeded")] + RateLimitExceeded, + #[error("room not found")] + RoomNotFound, + #[error("user not found")] + UserNotFound, + #[error("access denied")] + AccessDenied, + #[error("validation failed: {0}")] + Validation(String), + #[error("internal error: {0}")] + Internal(String), + #[error("database failed: {0}")] + Database(#[from] db::sqlx::Error), + #[error("cache failed: {0}")] + Cache(#[from] cache::CacheError), + #[error("socket.io failed: {0}")] + SocketIo(#[from] socketio::SocketIoError), + #[error("serialization failed: {0}")] + Serialization(#[from] serde_json::Error), + #[error("redis failed: {0}")] + Redis(#[from] redis::RedisError), + #[error("storage failed: {0}")] + Storage(#[from] storage::StorageError), +} + +impl ChannelError { + pub fn ws_error_code(&self) -> (u16, &'static str) { + match self { + ChannelError::Unauthorized => (401, "unauthorized"), + ChannelError::TokenInvalidOrExpired => (401, "token_invalid"), + ChannelError::AccessDenied => (403, "access_denied"), + ChannelError::Validation(_) => (422, "validation_failed"), + ChannelError::RateLimitExceeded => (429, "rate_limit_exceeded"), + ChannelError::RoomNotFound => (404, "not_found"), + ChannelError::UserNotFound => (404, "user_not_found"), + ChannelError::RenewalLimitExceeded => (429, "renewal_limit"), + ChannelError::Internal(_) => (500, "internal_error"), + ChannelError::Database(_) => (500, "internal_error"), + ChannelError::Cache(_) => (500, "internal_error"), + ChannelError::SocketIo(_) => (500, "internal_error"), + ChannelError::Serialization(_) => (500, "internal_error"), + ChannelError::Redis(_) => (500, "internal_error"), + ChannelError::Storage(_) => (500, "internal_error"), + } + } + + pub fn to_ws_error(&self) -> crate::http::out_event::WsError { + let (code, error_type) = self.ws_error_code(); + crate::http::out_event::WsError { + code: code as i32, + error: error_type.to_string(), + message: self.to_string(), + } + } +} + +pub type ChannelResult = Result; diff --git a/lib/channel/event/ai.rs b/lib/channel/event/ai.rs new file mode 100644 index 0000000..97fc4e2 --- /dev/null +++ b/lib/channel/event/ai.rs @@ -0,0 +1,44 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::event::{AgentInfo, RoomInfo}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AiAgentJoinedService { + pub room: RoomInfo, + pub agent: AgentInfo, + pub joined_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AiAgentLeftService { + pub room: RoomInfo, + pub agent: AgentInfo, + pub left_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomAiEntry { + pub agent_session: Uuid, + pub name: String, + pub agent_kind: String, + pub model_version: Option, + pub enabled: bool, + pub auto_reply: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomAiListService { + pub room: RoomInfo, + pub agents: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AiAgentStatusChangedService { + pub room: RoomInfo, + pub agent: AgentInfo, + pub old_status: String, + pub new_status: String, + pub changed_at: DateTime, +} diff --git a/lib/channel/event/attachment.rs b/lib/channel/event/attachment.rs new file mode 100644 index 0000000..57f863e --- /dev/null +++ b/lib/channel/event/attachment.rs @@ -0,0 +1,55 @@ +use chrono::{DateTime, Utc}; +use uuid::Uuid; +use crate::event::{RoomInfo, UserInfo}; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AttachmentEventType { + Uploaded, + ThumbnailGenerated, + Deleted, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum AttachmentEvent { + #[serde(rename = "attachment.uploaded")] + Uploaded(AttachmentUploadedService), + #[serde(rename = "attachment.thumbnail_generated")] + ThumbnailGenerated(AttachmentThumbnailService), + #[serde(rename = "attachment.deleted")] + Deleted(AttachmentDeletedService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AttachmentUploadedService { + pub id: Uuid, + pub room: RoomInfo, + pub message: Uuid, + pub filename: String, + pub content_type: Option, + pub size: i64, + pub url: Option, + pub uploaded_by: Uuid, + pub uploaded_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AttachmentThumbnailService { + pub id: Uuid, + pub room: RoomInfo, + pub thumbnail_url: String, + pub width: i32, + pub height: i32, + pub generated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AttachmentDeletedService { + pub id: Uuid, + pub room: RoomInfo, + pub deleted_by: UserInfo, + pub deleted_at: DateTime, +} diff --git a/lib/channel/event/ban.rs b/lib/channel/event/ban.rs new file mode 100644 index 0000000..cdf3bc7 --- /dev/null +++ b/lib/channel/event/ban.rs @@ -0,0 +1,52 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use crate::event::{UserInfo, WorkspaceInfo}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum BanEventType { + Banned, + Unbanned, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum BanEvent { + #[serde(rename = "ban.banned")] + Banned(BannedService), + #[serde(rename = "ban.unbanned")] + Unbanned(UnbannedService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BannedService { + pub workspace: WorkspaceInfo, + pub user: UserInfo, + pub banned_by: UserInfo, + pub reason: Option, + pub expires_at: Option>, + pub banned_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UnbannedService { + pub workspace: WorkspaceInfo, + pub user: UserInfo, + pub unbanned_by: UserInfo, + pub unbanned_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BanCreateClient { + pub workspace: WorkspaceInfo, + pub user: UserInfo, + pub reason: Option, + pub expires_at: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BanRemoveClient { + pub workspace: WorkspaceInfo, + pub user: UserInfo, +} diff --git a/lib/channel/event/category.rs b/lib/channel/event/category.rs new file mode 100644 index 0000000..c9c4260 --- /dev/null +++ b/lib/channel/event/category.rs @@ -0,0 +1,68 @@ +use chrono::{DateTime, Utc}; +use uuid::Uuid; +use crate::event::{UserInfo, WorkspaceInfo}; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CategoryEventType { + Created, + Updated, + Deleted, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum CategoryEvent { + #[serde(rename = "category.created")] + Created(CategoryCreatedService), + #[serde(rename = "category.updated")] + Updated(CategoryUpdatedService), + #[serde(rename = "category.deleted")] + Deleted(CategoryDeletedService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CategoryCreatedService { + pub id: Uuid, + #[serde(rename = "workspace")] + pub project: WorkspaceInfo, + pub name: String, + pub position: i32, + pub created_by: UserInfo, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CategoryUpdatedService { + pub id: Uuid, + #[serde(rename = "workspace")] + pub project: WorkspaceInfo, + pub name: Option, + pub position: Option, + pub updated_by: UserInfo, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CategoryDeletedService { + pub id: Uuid, + #[serde(rename = "workspace")] + pub project: WorkspaceInfo, + pub deleted_by: UserInfo, + pub deleted_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CategoryCreateClient { + pub project: WorkspaceInfo, + pub name: String, + pub position: i32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CategoryUpdateClient { + pub name: Option, + pub position: Option, +} diff --git a/lib/channel/event/common.rs b/lib/channel/event/common.rs new file mode 100644 index 0000000..ab7a7ea --- /dev/null +++ b/lib/channel/event/common.rs @@ -0,0 +1,92 @@ +use serde::{Deserialize, Serialize}; +use uuid::Uuid; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UserInfo { + pub id: Uuid, + pub username: String, + pub display_name: String, + pub avatar_url: String, +} + +impl UserInfo { + pub fn unknown(id: Uuid) -> Self { + Self { + id, + username: String::new(), + display_name: String::new(), + avatar_url: String::new(), + } + } + + pub fn from_model(m: &model::users::UserModel) -> Self { + Self { + id: m.id, + username: m.username.clone(), + display_name: m.display_name.clone(), + avatar_url: m.avatar_url.clone(), + } + } +} +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomInfo { + pub id: Uuid, + pub name: String, +} + +impl RoomInfo { + pub fn unknown(id: Uuid) -> Self { + Self { + id, + name: String::new(), + } + } + + pub fn from_model(m: &model::room::RoomModel) -> Self { + Self { + id: m.id, + name: m.name.clone(), + } + } +} +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkspaceInfo { + pub id: Uuid, + pub name: String, + pub avatar_url: String, +} + +impl WorkspaceInfo { + pub fn unknown(id: Uuid) -> Self { + Self { + id, + name: String::new(), + avatar_url: String::new(), + } + } + + pub fn from_model(m: &model::workspace::WorkspaceModel) -> Self { + Self { + id: m.id, + name: m.name.clone(), + avatar_url: m.avatar_url.clone(), + } + } +} +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AgentInfo { + pub id: Uuid, + pub name: String, + pub agent_type: String, + pub model_name: Option, +} + +impl AgentInfo { + pub fn unknown(id: Uuid) -> Self { + Self { + id, + name: String::new(), + agent_type: String::new(), + model_name: None, + } + } +} diff --git a/lib/channel/event/conversation.rs b/lib/channel/event/conversation.rs new file mode 100644 index 0000000..436b2a8 --- /dev/null +++ b/lib/channel/event/conversation.rs @@ -0,0 +1,123 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ConversationEventType { + Pinned, + Unpinned, + Muted, + Unmuted, + UnreadUpdated, + NotifyLevelChanged, +} + +impl ConversationEventType { + pub fn as_str(&self) -> &str { + match self { + Self::Pinned => "conversation.pinned", + Self::Unpinned => "conversation.unpinned", + Self::Muted => "conversation.muted", + Self::Unmuted => "conversation.unmuted", + Self::UnreadUpdated => "conversation.unread_updated", + Self::NotifyLevelChanged => "conversation.notify_level_changed", + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ConversationEvent { + #[serde(rename = "conversation.pinned")] + Pinned(ConversationPinnedService), + #[serde(rename = "conversation.unpinned")] + Unpinned(ConversationUnpinnedService), + #[serde(rename = "conversation.muted")] + Muted(ConversationMutedService), + #[serde(rename = "conversation.unmuted")] + Unmuted(ConversationUnmutedService), + #[serde(rename = "conversation.unread_updated")] + UnreadUpdated(ConversationUnreadUpdatedService), + #[serde(rename = "conversation.notify_level_changed")] + NotifyLevelChanged(ConversationNotifyLevelChangedService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConversationPinnedService { + pub user: UserInfo, + pub room: RoomInfo, + pub pinned_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConversationUnpinnedService { + pub user: UserInfo, + pub room: RoomInfo, + pub unpinned_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConversationMutedService { + pub user: UserInfo, + pub room: RoomInfo, + pub muted_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConversationUnmutedService { + pub user: UserInfo, + pub room: RoomInfo, + pub unmuted_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConversationUnreadUpdatedService { + pub user: UserInfo, + pub room: RoomInfo, + pub last_read_seq: i64, + pub unread_count: i64, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConversationNotifyLevelChangedService { + pub user: UserInfo, + pub room: RoomInfo, + pub old_level: String, + pub new_level: String, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConversationPinClient { + pub room: Uuid, + pub pin: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConversationMuteClient { + pub room: Uuid, + pub mute: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConversationNotifyLevelClient { + pub room: Uuid, + pub notify_level: String, +} +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConversationSummary { + pub room: Uuid, + pub room_name: String, + pub room_type: String, + pub is_pinned: bool, + pub is_muted: bool, + pub notify_level: String, + pub last_read_seq: i64, + pub max_seq: i64, + pub unread_count: i64, + pub last_read_at: Option>, +} diff --git a/lib/channel/event/dm.rs b/lib/channel/event/dm.rs new file mode 100644 index 0000000..cbad446 --- /dev/null +++ b/lib/channel/event/dm.rs @@ -0,0 +1,66 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum DmEventType { + Created, + Closed, + Reopened, +} + +impl DmEventType { + pub fn as_str(&self) -> &str { + match self { + Self::Created => "dm.created", + Self::Closed => "dm.closed", + Self::Reopened => "dm.reopened", + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum DmEvent { + #[serde(rename = "dm.created")] + Created(DmCreatedService), + #[serde(rename = "dm.closed")] + Closed(DmClosedService), + #[serde(rename = "dm.reopened")] + Reopened(DmReopenedService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DmCreatedService { + pub room: RoomInfo, + pub initiator: UserInfo, + pub recipient: UserInfo, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DmClosedService { + pub room: RoomInfo, + pub closed_by: UserInfo, + pub closed_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DmReopenedService { + pub room: RoomInfo, + pub reopened_by: UserInfo, + pub reopened_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DmCreateClient { + pub recipient: Uuid, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DmCloseClient { + pub room: Uuid, +} diff --git a/lib/channel/event/draft.rs b/lib/channel/event/draft.rs new file mode 100644 index 0000000..f708088 --- /dev/null +++ b/lib/channel/event/draft.rs @@ -0,0 +1,35 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use crate::event::{RoomInfo, UserInfo}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DraftSavedService { + pub user: UserInfo, + pub room: RoomInfo, + pub content: String, + pub saved_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DraftClearedService { + pub user: UserInfo, + pub room: RoomInfo, + pub cleared_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DraftSaveClient { + pub room: RoomInfo, + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DraftLoadClient { + pub room: RoomInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DraftClearClient { + pub room: RoomInfo, +} diff --git a/lib/channel/event/forward.rs b/lib/channel/event/forward.rs new file mode 100644 index 0000000..cd3a0e6 --- /dev/null +++ b/lib/channel/event/forward.rs @@ -0,0 +1,24 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageForwardedService { + pub id: Uuid, + pub seq: i64, + pub room: RoomInfo, + pub sender: UserInfo, + pub content: String, + pub content_type: String, + pub source_room: RoomInfo, + pub source_message_id: Uuid, + pub forwarded_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageForwardClient { + pub source_message_id: Uuid, + pub target_room: Uuid, +} diff --git a/lib/channel/event/invite.rs b/lib/channel/event/invite.rs new file mode 100644 index 0000000..3bfb4c1 --- /dev/null +++ b/lib/channel/event/invite.rs @@ -0,0 +1,83 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, WorkspaceInfo}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum InviteEventType { + Created, + Accepted, + Rejected, + Revoked, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum InviteEvent { + #[serde(rename = "invite.created")] + Created(InviteCreatedService), + #[serde(rename = "invite.accepted")] + Accepted(InviteAcceptedService), + #[serde(rename = "invite.rejected")] + Rejected(InviteRejectedService), + #[serde(rename = "invite.revoked")] + Revoked(InviteRevokedService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InviteCreatedService { + pub id: Uuid, + pub workspace: WorkspaceInfo, + pub room: Option, + pub inviter: UserInfo, + pub invitee: Option, + pub code: String, + pub max_uses: Option, + pub expires_at: Option>, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InviteAcceptedService { + pub id: Uuid, + pub workspace: WorkspaceInfo, + pub room: Option, + pub user: UserInfo, + pub accepted_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InviteRejectedService { + pub id: Uuid, + pub workspace: WorkspaceInfo, + pub user: UserInfo, + pub rejected_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InviteRevokedService { + pub id: Uuid, + pub workspace: WorkspaceInfo, + pub revoked_by: UserInfo, + pub revoked_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InviteCreateClient { + pub workspace: WorkspaceInfo, + pub room: Option, + pub max_uses: Option, + pub expires_at: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InviteAcceptClient { + pub code: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InviteRevokeClient { + pub id: Uuid, +} diff --git a/lib/channel/event/member.rs b/lib/channel/event/member.rs new file mode 100644 index 0000000..ded684e --- /dev/null +++ b/lib/channel/event/member.rs @@ -0,0 +1,89 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use crate::event::{RoomInfo, UserInfo}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemberJoinedService { + pub room: RoomInfo, + pub user: UserInfo, + pub project_role: Option, + pub joined_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemberRemovedService { + pub room: RoomInfo, + pub user: UserInfo, + pub removed_by: UserInfo, + pub removed_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReadReceiptService { + pub room: RoomInfo, + pub user: UserInfo, + pub last_read_seq: i64, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TypingStartService { + pub room: RoomInfo, + pub user: UserInfo, + pub sender_type: String, + pub started_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TypingStopService { + pub room: RoomInfo, + pub user: UserInfo, + pub sender_type: String, + pub stopped_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DndUpdatedService { + pub room: RoomInfo, + pub user: UserInfo, + pub do_not_disturb: bool, + pub dnd_start_hour: Option, + pub dnd_end_hour: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessGrantClient { + pub room: RoomInfo, + pub user: UserInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessRevokeClient { + pub room: RoomInfo, + pub user: UserInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReadReceiptClient { + pub room: RoomInfo, + pub last_read_seq: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TypingStartClient { + pub room: RoomInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TypingStopClient { + pub room: RoomInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DndUpdateClient { + pub room: RoomInfo, + pub do_not_disturb: bool, + pub dnd_start_hour: Option, + pub dnd_end_hour: Option, +} diff --git a/lib/channel/event/message.rs b/lib/channel/event/message.rs new file mode 100644 index 0000000..ef70577 --- /dev/null +++ b/lib/channel/event/message.rs @@ -0,0 +1,106 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo}; + +use crate::event::reaction::ReactionGroup; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageNewService { + pub id: Uuid, + pub seq: i64, + pub room: RoomInfo, + pub sender_type: String, + pub sender: UserInfo, + pub thread: Option, + pub in_reply_to: Option, + pub content: String, + pub content_type: String, + #[serde(default)] + pub pinned: bool, + pub system_type: Option, + #[serde(default)] + pub metadata: serde_json::Value, + pub thinking_content: Option, + pub thinking_is_chunked: Option, + pub send_at: DateTime, + #[serde(default)] + pub reactions: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageEditedService { + pub id: Uuid, + pub seq: i64, + pub room: RoomInfo, + pub sender: UserInfo, + pub content: String, + pub edited_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageRevokedService { + pub id: Uuid, + pub seq: i64, + pub room: RoomInfo, + pub revoked_by: UserInfo, + pub revoked_at: DateTime, +} +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageStreamStartService { + pub message_id: Uuid, + pub room: RoomInfo, + pub sse_url: Option, + pub display_name: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageStreamChunkService { + pub message_id: Uuid, + pub room: RoomInfo, + pub seq: i64, + pub content: String, + pub chunk_type: String, + pub display_name: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageStreamDoneService { + pub message_id: Uuid, + pub room: RoomInfo, + pub content: String, + pub thinking_content: Option, + pub display_name: Option, + pub error: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageListService { + pub room: RoomInfo, + pub messages: Vec, + pub total: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageSendClient { + pub room: Uuid, + pub content: String, + pub content_type: String, + pub thread: Option, + pub in_reply_to: Option, + pub attachment_ids: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageEditClient { + pub room: Uuid, + pub message_id: Uuid, + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageRevokeClient { + pub room: Uuid, + pub message_id: Uuid, +} diff --git a/lib/channel/event/message_read.rs b/lib/channel/event/message_read.rs new file mode 100644 index 0000000..31b17fc --- /dev/null +++ b/lib/channel/event/message_read.rs @@ -0,0 +1,45 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageReadService { + pub room: RoomInfo, + pub message_id: Uuid, + pub message_seq: i64, + pub reader: UserInfo, + pub read_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageReadBatchService { + pub room: RoomInfo, + pub message_ids: Vec, + pub last_seq: i64, + pub reader: UserInfo, + pub read_at: DateTime, +} +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageReadersService { + pub message_id: Uuid, + pub readers: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageReaderEntry { + pub user: UserInfo, + pub read_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageMarkReadClient { + pub room: Uuid, + pub message_ids: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageGetReadersClient { + pub message_id: Uuid, +} diff --git a/lib/channel/event/mod.rs b/lib/channel/event/mod.rs new file mode 100644 index 0000000..6aea10c --- /dev/null +++ b/lib/channel/event/mod.rs @@ -0,0 +1,114 @@ +pub mod ai; +pub mod attachment; +pub mod ban; +pub mod category; +pub mod common; +pub mod conversation; +pub mod dm; +pub mod draft; +pub mod forward; +pub mod invite; +pub mod member; +pub mod message; +pub mod message_read; +pub mod notify; +pub mod pin; +pub mod presence; +pub mod reaction; +pub mod rooms; +pub mod search; +pub mod star; +pub mod thread; +pub mod voice; +pub mod workspace; + +pub use common::{AgentInfo, RoomInfo, UserInfo, WorkspaceInfo}; + +use model::room::RoomMessageModel; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use uuid::Uuid; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum ChannelEventType { + MessageCreated, + MessageUpdated, + MessageDeleted, + ReactionCreated, + ReactionDeleted, + MessageRead, + DmCreated, + DmClosed, + ConversationUpdated, + Custom(String), +} + +impl ChannelEventType { + pub fn as_str(&self) -> &str { + match self { + Self::MessageCreated => "message.created", + Self::MessageUpdated => "message.updated", + Self::MessageDeleted => "message.deleted", + Self::ReactionCreated => "reaction.created", + Self::ReactionDeleted => "reaction.deleted", + Self::MessageRead => "message.read", + Self::DmCreated => "dm.created", + Self::DmClosed => "dm.closed", + Self::ConversationUpdated => "conversation.updated", + Self::Custom(value) => value, + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ChannelEvent { + pub room: Uuid, + pub seq: Option, + pub event_type: String, + pub message: Option, + pub payload: Option, + #[serde(default)] + pub sender_user: Option, +} + +impl ChannelEvent { + pub fn message_created(message: RoomMessageModel) -> Self { + Self { + room: message.room, + seq: Some(message.seq), + event_type: ChannelEventType::MessageCreated.as_str().to_owned(), + message: Some(message), + payload: None, + sender_user: None, + } + } + + pub fn message_created_with_sender( + message: RoomMessageModel, + sender: UserInfo, + ) -> Self { + Self { + room: message.room, + seq: Some(message.seq), + event_type: ChannelEventType::MessageCreated.as_str().to_owned(), + message: Some(message), + payload: None, + sender_user: Some(sender), + } + } + + pub fn custom( + room: Uuid, + event_type: impl Into, + payload: Value, + ) -> Self { + Self { + room, + seq: None, + event_type: event_type.into(), + message: None, + payload: Some(payload), + sender_user: None, + } + } +} diff --git a/lib/channel/event/notify.rs b/lib/channel/event/notify.rs new file mode 100644 index 0000000..fad42e1 --- /dev/null +++ b/lib/channel/event/notify.rs @@ -0,0 +1,69 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, WorkspaceInfo}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum NotifyEventType { + Created, + Read, + Archived, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum NotifyEvent { + #[serde(rename = "notify.created")] + Created(NotifyCreatedService), + #[serde(rename = "notify.read")] + Read(NotifyReadService), + #[serde(rename = "notify.archived")] + Archived(NotifyArchivedService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NotifyCreatedService { + pub id: Uuid, + pub room: Option, + pub workspace: Option, + pub user: UserInfo, + pub notification_type: String, + pub title: String, + pub content: Option, + pub related_message_id: Option, + pub related_user: Option, + pub metadata: Option, + pub created_at: DateTime, + pub deep_link_url: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NotifyReadService { + pub id: Uuid, + pub user: UserInfo, + pub read_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NotifyArchivedService { + pub id: Uuid, + pub user: UserInfo, + pub archived_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NotifyReadClient { + pub id: Uuid, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NotifyReadAllClient { + pub workspace: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NotifyArchiveClient { + pub id: Uuid, +} diff --git a/lib/channel/event/pin.rs b/lib/channel/event/pin.rs new file mode 100644 index 0000000..dcf3c0d --- /dev/null +++ b/lib/channel/event/pin.rs @@ -0,0 +1,49 @@ +use chrono::{DateTime, Utc}; +use uuid::Uuid; +use serde::{Deserialize, Serialize}; + +use crate::event::{RoomInfo, UserInfo}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PinEventType { + Added, + Removed, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum PinEvent { + #[serde(rename = "pin.added")] + Added(PinAddedService), + #[serde(rename = "pin.removed")] + Removed(PinRemovedService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PinAddedService { + pub room: RoomInfo, + pub message: Uuid, + pub pinned_by: UserInfo, + pub pinned_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PinRemovedService { + pub room: RoomInfo, + pub message: Uuid, + pub removed_by: UserInfo, + pub removed_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PinAddClient { + pub room: RoomInfo, + pub message: Uuid, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PinRemoveClient { + pub room: RoomInfo, + pub message: Uuid, +} diff --git a/lib/channel/event/presence.rs b/lib/channel/event/presence.rs new file mode 100644 index 0000000..d44a03c --- /dev/null +++ b/lib/channel/event/presence.rs @@ -0,0 +1,58 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use crate::event::{UserInfo, WorkspaceInfo}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PresenceEventType { + Changed, + CustomStatusUpdated, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum UserPresenceStatus { + Online, + Idle, + Dnd, + Offline, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum PresenceEvent { + #[serde(rename = "presence.changed")] + Changed(PresenceChangedService), + #[serde(rename = "presence.custom_status_updated")] + CustomStatusUpdated(CustomStatusUpdatedService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PresenceChangedService { + pub user: UserInfo, + #[serde(rename = "workspace")] + pub project: Option, + pub status: UserPresenceStatus, + pub last_seen_at: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CustomStatusUpdatedService { + pub user: UserInfo, + pub emoji: Option, + pub text: Option, + pub expires_at: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PresenceUpdateClient { + pub status: UserPresenceStatus, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CustomStatusUpdateClient { + pub emoji: Option, + pub text: Option, + pub expires_at: Option>, +} diff --git a/lib/channel/event/reaction.rs b/lib/channel/event/reaction.rs new file mode 100644 index 0000000..a618777 --- /dev/null +++ b/lib/channel/event/reaction.rs @@ -0,0 +1,73 @@ +use chrono::{DateTime, Utc}; +use uuid::Uuid; +use serde::{Deserialize, Serialize}; + +use crate::event::{RoomInfo, UserInfo}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ReactionEventType { + Added, + Removed, + BatchUpdated, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ReactionEvent { + #[serde(rename = "reaction.added")] + Added(ReactionAddedService), + #[serde(rename = "reaction.removed")] + Removed(ReactionRemovedService), + #[serde(rename = "reaction.batch_updated")] + BatchUpdated(ReactionBatchUpdatedService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReactionAddedService { + pub id: Uuid, + pub room: RoomInfo, + pub message: Uuid, + pub user: UserInfo, + pub emoji: String, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReactionRemovedService { + pub id: Uuid, + pub room: RoomInfo, + pub message: Uuid, + pub user: UserInfo, + pub emoji: String, + pub removed_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReactionBatchUpdatedService { + pub room: RoomInfo, + pub message: Uuid, + pub reactions: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReactionGroup { + pub emoji: String, + pub count: i64, + pub reacted_by_me: bool, + pub users: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReactionAddClient { + pub room: RoomInfo, + pub message: Uuid, + pub emoji: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReactionRemoveClient { + pub room: RoomInfo, + pub message: Uuid, + pub emoji: String, +} diff --git a/lib/channel/event/rooms.rs b/lib/channel/event/rooms.rs new file mode 100644 index 0000000..34b5cbf --- /dev/null +++ b/lib/channel/event/rooms.rs @@ -0,0 +1,134 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, WorkspaceInfo}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RoomEventType { + Created, + Deleted, + Renamed, + TopicUpdated, + SettingsUpdated, + Moved, + AiUpdated, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum RoomEvent { + #[serde(rename = "room.created")] + Created(RoomCreatedService), + #[serde(rename = "room.deleted")] + Deleted(RoomDeletedService), + #[serde(rename = "room.renamed")] + Renamed(RoomRenamedService), + #[serde(rename = "room.moved")] + Moved(RoomMovedService), + #[serde(rename = "room.ai_updated")] + AiUpdated(RoomAiUpdatedService), + #[serde(rename = "room.topic_updated")] + TopicUpdated(RoomTopicUpdatedService), + #[serde(rename = "room.settings_updated")] + SettingsUpdated(RoomSettingsUpdatedService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomCreatedService { + pub room: RoomInfo, + pub workspace: WorkspaceInfo, + pub public: bool, + pub category: Option, + pub created_by: UserInfo, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomDeletedService { + pub room: RoomInfo, + pub workspace: WorkspaceInfo, + pub deleted_by: UserInfo, + pub deleted_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomRenamedService { + pub room: RoomInfo, + pub workspace: WorkspaceInfo, + pub old_name: String, + pub new_name: String, + pub renamed_by: UserInfo, + pub renamed_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomMovedService { + pub room: RoomInfo, + pub workspace: WorkspaceInfo, + pub old_category: Option, + pub new_category: Option, + pub moved_by: UserInfo, + pub moved_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomAiUpdatedService { + pub room: RoomInfo, + pub workspace: WorkspaceInfo, + pub model: Uuid, + pub model_name: String, + pub version: i64, + pub agent_type: String, + pub updated_by: UserInfo, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomTopicUpdatedService { + pub room: RoomInfo, + pub workspace: WorkspaceInfo, + pub old_topic: Option, + pub new_topic: Option, + pub updated_by: UserInfo, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomSettingsUpdatedService { + pub room: RoomInfo, + pub workspace: WorkspaceInfo, + pub slowmode_seconds: Option, + pub nsfw: bool, + pub default_auto_archive_duration: Option, + pub updated_by: UserInfo, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomCreateClient { + pub workspace: WorkspaceInfo, + pub room_name: String, + pub public: bool, + pub category: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomUpdateClient { + pub room_name: Option, + pub public: Option, + pub category: Option, + pub topic: Option, + pub slowmode_seconds: Option, + pub nsfw: Option, + pub default_auto_archive_duration: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomDeleteClient {} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoomLoadClient { + pub room: RoomInfo, +} diff --git a/lib/channel/event/search.rs b/lib/channel/event/search.rs new file mode 100644 index 0000000..3240117 --- /dev/null +++ b/lib/channel/event/search.rs @@ -0,0 +1,33 @@ +use chrono::{DateTime, Utc}; +use uuid::Uuid; +use serde::{Deserialize, Serialize}; + +use crate::event::{RoomInfo, UserInfo, message::MessageNewService}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResultService { + pub q: String, + pub room: Option, + pub messages: Vec, + pub total: i64, + pub took_ms: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchMessageHitService { + #[serde(flatten)] + pub message: MessageNewService, + pub highlighted_content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchClient { + pub q: String, + pub room: Option, + pub start_time: Option>, + pub end_time: Option>, + pub sender: Option, + pub content_type: Option, + pub limit: Option, + pub offset: Option, +} diff --git a/lib/channel/event/star.rs b/lib/channel/event/star.rs new file mode 100644 index 0000000..8bda44b --- /dev/null +++ b/lib/channel/event/star.rs @@ -0,0 +1,47 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageStarredService { + pub room: RoomInfo, + pub message_id: Uuid, + pub message_seq: i64, + pub starred_by: UserInfo, + pub starred_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageUnstarredService { + pub room: RoomInfo, + pub message_id: Uuid, + pub unstarred_by: UserInfo, + pub unstarred_at: DateTime, +} +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StarredMessageEntry { + pub message_id: Uuid, + pub room: RoomInfo, + pub seq: i64, + pub content: String, + pub content_type: String, + pub sender: UserInfo, + pub starred_at: DateTime, + pub sent_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageStarClient { + pub room: Uuid, + pub message_id: Uuid, + pub star: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StarredListClient { + pub room: Option, + pub limit: Option, + pub offset: Option, +} diff --git a/lib/channel/event/thread.rs b/lib/channel/event/thread.rs new file mode 100644 index 0000000..3d373fc --- /dev/null +++ b/lib/channel/event/thread.rs @@ -0,0 +1,105 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ThreadEventType { + Created, + Updated, + Resolved, + Archived, + ParticipantJoined, + ParticipantLeft, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ThreadEvent { + #[serde(rename = "thread.created")] + Created(ThreadCreatedService), + #[serde(rename = "thread.updated")] + Updated(ThreadUpdatedService), + #[serde(rename = "thread.resolved")] + Resolved(ThreadResolvedService), + #[serde(rename = "thread.archived")] + Archived(ThreadArchivedService), + #[serde(rename = "thread.participant_joined")] + ParticipantJoined(ThreadParticipantJoinedService), + #[serde(rename = "thread.participant_left")] + ParticipantLeft(ThreadParticipantLeftService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadCreatedService { + pub id: Uuid, + pub room: RoomInfo, + pub parent: i64, + pub created_by: UserInfo, + pub participants: serde_json::Value, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadUpdatedService { + pub id: Uuid, + pub room: RoomInfo, + pub last_message_at: Option>, + pub last_message_preview: Option, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadResolvedService { + pub id: Uuid, + pub room: RoomInfo, + pub resolved_by: UserInfo, + pub resolved_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadArchivedService { + pub id: Uuid, + pub room: RoomInfo, + pub archived_by: UserInfo, + pub archived_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadParticipantJoinedService { + pub id: Uuid, + pub room: RoomInfo, + pub user: UserInfo, + pub joined_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadParticipantLeftService { + pub id: Uuid, + pub room: RoomInfo, + pub user: UserInfo, + pub left_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadCreateClient { + pub room: RoomInfo, + pub parent_seq: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadResolveClient { + pub thread_id: Uuid, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadArchiveClient { + pub thread_id: Uuid, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadLoadClient { + pub thread_id: Uuid, +} diff --git a/lib/channel/event/voice.rs b/lib/channel/event/voice.rs new file mode 100644 index 0000000..1fd15f9 --- /dev/null +++ b/lib/channel/event/voice.rs @@ -0,0 +1,125 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use crate::event::{RoomInfo, UserInfo, WorkspaceInfo}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum VoiceEventType { + ChannelJoined, + ChannelLeft, + MuteUpdated, + DeafUpdated, + ScreenShareStarted, + ScreenShareStopped, + SpeakingStarted, + SpeakingStopped, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum VoiceEvent { + #[serde(rename = "voice.channel_joined")] + ChannelJoined(VoiceChannelJoinedService), + #[serde(rename = "voice.channel_left")] + ChannelLeft(VoiceChannelLeftService), + #[serde(rename = "voice.mute_updated")] + MuteUpdated(VoiceMuteUpdatedService), + #[serde(rename = "voice.deaf_updated")] + DeafUpdated(VoiceDeafUpdatedService), + #[serde(rename = "voice.screen_share_started")] + ScreenShareStarted(ScreenShareStartedService), + #[serde(rename = "voice.screen_share_stopped")] + ScreenShareStopped(ScreenShareStoppedService), + #[serde(rename = "voice.speaking_started")] + SpeakingStarted(SpeakingStartedService), + #[serde(rename = "voice.speaking_stopped")] + SpeakingStopped(SpeakingStoppedService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VoiceChannelJoinedService { + pub room: RoomInfo, + pub workspace: Option, + pub user: UserInfo, + pub muted: bool, + pub deafened: bool, + pub video: bool, + pub joined_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VoiceChannelLeftService { + pub room: RoomInfo, + pub workspace: Option, + pub user: UserInfo, + pub left_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VoiceMuteUpdatedService { + pub room: RoomInfo, + pub user: UserInfo, + pub muted: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VoiceDeafUpdatedService { + pub room: RoomInfo, + pub user: UserInfo, + pub deafened: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScreenShareStartedService { + pub room: RoomInfo, + pub user: UserInfo, + pub started_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScreenShareStoppedService { + pub room: RoomInfo, + pub user: UserInfo, + pub stopped_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SpeakingStartedService { + pub room: RoomInfo, + pub user: UserInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SpeakingStoppedService { + pub room: RoomInfo, + pub user: UserInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VoiceJoinClient { + pub room: RoomInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VoiceLeaveClient { + pub room: RoomInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VoiceMuteClient { + pub room: RoomInfo, + pub muted: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VoiceDeafClient { + pub room: RoomInfo, + pub deafened: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScreenShareClient { + pub room: RoomInfo, + pub start: bool, +} diff --git a/lib/channel/event/workspace.rs b/lib/channel/event/workspace.rs new file mode 100644 index 0000000..1eabbe9 --- /dev/null +++ b/lib/channel/event/workspace.rs @@ -0,0 +1,120 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, WorkspaceInfo}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum WorkspaceEventType { + RoomCreated, + RoomDeleted, + RoomRenamed, + RoomMoved, + RepoCreated, + RepoUpdated, + RepoDeleted, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum WorkspaceEvent { + #[serde(rename = "workspace.room_created")] + RoomCreated(WorkspaceRoomCreatedService), + #[serde(rename = "workspace.room_deleted")] + RoomDeleted(WorkspaceRoomDeletedService), + #[serde(rename = "workspace.room_renamed")] + RoomRenamed(WorkspaceRoomRenamedService), + #[serde(rename = "workspace.room_moved")] + RoomMoved(WorkspaceRoomMovedService), + #[serde(rename = "workspace.repo_created")] + RepoCreated(WorkspaceRepoCreatedService), + #[serde(rename = "workspace.repo_updated")] + RepoUpdated(WorkspaceRepoUpdatedService), + #[serde(rename = "workspace.repo_deleted")] + RepoDeleted(WorkspaceRepoDeletedService), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkspaceRoomCreatedService { + pub workspace: WorkspaceInfo, + pub room: RoomInfo, + pub public: bool, + pub category: Option, + pub created_by: UserInfo, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkspaceRoomDeletedService { + pub workspace: WorkspaceInfo, + pub room: RoomInfo, + pub deleted_by: UserInfo, + pub deleted_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkspaceRoomRenamedService { + pub workspace: WorkspaceInfo, + pub room: RoomInfo, + pub old_name: String, + pub new_name: String, + pub renamed_by: UserInfo, + pub renamed_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkspaceRoomMovedService { + pub workspace: WorkspaceInfo, + pub room: RoomInfo, + pub old_category: Option, + pub new_category: Option, + pub moved_by: UserInfo, + pub moved_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkspaceRepoCreatedService { + pub workspace: WorkspaceInfo, + pub repo: Uuid, + pub repo_name: String, + pub created_by: UserInfo, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkspaceRepoUpdatedService { + pub workspace: WorkspaceInfo, + pub repo: Uuid, + pub updated_by: UserInfo, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkspaceRepoDeletedService { + pub workspace: WorkspaceInfo, + pub repo: Uuid, + pub deleted_by: UserInfo, + pub deleted_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkspaceRoomCreateClient { + pub workspace: Uuid, + pub room_name: String, + pub public: bool, + pub category: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkspaceRoomUpdateClient { + pub room: Uuid, + pub room_name: Option, + pub public: Option, + pub category: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkspaceRoomDeleteClient { + pub room: Uuid, +} diff --git a/lib/channel/http/dispatch.rs b/lib/channel/http/dispatch.rs new file mode 100644 index 0000000..714ed56 --- /dev/null +++ b/lib/channel/http/dispatch.rs @@ -0,0 +1,114 @@ +use uuid::Uuid; + +use crate::event::{member, message, notify, reaction}; +use crate::event::{RoomInfo, UserInfo}; + +use super::out_event::WsOutEvent; + +pub struct EventDispatcher; + +impl EventDispatcher { + pub fn dispatch_message( + room_id: Uuid, + room_name: &str, + msg: &model::room::RoomMessageModel, + ) -> WsOutEvent { + let room = RoomInfo { + id: room_id, + name: room_name.to_string(), + }; + WsOutEvent::MessageNew { + room: room.clone(), + data: message::MessageNewService { + id: msg.id, + seq: msg.seq, + room, + sender_type: "user".to_string(), + sender: UserInfo::unknown(msg.author), + thread: msg.thread, + in_reply_to: msg.parent, + content: msg.content.clone(), + content_type: msg.content_type.clone(), + pinned: msg.pinned, + system_type: msg.system_type.clone(), + metadata: msg.metadata.clone(), + thinking_content: None, + thinking_is_chunked: None, + send_at: msg.created_at, + reactions: vec![], + }, + } + } + + pub fn dispatch_typing_start( + room_id: Uuid, + _room_name: &str, + user_id: Uuid, + display_name: Option, + ) -> WsOutEvent { + let user = UserInfo { + id: user_id, + username: display_name.clone().unwrap_or_default(), + display_name: display_name.unwrap_or_default(), + avatar_url: String::new(), + }; + let room = RoomInfo::unknown(room_id); + WsOutEvent::TypingStart { + room: room.clone(), + data: member::TypingStartService { + room, + user, + sender_type: "user".to_string(), + }, + } + } + + pub fn dispatch_typing_stop( + room_id: Uuid, + _room_name: &str, + user_id: Uuid, + display_name: Option, + ) -> WsOutEvent { + let user = UserInfo { + id: user_id, + username: display_name.clone().unwrap_or_default(), + display_name: display_name.unwrap_or_default(), + avatar_url: String::new(), + }; + let room = RoomInfo::unknown(room_id); + WsOutEvent::TypingStop { + room: room.clone(), + data: member::TypingStopService { + room, + user, + sender_type: "user".to_string(), + }, + } + } + + pub fn dispatch_reactions( + room_id: Uuid, + room_name: &str, + message_id: Uuid, + groups: Vec, + ) -> WsOutEvent { + let room = RoomInfo { + id: room_id, + name: room_name.to_string(), + }; + WsOutEvent::ReactionBatchUpdated { + room: room.clone(), + data: reaction::ReactionBatchUpdatedService { + room, + message: message_id, + reactions: groups, + }, + } + } + + pub fn dispatch_notification( + data: notify::NotifyCreatedService, + ) -> WsOutEvent { + WsOutEvent::NotifyCreated { data } + } +} diff --git a/lib/channel/http/handler/ai.rs b/lib/channel/http/handler/ai.rs new file mode 100644 index 0000000..1e0be2e --- /dev/null +++ b/lib/channel/http/handler/ai.rs @@ -0,0 +1,159 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{AgentInfo, RoomInfo, ai}; +use crate::{ChannelBus, ChannelError, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn ai_list( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let rows = db::sqlx::query_as::<_, (Uuid, Option, Option, Option, bool, bool)>( + "SELECT ra.agent_session, s.name, s.agent_kind, s.model_version, ra.enabled, ra.auto_reply \ + FROM room_ai ra \ + LEFT JOIN agent_session s ON s.id = ra.agent_session AND s.deleted_at IS NULL \ + WHERE ra.room = $1", + ) + .bind(room) + .fetch_all(bus.inner.db.reader()) + .await?; + + let agents = rows + .into_iter() + .filter_map(|(agent_session, name, agent_kind, model_version, enabled, auto_reply)| { + name.map(|n| ai::RoomAiEntry { + agent_session, + name: n, + agent_kind: agent_kind.unwrap_or_default(), + model_version, + enabled, + auto_reply, + }) + }) + .collect(); + + let ai_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + Ok(Some(WsOutEvent::AiAgentList { + room: ai_room.clone(), + data: ai::RoomAiListService { + room: ai_room, + agents, + }, + })) + } + + pub(super) async fn ai_upsert( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + model: Uuid, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let session = db::sqlx::query_as::<_, model::agent::AgentSessionModel>( + "SELECT id, \"user\", wk, name, description, agent_kind, model_version, \ + system_prompt, temperature, max_output_tokens, enabled, created_by, \ + created_at, updated_at, deleted_at \ + FROM agent_session WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(model) + .fetch_one(bus.inner.db.reader()) + .await + .map_err(|e| match e { + db::sqlx::Error::RowNotFound => ChannelError::RoomNotFound, + other => ChannelError::Database(other), + })?; + db::sqlx::query_as::<_, model::room::RoomAiModel>( + "INSERT INTO room_ai (room, agent_session, enabled, auto_reply, created_by, created_at, updated_at) \ + VALUES ($1, $2, true, false, $3, now(), now()) \ + ON CONFLICT (room, agent_session) DO UPDATE SET enabled = true, updated_at = now() \ + RETURNING room, agent_session, enabled, auto_reply, created_by, created_at, updated_at", + ) + .bind(room) + .bind(model) + .bind(user_id) + .fetch_one(bus.inner.db.writer()) + .await?; + let ai_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let data = ai::AiAgentJoinedService { + room: ai_room, + agent: AgentInfo { + id: model, + name: session.name.clone(), + agent_type: session.agent_kind.clone(), + model_name: None, + }, + joined_at: Utc::now(), + }; + bus.publish_room_event(room, "ai.agent_joined", &data) + .await?; + + Ok(Some(WsOutEvent::AiAgentJoined { room: data.room.clone(), data })) + } + + pub(super) async fn ai_delete( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + agent_id: Uuid, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let session = db::sqlx::query_as::<_, model::agent::AgentSessionModel>( + "SELECT id, \"user\", wk, name, description, agent_kind, model_version, \ + system_prompt, temperature, max_output_tokens, enabled, created_by, \ + created_at, updated_at, deleted_at \ + FROM agent_session WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(agent_id) + .fetch_optional(bus.inner.db.reader()) + .await?; + + let result = db::sqlx::query( + "DELETE FROM room_ai WHERE room = $1 AND agent_session = $2", + ) + .bind(room) + .bind(agent_id) + .execute(bus.inner.db.writer()) + .await?; + + if result.rows_affected() == 0 { + return Err(ChannelError::RoomNotFound); + } + let ai_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let agent_info = session.map(|s| AgentInfo { + id: s.id, + name: s.name, + agent_type: s.agent_kind, + model_name: None, + }).unwrap_or_else(|| AgentInfo::unknown(agent_id)); + + let data = ai::AiAgentLeftService { + room: ai_room, + agent: agent_info, + left_at: Utc::now(), + }; + bus.publish_room_event(room, "ai.agent_left", &data).await?; + + Ok(Some(WsOutEvent::AiAgentLeft { room: data.room.clone(), data })) + } + + pub(super) async fn ai_stop( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + bus.publish_room_event( + room, + "ai.stop", + &serde_json::json!({"stopped_by": user_id}), + ) + .await?; + Ok(None) + } +} diff --git a/lib/channel/http/handler/ban.rs b/lib/channel/http/handler/ban.rs new file mode 100644 index 0000000..1fe8760 --- /dev/null +++ b/lib/channel/http/handler/ban.rs @@ -0,0 +1,74 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{UserInfo, WorkspaceInfo, ban}; +use crate::{ChannelBus, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn ban_create( + bus: &ChannelBus, + _user_id: Uuid, + workspace: Uuid, + user: Uuid, + reason: Option, + _expires_at: Option>, + ) -> ChannelResult> { + Self::ensure_workspace_member(bus, _user_id, workspace).await?; + db::sqlx::query( + "INSERT INTO user_blacklist (\"user\", black, created_at) \ + VALUES ($1, $2, now()) \ + ON CONFLICT DO NOTHING", + ) + .bind(user) + .bind(_user_id) + .execute(bus.inner.db.writer()) + .await?; + let ban_key = format!("ban:{}:{}:{}", workspace, _user_id, user); + let ban_data = serde_json::json!({ + "workspace": workspace, + "banned_by": _user_id, + "reason": reason, + "expires_at": _expires_at, + "banned_at": Utc::now(), + }); + bus.inner.cache.set(&ban_key, &ban_data).await?; + let data = ban::BannedService { + workspace: bus.lookup_workspace(workspace).await.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), + user: bus.lookup_user(user).await.unwrap_or_else(|_| UserInfo::unknown(user)), + banned_by: bus.lookup_user(_user_id).await.unwrap_or_else(|_| UserInfo::unknown(_user_id)), + reason, + expires_at: _expires_at, + banned_at: Utc::now(), + }; + bus.workspace_changed(workspace).await?; + Ok(Some(WsOutEvent::UserBanned { data })) + } + + pub(super) async fn ban_remove( + bus: &ChannelBus, + _user_id: Uuid, + workspace: Uuid, + user: Uuid, + ) -> ChannelResult> { + db::sqlx::query( + "DELETE FROM user_blacklist WHERE \"user\" = $1 AND black = $2", + ) + .bind(user) + .bind(_user_id) + .execute(bus.inner.db.writer()) + .await?; + let ban_key = format!("ban:{}:{}:{}", workspace, _user_id, user); + bus.inner.cache.remove(&ban_key).await?; + let data = ban::UnbannedService { + workspace: bus.lookup_workspace(workspace).await.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), + user: bus.lookup_user(user).await.unwrap_or_else(|_| UserInfo::unknown(user)), + unbanned_by: bus.lookup_user(_user_id).await.unwrap_or_else(|_| UserInfo::unknown(_user_id)), + unbanned_at: Utc::now(), + }; + bus.workspace_changed(workspace).await?; + Ok(Some(WsOutEvent::UserUnbanned { data })) + } +} diff --git a/lib/channel/http/handler/category.rs b/lib/channel/http/handler/category.rs new file mode 100644 index 0000000..017609f --- /dev/null +++ b/lib/channel/http/handler/category.rs @@ -0,0 +1,138 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{UserInfo, WorkspaceInfo, category}; +use crate::{ChannelBus, ChannelError, ChannelResult}; + +use super::MAX_CATEGORY_NAME_LEN; +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn category_create( + bus: &ChannelBus, + user_id: Uuid, + workspace: Uuid, + name: String, + position: Option, + ) -> ChannelResult> { + if name.is_empty() || name.len() > MAX_CATEGORY_NAME_LEN { + return Err(ChannelError::Validation("invalid category name".into())); + } + Self::ensure_workspace_member(bus, user_id, workspace).await?; + let row = db::sqlx::query_as::<_, model::room::RoomCategoryModel>( + "INSERT INTO room_category (wk, name, position, created_at, updated_at) \ + VALUES ($1, $2, $3, now(), now()) \ + RETURNING id, wk, name, position, collapsed, created_at, updated_at", + ) + .bind(workspace) + .bind(&name) + .bind(position.unwrap_or(0)) + .fetch_one(bus.inner.db.writer()) + .await?; + let cc_workspace = bus + .lookup_workspace(workspace) + .await + .unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)); + let cc_user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = category::CategoryCreatedService { + id: row.id, + project: cc_workspace, + name: row.name, + position: row.position, + created_by: cc_user, + created_at: row.created_at, + }; + bus.workspace_changed(workspace).await?; + Ok(Some(WsOutEvent::CategoryCreated { + workspace: data.project.clone(), + data, + })) + } + + pub(super) async fn category_update( + bus: &ChannelBus, + user_id: Uuid, + id: Uuid, + name: Option, + position: Option, + ) -> ChannelResult> { + let old = db::sqlx::query_as::<_, model::room::RoomCategoryModel>( + "SELECT id, wk, name, position, collapsed, created_at, updated_at \ + FROM room_category WHERE id = $1", + ) + .bind(id) + .fetch_one(bus.inner.db.reader()) + .await?; + Self::ensure_workspace_member(bus, user_id, old.wk).await?; + let new_name = name.unwrap_or(old.name.clone()); + let new_position = position.unwrap_or(old.position); + db::sqlx::query( + "UPDATE room_category SET name = $2, position = $3, updated_at = now() WHERE id = $1", + ) + .bind(id) + .bind(&new_name) + .bind(new_position) + .execute(bus.inner.db.writer()) + .await?; + let cu_workspace = bus + .lookup_workspace(old.wk) + .await + .unwrap_or_else(|_| WorkspaceInfo::unknown(old.wk)); + let cu_user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = category::CategoryUpdatedService { + id, + project: cu_workspace, + name: Some(new_name), + position: Some(new_position), + updated_by: cu_user, + updated_at: Utc::now(), + }; + bus.workspace_changed(old.wk).await?; + Ok(Some(WsOutEvent::CategoryUpdated { workspace: data.project.clone(), data })) + } + + pub(super) async fn category_delete( + bus: &ChannelBus, + _user_id: Uuid, + id: Uuid, + ) -> ChannelResult> { + let existing = db::sqlx::query_as::<_, model::room::RoomCategoryModel>( + "SELECT id, wk, name, position, collapsed, created_at, updated_at \ + FROM room_category WHERE id = $1", + ) + .bind(id) + .fetch_one(bus.inner.db.reader()) + .await?; + Self::ensure_workspace_member(bus, _user_id, existing.wk).await?; + let row = db::sqlx::query_as::<_, model::room::RoomCategoryModel>( + "DELETE FROM room_category WHERE id = $1 \ + RETURNING id, wk, name, position, collapsed, created_at, updated_at", + ) + .bind(id) + .fetch_one(bus.inner.db.writer()) + .await?; + let cd_workspace = bus + .lookup_workspace(row.wk) + .await + .unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)); + let cd_user = bus + .lookup_user(_user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(_user_id)); + let data = category::CategoryDeletedService { + id: row.id, + project: cd_workspace.clone(), + deleted_by: cd_user, + deleted_at: Utc::now(), + }; + bus.workspace_changed(row.wk).await?; + Ok(Some(WsOutEvent::CategoryDeleted { workspace: cd_workspace, data })) + } +} diff --git a/lib/channel/http/handler/conversation.rs b/lib/channel/http/handler/conversation.rs new file mode 100644 index 0000000..47ce248 --- /dev/null +++ b/lib/channel/http/handler/conversation.rs @@ -0,0 +1,231 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, conversation}; +use crate::{ChannelBus, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn conversation_pin( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + pin: bool, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let now = Utc::now(); + db::sqlx::query( + "INSERT INTO user_room_state (\"user\", room, is_pinned, updated_at) \ + VALUES ($1, $2, $3, $4) \ + ON CONFLICT (\"user\", room) DO UPDATE \ + SET is_pinned = $3, updated_at = $4", + ) + .bind(user_id) + .bind(room) + .bind(pin) + .bind(now) + .execute(bus.inner.db.writer()) + .await?; + + let room_info = + bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let user_info = + bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + + if pin { + let data = conversation::ConversationPinnedService { + user: user_info, + room: room_info.clone(), + pinned_at: now, + }; + bus.emit_to_user(user_id, "conversation.pinned", &data).await?; + Ok(Some(WsOutEvent::ConversationPinned { + room: room_info, + data, + })) + } else { + let data = conversation::ConversationUnpinnedService { + user: user_info, + room: room_info.clone(), + unpinned_at: now, + }; + bus.emit_to_user(user_id, "conversation.unpinned", &data).await?; + Ok(Some(WsOutEvent::ConversationUnpinned { + room: room_info, + data, + })) + } + } + pub(super) async fn conversation_mute( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + mute: bool, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let now = Utc::now(); + + db::sqlx::query( + "INSERT INTO user_room_state (\"user\", room, is_muted, updated_at) \ + VALUES ($1, $2, $3, $4) \ + ON CONFLICT (\"user\", room) DO UPDATE \ + SET is_muted = $3, updated_at = $4", + ) + .bind(user_id) + .bind(room) + .bind(mute) + .bind(now) + .execute(bus.inner.db.writer()) + .await?; + + let room_info = + bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let user_info = + bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + + if mute { + let data = conversation::ConversationMutedService { + user: user_info, + room: room_info.clone(), + muted_at: now, + }; + bus.emit_to_user(user_id, "conversation.muted", &data).await?; + Ok(Some(WsOutEvent::ConversationMuted { + room: room_info, + data, + })) + } else { + let data = conversation::ConversationUnmutedService { + user: user_info, + room: room_info.clone(), + unmuted_at: now, + }; + bus.emit_to_user(user_id, "conversation.unmuted", &data).await?; + Ok(Some(WsOutEvent::ConversationUnmuted { + room: room_info, + data, + })) + } + } + pub(super) async fn conversation_notify_level( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + notify_level: String, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let valid = matches!(notify_level.as_str(), "all" | "mentions" | "none"); + if !valid { + return Err(crate::ChannelError::Internal( + "notify_level must be 'all', 'mentions', or 'none'".to_string(), + )); + } + + let now = Utc::now(); + let old_level: Option<(String,)> = db::sqlx::query_as( + "SELECT notify_level FROM user_room_state \ + WHERE \"user\" = $1 AND room = $2", + ) + .bind(user_id) + .bind(room) + .fetch_optional(bus.inner.db.reader()) + .await?; + let old = old_level.map(|r| r.0).unwrap_or_else(|| "all".to_string()); + + db::sqlx::query( + "INSERT INTO user_room_state (\"user\", room, notify_level, updated_at) \ + VALUES ($1, $2, $3, $4) \ + ON CONFLICT (\"user\", room) DO UPDATE \ + SET notify_level = $3, updated_at = $4", + ) + .bind(user_id) + .bind(room) + .bind(¬ify_level) + .bind(now) + .execute(bus.inner.db.writer()) + .await?; + + let room_info = + bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let user_info = + bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + + let data = conversation::ConversationNotifyLevelChangedService { + user: user_info, + room: room_info.clone(), + old_level: old, + new_level: notify_level, + updated_at: now, + }; + bus.emit_to_user(user_id, "conversation.notify_level_changed", &data) + .await?; + Ok(None) + } + pub(super) async fn conversation_list( + bus: &ChannelBus, + user_id: Uuid, + ) -> ChannelResult> { + let rooms = crate::rooms::user_rooms( + &bus.inner.db, + &bus.inner.cache, + &bus.inner.config, + user_id, + ) + .await?; + + if rooms.is_empty() { + return Ok(Some(WsOutEvent::ConversationList { data: vec![] })); + } + let rows = db::sqlx::query_as::<_, ( + Uuid, // room id + String, // room name + String, // room type + bool, // is_pinned + bool, // is_muted + String, // notify_level + i64, // last_read_seq + i64, // max seq from room_message + )>( + "SELECT r.id, r.name, r.room_type, \ + COALESCE(s.is_pinned, false), \ + COALESCE(s.is_muted, false), \ + COALESCE(s.notify_level, 'all'), \ + COALESCE(s.last_read_seq, 0), \ + COALESCE((SELECT MAX(seq) FROM room_message \ + WHERE room = r.id AND deleted_at IS NULL), 0) \ + FROM room r \ + LEFT JOIN user_room_state s ON s.room = r.id AND s.\"user\" = $1 \ + WHERE r.id = ANY($2) AND r.deleted_at IS NULL AND r.is_archived = false \ + ORDER BY COALESCE(s.is_pinned, false) DESC, r.name", + ) + .bind(user_id) + .bind(&rooms) + .fetch_all(bus.inner.db.reader()) + .await?; + + let summaries: Vec = rows + .into_iter() + .map( + |(id, name, room_type, is_pinned, is_muted, notify_level, last_read_seq, max_seq)| { + let unread = (max_seq - last_read_seq).max(0); + conversation::ConversationSummary { + room: id, + room_name: name, + room_type, + is_pinned, + is_muted, + notify_level, + last_read_seq, + max_seq, + unread_count: unread, + last_read_at: None, + } + }, + ) + .collect(); + + Ok(Some(WsOutEvent::ConversationList { data: summaries })) + } +} diff --git a/lib/channel/http/handler/dm.rs b/lib/channel/http/handler/dm.rs new file mode 100644 index 0000000..680a5ee --- /dev/null +++ b/lib/channel/http/handler/dm.rs @@ -0,0 +1,248 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, dm}; +use crate::{ChannelBus, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn dm_create( + bus: &ChannelBus, + user_id: Uuid, + recipient: Uuid, + ) -> ChannelResult> { + if user_id == recipient { + return Err(crate::ChannelError::Internal( + "cannot create DM with yourself".to_string(), + )); + } + let recipient_exists: Option<(Uuid,)> = db::sqlx::query_as( + "SELECT id FROM \"user\" WHERE id = $1", + ) + .bind(recipient) + .fetch_optional(bus.inner.db.reader()) + .await?; + if recipient_exists.is_none() { + return Err(crate::ChannelError::UserNotFound); + } + let (initiator, other) = if user_id < recipient { + (user_id, recipient) + } else { + (recipient, user_id) + }; + let existing: Option<(Uuid, Uuid, bool)> = db::sqlx::query_as( + "SELECT room, initiator, is_closed FROM dm_conversation \ + WHERE initiator = $1 AND recipient = $2", + ) + .bind(initiator) + .bind(other) + .fetch_optional(bus.inner.db.reader()) + .await?; + + let now = Utc::now(); + + let (room_id, is_reopen) = if let Some((room, _, is_closed)) = existing { + if is_closed { + db::sqlx::query( + "UPDATE dm_conversation SET is_closed = false, closed_at = NULL, \ + updated_at = now() WHERE initiator = $1 AND recipient = $2", + ) + .bind(initiator) + .bind(other) + .execute(bus.inner.db.writer()) + .await?; + db::sqlx::query( + "UPDATE room SET is_archived = false, updated_at = now() WHERE id = $1", + ) + .bind(room) + .execute(bus.inner.db.writer()) + .await?; + + (room, true) + } else { + (room, false) + } + } else { + let shared_wk: Option<(Uuid,)> = db::sqlx::query_as( + "SELECT wm1.wk FROM wk_member wm1 \ + INNER JOIN wk_member wm2 ON wm2.wk = wm1.wk \ + WHERE wm1.\"user\" = $1 AND wm1.leave_at IS NULL \ + AND wm2.\"user\" = $2 AND wm2.leave_at IS NULL \ + LIMIT 1", + ) + .bind(user_id) + .bind(recipient) + .fetch_optional(bus.inner.db.reader()) + .await?; + + let wk = shared_wk.map(|r| r.0).unwrap_or_else(|| { + Uuid::nil() + }); + let room_id = Uuid::new_v4(); + db::sqlx::query( + "INSERT INTO room (id, wk, name, topic, room_type, position, is_private, \ + created_by, created_at, updated_at) \ + VALUES ($1, $2, $3, NULL, 'DM', 0, true, $4, now(), now())", + ) + .bind(room_id) + .bind(wk) + .bind(format!("dm-{}", &room_id.to_string()[..8])) + .bind(user_id) + .execute(bus.inner.db.writer()) + .await?; + db::sqlx::query( + "INSERT INTO dm_conversation (room, initiator, recipient, created_at, updated_at) \ + VALUES ($1, $2, $3, now(), now()) \ + ON CONFLICT (initiator, recipient) DO NOTHING", + ) + .bind(room_id) + .bind(initiator) + .bind(other) + .execute(bus.inner.db.writer()) + .await?; + for uid in &[user_id, recipient] { + db::sqlx::query( + "INSERT INTO room_permission_overwrite \ + (room, target_type, target_id, allow_mask, deny_mask, created_at) \ + VALUES ($1, 'user', $2, 0, 0, now()) \ + ON CONFLICT DO NOTHING", + ) + .bind(room_id) + .bind(uid) + .execute(bus.inner.db.writer()) + .await?; + } + + (room_id, false) + }; + let _ = crate::rooms::refresh_user_rooms_cache( + &bus.inner.db, + &bus.inner.cache, + &bus.inner.config, + user_id, + ) + .await; + let _ = crate::rooms::refresh_user_rooms_cache( + &bus.inner.db, + &bus.inner.cache, + &bus.inner.config, + recipient, + ) + .await; + + let room_info = + bus.lookup_room(room_id).await.unwrap_or_else(|_| RoomInfo::unknown(room_id)); + let initiator_info = + bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let recipient_info = + bus.lookup_user(recipient).await.unwrap_or_else(|_| UserInfo::unknown(recipient)); + + if is_reopen { + let data = dm::DmReopenedService { + room: room_info.clone(), + reopened_by: initiator_info, + reopened_at: now, + }; + bus.emit_to_user(user_id, "dm.reopened", &data).await?; + bus.emit_to_user(recipient, "dm.reopened", &data).await?; + Ok(Some(WsOutEvent::DmReopened { + room: room_info, + data, + })) + } else { + let data = dm::DmCreatedService { + room: room_info.clone(), + initiator: initiator_info, + recipient: recipient_info, + created_at: now, + }; + bus.emit_to_user(user_id, "dm.created", &data).await?; + bus.emit_to_user(recipient, "dm.created", &data).await?; + Ok(Some(WsOutEvent::DmCreated { + room: room_info, + data, + })) + } + } + pub(super) async fn dm_close( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + ) -> ChannelResult> { + let now = Utc::now(); + + let result = db::sqlx::query( + "UPDATE dm_conversation SET is_closed = true, closed_at = $1, updated_at = $1 \ + WHERE room = $2 AND (initiator = $3 OR recipient = $3) AND is_closed = false", + ) + .bind(now) + .bind(room) + .bind(user_id) + .execute(bus.inner.db.writer()) + .await?; + + if result.rows_affected() == 0 { + return Ok(None); + } + + let room_info = + bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let closed_by = + bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + + let data = dm::DmClosedService { + room: room_info.clone(), + closed_by, + closed_at: now, + }; + bus.publish_room_event(room, "dm.closed", &data).await?; + Ok(Some(WsOutEvent::DmClosed { + room: room_info, + data, + })) + } + pub(super) async fn dm_list( + bus: &ChannelBus, + user_id: Uuid, + ) -> ChannelResult> { + let rows = db::sqlx::query_as::<_, (Uuid, Uuid, Uuid, chrono::DateTime)>( + "SELECT dc.room, dc.initiator, dc.recipient, dc.created_at \ + FROM dm_conversation dc \ + INNER JOIN room r ON r.id = dc.room \ + WHERE (dc.initiator = $1 OR dc.recipient = $1) \ + AND dc.is_closed = false \ + AND r.deleted_at IS NULL \ + ORDER BY dc.updated_at DESC", + ) + .bind(user_id) + .fetch_all(bus.inner.db.reader()) + .await?; + + let mut results = Vec::with_capacity(rows.len()); + for (room_id, initiator_id, recipient_id, created_at) in rows { + let room_info = bus + .lookup_room(room_id) + .await + .unwrap_or_else(|_| RoomInfo::unknown(room_id)); + let initiator_info = bus + .lookup_user(initiator_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(initiator_id)); + let recipient_info = bus + .lookup_user(recipient_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(recipient_id)); + + results.push(dm::DmCreatedService { + room: room_info, + initiator: initiator_info, + recipient: recipient_info, + created_at, + }); + } + + Ok(Some(WsOutEvent::DmList { data: results })) + } +} diff --git a/lib/channel/http/handler/draft.rs b/lib/channel/http/handler/draft.rs new file mode 100644 index 0000000..74e432b --- /dev/null +++ b/lib/channel/http/handler/draft.rs @@ -0,0 +1,52 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, draft}; +use crate::{ChannelBus, ChannelError, ChannelResult}; + +use super::{MAX_TEXT_LEN}; +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn draft_save( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + content: String, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + if content.len() > MAX_TEXT_LEN { + return Err(ChannelError::Validation("draft too long".into())); + } + let key = format!("draft:{}:{}", user_id, room); + bus.inner.cache.set(&key, &content).await?; + let ds_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let ds_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = draft::DraftSavedService { + user: ds_user, + room: ds_room, + content, + saved_at: Utc::now(), + }; + Ok(Some(WsOutEvent::DraftSaved { room: data.room.clone(), data })) + } + + pub(super) async fn draft_clear( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let key = format!("draft:{}:{}", user_id, room); + bus.inner.cache.remove(&key).await?; + let dc_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let dc_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = draft::DraftClearedService { + user: dc_user, + room: dc_room, + cleared_at: Utc::now(), + }; + Ok(Some(WsOutEvent::DraftCleared { room: data.room.clone(), data })) + } +} diff --git a/lib/channel/http/handler/forward.rs b/lib/channel/http/handler/forward.rs new file mode 100644 index 0000000..5fb6085 --- /dev/null +++ b/lib/channel/http/handler/forward.rs @@ -0,0 +1,89 @@ +use uuid::Uuid; + +use crate::event::{RoomInfo, forward}; +use crate::{ChannelBus, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn message_forward( + bus: &ChannelBus, + user_id: Uuid, + source_message_id: Uuid, + target_room: Uuid, + ) -> ChannelResult> { + let source = Self::load_message(bus, source_message_id).await?; + + Self::ensure_room_access(bus, user_id, source.room).await?; + Self::ensure_room_access(bus, user_id, target_room).await?; + + let seq = bus.inner.seq.seq(target_room).await?; + let sender = bus.lookup_user(user_id).await?; + + let source_room_info = bus + .lookup_room(source.room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(source.room)); + let forward_content = format!( + "> Forwarded from {}:\n\n{}", + source_room_info.name, source.content + ); + + let metadata = serde_json::json!({ + "source_room_id": source.room, + "source_message_id": source.id, + "source_room_name": source_room_info.name, + "forwarded_by": user_id, + }); + + let row = db::sqlx::query_as::<_, model::room::RoomMessageModel>( + "INSERT INTO room_message \ + (room, seq, thread, parent, author, content, content_type, system_type, metadata) \ + VALUES ($1, $2, NULL, NULL, $3, $4, 'forward', NULL, $5) \ + RETURNING id, room, seq, thread, parent, author, content, content_type, pinned, \ + system_type, metadata, edited_at, created_at, updated_at, deleted_at", + ) + .bind(target_room) + .bind(seq) + .bind(user_id) + .bind(&forward_content) + .bind(&metadata) + .fetch_one(bus.inner.db.writer()) + .await?; + + let target_room_info = bus + .lookup_room(target_room) + .await + .unwrap_or_else(|_| RoomInfo::unknown(target_room)); + + let fwd_id = row.id; + let fwd_seq = row.seq; + let fwd_content = row.content.clone(); + let fwd_content_type = row.content_type.clone(); + let fwd_created_at = row.created_at; + + bus.publish_room_message( + row, + Some(bus.lookup_user(user_id).await?), + ) + .await?; + + let data = forward::MessageForwardedService { + id: fwd_id, + seq: fwd_seq, + room: target_room_info.clone(), + sender, + content: fwd_content, + content_type: fwd_content_type, + source_room: source_room_info, + source_message_id: source.id, + forwarded_at: fwd_created_at, + }; + + Ok(Some(WsOutEvent::MessageForwarded { + room: target_room_info, + data, + })) + } +} diff --git a/lib/channel/http/handler/helpers.rs b/lib/channel/http/handler/helpers.rs new file mode 100644 index 0000000..7a487d2 --- /dev/null +++ b/lib/channel/http/handler/helpers.rs @@ -0,0 +1,157 @@ +use std::collections::HashMap; + +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, message, reaction}; +use crate::{ChannelBus, ChannelError, ChannelResult}; + +use super::WsHandler; + +impl WsHandler { + pub(super) async fn ensure_room_access( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + ) -> ChannelResult<()> { + let rooms = crate::rooms::user_rooms( + &bus.inner.db, + &bus.inner.cache, + &bus.inner.config, + user_id, + ) + .await?; + if rooms.contains(&room) { + Ok(()) + } else { + Err(ChannelError::AccessDenied) + } + } + + pub(super) async fn ensure_message_in_room( + bus: &ChannelBus, + room: Uuid, + message: Uuid, + ) -> ChannelResult<()> { + let exists: Option<(Uuid,)> = db::sqlx::query_as( + "SELECT id FROM room_message WHERE id = $1 AND room = $2 AND deleted_at IS NULL", + ) + .bind(message) + .bind(room) + .fetch_optional(bus.inner.db.reader()) + .await?; + exists.map(|_| ()).ok_or(ChannelError::RoomNotFound) + } + + pub(super) async fn reaction_groups_for_messages( + bus: &ChannelBus, + user_id: Uuid, + message_ids: &[Uuid], + ) -> ChannelResult>> { + if message_ids.is_empty() { + return Ok(HashMap::new()); + } + + let rows = db::sqlx::query_as::<_, (Uuid, String, Uuid)>( + "SELECT message, reaction, \"user\" FROM room_reaction \ + WHERE message = ANY($1) ORDER BY created_at ASC", + ) + .bind(message_ids) + .fetch_all(bus.inner.db.reader()) + .await?; + + let user_ids: Vec = rows.iter().map(|(_, _, user)| *user).collect(); + let users = bus.lookup_users(&user_ids).await.unwrap_or_default(); + let mut grouped: HashMap> = + HashMap::new(); + + for (message_id, emoji, reactor) in rows { + let group = grouped + .entry(message_id) + .or_default() + .entry(emoji.clone()) + .or_insert_with(|| reaction::ReactionGroup { + emoji: emoji.clone(), + count: 0, + reacted_by_me: false, + users: Vec::new(), + }); + group.count += 1; + group.reacted_by_me |= reactor == user_id; + group.users.push( + users + .get(&reactor) + .cloned() + .unwrap_or_else(|| UserInfo::unknown(reactor)), + ); + } + + Ok(grouped + .into_iter() + .map(|(message_id, groups)| { + (message_id, groups.into_values().collect::>()) + }) + .collect()) + } + + pub(super) async fn ensure_workspace_member( + bus: &ChannelBus, + user_id: Uuid, + wk: Uuid, + ) -> ChannelResult<()> { + let row: Option<(Uuid,)> = db::sqlx::query_as( + "SELECT wk FROM wk_member WHERE wk = $1 AND \"user\" = $2 AND leave_at IS NULL", + ) + .bind(wk) + .bind(user_id) + .fetch_optional(bus.inner.db.reader()) + .await?; + row.map(|_| ()).ok_or(ChannelError::AccessDenied) + } + + #[allow(dead_code)] + pub(super) fn missed_message_data( + m: crate::MissedMessage, + ) -> message::MessageNewService { + message::MessageNewService { + id: m.message_id, + seq: m.seq, + room: RoomInfo::unknown(m.room_id), + sender_type: "user".to_string(), + sender: UserInfo::unknown(m.sender_id), + thread: None, + in_reply_to: None, + content: m.content, + content_type: "text".to_string(), + pinned: false, + system_type: None, + metadata: serde_json::Value::Null, + thinking_content: None, + thinking_is_chunked: None, + send_at: m.send_at, + reactions: vec![], + } + } + #[allow(dead_code)] + pub(super) fn message_data( + m: model::room::RoomMessageModel, + ) -> message::MessageNewService { + message::MessageNewService { + id: m.id, + seq: m.seq, + room: RoomInfo::unknown(m.room), + sender_type: "user".to_string(), + sender: UserInfo::unknown(m.author), + thread: m.thread, + in_reply_to: m.parent, + content: m.content, + content_type: m.content_type, + pinned: m.pinned, + system_type: m.system_type, + metadata: m.metadata, + thinking_content: None, + thinking_is_chunked: None, + send_at: m.created_at, + reactions: vec![], + } + } +} diff --git a/lib/channel/http/handler/invite.rs b/lib/channel/http/handler/invite.rs new file mode 100644 index 0000000..cde04e0 --- /dev/null +++ b/lib/channel/http/handler/invite.rs @@ -0,0 +1,123 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, WorkspaceInfo, invite}; +use crate::{ChannelBus, ChannelError, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn invite_create( + bus: &ChannelBus, + user_id: Uuid, + workspace: Uuid, + _room: Option, + _max_uses: Option, + _expires_at: Option>, + ) -> ChannelResult> { + Self::ensure_workspace_member(bus, user_id, workspace).await?; + let invite_id = Uuid::now_v7(); + let code = Uuid::now_v7().to_string(); + let id_key = format!("invite:id:{}", invite_id); + let code_key = format!("invite:code:{}", code); + let meta = serde_json::json!({ + "workspace": workspace, + "created_by": user_id, + "room": _room, + "max_uses": _max_uses, + "expires_at": _expires_at, + }); + bus.inner.cache.set(&id_key, &meta.to_string()).await?; + bus.inner.cache.set(&code_key, &invite_id.to_string()).await?; + let inv_room = match _room { + Some(r) => Some(bus.lookup_room(r).await.unwrap_or_else(|_| RoomInfo::unknown(r))), + None => None, + }; + let data = invite::InviteCreatedService { + id: invite_id, + workspace: bus.lookup_workspace(workspace).await.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), + room: inv_room, + inviter: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + invitee: None, + code, + max_uses: _max_uses, + expires_at: _expires_at, + created_at: Utc::now(), + }; + Ok(Some(WsOutEvent::InviteCreated { data })) + } + + pub(super) async fn invite_accept( + bus: &ChannelBus, + user_id: Uuid, + code: String, + ) -> ChannelResult> { + let code_key = format!("invite:code:{}", code); + let invite_id_str: Option = bus.inner.cache.get(&code_key).await?; + let invite_id = invite_id_str + .as_deref() + .and_then(|s| Uuid::parse_str(s).ok()) + .ok_or(ChannelError::RoomNotFound)?; + let id_key = format!("invite:id:{}", invite_id); + let stored: Option = bus.inner.cache.get(&id_key).await?; + let meta: serde_json::Value = stored + .as_deref() + .and_then(|s| serde_json::from_str(s).ok()) + .ok_or(ChannelError::RoomNotFound)?; + let wk = meta["workspace"] + .as_str() + .and_then(|s| Uuid::parse_str(s).ok()) + .ok_or(ChannelError::RoomNotFound)?; + db::sqlx::query( + "INSERT INTO wk_member (wk, \"user\", owner, admin, join_at) \ + VALUES ($1, $2, false, false, now()) \ + ON CONFLICT DO NOTHING", + ) + .bind(wk) + .bind(user_id) + .execute(bus.inner.db.writer()) + .await?; + db::sqlx::query( + "INSERT INTO wk_apply_join (wk, \"user\", status, created_at, updated_at) \ + VALUES ($1, $2, 'accepted', now(), now())", + ) + .bind(wk) + .bind(user_id) + .execute(bus.inner.db.writer()) + .await?; + bus.inner.cache.remove(&code_key).await?; + bus.inner.cache.remove(&id_key).await?; + let data = invite::InviteAcceptedService { + id: Uuid::now_v7(), + workspace: bus.lookup_workspace(wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(wk)), + room: None, + user: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + accepted_at: Utc::now(), + }; + bus.workspace_changed(wk).await?; + Ok(Some(WsOutEvent::InviteAccepted { data })) + } + + pub(super) async fn invite_revoke( + bus: &ChannelBus, + _user_id: Uuid, + id: Uuid, + ) -> ChannelResult> { + let id_key = format!("invite:id:{}", id); + let stored: Option = bus.inner.cache.get(&id_key).await?; + let meta: serde_json::Value = stored + .as_deref() + .and_then(|s| serde_json::from_str(s).ok()) + .ok_or(ChannelError::RoomNotFound)?; + let created_by = meta["created_by"] + .as_str() + .and_then(|s| Uuid::parse_str(s).ok()) + .ok_or(ChannelError::RoomNotFound)?; + if created_by != _user_id { + return Err(ChannelError::AccessDenied); + } + bus.inner.cache.remove(&id_key).await?; + Ok(None) + } +} diff --git a/lib/channel/http/handler/message.rs b/lib/channel/http/handler/message.rs new file mode 100644 index 0000000..de60d10 --- /dev/null +++ b/lib/channel/http/handler/message.rs @@ -0,0 +1,590 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, message, thread}; +use crate::{ + ChannelBus, ChannelError, ChannelResult, + pagination::{MessagePagination, PaginationDirection, PaginationParams}, +}; + +use super::{MAX_MESSAGES_PER_REQUEST, MAX_TEXT_LEN}; +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + /// Count non-deleted sibling replies to the same parent message. + async fn count_sibling_replies( + bus: &ChannelBus, + parent_id: Uuid, + ) -> ChannelResult { + let (count,): (i64,) = db::sqlx::query_as( + "SELECT COUNT(*) FROM room_message WHERE parent = $1 AND deleted_at IS NULL", + ) + .bind(parent_id) + .fetch_one(bus.inner.db.reader()) + .await?; + Ok(count) + } + + /// Walk the reply parent chain and return (root_message_id, root_message_seq, chain_depth). + async fn reply_chain_info( + bus: &ChannelBus, + parent_id: Uuid, + ) -> ChannelResult<(Uuid, i64, i32)> { + let rows: Vec<(Uuid, i64, i32)> = db::sqlx::query_as( + r#"WITH RECURSIVE chain AS ( + SELECT id, parent, seq, 1 AS depth + FROM room_message + WHERE id = $1 AND deleted_at IS NULL + UNION ALL + SELECT m.id, m.parent, m.seq, c.depth + 1 + FROM room_message m + JOIN chain c ON m.id = c.parent + WHERE m.deleted_at IS NULL + ) + SELECT id, seq, depth FROM chain ORDER BY depth DESC"#, + ) + .bind(parent_id) + .fetch_all(bus.inner.db.reader()) + .await?; + + let root_id = rows.first().map(|r| r.0).unwrap_or(parent_id); + let root_seq = rows.first().map(|r| r.1).unwrap_or(0); + let max_depth = rows.first().map(|r| r.2).unwrap_or(1); + Ok((root_id, root_seq, max_depth)) + } + + /// Check if any message in the reply parent chain already belongs to a thread. + async fn find_thread_in_chain( + bus: &ChannelBus, + parent_id: Uuid, + ) -> ChannelResult> { + let row: Option<(Uuid,)> = db::sqlx::query_as( + r#"WITH RECURSIVE chain AS ( + SELECT id, parent, thread + FROM room_message + WHERE id = $1 AND deleted_at IS NULL + UNION ALL + SELECT m.id, m.parent, m.thread + FROM room_message m + JOIN chain c ON m.id = c.parent + WHERE m.deleted_at IS NULL + ) + SELECT thread FROM chain WHERE thread IS NOT NULL LIMIT 1"#, + ) + .bind(parent_id) + .fetch_optional(bus.inner.db.reader()) + .await?; + Ok(row.map(|r| r.0)) + } + + /// Update all messages in the reply parent chain to point to the given thread. + async fn attach_chain_to_thread( + bus: &ChannelBus, + parent_id: Uuid, + thread_id: Uuid, + ) -> ChannelResult<()> { + db::sqlx::query( + r#"WITH RECURSIVE chain AS ( + SELECT id FROM room_message + WHERE id = $1 AND deleted_at IS NULL + UNION ALL + SELECT m.id FROM room_message m + JOIN chain c ON m.id = c.parent + WHERE m.deleted_at IS NULL + ) + UPDATE room_message SET thread = $2, updated_at = now() + WHERE id IN (SELECT id FROM chain) AND thread IS NULL"#, + ) + .bind(parent_id) + .bind(thread_id) + .execute(bus.inner.db.writer()) + .await?; + Ok(()) + } + + pub(super) async fn message_create( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + content: String, + content_type: Option, + thread: Option, + in_reply_to: Option, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + if content.len() > MAX_TEXT_LEN { + return Err(ChannelError::Validation( + "message exceeds maximum length".to_string(), + )); + } + if let Some(parent_message) = in_reply_to { + Self::ensure_message_in_room(bus, room, parent_message).await?; + } + + // ── Auto-thread logic ────────────────────────────────────────── + let mut events: Vec = Vec::new(); + let effective_thread: Option = if let Some(ref parent_id) = in_reply_to { + if thread.is_some() { + thread + } else { + let existing = Self::find_thread_in_chain(bus, *parent_id).await?; + if let Some(tid) = existing { + Some(tid) + } else { + let sibling_count = Self::count_sibling_replies(bus, *parent_id).await?; + let (root_id, root_seq, chain_depth) = Self::reply_chain_info(bus, *parent_id).await?; + let should_create = sibling_count >= 3 || chain_depth >= 5; + + if should_create { + let seq = bus.inner.seq.seq(room).await?; + let thread_row = db::sqlx::query_as::<_, model::room::RoomThreadModel>( + "INSERT INTO room_thread (room, seq, starter_message, title, created_by, created_at, updated_at) \ + VALUES ($1, $2, $3, '', $4, now(), now()) \ + RETURNING id, room, seq, starter_message, title, created_by, archived, locked, \ + last_message_at, created_at, updated_at, archived_at", + ) + .bind(room) + .bind(seq) + .bind(root_id) // UUID of the root message + .bind(user_id) + .fetch_one(bus.inner.db.writer()) + .await?; + + let new_thread_id = thread_row.id; + Self::attach_chain_to_thread(bus, *parent_id, new_thread_id).await?; + + let tc_room = bus.lookup_room(room).await + .unwrap_or_else(|_| RoomInfo::unknown(room)); + let created_by = bus.lookup_user(user_id).await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = thread::ThreadCreatedService { + id: new_thread_id, + room: tc_room, + parent: root_seq, + created_by, + participants: serde_json::Value::Null, + created_at: thread_row.created_at, + }; + bus.publish_room_event(room, "thread.created", &data).await?; + events.push(WsOutEvent::ThreadCreated { + room: data.room.clone(), + data, + }); + + Some(new_thread_id) + } else { + None + } + } + } + } else { + thread + }; + // ── End auto-thread logic ────────────────────────────────────── + + if let Some(thread_id) = effective_thread { + let exists: Option<(Uuid,)> = db::sqlx::query_as( + "SELECT id FROM room_thread WHERE id = $1 AND room = $2", + ) + .bind(thread_id) + .bind(room) + .fetch_optional(bus.inner.db.reader()) + .await?; + if exists.is_none() { + return Err(ChannelError::RoomNotFound); + } + } + + let seq = bus.inner.seq.seq(room).await?; + let sender = bus.lookup_user(user_id).await?; + let sender_for_response = sender.clone(); + let row = db::sqlx::query_as::<_, model::room::RoomMessageModel>( + "INSERT INTO room_message (room, seq, thread, parent, author, content, content_type) \ + VALUES ($1, $2, $3, $4, $5, $6, $7) \ + RETURNING id, room, seq, thread, parent, author, content, content_type, pinned, \ + system_type, metadata, edited_at, created_at, updated_at, deleted_at", + ) + .bind(room) + .bind(seq) + .bind(effective_thread) + .bind(in_reply_to) + .bind(user_id) + .bind(content) + .bind(content_type.unwrap_or_else(|| "text".to_string())) + .fetch_one(bus.inner.db.writer()) + .await?; + + bus.publish_room_message( + row.clone(), + Some(sender), + ).await?; + let msg_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + events.push(WsOutEvent::MessageNew { + room: msg_room.clone(), + data: message::MessageNewService { + id: row.id, + seq: row.seq, + room: msg_room, + sender_type: "user".to_string(), + sender: sender_for_response, + thread: row.thread, + in_reply_to: row.parent, + content: row.content, + content_type: row.content_type, + pinned: row.pinned, + system_type: row.system_type, + metadata: row.metadata, + thinking_content: None, + thinking_is_chunked: None, + send_at: row.created_at, + reactions: vec![], + }, + }); + + Ok(events.into_iter().find(|e| matches!(e, WsOutEvent::MessageNew { .. }))) + } + + pub(super) async fn message_update( + bus: &ChannelBus, + user_id: Uuid, + message_id: Uuid, + content: String, + ) -> ChannelResult> { + if content.len() > MAX_TEXT_LEN { + return Err(ChannelError::Validation( + "message exceeds maximum length".to_string(), + )); + } + let room_id: (Uuid,) = + db::sqlx::query_as("SELECT room FROM room_message WHERE id = $1 AND deleted_at IS NULL") + .bind(message_id) + .fetch_optional(bus.inner.db.reader()) + .await? + .ok_or(ChannelError::RoomNotFound)?; + Self::ensure_room_access(bus, user_id, room_id.0).await?; + + let old = Self::load_message(bus, message_id).await?; + if old.author != user_id { + return Err(ChannelError::Unauthorized); + } + let row = db::sqlx::query_as::<_, model::room::RoomMessageModel>( + "UPDATE room_message SET content = $2, edited_at = now(), updated_at = now() \ + WHERE id = $1 AND deleted_at IS NULL \ + RETURNING id, room, seq, thread, parent, author, content, content_type, pinned, \ + system_type, metadata, edited_at, created_at, updated_at, deleted_at", + ) + .bind(message_id) + .bind(content) + .fetch_one(bus.inner.db.writer()) + .await?; + db::sqlx::query( + "INSERT INTO room_message_edit_history (message, seq, editor, old_content, new_content) \ + VALUES ($1, $2, $3, $4, $5)", + ) + .bind(message_id) + .bind(row.seq) + .bind(user_id) + .bind(old.content) + .bind(row.content.clone()) + .execute(bus.inner.db.writer()) + .await?; + + let sender = bus.lookup_user(row.author).await.unwrap_or_else(|_| UserInfo::unknown(row.author)); + let room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); + let data = message::MessageEditedService { + id: row.id, + seq: row.seq, + room, + sender, + content: row.content, + edited_at: row.edited_at.unwrap_or_else(Utc::now), + }; + bus.publish_room_event(row.room, "message.edited", &data) + .await?; + Ok(Some(WsOutEvent::MessageEdited { + room: data.room.clone(), + data, + })) + } + + pub(super) async fn message_revoke( + bus: &ChannelBus, + user_id: Uuid, + message_id: Uuid, + ) -> ChannelResult> { + let room_id: (Uuid,) = + db::sqlx::query_as("SELECT room FROM room_message WHERE id = $1 AND deleted_at IS NULL") + .bind(message_id) + .fetch_optional(bus.inner.db.reader()) + .await? + .ok_or(ChannelError::RoomNotFound)?; + Self::ensure_room_access(bus, user_id, room_id.0).await?; + let old = Self::load_message(bus, message_id).await?; + if old.author != user_id { + return Err(ChannelError::Unauthorized); + } + if let Some(window_secs) = bus.inner.config.revoke_window_secs { + let elapsed = Utc::now().signed_duration_since(old.created_at); + if elapsed.num_seconds() > window_secs as i64 { + return Err(ChannelError::Validation( + "message revoke window expired".to_string(), + )); + } + } + let row = db::sqlx::query_as::<_, model::room::RoomMessageModel>( + "UPDATE room_message SET deleted_at = now(), updated_at = now() \ + WHERE id = $1 AND deleted_at IS NULL \ + RETURNING id, room, seq, thread, parent, author, content, content_type, pinned, \ + system_type, metadata, edited_at, created_at, updated_at, deleted_at", + ) + .bind(message_id) + .fetch_one(bus.inner.db.writer()) + .await?; + let revoked_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); + let data = message::MessageRevokedService { + id: row.id, + seq: row.seq, + room, + revoked_by, + revoked_at: row.deleted_at.unwrap_or_else(Utc::now), + }; + bus.publish_room_event(row.room, "message.revoked", &data) + .await?; + Ok(Some(WsOutEvent::MessageRevoked { + room: data.room.clone(), + data, + })) + } + + pub(super) async fn load_message( + bus: &ChannelBus, + message_id: Uuid, + ) -> ChannelResult { + db::sqlx::query_as::<_, model::room::RoomMessageModel>( + "SELECT id, room, seq, thread, parent, author, content, content_type, pinned, \ + system_type, metadata, edited_at, created_at, updated_at, deleted_at \ + FROM room_message WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(message_id) + .fetch_one(bus.inner.db.reader()) + .await + .map_err(|e| match e { + db::sqlx::Error::RowNotFound => ChannelError::RoomNotFound, + other => ChannelError::Database(other), + }) + } + + pub(super) async fn message_list( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + before_seq: Option, + after_seq: Option, + limit: Option, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let limit = limit.unwrap_or(MAX_MESSAGES_PER_REQUEST); + let direction = if before_seq.is_some() { + PaginationDirection::Before + } else { + PaginationDirection::After + }; + let cursor = before_seq + .map(|s| s.to_string()) + .or(after_seq.map(|s| s.to_string())); + + let pagination = MessagePagination::new(bus.inner.db.clone()); + let page = pagination + .get_messages(PaginationParams { + room_id: room, + limit, + cursor, + direction, + }) + .await?; + + let mut page_messages = page.messages; + if before_seq.is_some() || (before_seq.is_none() && after_seq.is_none()) { + page_messages.reverse(); + } + let message_ids: Vec = page_messages.iter().map(|m| m.id).collect(); + let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) + .await + .unwrap_or_default(); + + let list_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + + let mut messages: Vec = + Vec::with_capacity(page_messages.len()); + for m in page_messages { + let sender = bus.lookup_user(m.sender_id).await + .unwrap_or_else(|_| UserInfo::unknown(m.sender_id)); + messages.push(message::MessageNewService { + id: m.id, + seq: m.seq, + room: list_room.clone(), + sender_type: "user".to_string(), + sender, + thread: m.thread, + in_reply_to: m.parent, + content: m.content, + content_type: m.content_type, + pinned: m.pinned, + system_type: m.system_type, + metadata: m.metadata, + thinking_content: None, + thinking_is_chunked: None, + send_at: m.send_at, + reactions: reactions.get(&m.id).cloned().unwrap_or_default(), + }); + } + + let total = messages.len() as i64; + Ok(Some(WsOutEvent::MessageList { + room: list_room.clone(), + data: message::MessageListService { + room: list_room, + messages, + total, + }, + })) + } + + pub(super) async fn message_around( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + seq: i64, + limit: Option, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let size = limit + .unwrap_or(MAX_MESSAGES_PER_REQUEST) + .min(MAX_MESSAGES_PER_REQUEST) as i64; + let rows = db::sqlx::query_as::<_, model::room::RoomMessageModel>( + "(SELECT id, room, seq, thread, parent, author, content, content_type, pinned, \ + system_type, metadata, edited_at, created_at, updated_at, deleted_at \ + FROM room_message \ + WHERE room = $1 AND seq < $2 AND deleted_at IS NULL \ + ORDER BY seq DESC LIMIT $3) \ + UNION ALL \ + (SELECT id, room, seq, thread, parent, author, content, content_type, pinned, \ + system_type, metadata, edited_at, created_at, updated_at, deleted_at \ + FROM room_message \ + WHERE room = $1 AND seq >= $2 AND deleted_at IS NULL \ + ORDER BY seq ASC LIMIT $3) \ + ORDER BY seq ASC", + ) + .bind(room) + .bind(seq) + .bind(size) + .fetch_all(bus.inner.db.reader()) + .await?; + let author_ids: Vec = rows.iter().map(|r| r.author).collect(); + let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); + let message_ids: Vec = rows.iter().map(|r| r.id).collect(); + let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) + .await + .unwrap_or_default(); + let around_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let messages = rows + .into_iter() + .map(|r| { + let sender = user_map + .get(&r.author) + .cloned() + .unwrap_or_else(|| UserInfo::unknown(r.author)); + message::MessageNewService { + id: r.id, + seq: r.seq, + room: around_room.clone(), + sender_type: "user".to_string(), + sender, + thread: r.thread, + in_reply_to: r.parent, + content: r.content, + content_type: r.content_type, + pinned: r.pinned, + system_type: r.system_type, + metadata: r.metadata, + thinking_content: None, + thinking_is_chunked: None, + send_at: r.created_at, + reactions: reactions.get(&r.id).cloned().unwrap_or_default(), + } + }) + .collect::>(); + let total = messages.len() as i64; + Ok(Some(WsOutEvent::MessageList { + room: around_room.clone(), + data: message::MessageListService { + room: around_room, + messages, + total, + }, + })) + } + + pub(super) async fn missed_messages( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + after_seq: i64, + limit: Option, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let limit = + limit.unwrap_or(MAX_MESSAGES_PER_REQUEST as i64).max(0) as usize; + let messages = bus + .inner + .reconnect + .get_missed_messages(room, after_seq) + .await?; + let author_ids: Vec = messages.iter().map(|m| m.sender_id).collect(); + let message_ids: Vec = messages.iter().map(|m| m.message_id).collect(); + let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); + let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) + .await + .unwrap_or_default(); + let missed_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let messages = messages + .into_iter() + .take(limit) + .map(|m| { + let sender = user_map + .get(&m.sender_id) + .cloned() + .unwrap_or_else(|| UserInfo::unknown(m.sender_id)); + message::MessageNewService { + id: m.message_id, + seq: m.seq, + room: missed_room.clone(), + sender_type: "user".to_string(), + sender, + thread: None, + in_reply_to: None, + content: m.content, + content_type: "text".to_string(), + pinned: false, + system_type: None, + metadata: serde_json::Value::Null, + thinking_content: None, + thinking_is_chunked: None, + send_at: m.send_at, + reactions: reactions.get(&m.message_id).cloned().unwrap_or_default(), + } + }) + .collect::>(); + let data = message::MessageListService { + room: missed_room.clone(), + total: messages.len() as i64, + messages, + }; + Ok(Some(WsOutEvent::MessageList { + room: missed_room, + data, + })) + } +} diff --git a/lib/channel/http/handler/message_read.rs b/lib/channel/http/handler/message_read.rs new file mode 100644 index 0000000..5d849b6 --- /dev/null +++ b/lib/channel/http/handler/message_read.rs @@ -0,0 +1,127 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, message_read}; +use crate::{ChannelBus, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn message_mark_read( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + message_ids: Vec, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + + if message_ids.is_empty() { + return Ok(None); + } + + let now = Utc::now(); + db::sqlx::query( + "INSERT INTO message_read (message, room, \"user\", read_at) \ + SELECT m.id, m.room, $1, $2 \ + FROM room_message m \ + WHERE m.id = ANY($3) AND m.room = $4 AND m.deleted_at IS NULL \ + ON CONFLICT (message, \"user\") DO NOTHING", + ) + .bind(user_id) + .bind(now) + .bind(&message_ids) + .bind(room) + .execute(bus.inner.db.writer()) + .await?; + let max_seq_row: Option<(i64,)> = db::sqlx::query_as( + "SELECT MAX(seq) FROM room_message \ + WHERE id = ANY($1) AND room = $2 AND deleted_at IS NULL", + ) + .bind(&message_ids) + .bind(room) + .fetch_optional(bus.inner.db.reader()) + .await?; + let max_seq = max_seq_row.map(|r| r.0).unwrap_or(0); + + if max_seq > 0 { + db::sqlx::query( + "INSERT INTO user_room_state (\"user\", room, last_read_seq, last_read_at, updated_at) \ + VALUES ($1, $2, $3, $4, $4) \ + ON CONFLICT (\"user\", room) DO UPDATE \ + SET last_read_seq = GREATEST(user_room_state.last_read_seq, $3), \ + last_read_at = $4, updated_at = $4", + ) + .bind(user_id) + .bind(room) + .bind(max_seq) + .bind(now) + .execute(bus.inner.db.writer()) + .await?; + } + + let room_info = + bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let reader_info = + bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + + let data = message_read::MessageReadBatchService { + room: room_info.clone(), + message_ids, + last_seq: max_seq, + reader: reader_info, + read_at: now, + }; + bus.publish_room_event(room, "message.read_batch", &data).await?; + Ok(Some(WsOutEvent::MessageReadBatch { + room: room_info, + data, + })) + } + pub(super) async fn message_get_readers( + bus: &ChannelBus, + user_id: Uuid, + message_id: Uuid, + ) -> ChannelResult> { + let msg_room: Option<(Uuid, Uuid)> = db::sqlx::query_as( + "SELECT room, seq FROM room_message WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(message_id) + .fetch_optional(bus.inner.db.reader()) + .await?; + + let Some((room, _seq)) = msg_room else { + return Err(crate::ChannelError::RoomNotFound); + }; + Self::ensure_room_access(bus, user_id, room).await?; + + let rows = db::sqlx::query_as::<_, (Uuid, chrono::DateTime)>( + "SELECT \"user\", read_at FROM message_read \ + WHERE message = $1 ORDER BY read_at ASC", + ) + .bind(message_id) + .fetch_all(bus.inner.db.reader()) + .await?; + + let user_ids: Vec = rows.iter().map(|(uid, _)| *uid).collect(); + let users = bus.lookup_users(&user_ids).await.unwrap_or_default(); + + let readers: Vec = rows + .into_iter() + .map(|(uid, read_at)| message_read::MessageReaderEntry { + user: users + .get(&uid) + .cloned() + .unwrap_or_else(|| UserInfo::unknown(uid)), + read_at, + }) + .collect(); + + Ok(Some(WsOutEvent::MessageReaders { + data: message_read::MessageReadersService { + message_id, + readers, + }, + })) + } +} diff --git a/lib/channel/http/handler/mod.rs b/lib/channel/http/handler/mod.rs new file mode 100644 index 0000000..4ba1406 --- /dev/null +++ b/lib/channel/http/handler/mod.rs @@ -0,0 +1,336 @@ +use uuid::Uuid; + +use crate::{ChannelBus, ChannelResult}; + +pub(crate) use super::out_event::WsOutEvent; +use super::types::{WS_PROTOCOL_VERSION, WsInMessage}; + +pub(crate) const MAX_TEXT_LEN: usize = 64 * 1024; +pub(crate) const MAX_MESSAGES_PER_REQUEST: u64 = 100; +pub(crate) const MAX_ROOM_NAME_LEN: usize = 100; +pub(crate) const MAX_CATEGORY_NAME_LEN: usize = 50; + +mod helpers; + +mod subscription; +mod message; +mod room; +mod category; +mod reaction; +mod thread; +mod pin; +mod draft; +mod notification; +mod presence; +mod invite; +mod ban; +mod voice; +mod ai; +mod search; +mod user; +mod conversation; +mod dm; +mod forward; +mod message_read; +mod star; + +pub struct WsHandler; + +impl WsHandler { + pub async fn handle( + bus: &ChannelBus, + user_id: Uuid, + msg: WsInMessage, + ) -> ChannelResult> { + match msg { + WsInMessage::Ping => Ok(Some(WsOutEvent::Pong { + protocol_version: WS_PROTOCOL_VERSION, + })), + WsInMessage::Subscribe { room } => { + Self::subscribe(bus, user_id, room).await + } + WsInMessage::Unsubscribe { room } => { + Self::unsubscribe(bus, user_id, room).await + } + WsInMessage::TypingStart { room } => { + Self::typing(bus, room, user_id, "start").await + } + WsInMessage::TypingStop { room } => { + Self::typing(bus, room, user_id, "stop").await + } + WsInMessage::ReadReceipt { + room, + last_read_seq, + } => Self::read_receipt(bus, user_id, room, last_read_seq).await, + WsInMessage::MessageList { + room, + before_seq, + after_seq, + limit, + } => { + Self::message_list( + bus, user_id, room, before_seq, after_seq, limit, + ) + .await + } + WsInMessage::MessageAround { room, seq, limit } => { + Self::message_around(bus, user_id, room, seq, limit).await + } + WsInMessage::MessageCreate { + room, + content, + content_type, + thread, + in_reply_to, + } => { + Self::message_create( + bus, + user_id, + room, + content, + content_type, + thread, + in_reply_to, + ) + .await + } + WsInMessage::MessageUpdate { message, content } => { + Self::message_update(bus, user_id, message, content).await + } + WsInMessage::MessageRevoke { message } => { + Self::message_revoke(bus, user_id, message).await + } + WsInMessage::RoomGet { room } => { + Self::room_get(bus, user_id, room).await + } + WsInMessage::RoomCreate { + workspace, + room_name, + public, + category, + } => { + Self::room_create( + bus, user_id, workspace, room_name, public, category, + ) + .await + } + WsInMessage::RoomUpdate { + room, + room_name, + public, + category, + } => Self::room_update(bus, user_id, room, room_name, public, category).await, + WsInMessage::RoomDelete { room } => { + Self::room_delete(bus, user_id, room).await + } + WsInMessage::CategoryCreate { + workspace, + name, + position, + } => Self::category_create(bus, user_id, workspace, name, position).await, + WsInMessage::CategoryUpdate { id, name, position } => { + Self::category_update(bus, user_id, id, name, position).await + } + WsInMessage::CategoryDelete { id } => { + Self::category_delete(bus, user_id, id).await + } + WsInMessage::AccessGrant { room, user } => { + Self::access_grant(bus, user_id, room, user).await + } + WsInMessage::AccessRevoke { room, user } => { + Self::access_revoke(bus, user_id, room, user).await + } + WsInMessage::StateSetReadSeq { + room, + last_read_seq, + } => Self::read_receipt(bus, user_id, room, last_read_seq).await, + WsInMessage::MissedMessages { + room, + after_seq, + limit, + } => { + Self::missed_messages(bus, user_id, room, after_seq, limit) + .await + } + WsInMessage::CsrfToken => Self::csrf_token(bus, user_id).await, + WsInMessage::StateUpdateDnd { + room, + do_not_disturb, + dnd_start_hour, + dnd_end_hour, + } => { + Self::dnd_update( + bus, user_id, room, do_not_disturb, dnd_start_hour, + dnd_end_hour, + ) + .await + } + WsInMessage::ReactionAdd { + room, + message, + emoji, + } => Self::reaction_add(bus, user_id, room, message, emoji).await, + WsInMessage::ReactionRemove { + room, + message, + emoji, + } => { + Self::reaction_remove(bus, user_id, room, message, emoji).await + } + WsInMessage::ThreadCreate { room, parent } => { + Self::thread_create(bus, user_id, room, parent).await + } + WsInMessage::ThreadResolve { thread_id } => { + Self::thread_resolve(bus, user_id, thread_id).await + } + WsInMessage::ThreadArchive { thread_id } => { + Self::thread_archive(bus, user_id, thread_id).await + } + WsInMessage::PinAdd { room, message } => { + Self::pin_add(bus, user_id, room, message).await + } + WsInMessage::PinRemove { room, message } => { + Self::pin_remove(bus, user_id, room, message).await + } + WsInMessage::DraftSave { room, content } => { + Self::draft_save(bus, user_id, room, content).await + } + WsInMessage::DraftClear { room } => { + Self::draft_clear(bus, user_id, room).await + } + WsInMessage::NotificationMarkRead { id } => { + Self::notification_mark_read(bus, user_id, id).await + } + WsInMessage::NotificationMarkAllRead { workspace_id } => { + Self::notification_mark_all_read(bus, user_id, workspace_id).await + } + WsInMessage::NotificationArchive { id } => { + Self::notification_archive(bus, user_id, id).await + } + WsInMessage::PresenceUpdate { status } => { + Self::presence_update(bus, user_id, status).await + } + WsInMessage::CustomStatusUpdate { + emoji, + text, + expires_at, + } => { + Self::custom_status_update(bus, user_id, emoji, text, expires_at) + .await + } + WsInMessage::InviteCreate { + workspace, + room, + max_uses, + expires_at, + } => { + Self::invite_create(bus, user_id, workspace, room, max_uses, expires_at) + .await + } + WsInMessage::InviteAccept { code } => { + Self::invite_accept(bus, user_id, code).await + } + WsInMessage::InviteRevoke { id } => { + Self::invite_revoke(bus, user_id, id).await + } + WsInMessage::BanCreate { + workspace, + user, + reason, + expires_at, + } => { + Self::ban_create(bus, user_id, workspace, user, reason, expires_at) + .await + } + WsInMessage::BanRemove { workspace, user } => { + Self::ban_remove(bus, user_id, workspace, user).await + } + WsInMessage::VoiceJoin { room } => { + Self::voice_join(bus, user_id, room).await + } + WsInMessage::VoiceLeave { room } => { + Self::voice_leave(bus, user_id, room).await + } + WsInMessage::VoiceMute { room, muted } => { + Self::voice_mute(bus, user_id, room, muted).await + } + WsInMessage::VoiceDeaf { room, deafened } => { + Self::voice_deaf(bus, user_id, room, deafened).await + } + WsInMessage::ScreenShare { room, start } => { + Self::screen_share(bus, user_id, room, start).await + } + WsInMessage::AiList { room } => { + Self::ai_list(bus, user_id, room).await + } + WsInMessage::AiUpsert { room, model } => { + Self::ai_upsert(bus, user_id, room, model).await + } + WsInMessage::AiDelete { room, agent_id } => { + Self::ai_delete(bus, user_id, room, agent_id).await + } + WsInMessage::AiStop { room } => { + Self::ai_stop(bus, user_id, room).await + } + WsInMessage::UserSummary { username } => { + Self::user_summary(bus, username).await + } + WsInMessage::Search { + q, + room, + limit, + offset, + .. + } => Self::search(bus, user_id, q, room, limit, offset).await, + WsInMessage::ConversationPin { room, pin } => { + Self::conversation_pin(bus, user_id, room, pin).await + } + WsInMessage::ConversationMute { room, mute } => { + Self::conversation_mute(bus, user_id, room, mute).await + } + WsInMessage::ConversationNotifyLevel { + room, + notify_level, + } => { + Self::conversation_notify_level( + bus, user_id, room, notify_level, + ) + .await + } + WsInMessage::ConversationList => { + Self::conversation_list(bus, user_id).await + } + WsInMessage::DmCreate { recipient } => { + Self::dm_create(bus, user_id, recipient).await + } + WsInMessage::DmClose { room } => { + Self::dm_close(bus, user_id, room).await + } + WsInMessage::DmList => Self::dm_list(bus, user_id).await, + WsInMessage::MessageMarkRead { + room, + message_ids, + } => { + Self::message_mark_read(bus, user_id, room, message_ids).await + } + WsInMessage::MessageGetReaders { message_id } => { + Self::message_get_readers(bus, user_id, message_id).await + } + WsInMessage::MessageStar { + room, + message, + star, + } => Self::message_star(bus, user_id, room, message, star).await, + WsInMessage::StarredList { room, limit } => { + Self::starred_list(bus, user_id, room, limit).await + } + WsInMessage::MessageForward { + source_message_id, + target_room, + } => { + Self::message_forward(bus, user_id, source_message_id, target_room) + .await + } + } + } +} diff --git a/lib/channel/http/handler/notification.rs b/lib/channel/http/handler/notification.rs new file mode 100644 index 0000000..6ee34f8 --- /dev/null +++ b/lib/channel/http/handler/notification.rs @@ -0,0 +1,77 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{UserInfo, notify}; +use crate::{ChannelBus, ChannelError, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn notification_mark_read( + bus: &ChannelBus, + user_id: Uuid, + id: Uuid, + ) -> ChannelResult> { + let result = db::sqlx::query( + "UPDATE user_app_notify SET read_at = now(), updated_at = now() \ + WHERE id = $1 AND \"user\" = $2 AND read_at IS NULL AND archived_at IS NULL", + ) + .bind(id) + .bind(user_id) + .execute(bus.inner.db.writer()) + .await?; + if result.rows_affected() == 0 { + return Err(ChannelError::RoomNotFound); + } + let nr_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = notify::NotifyReadService { + id, + user: nr_user, + read_at: Utc::now(), + }; + Ok(Some(WsOutEvent::NotifyRead { data })) + } + + pub(super) async fn notification_mark_all_read( + bus: &ChannelBus, + user_id: Uuid, + workspace_id: Option, + ) -> ChannelResult> { + if let Some(wk) = workspace_id { + db::sqlx::query( + "UPDATE user_app_notify SET read_at = now(), updated_at = now() \ + WHERE \"user\" = $1 AND target_id = $2 AND read_at IS NULL", + ) + .bind(user_id) + .bind(wk) + .execute(bus.inner.db.writer()) + .await?; + } else { + db::sqlx::query( + "UPDATE user_app_notify SET read_at = now(), updated_at = now() \ + WHERE \"user\" = $1 AND read_at IS NULL", + ) + .bind(user_id) + .execute(bus.inner.db.writer()) + .await?; + } + Ok(None) + } + + pub(super) async fn notification_archive( + bus: &ChannelBus, + user_id: Uuid, + id: Uuid, + ) -> ChannelResult> { + db::sqlx::query( + "UPDATE user_app_notify SET archived_at = now(), updated_at = now() \ + WHERE id = $1 AND \"user\" = $2 AND archived_at IS NULL", + ) + .bind(id) + .bind(user_id) + .execute(bus.inner.db.writer()) + .await?; + Ok(None) + } +} diff --git a/lib/channel/http/handler/pin.rs b/lib/channel/http/handler/pin.rs new file mode 100644 index 0000000..0f916b2 --- /dev/null +++ b/lib/channel/http/handler/pin.rs @@ -0,0 +1,83 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, pin}; +use crate::{ChannelBus, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn pin_add( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + message: Uuid, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + Self::ensure_message_in_room(bus, room, message).await?; + let seq = bus.inner.seq.seq(room).await?; + let result = db::sqlx::query( + "INSERT INTO room_pin (room, message, seq, pinned_by, created_at) \ + VALUES ($1, $2, $3, $4, now()) \ + ON CONFLICT DO NOTHING", + ) + .bind(room) + .bind(message) + .bind(seq) + .bind(user_id) + .execute(bus.inner.db.writer()) + .await?; + if result.rows_affected() == 0 { + return Ok(None); + } + db::sqlx::query("UPDATE room_message SET pinned = true, updated_at = now() WHERE id = $1") + .bind(message) + .execute(bus.inner.db.writer()) + .await?; + let pa_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let pinned_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = pin::PinAddedService { + room: pa_room, + message, + pinned_by, + pinned_at: Utc::now(), + }; + bus.publish_room_event(room, "pin.added", &data).await?; + Ok(Some(WsOutEvent::PinAdded { room: data.room.clone(), data })) + } + + pub(super) async fn pin_remove( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + message: Uuid, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + Self::ensure_message_in_room(bus, room, message).await?; + let result = db::sqlx::query( + "DELETE FROM room_pin WHERE room = $1 AND message = $2", + ) + .bind(room) + .bind(message) + .execute(bus.inner.db.writer()) + .await?; + if result.rows_affected() == 0 { + return Ok(None); + } + db::sqlx::query("UPDATE room_message SET pinned = false, updated_at = now() WHERE id = $1") + .bind(message) + .execute(bus.inner.db.writer()) + .await?; + let pr_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let removed_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = pin::PinRemovedService { + room: pr_room, + message, + removed_by, + removed_at: Utc::now(), + }; + bus.publish_room_event(room, "pin.removed", &data).await?; + Ok(Some(WsOutEvent::PinRemoved { room: data.room.clone(), data })) + } +} diff --git a/lib/channel/http/handler/presence.rs b/lib/channel/http/handler/presence.rs new file mode 100644 index 0000000..3ddfaab --- /dev/null +++ b/lib/channel/http/handler/presence.rs @@ -0,0 +1,94 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, member, presence}; +use crate::{ChannelBus, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn dnd_update( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + do_not_disturb: Option, + dnd_start_hour: Option, + dnd_end_hour: Option, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let do_not_disturb = do_not_disturb.unwrap_or(false); + let start_hour = dnd_start_hour.map(|h| h as i32); + let end_hour = dnd_end_hour.map(|h| h as i32); + let key = format!("dnd:{}:{}", user_id, room); + let dnd_data = serde_json::json!({ + "do_not_disturb": do_not_disturb, + "dnd_start_hour": start_hour, + "dnd_end_hour": end_hour, + }); + bus.inner.cache.set(&key, &dnd_data).await?; + let dnd_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let dnd_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = member::DndUpdatedService { + room: dnd_room, + user: dnd_user, + do_not_disturb, + dnd_start_hour: start_hour, + dnd_end_hour: end_hour, + }; + bus.publish_room_event(room, "member.dnd_updated", &data).await?; + Ok(None) + } + + pub(super) async fn presence_update( + bus: &ChannelBus, + user_id: Uuid, + status: presence::UserPresenceStatus, + ) -> ChannelResult> { + let pc_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = presence::PresenceChangedService { + user: pc_user, + project: None, + status, + last_seen_at: Some(Utc::now()), + }; + let rooms = crate::rooms::user_rooms( + &bus.inner.db, + &bus.inner.cache, + &bus.inner.config, + user_id, + ) + .await?; + for room in rooms { + bus.publish_room_event(room, "presence.changed", &data).await?; + } + Ok(Some(WsOutEvent::PresenceChanged { data })) + } + + pub(super) async fn custom_status_update( + bus: &ChannelBus, + user_id: Uuid, + emoji: Option, + text: Option, + expires_at: Option>, + ) -> ChannelResult> { + let cs_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = presence::CustomStatusUpdatedService { + user: cs_user, + emoji, + text, + expires_at, + }; + let rooms = crate::rooms::user_rooms( + &bus.inner.db, + &bus.inner.cache, + &bus.inner.config, + user_id, + ) + .await?; + for room in rooms { + bus.publish_room_event(room, "custom_status.updated", &data).await?; + } + Ok(Some(WsOutEvent::CustomStatusUpdated { data })) + } +} diff --git a/lib/channel/http/handler/reaction.rs b/lib/channel/http/handler/reaction.rs new file mode 100644 index 0000000..786cd50 --- /dev/null +++ b/lib/channel/http/handler/reaction.rs @@ -0,0 +1,91 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, reaction}; +use crate::{ChannelBus, ChannelError, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn reaction_add( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + message: Uuid, + emoji: String, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + Self::ensure_message_in_room(bus, room, message).await?; + if emoji.is_empty() || emoji.len() > 100 { + return Err(ChannelError::Validation("invalid emoji".into())); + } + let seq = bus.inner.seq.seq(room).await?; + let result = db::sqlx::query( + "INSERT INTO room_reaction (message, \"user\", seq, reaction, created_at) \ + VALUES ($1, $2, $3, $4, now()) \ + ON CONFLICT DO NOTHING", + ) + .bind(message) + .bind(user_id) + .bind(seq) + .bind(&emoji) + .execute(bus.inner.db.writer()) + .await?; + if result.rows_affected() == 0 { + return Ok(None); + } + let user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); + let rct_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let data = reaction::ReactionAddedService { + id: Uuid::now_v7(), + room: rct_room, + message, + user, + emoji, + created_at: Utc::now(), + }; + bus.publish_room_event(room, "reaction.added", &data).await?; + Ok(Some(WsOutEvent::ReactionAdded { room: data.room.clone(), data })) + } + + pub(super) async fn reaction_remove( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + message: Uuid, + emoji: String, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + Self::ensure_message_in_room(bus, room, message).await?; + let result = db::sqlx::query( + "DELETE FROM room_reaction WHERE message = $1 AND \"user\" = $2 AND reaction = $3", + ) + .bind(message) + .bind(user_id) + .bind(&emoji) + .execute(bus.inner.db.writer()) + .await?; + if result.rows_affected() == 0 { + return Ok(None); + } + let user = bus + .lookup_user(user_id) + .await + .unwrap_or_else(|_| UserInfo::unknown(user_id)); + let rct_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let data = reaction::ReactionRemovedService { + id: Uuid::now_v7(), + room: rct_room, + message, + user, + emoji, + removed_at: Utc::now(), + }; + bus.publish_room_event(room, "reaction.removed", &data).await?; + Ok(Some(WsOutEvent::ReactionRemoved { room: data.room.clone(), data })) + } +} diff --git a/lib/channel/http/handler/room.rs b/lib/channel/http/handler/room.rs new file mode 100644 index 0000000..a72518c --- /dev/null +++ b/lib/channel/http/handler/room.rs @@ -0,0 +1,259 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, WorkspaceInfo, member, rooms}; +use crate::{ChannelBus, ChannelError, ChannelResult}; + +use super::{MAX_ROOM_NAME_LEN}; +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn room_get( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let row = db::sqlx::query_as::<_, model::room::RoomModel>( + "SELECT id, wk, parent, name, topic, room_type, position, \ + is_private, is_archived, created_by, created_at, updated_at, deleted_at \ + FROM room WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(room) + .fetch_one(bus.inner.db.reader()) + .await?; + Ok(Some(WsOutEvent::Response { + request_id: Uuid::nil(), + data: serde_json::json!({ + "id": row.id, + "wk": row.wk, + "name": row.name, + "topic": row.topic, + "room_type": row.room_type, + "is_private": row.is_private, + "is_archived": row.is_archived, + "parent": row.parent, + "created_by": row.created_by, + "created_at": row.created_at, + }), + })) + } + + pub(super) async fn room_create( + bus: &ChannelBus, + user_id: Uuid, + workspace: Uuid, + room_name: String, + public: bool, + category: Option, + ) -> ChannelResult> { + if room_name.is_empty() || room_name.len() > MAX_ROOM_NAME_LEN { + return Err(ChannelError::Validation("invalid room name".into())); + } + Self::ensure_workspace_member(bus, user_id, workspace).await?; + let is_private = !public; + let row = db::sqlx::query_as::<_, model::room::RoomModel>( + "INSERT INTO room (wk, parent, name, room_type, is_private, created_by, created_at, updated_at) \ + VALUES ($1, $2, $3, 'channel', $4, $5, now(), now()) \ + RETURNING id, wk, parent, name, topic, room_type, position, \ + is_private, is_archived, created_by, created_at, updated_at, deleted_at", + ) + .bind(workspace) + .bind(category) + .bind(&room_name) + .bind(is_private) + .bind(user_id) + .fetch_one(bus.inner.db.writer()) + .await?; + db::sqlx::query( + "INSERT INTO room_permission_overwrite \ + (room, target_type, target_id, allow_permissions, deny_permissions, created_at, updated_at) \ + VALUES ($1, 'user', $2, '', '', now(), now())", + ) + .bind(row.id) + .bind(user_id) + .execute(bus.inner.db.writer()) + .await?; + let data = rooms::RoomCreatedService { + room: RoomInfo::from_model(&row), + workspace: bus.lookup_workspace(workspace).await.unwrap_or_else(|_| WorkspaceInfo::unknown(workspace)), + public, + category, + created_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + created_at: row.created_at, + }; + bus.publish_room_event(row.id, "room.created", &data).await?; + bus.room_changed(row.id).await?; + Ok(Some(WsOutEvent::RoomCreated { room: data.room.clone(), data })) + } + + pub(super) async fn room_update( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + room_name: Option, + public: Option, + category: Option, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let old = db::sqlx::query_as::<_, model::room::RoomModel>( + "SELECT id, wk, parent, name, topic, room_type, position, \ + is_private, is_archived, created_by, created_at, updated_at, deleted_at \ + FROM room WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(room) + .fetch_one(bus.inner.db.reader()) + .await?; + let new_name = room_name.unwrap_or(old.name.clone()); + let new_private = + public.map(|p| !p).unwrap_or(old.is_private); + let new_category = category.or(old.parent); + let row = db::sqlx::query_as::<_, model::room::RoomModel>( + "UPDATE room SET name = $2, is_private = $3, parent = $4, updated_at = now() \ + WHERE id = $1 AND deleted_at IS NULL \ + RETURNING id, wk, parent, name, topic, room_type, position, \ + is_private, is_archived, created_by, created_at, updated_at, deleted_at", + ) + .bind(room) + .bind(&new_name) + .bind(new_private) + .bind(new_category) + .fetch_one(bus.inner.db.writer()) + .await?; + let mut renamed = false; + if new_name != old.name { + let data = rooms::RoomRenamedService { + room: RoomInfo::from_model(&row), + workspace: bus.lookup_workspace(row.wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), + old_name: old.name.clone(), + new_name: new_name, + renamed_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + renamed_at: Utc::now(), + }; + bus.publish_room_event(room, "room.renamed", &data).await?; + renamed = true; + } + if new_private != old.is_private || new_category != old.parent { + let data = rooms::RoomSettingsUpdatedService { + room: RoomInfo::from_model(&row), + workspace: bus.lookup_workspace(row.wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), + slowmode_seconds: None, + nsfw: false, + default_auto_archive_duration: None, + updated_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + updated_at: Utc::now(), + }; + bus.publish_room_event(room, "room.settings_updated", &data).await?; + } + bus.room_changed(room).await?; + if renamed { + return Ok(Some(WsOutEvent::RoomRenamed { + room: RoomInfo::from_model(&row), + data: rooms::RoomRenamedService { + room: RoomInfo::from_model(&row), + workspace: bus.lookup_workspace(row.wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), + old_name: old.name, + new_name: row.name, + renamed_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + renamed_at: Utc::now(), + }, + })); + } + Ok(None) + } + + pub(super) async fn room_delete( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let old = db::sqlx::query_as::<_, model::room::RoomModel>( + "SELECT id, wk, parent, name, topic, room_type, position, \ + is_private, is_archived, created_by, created_at, updated_at, deleted_at \ + FROM room WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(room) + .fetch_one(bus.inner.db.reader()) + .await?; + if old.created_by != user_id { + return Err(ChannelError::AccessDenied); + } + let row = db::sqlx::query_as::<_, model::room::RoomModel>( + "UPDATE room SET deleted_at = now(), updated_at = now() \ + WHERE id = $1 AND deleted_at IS NULL \ + RETURNING id, wk, parent, name, topic, room_type, position, \ + is_private, is_archived, created_by, created_at, updated_at, deleted_at", + ) + .bind(room) + .fetch_one(bus.inner.db.writer()) + .await?; + let data = rooms::RoomDeletedService { + room: RoomInfo::from_model(&row), + workspace: bus.lookup_workspace(row.wk).await.unwrap_or_else(|_| WorkspaceInfo::unknown(row.wk)), + deleted_by: bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)), + deleted_at: Utc::now(), + }; + bus.publish_room_event(room, "room.deleted", &data).await?; + bus.room_changed(room).await?; + Ok(Some(WsOutEvent::RoomDeleted { room: data.room.clone(), data })) + } + + pub(super) async fn access_grant( + bus: &ChannelBus, + _user_id: Uuid, + room: Uuid, + target_user: Uuid, + ) -> ChannelResult> { + db::sqlx::query( + "INSERT INTO room_permission_overwrite \ + (room, target_type, target_id, allow_permissions, deny_permissions, created_at, updated_at) \ + SELECT $1, 'user', $2, '', '', now(), now() \ + WHERE NOT EXISTS (SELECT 1 FROM room_permission_overwrite WHERE room = $1 AND target_type = 'user' AND target_id = $2)", + ) + .bind(room) + .bind(target_user) + .execute(bus.inner.db.writer()) + .await?; + let mj_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let mj_user = bus.lookup_user(target_user).await.unwrap_or_else(|_| UserInfo::unknown(target_user)); + let data = member::MemberJoinedService { + room: mj_room, + user: mj_user, + project_role: None, + joined_at: Utc::now(), + }; + bus.publish_room_event(room, "member.joined", &data).await?; + bus.room_changed(room).await?; + Ok(Some(WsOutEvent::MemberJoined { room: data.room.clone(), data })) + } + + pub(super) async fn access_revoke( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + target_user: Uuid, + ) -> ChannelResult> { + db::sqlx::query( + "DELETE FROM room_permission_overwrite \ + WHERE room = $1 AND target_type = 'user' AND target_id = $2", + ) + .bind(room) + .bind(target_user) + .execute(bus.inner.db.writer()) + .await?; + let mr_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let mr_target = bus.lookup_user(target_user).await.unwrap_or_else(|_| UserInfo::unknown(target_user)); + let mr_remover = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = member::MemberRemovedService { + room: mr_room, + user: mr_target, + removed_by: mr_remover, + removed_at: Utc::now(), + }; + bus.publish_room_event(room, "member.removed", &data).await?; + bus.room_changed(room).await?; + Ok(Some(WsOutEvent::MemberRemoved { room: data.room.clone(), data })) + } +} diff --git a/lib/channel/http/handler/search.rs b/lib/channel/http/handler/search.rs new file mode 100644 index 0000000..64cb239 --- /dev/null +++ b/lib/channel/http/handler/search.rs @@ -0,0 +1,90 @@ +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo}; +use crate::{ + ChannelBus, ChannelError, ChannelResult, + search::{SearchEngine, SearchQuery}, +}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn search( + bus: &ChannelBus, + user_id: Uuid, + q: String, + room: Option, + limit: Option, + offset: Option, + ) -> ChannelResult> { + if let Some(room_id) = room { + Self::ensure_room_access(bus, user_id, room_id).await?; + } else { + return Err(ChannelError::Validation( + "room is required for websocket search".to_string(), + )); + } + let engine = SearchEngine::new(bus.inner.db.clone()); + let result = engine + .search(SearchQuery { + query: q.clone(), + room_id: room, + user_id: None, + limit: limit.unwrap_or(50), + offset: offset.unwrap_or(0), + }) + .await?; + + let author_ids: Vec = result.hits.iter().map(|h| h.sender_id).collect(); + let message_ids: Vec = result.hits.iter().map(|h| h.message_id).collect(); + let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); + let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) + .await + .unwrap_or_default(); + + let search_room = match room { + Some(r) => Some(bus.lookup_room(r).await.unwrap_or_else(|_| RoomInfo::unknown(r))), + None => None, + }; + let search_msg_room = search_room.clone().unwrap_or_else(|| RoomInfo::unknown(room.unwrap_or_default())); + let data = crate::event::search::SearchResultService { + q, + room: search_room, + messages: result + .hits + .into_iter() + .map(|h| { + let sender = user_map + .get(&h.sender_id) + .cloned() + .unwrap_or_else(|| UserInfo::unknown(h.sender_id)); + crate::event::search::SearchMessageHitService { + message: crate::event::message::MessageNewService { + id: h.message_id, + seq: 0, + room: search_msg_room.clone(), + sender_type: "user".to_string(), + sender, + thread: None, + in_reply_to: None, + content: h.content.clone(), + content_type: "text".to_string(), + pinned: false, + system_type: None, + metadata: serde_json::Value::Null, + thinking_content: None, + thinking_is_chunked: None, + send_at: h.send_at, + reactions: reactions.get(&h.message_id).cloned().unwrap_or_default(), + }, + highlighted_content: h.highlighted, + }}) + .collect(), + total: result.total as i64, + took_ms: 0, + }; + + Ok(Some(WsOutEvent::SearchResult { data })) + } +} diff --git a/lib/channel/http/handler/star.rs b/lib/channel/http/handler/star.rs new file mode 100644 index 0000000..1df491d --- /dev/null +++ b/lib/channel/http/handler/star.rs @@ -0,0 +1,157 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, star}; +use crate::{ChannelBus, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn message_star( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + message: Uuid, + do_star: bool, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + Self::ensure_message_in_room(bus, room, message).await?; + + let room_info = + bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let user_info = + bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + + if do_star { + let result = db::sqlx::query( + "INSERT INTO message_star (message, room, \"user\", created_at) \ + VALUES ($1, $2, $3, now()) ON CONFLICT (message, \"user\") DO NOTHING", + ) + .bind(message) + .bind(room) + .bind(user_id) + .execute(bus.inner.db.writer()) + .await?; + + if result.rows_affected() == 0 { + return Ok(None); + } + + let seq_row: Option<(i64,)> = db::sqlx::query_as( + "SELECT seq FROM room_message WHERE id = $1", + ) + .bind(message) + .fetch_optional(bus.inner.db.reader()) + .await?; + + let data = star::MessageStarredService { + room: room_info.clone(), + message_id: message, + message_seq: seq_row.map(|r| r.0).unwrap_or(0), + starred_by: user_info, + starred_at: Utc::now(), + }; + bus.emit_to_user(user_id, "message.starred", &data).await?; + Ok(Some(WsOutEvent::MessageStarred { + room: room_info, + data, + })) + } else { + let result = db::sqlx::query( + "DELETE FROM message_star WHERE message = $1 AND \"user\" = $2", + ) + .bind(message) + .bind(user_id) + .execute(bus.inner.db.writer()) + .await?; + + if result.rows_affected() == 0 { + return Ok(None); + } + + let data = star::MessageUnstarredService { + room: room_info.clone(), + message_id: message, + unstarred_by: user_info, + unstarred_at: Utc::now(), + }; + bus.emit_to_user(user_id, "message.unstarred", &data).await?; + Ok(Some(WsOutEvent::MessageUnstarred { + room: room_info, + data, + })) + } + } + pub(super) async fn starred_list( + bus: &ChannelBus, + user_id: Uuid, + room: Option, + limit: Option, + ) -> ChannelResult> { + let limit = limit.unwrap_or(50).min(100) as i64; + + let rows = if let Some(room_id) = room { + Self::ensure_room_access(bus, user_id, room_id).await?; + db::sqlx::query_as::<_, (Uuid, Uuid, i64, String, String, Uuid, chrono::DateTime, chrono::DateTime)>( + "SELECT ms.id, rm.id, rm.seq, rm.content, rm.content_type, rm.author, ms.created_at, rm.created_at \ + FROM message_star ms \ + JOIN room_message rm ON rm.id = ms.message \ + WHERE ms.\"user\" = $1 AND ms.room = $2 AND rm.deleted_at IS NULL \ + ORDER BY ms.created_at DESC LIMIT $3", + ) + .bind(user_id) + .bind(room_id) + .bind(limit) + .fetch_all(bus.inner.db.reader()) + .await? + } else { + db::sqlx::query_as::<_, (Uuid, Uuid, i64, String, String, Uuid, chrono::DateTime, chrono::DateTime)>( + "SELECT ms.id, rm.id, rm.seq, rm.content, rm.content_type, rm.author, ms.created_at, rm.created_at \ + FROM message_star ms \ + JOIN room_message rm ON rm.id = ms.message \ + WHERE ms.\"user\" = $1 AND rm.deleted_at IS NULL \ + ORDER BY ms.created_at DESC LIMIT $2", + ) + .bind(user_id) + .bind(limit) + .fetch_all(bus.inner.db.reader()) + .await? + }; + + let author_ids: Vec = rows.iter().map(|r| r.5).collect(); + let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); + + let mut entries = Vec::with_capacity(rows.len()); + for (_star_id, msg_id, seq, content, content_type, author_id, starred_at, sent_at) in rows { + let msg_room_row: Option<(Uuid,)> = db::sqlx::query_as( + "SELECT room FROM room_message WHERE id = $1", + ) + .bind(msg_id) + .fetch_optional(bus.inner.db.reader()) + .await?; + let msg_room_id = msg_room_row.map(|r| r.0).unwrap_or(Uuid::nil()); + let room_info = bus + .lookup_room(msg_room_id) + .await + .unwrap_or_else(|_| RoomInfo::unknown(msg_room_id)); + let sender = user_map + .get(&author_id) + .cloned() + .unwrap_or_else(|| UserInfo::unknown(author_id)); + + entries.push(star::StarredMessageEntry { + message_id: msg_id, + room: room_info, + seq, + content, + content_type, + sender, + starred_at, + sent_at, + }); + } + + Ok(Some(WsOutEvent::StarredList { data: entries })) + } +} diff --git a/lib/channel/http/handler/subscription.rs b/lib/channel/http/handler/subscription.rs new file mode 100644 index 0000000..c5a64ca --- /dev/null +++ b/lib/channel/http/handler/subscription.rs @@ -0,0 +1,143 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, member}; +use crate::{ChannelBus, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn subscribe( + bus: &ChannelBus, + user_id: Uuid, + _room: Uuid, + ) -> ChannelResult> { + bus.refresh_user(user_id).await?; + Ok(None) + } + + pub(super) async fn unsubscribe( + _bus: &ChannelBus, + _user_id: Uuid, + _room: Uuid, + ) -> ChannelResult> { + Ok(None) + } + + pub(super) async fn typing( + bus: &ChannelBus, + room: Uuid, + user_id: Uuid, + action: &str, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + + let key = (room, user_id); + + if action == "start" { + let ty_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let ty_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let already_typing = bus.inner.typing_states.contains_key(&key); + if let Some((_, (_, _, old_cancel))) = bus.inner.typing_states.remove(&key) { + old_cancel.cancel(); + } + + let cancel = tokio_util::sync::CancellationToken::new(); + let cancel_clone = cancel.clone(); + let bus_clone = bus.clone(); + let user_clone = ty_user.clone(); + let room_clone = ty_room.clone(); + bus.inner.typing_states.insert(key, (ty_user.clone(), ty_room.clone(), cancel)); + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_secs(10)).await; + if cancel_clone.is_cancelled() { + return; + } + bus_clone.inner.typing_states.remove(&(room_clone.id, user_clone.id)); + let room_id = room_clone.id; + let stop_data = member::TypingStopService { + room: room_clone, + user: user_clone, + sender_type: "user".to_string(), + stopped_at: Utc::now(), + }; + let _ = bus_clone.publish_room_event(room_id, "typing.stop", &stop_data).await; + }); + if !already_typing { + let data = member::TypingStartService { + room: ty_room, + user: ty_user, + sender_type: "user".to_string(), + started_at: Utc::now(), + }; + bus.publish_room_event(room, "typing.start", &data).await?; + return Ok(Some(WsOutEvent::TypingStart { room: data.room.clone(), data })); + } + Ok(None) + } else { + if let Some((_, (_, _, cancel))) = bus.inner.typing_states.remove(&key) { + cancel.cancel(); + } + + let ty_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let ty_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + + let data = member::TypingStopService { + room: ty_room, + user: ty_user, + sender_type: "user".to_string(), + stopped_at: Utc::now(), + }; + bus.publish_room_event(room, "typing.stop", &data).await?; + Ok(Some(WsOutEvent::TypingStop { room: data.room.clone(), data })) + } + } + + pub(super) async fn read_receipt( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + last_read_seq: i64, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + bus.inner + .reconnect + .save_client_state(user_id, room, last_read_seq) + .await?; + db::sqlx::query( + "INSERT INTO user_room_state (\"user\", room, last_read_seq, last_read_at, updated_at) \ + VALUES ($1, $2, $3, now(), now()) \ + ON CONFLICT (\"user\", room) DO UPDATE \ + SET last_read_seq = GREATEST(user_room_state.last_read_seq, $3), \ + last_read_at = now(), updated_at = now()", + ) + .bind(user_id) + .bind(room) + .bind(last_read_seq) + .execute(bus.inner.db.writer()) + .await?; + let rr_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let rr_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = member::ReadReceiptService { + room: rr_room.clone(), + user: rr_user, + last_read_seq, + updated_at: Utc::now(), + }; + bus.publish_room_event(room, "member.read_receipt", &data) + .await?; + Ok(Some(WsOutEvent::ReadReceipt { + room: rr_room, + data, + })) + } + + pub(super) async fn csrf_token( + bus: &ChannelBus, + user_id: Uuid, + ) -> ChannelResult> { + let token = bus.inner.csrf.generate_token(user_id).await?; + Ok(Some(WsOutEvent::CsrfToken { token })) + } +} diff --git a/lib/channel/http/handler/thread.rs b/lib/channel/http/handler/thread.rs new file mode 100644 index 0000000..95c7d00 --- /dev/null +++ b/lib/channel/http/handler/thread.rs @@ -0,0 +1,121 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, thread}; +use crate::{ChannelBus, ChannelError, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn thread_create( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + parent: i64, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + // Look up the message UUID from seq + room + let parent_id: Option<(Uuid,)> = db::sqlx::query_as( + "SELECT id FROM room_message WHERE room = $1 AND seq = $2 AND deleted_at IS NULL", + ) + .bind(room) + .bind(parent) + .fetch_optional(bus.inner.db.reader()) + .await?; + let parent_msg_id = parent_id.ok_or(ChannelError::RoomNotFound)?.0; + let seq = bus.inner.seq.seq(room).await?; + let row = db::sqlx::query_as::<_, model::room::RoomThreadModel>( + "INSERT INTO room_thread (room, seq, starter_message, title, created_by, created_at, updated_at) \ + VALUES ($1, $2, $3, '', $4, now(), now()) \ + RETURNING id, room, seq, starter_message, title, created_by, archived, locked, \ + last_message_at, created_at, updated_at, archived_at", + ) + .bind(room) + .bind(seq) + .bind(parent_msg_id) // UUID of the starter message + .bind(user_id) + .fetch_one(bus.inner.db.writer()) + .await?; + let tc_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let created_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = thread::ThreadCreatedService { + id: row.id, + room: tc_room, + parent, + created_by, + participants: serde_json::Value::Null, + created_at: row.created_at, + }; + bus.publish_room_event(room, "thread.created", &data).await?; + Ok(Some(WsOutEvent::ThreadCreated { room: data.room.clone(), data })) + } + + pub(super) async fn thread_resolve( + bus: &ChannelBus, + user_id: Uuid, + thread_id: Uuid, + ) -> ChannelResult> { + let existing: (Uuid,) = db::sqlx::query_as( + "SELECT room FROM room_thread WHERE id = $1", + ) + .bind(thread_id) + .fetch_optional(bus.inner.db.reader()) + .await? + .ok_or(ChannelError::RoomNotFound)?; + Self::ensure_room_access(bus, user_id, existing.0).await?; + let row = db::sqlx::query_as::<_, model::room::RoomThreadModel>( + "UPDATE room_thread SET locked = true, updated_at = now() \ + WHERE id = $1 \ + RETURNING id, room, seq, starter_message, title, created_by, archived, locked, \ + last_message_at, created_at, updated_at, archived_at", + ) + .bind(thread_id) + .fetch_one(bus.inner.db.writer()) + .await?; + let tr_room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); + let resolved_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = thread::ThreadResolvedService { + id: row.id, + room: tr_room, + resolved_by, + resolved_at: Utc::now(), + }; + bus.publish_room_event(row.room, "thread.resolved", &data).await?; + Ok(Some(WsOutEvent::ThreadResolved { room: data.room.clone(), data })) + } + + pub(super) async fn thread_archive( + bus: &ChannelBus, + user_id: Uuid, + thread_id: Uuid, + ) -> ChannelResult> { + let existing: (Uuid,) = db::sqlx::query_as( + "SELECT room FROM room_thread WHERE id = $1", + ) + .bind(thread_id) + .fetch_optional(bus.inner.db.reader()) + .await? + .ok_or(ChannelError::RoomNotFound)?; + Self::ensure_room_access(bus, user_id, existing.0).await?; + let row = db::sqlx::query_as::<_, model::room::RoomThreadModel>( + "UPDATE room_thread SET archived = true, archived_at = now(), updated_at = now() \ + WHERE id = $1 \ + RETURNING id, room, seq, starter_message, title, created_by, archived, locked, \ + last_message_at, created_at, updated_at, archived_at", + ) + .bind(thread_id) + .fetch_one(bus.inner.db.writer()) + .await?; + let ta_room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); + let archived_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = thread::ThreadArchivedService { + id: row.id, + room: ta_room, + archived_by, + archived_at: Utc::now(), + }; + bus.publish_room_event(row.room, "thread.archived", &data).await?; + Ok(Some(WsOutEvent::ThreadArchived { room: data.room.clone(), data })) + } +} diff --git a/lib/channel/http/handler/user.rs b/lib/channel/http/handler/user.rs new file mode 100644 index 0000000..37068ab --- /dev/null +++ b/lib/channel/http/handler/user.rs @@ -0,0 +1,35 @@ +use uuid::Uuid; + +use crate::{ChannelBus, ChannelError, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn user_summary( + bus: &ChannelBus, + username: String, + ) -> ChannelResult> { + let user = db::sqlx::query_as::<_, model::users::UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, \ + allow_use, can_search, last_sign_in_at, created_at, updated_at \ + FROM \"user\" WHERE username = $1", + ) + .bind(&username) + .fetch_one(bus.inner.db.reader()) + .await + .map_err(|e| match e { + db::sqlx::Error::RowNotFound => ChannelError::RoomNotFound, + other => ChannelError::Database(other), + })?; + Ok(Some(WsOutEvent::Response { + request_id: Uuid::nil(), + data: serde_json::json!({ + "id": user.id, + "username": user.username, + "display_name": user.display_name, + "avatar_url": user.avatar_url, + }), + })) + } +} diff --git a/lib/channel/http/handler/voice.rs b/lib/channel/http/handler/voice.rs new file mode 100644 index 0000000..403f402 --- /dev/null +++ b/lib/channel/http/handler/voice.rs @@ -0,0 +1,97 @@ +use chrono::Utc; +use uuid::Uuid; + +use crate::event::{RoomInfo, UserInfo, voice}; +use crate::{ChannelBus, ChannelResult}; + +use super::WsOutEvent; +use super::WsHandler; + +impl WsHandler { + pub(super) async fn voice_join( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let vj_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let vj_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = voice::VoiceChannelJoinedService { + room: vj_room, + workspace: None, + user: vj_user, + muted: false, + deafened: false, + video: false, + joined_at: Utc::now(), + }; + bus.publish_room_event(room, "voice.channel_joined", &data).await?; + Ok(Some(WsOutEvent::VoiceChannelJoined { room: data.room.clone(), data })) + } + + pub(super) async fn voice_leave( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + let vl_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); + let vl_user = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); + let data = voice::VoiceChannelLeftService { + room: vl_room, + workspace: None, + user: vl_user, + left_at: Utc::now(), + }; + bus.publish_room_event(room, "voice.channel_left", &data).await?; + Ok(Some(WsOutEvent::VoiceChannelLeft { room: data.room.clone(), data })) + } + + pub(super) async fn voice_mute( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + muted: bool, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + bus.publish_room_event( + room, + "voice.mute_updated", + &serde_json::json!({"user_id": user_id, "muted": muted}), + ) + .await?; + Ok(None) + } + + pub(super) async fn voice_deaf( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + deafened: bool, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + bus.publish_room_event( + room, + "voice.deaf_updated", + &serde_json::json!({"user_id": user_id, "deafened": deafened}), + ) + .await?; + Ok(None) + } + + pub(super) async fn screen_share( + bus: &ChannelBus, + user_id: Uuid, + room: Uuid, + start: bool, + ) -> ChannelResult> { + Self::ensure_room_access(bus, user_id, room).await?; + bus.publish_room_event( + room, + "voice.screen_share", + &serde_json::json!({"user_id": user_id, "start": start}), + ) + .await?; + Ok(None) + } +} diff --git a/lib/channel/http/mod.rs b/lib/channel/http/mod.rs new file mode 100644 index 0000000..1fd8392 --- /dev/null +++ b/lib/channel/http/mod.rs @@ -0,0 +1,10 @@ +pub mod handler; +pub mod out_event; +pub mod session; +pub mod types; +pub mod ws; + +pub use handler::WsHandler; +pub use out_event::{WsError, WsOutEvent}; +pub use session::WsUserCtx; +pub use types::{WS_PROTOCOL_VERSION, WsInMessage}; diff --git a/lib/channel/http/out_event.rs b/lib/channel/http/out_event.rs new file mode 100644 index 0000000..4b9bdcf --- /dev/null +++ b/lib/channel/http/out_event.rs @@ -0,0 +1,287 @@ +use serde::Serialize; +use uuid::Uuid; + +use crate::event::{ + RoomInfo, WorkspaceInfo, + ai, attachment, ban, category, conversation, dm, draft, forward, invite, + member, message, message_read, notify, pin, presence, reaction, rooms, + search, star, thread, voice, workspace, +}; + +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum WsOutEvent { + Pong { + protocol_version: u32, + }, + Error(WsError), + CsrfToken { + token: String, + }, + RoomCreated { + room: RoomInfo, + data: rooms::RoomCreatedService, + }, + RoomDeleted { + room: RoomInfo, + data: rooms::RoomDeletedService, + }, + RoomRenamed { + room: RoomInfo, + data: rooms::RoomRenamedService, + }, + RoomTopicUpdated { + room: RoomInfo, + data: rooms::RoomTopicUpdatedService, + }, + RoomSettingsUpdated { + room: RoomInfo, + data: rooms::RoomSettingsUpdatedService, + }, + RoomMoved { + room: RoomInfo, + data: rooms::RoomMovedService, + }, + MessageNew { + room: RoomInfo, + data: message::MessageNewService, + }, + MessageEdited { + room: RoomInfo, + data: message::MessageEditedService, + }, + MessageRevoked { + room: RoomInfo, + data: message::MessageRevokedService, + }, + MessageStreamStart { + room: RoomInfo, + data: message::MessageStreamStartService, + }, + MessageStreamChunk { + room: RoomInfo, + data: message::MessageStreamChunkService, + }, + MessageStreamDone { + room: RoomInfo, + data: message::MessageStreamDoneService, + }, + MessageList { + room: RoomInfo, + data: message::MessageListService, + }, + MemberJoined { + room: RoomInfo, + data: member::MemberJoinedService, + }, + MemberRemoved { + room: RoomInfo, + data: member::MemberRemovedService, + }, + ReadReceipt { + room: RoomInfo, + data: member::ReadReceiptService, + }, + TypingStart { + room: RoomInfo, + data: member::TypingStartService, + }, + TypingStop { + room: RoomInfo, + data: member::TypingStopService, + }, + ReactionAdded { + room: RoomInfo, + data: reaction::ReactionAddedService, + }, + ReactionRemoved { + room: RoomInfo, + data: reaction::ReactionRemovedService, + }, + ReactionBatchUpdated { + room: RoomInfo, + data: reaction::ReactionBatchUpdatedService, + }, + ThreadCreated { + room: RoomInfo, + data: thread::ThreadCreatedService, + }, + ThreadUpdated { + room: RoomInfo, + data: thread::ThreadUpdatedService, + }, + ThreadResolved { + room: RoomInfo, + data: thread::ThreadResolvedService, + }, + ThreadArchived { + room: RoomInfo, + data: thread::ThreadArchivedService, + }, + CategoryCreated { + workspace: WorkspaceInfo, + data: category::CategoryCreatedService, + }, + CategoryUpdated { + workspace: WorkspaceInfo, + data: category::CategoryUpdatedService, + }, + CategoryDeleted { + workspace: WorkspaceInfo, + data: category::CategoryDeletedService, + }, + PinAdded { + room: RoomInfo, + data: pin::PinAddedService, + }, + PinRemoved { + room: RoomInfo, + data: pin::PinRemovedService, + }, + WorkspaceRoomCreated { + workspace: WorkspaceInfo, + data: workspace::WorkspaceRoomCreatedService, + }, + WorkspaceRoomDeleted { + workspace: WorkspaceInfo, + data: workspace::WorkspaceRoomDeletedService, + }, + DraftSaved { + room: RoomInfo, + data: draft::DraftSavedService, + }, + DraftCleared { + room: RoomInfo, + data: draft::DraftClearedService, + }, + SearchResult { + data: search::SearchResultService, + }, + NotifyCreated { + data: notify::NotifyCreatedService, + }, + NotifyRead { + data: notify::NotifyReadService, + }, + PresenceChanged { + data: presence::PresenceChangedService, + }, + CustomStatusUpdated { + data: presence::CustomStatusUpdatedService, + }, + InviteCreated { + data: invite::InviteCreatedService, + }, + InviteAccepted { + data: invite::InviteAcceptedService, + }, + AttachmentUploaded { + room: RoomInfo, + data: attachment::AttachmentUploadedService, + }, + UserBanned { + data: ban::BannedService, + }, + UserUnbanned { + data: ban::UnbannedService, + }, + AiAgentJoined { + room: RoomInfo, + data: ai::AiAgentJoinedService, + }, + AiAgentLeft { + room: RoomInfo, + data: ai::AiAgentLeftService, + }, + AiAgentList { + room: RoomInfo, + data: ai::RoomAiListService, + }, + AiAgentStatusChanged { + room: RoomInfo, + data: ai::AiAgentStatusChangedService, + }, + VoiceChannelJoined { + room: RoomInfo, + data: voice::VoiceChannelJoinedService, + }, + VoiceChannelLeft { + room: RoomInfo, + data: voice::VoiceChannelLeftService, + }, + ConversationPinned { + room: RoomInfo, + data: conversation::ConversationPinnedService, + }, + ConversationUnpinned { + room: RoomInfo, + data: conversation::ConversationUnpinnedService, + }, + ConversationMuted { + room: RoomInfo, + data: conversation::ConversationMutedService, + }, + ConversationUnmuted { + room: RoomInfo, + data: conversation::ConversationUnmutedService, + }, + ConversationUnreadUpdated { + room: RoomInfo, + data: conversation::ConversationUnreadUpdatedService, + }, + ConversationList { + data: Vec, + }, + DmCreated { + room: RoomInfo, + data: dm::DmCreatedService, + }, + DmClosed { + room: RoomInfo, + data: dm::DmClosedService, + }, + DmReopened { + room: RoomInfo, + data: dm::DmReopenedService, + }, + DmList { + data: Vec, + }, + MessageRead { + room: RoomInfo, + data: message_read::MessageReadService, + }, + MessageReadBatch { + room: RoomInfo, + data: message_read::MessageReadBatchService, + }, + MessageReaders { + data: message_read::MessageReadersService, + }, + MessageStarred { + room: RoomInfo, + data: star::MessageStarredService, + }, + MessageUnstarred { + room: RoomInfo, + data: star::MessageUnstarredService, + }, + StarredList { + data: Vec, + }, + MessageForwarded { + room: RoomInfo, + data: forward::MessageForwardedService, + }, + Response { + request_id: Uuid, + data: serde_json::Value, + }, +} + +#[derive(Debug, Clone, Serialize)] +pub struct WsError { + pub code: i32, + pub error: String, + pub message: String, +} diff --git a/lib/channel/http/session.rs b/lib/channel/http/session.rs new file mode 100644 index 0000000..0313e51 --- /dev/null +++ b/lib/channel/http/session.rs @@ -0,0 +1,22 @@ +use uuid::Uuid; + +use crate::token::ChannelTokenContext; + +#[derive(Clone)] +pub struct WsUserCtx { + pub user_id: Uuid, + pub device_id: String, + pub client_id: String, + pub display_name: String, +} + +impl From for WsUserCtx { + fn from(ctx: ChannelTokenContext) -> Self { + Self { + user_id: ctx.user_id, + device_id: ctx.device_id, + client_id: ctx.client_id, + display_name: ctx.user_id.to_string(), + } + } +} diff --git a/lib/channel/http/types.rs b/lib/channel/http/types.rs new file mode 100644 index 0000000..aa6b56b --- /dev/null +++ b/lib/channel/http/types.rs @@ -0,0 +1,329 @@ +pub const WS_PROTOCOL_VERSION: u32 = 1; + +use chrono::{DateTime, Utc}; +use serde::Deserialize; +use uuid::Uuid; + +use crate::event::presence::UserPresenceStatus; + +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum WsInMessage { + Ping, + Subscribe { + room: Uuid, + }, + Unsubscribe { + room: Uuid, + }, + TypingStart { + room: Uuid, + }, + TypingStop { + room: Uuid, + }, + ReadReceipt { + room: Uuid, + last_read_seq: i64, + }, + MessageList { + room: Uuid, + before_seq: Option, + after_seq: Option, + limit: Option, + }, + MessageAround { + room: Uuid, + seq: i64, + limit: Option, + }, + MessageCreate { + room: Uuid, + content: String, + content_type: Option, + thread: Option, + in_reply_to: Option, + }, + MessageUpdate { + message: Uuid, + content: String, + }, + MessageRevoke { + message: Uuid, + }, + RoomGet { + room: Uuid, + }, + RoomCreate { + workspace: Uuid, + room_name: String, + public: bool, + category: Option, + }, + RoomUpdate { + room: Uuid, + room_name: Option, + public: Option, + category: Option, + }, + RoomDelete { + room: Uuid, + }, + CategoryCreate { + workspace: Uuid, + name: String, + position: Option, + }, + CategoryUpdate { + id: Uuid, + name: Option, + position: Option, + }, + CategoryDelete { + id: Uuid, + }, + AccessGrant { + room: Uuid, + user: Uuid, + }, + AccessRevoke { + room: Uuid, + user: Uuid, + }, + StateSetReadSeq { + room: Uuid, + last_read_seq: i64, + }, + MissedMessages { + room: Uuid, + after_seq: i64, + limit: Option, + }, + CsrfToken, + StateUpdateDnd { + room: Uuid, + do_not_disturb: Option, + dnd_start_hour: Option, + dnd_end_hour: Option, + }, + ReactionAdd { + room: Uuid, + message: Uuid, + emoji: String, + }, + ReactionRemove { + room: Uuid, + message: Uuid, + emoji: String, + }, + ThreadCreate { + room: Uuid, + parent: i64, + }, + ThreadResolve { + thread_id: Uuid, + }, + ThreadArchive { + thread_id: Uuid, + }, + PinAdd { + room: Uuid, + message: Uuid, + }, + PinRemove { + room: Uuid, + message: Uuid, + }, + DraftSave { + room: Uuid, + content: String, + }, + DraftClear { + room: Uuid, + }, + Search { + q: String, + room: Option, + start_time: Option>, + end_time: Option>, + sender_id: Option, + content_type: Option, + limit: Option, + offset: Option, + }, + NotificationMarkRead { + id: Uuid, + }, + NotificationMarkAllRead { + workspace_id: Option, + }, + NotificationArchive { + id: Uuid, + }, + PresenceUpdate { + status: UserPresenceStatus, + }, + CustomStatusUpdate { + emoji: Option, + text: Option, + expires_at: Option>, + }, + InviteCreate { + workspace: Uuid, + room: Option, + max_uses: Option, + expires_at: Option>, + }, + InviteAccept { + code: String, + }, + InviteRevoke { + id: Uuid, + }, + BanCreate { + workspace: Uuid, + user: Uuid, + reason: Option, + expires_at: Option>, + }, + BanRemove { + workspace: Uuid, + user: Uuid, + }, + VoiceJoin { + room: Uuid, + }, + VoiceLeave { + room: Uuid, + }, + VoiceMute { + room: Uuid, + muted: bool, + }, + VoiceDeaf { + room: Uuid, + deafened: bool, + }, + ScreenShare { + room: Uuid, + start: bool, + }, + AiList { + room: Uuid, + }, + AiUpsert { + room: Uuid, + model: Uuid, + }, + AiDelete { + room: Uuid, + agent_id: Uuid, + }, + AiStop { + room: Uuid, + }, + UserSummary { + username: String, + }, + ConversationPin { + room: Uuid, + pin: bool, + }, + ConversationMute { + room: Uuid, + mute: bool, + }, + ConversationNotifyLevel { + room: Uuid, + notify_level: String, + }, + ConversationList, + DmCreate { + recipient: Uuid, + }, + DmClose { + room: Uuid, + }, + DmList, + MessageMarkRead { + room: Uuid, + message_ids: Vec, + }, + MessageGetReaders { + message_id: Uuid, + }, + MessageStar { + room: Uuid, + message: Uuid, + star: bool, + }, + StarredList { + room: Option, + limit: Option, + }, + MessageForward { + source_message_id: Uuid, + target_room: Uuid, + }, +} + +macro_rules! room_variants { + ($self:expr, $($variant:ident),* $(,)?) => { + match $self { + $( Self::$variant { room, .. } => Some(*room), )* + _ => None, + } + }; +} + +impl WsInMessage { + pub fn room_id(&self) -> Option { + room_variants!( + self, + Subscribe, + Unsubscribe, + TypingStart, + TypingStop, + ReadReceipt, + MessageCreate, + MessageList, + RoomUpdate, + RoomDelete, + RoomGet, + AccessGrant, + AccessRevoke, + StateSetReadSeq, + StateUpdateDnd, + MissedMessages, + ReactionAdd, + ReactionRemove, + ThreadCreate, + PinAdd, + PinRemove, + DraftSave, + DraftClear, + VoiceJoin, + VoiceLeave, + VoiceMute, + VoiceDeaf, + ScreenShare, + AiList, + AiUpsert, + AiDelete, + AiStop, + ConversationPin, + ConversationMute, + ConversationNotifyLevel, + DmClose, + MessageMarkRead, + MessageStar, + ) + .or_else(|| match self { + Self::Search { + room: Some(room), .. + } => Some(*room), + Self::MessageForward { target_room, .. } => Some(*target_room), + _ => None, + }) + } +} diff --git a/lib/channel/http/ws.rs b/lib/channel/http/ws.rs new file mode 100644 index 0000000..f4d3373 --- /dev/null +++ b/lib/channel/http/ws.rs @@ -0,0 +1,153 @@ +use socketio::{EventPayload, Socket}; +use uuid::Uuid; + +use crate::{ChannelBus, ChannelError, ChannelResult}; + +use super::handler::WsHandler; +use super::out_event::{WsError, WsOutEvent}; +use super::types::WsInMessage; + +const CHANNEL_EVENT: &str = "channel.message"; + +pub async fn register_message_handler( + bus: &ChannelBus, +) -> crate::ChannelResult<()> { + let namespace = bus.inner.io.namespace(&bus.inner.config.namespace).await; + + let bus_clone = bus.clone(); + namespace + .on(CHANNEL_EVENT, move |socket, data: EventPayload| { + let bus = bus_clone.clone(); + async move { + handle_inbound(&bus, &socket, data).await; + } + }) + .await; + + Ok(()) +} + +async fn handle_inbound(bus: &ChannelBus, socket: &Socket, data: EventPayload) { + let user_id = match socket.session_user() { + Some(id) => id, + None => { + tracing::warn!("channel message from unauthenticated socket"); + send_error(socket, ChannelError::Unauthorized.to_ws_error()).await; + return; + } + }; + let payload = match data.args.first() { + Some(v) => v, + None => { + tracing::warn!("channel message with empty args"); + return; + } + }; + + let parsed = payload; + + let text = serde_json::to_string(payload).unwrap_or_default(); + if parsed + .get("type") + .and_then(|t| t.as_str()) + == Some("ping") + { + let pong = WsOutEvent::Pong { + protocol_version: super::types::WS_PROTOCOL_VERSION, + }; + send_event(socket, &pong).await.ok(); + return; + } + if !check_rate_limit(bus, user_id).await { + send_error(socket, ChannelError::RateLimitExceeded.to_ws_error()).await; + return; + } + if text.len() > super::handler::MAX_TEXT_LEN { + send_error( + socket, + WsError { + code: 422, + error: "message_too_long".to_string(), + message: "message exceeds maximum length".to_string(), + }, + ) + .await; + return; + } + let request_id: Option = parsed + .get("_request_id") + .and_then(|r| serde_json::from_value(r.clone()).ok()); + match serde_json::from_value::(payload.clone()) { + Ok(in_msg) => match WsHandler::handle(bus, user_id, in_msg).await { + Ok(Some(event)) => { + let rid = request_id.unwrap_or(Uuid::nil()); + let resp = WsOutEvent::Response { + request_id: rid, + data: serde_json::to_value(&event).unwrap_or_default(), + }; + send_event(socket, &resp).await.ok(); + } + Ok(None) => { + let rid = request_id.unwrap_or(Uuid::nil()); + let ack = WsOutEvent::Response { + request_id: rid, + data: serde_json::json!({"ok": true}), + }; + send_event(socket, &ack).await.ok(); + } + Err(e) => { + tracing::warn!(user_id = %user_id, error = %e, "WS message processing failed"); + let rid = request_id.unwrap_or(Uuid::nil()); + let err_resp = WsOutEvent::Response { + request_id: rid, + data: serde_json::to_value(&e.to_ws_error()) + .unwrap_or_default(), + }; + send_event(socket, &err_resp).await.ok(); + } + }, + Err(e) => { + tracing::warn!(error = %e, "WS transport parse error"); + send_error( + socket, + WsError { + code: 400, + error: "parse_error".to_string(), + message: e.to_string(), + }, + ) + .await; + } + } +} + +async fn check_rate_limit(bus: &ChannelBus, user_id: Uuid) -> bool { + bus.inner + .rate_limiter + .check_rate_limit(user_id, "ws_message") + .await + .unwrap_or(true) +} + +async fn send_event(socket: &Socket, event: &WsOutEvent) -> ChannelResult<()> { + let json = serde_json::to_string(event)?; + socket + .emit(CHANNEL_EVENT, &json) + .await + .map_err(|e| { + tracing::warn!(error = %e, "WS send failed"); + ChannelError::SocketIo(e) + }) +} + +async fn send_error(socket: &Socket, error: WsError) { + let json = serde_json::json!({ + "type": "error", + "code": error.code, + "error": error.error, + "message": error.message, + }); + if let Err(e) = socket.emit(CHANNEL_EVENT, json.to_string()).await { + tracing::warn!(error = %e, "WS error send failed"); + } +} diff --git a/lib/channel/lib.rs b/lib/channel/lib.rs new file mode 100644 index 0000000..d0e7c26 --- /dev/null +++ b/lib/channel/lib.rs @@ -0,0 +1,39 @@ +mod ack; +mod bus; +mod cdn; +mod circuit_breaker; +mod config; +mod dedup; +mod envelope; +mod error; +pub mod event; +pub mod http; +mod metrics; +mod pagination; +mod reconnect; +pub mod rooms; +mod search; +mod security; +mod seq; +mod token; + +pub use ack::{AckRequest, AckResponse, AckStatus, AckTracker, MessageAck}; +pub use bus::ChannelBus; +pub use cdn::{CdnManager, CdnStoredFile}; +pub use circuit_breaker::{CircuitBreaker, CircuitBreakerError}; +pub use config::ChannelBusConfig; +pub use dedup::DeduplicationManager; +pub use envelope::ChannelEnvelope; +pub use error::{ChannelError, ChannelResult}; +pub use metrics::ChannelMetrics; +pub use pagination::{ + MessageItem, MessagePage, MessagePagination, PaginationDirection, + PaginationParams, +}; +pub use reconnect::{ClientState, MissedMessage, ReconnectManager}; +pub use search::{SearchEngine, SearchHit, SearchQuery, SearchResult}; +pub use security::{CsrfProtection, RateLimiter}; +pub use seq::SeqAllocator; +pub use token::{ + ChannelAccessToken, ChannelTokenApply, ChannelTokenContext, TOKEN_TTL_SECS, +}; diff --git a/lib/channel/metrics.rs b/lib/channel/metrics.rs new file mode 100644 index 0000000..7f32113 --- /dev/null +++ b/lib/channel/metrics.rs @@ -0,0 +1,45 @@ +use std::sync::Arc; + +#[derive(Clone)] +pub struct ChannelMetrics { + pub messages_sent: Arc, + pub messages_received: Arc, + pub messages_failed: Arc, + pub active_connections: Arc, +} + +impl ChannelMetrics { + pub fn new() -> Self { + Self { + messages_sent: Arc::new(std::sync::atomic::AtomicU64::new(0)), + messages_received: Arc::new(std::sync::atomic::AtomicU64::new(0)), + messages_failed: Arc::new(std::sync::atomic::AtomicU64::new(0)), + active_connections: Arc::new(std::sync::atomic::AtomicI64::new(0)), + } + } + + pub fn increment_sent(&self) { + self.messages_sent + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + + pub fn increment_received(&self) { + self.messages_received + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + + pub fn increment_failed(&self) { + self.messages_failed + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + + pub fn increment_connections(&self) { + self.active_connections + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + + pub fn decrement_connections(&self) { + self.active_connections + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } +} diff --git a/lib/channel/pagination.rs b/lib/channel/pagination.rs new file mode 100644 index 0000000..811bf1c --- /dev/null +++ b/lib/channel/pagination.rs @@ -0,0 +1,220 @@ +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::rooms::RM_COLUMNS; +use crate::{ChannelError, ChannelResult}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessagePage { + pub messages: Vec, + pub has_more: bool, + pub next_cursor: Option, + pub prev_cursor: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageItem { + pub id: Uuid, + pub room_id: Uuid, + pub seq: i64, + pub thread: Option, + pub parent: Option, + pub content: String, + pub content_type: String, + pub pinned: bool, + pub system_type: Option, + pub metadata: serde_json::Value, + pub sender_id: Uuid, + pub send_at: chrono::DateTime, + pub edited_at: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PaginationParams { + pub room_id: Uuid, + pub limit: u64, + pub cursor: Option, + pub direction: PaginationDirection, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum PaginationDirection { + Before, + After, +} + +pub struct MessagePagination { + db: db::AppDatabase, +} + +impl MessagePagination { + pub fn new(db: db::AppDatabase) -> Self { + Self { db } + } + + pub async fn get_messages( + &self, + params: PaginationParams, + ) -> ChannelResult { + let limit = std::cmp::Ord::min(params.limit, 100) as i64; + let cursor_seq = params.cursor.and_then(|c| c.parse::().ok()); + + let messages = match (params.direction, cursor_seq) { + (PaginationDirection::Before, Some(seq)) => { + db::sqlx::query_as::<_, model::room::RoomMessageModel>( + db::sqlx::AssertSqlSafe(format!( + "SELECT {RM_COLUMNS} FROM room_message \ + WHERE room = $1 AND seq < $2 AND deleted_at IS NULL \ + ORDER BY seq DESC LIMIT $3" + )), + ) + .bind(params.room_id) + .bind(seq) + .bind(limit + 1) + .fetch_all(self.db.reader()) + .await? + } + (PaginationDirection::After, Some(seq)) => { + db::sqlx::query_as::<_, model::room::RoomMessageModel>( + db::sqlx::AssertSqlSafe(format!( + "SELECT {RM_COLUMNS} FROM room_message \ + WHERE room = $1 AND seq > $2 AND deleted_at IS NULL \ + ORDER BY seq ASC LIMIT $3" + )), + ) + .bind(params.room_id) + .bind(seq) + .bind(limit + 1) + .fetch_all(self.db.reader()) + .await? + } + _ => { + db::sqlx::query_as::<_, model::room::RoomMessageModel>( + db::sqlx::AssertSqlSafe(format!( + "SELECT {RM_COLUMNS} FROM room_message \ + WHERE room = $1 AND deleted_at IS NULL \ + ORDER BY seq DESC LIMIT $2" + )), + ) + .bind(params.room_id) + .bind(limit + 1) + .fetch_all(self.db.reader()) + .await? + } + }; + + let has_more = messages.len() > limit as usize; + let messages: Vec<_> = + messages.into_iter().take(limit as usize).collect(); + + let next_cursor = if has_more { + messages.last().map(|m| m.seq.to_string()) + } else { + None + }; + let prev_cursor = messages.first().map(|m| m.seq.to_string()); + + let items: Vec = messages + .into_iter() + .map(|m| MessageItem { + id: m.id, + room_id: m.room, + seq: m.seq, + thread: m.thread, + parent: m.parent, + content: m.content, + content_type: m.content_type, + pinned: m.pinned, + system_type: m.system_type, + metadata: m.metadata, + sender_id: m.author, + send_at: m.created_at, + edited_at: m.edited_at, + }) + .collect(); + + Ok(MessagePage { + messages: items, + has_more, + next_cursor, + prev_cursor, + }) + } + + pub async fn get_messages_around( + &self, + room_id: Uuid, + message_id: Uuid, + context_size: i64, + ) -> ChannelResult { + let target = db::sqlx::query_as::<_, model::room::RoomMessageModel>( + db::sqlx::AssertSqlSafe(format!( + "SELECT {RM_COLUMNS} FROM room_message \ + WHERE id = $1 AND room = $2 AND deleted_at IS NULL" + )), + ) + .bind(message_id) + .bind(room_id) + .fetch_optional(self.db.reader()) + .await? + .ok_or(ChannelError::Internal("message not found".to_string()))?; + + let before = db::sqlx::query_as::<_, model::room::RoomMessageModel>( + db::sqlx::AssertSqlSafe(format!( + "SELECT {RM_COLUMNS} FROM room_message \ + WHERE room = $1 AND seq < $2 AND deleted_at IS NULL \ + ORDER BY seq DESC LIMIT $3" + )), + ) + .bind(room_id) + .bind(target.seq) + .bind(context_size) + .fetch_all(self.db.reader()) + .await?; + + let after = db::sqlx::query_as::<_, model::room::RoomMessageModel>( + db::sqlx::AssertSqlSafe(format!( + "SELECT {RM_COLUMNS} FROM room_message \ + WHERE room = $1 AND seq > $2 AND deleted_at IS NULL \ + ORDER BY seq ASC LIMIT $3" + )), + ) + .bind(room_id) + .bind(target.seq) + .bind(context_size) + .fetch_all(self.db.reader()) + .await?; + + let mut all_messages = before; + all_messages.reverse(); + all_messages.push(target); + all_messages.extend(after); + + let items: Vec = all_messages + .into_iter() + .map(|m| MessageItem { + id: m.id, + room_id: m.room, + seq: m.seq, + thread: m.thread, + parent: m.parent, + content: m.content, + content_type: m.content_type, + pinned: m.pinned, + system_type: m.system_type, + metadata: m.metadata, + sender_id: m.author, + send_at: m.created_at, + edited_at: m.edited_at, + }) + .collect(); + + Ok(MessagePage { + messages: items, + has_more: false, + next_cursor: None, + prev_cursor: None, + }) + } +} diff --git a/lib/channel/reconnect.rs b/lib/channel/reconnect.rs new file mode 100644 index 0000000..97d394c --- /dev/null +++ b/lib/channel/reconnect.rs @@ -0,0 +1,113 @@ +use std::collections::HashMap; +use uuid::Uuid; + +use model::room::RoomMessageModel; +use serde::{Deserialize, Serialize}; + +use crate::rooms::RM_COLUMNS; +use crate::ChannelResult; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientState { + pub user_id: Uuid, + pub last_seq: HashMap, + pub last_seen: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MissedMessage { + pub room_id: Uuid, + pub message_id: Uuid, + pub seq: i64, + pub content: String, + pub sender_id: Uuid, + pub send_at: chrono::DateTime, +} + +#[derive(Clone)] +pub struct ReconnectManager { + cache: cache::AppCache, + db: db::AppDatabase, +} + +impl ReconnectManager { + pub fn new(cache: cache::AppCache, db: db::AppDatabase) -> Self { + Self { cache, db } + } + + pub async fn save_client_state( + &self, + user_id: Uuid, + room_id: Uuid, + last_seq: i64, + ) -> ChannelResult<()> { + let key = format!("client:state:{}:{}", user_id, room_id); + self.cache.set(&key, &last_seq.to_string()).await?; + if let Some(cluster) = &self.cache.cluster { + cluster + .expire(&key, std::time::Duration::from_secs(86400)) + .await?; + } + Ok(()) + } + + pub async fn get_last_seq( + &self, + user_id: Uuid, + room_id: Uuid, + ) -> ChannelResult> { + let key = format!("client:state:{}:{}", user_id, room_id); + let value: Option = self.cache.get(&key).await?; + Ok(value.and_then(|v| v.parse::().ok())) + } + + pub async fn get_missed_messages( + &self, + room_id: Uuid, + since_seq: i64, + ) -> ChannelResult> { + let messages = db::sqlx::query_as::<_, RoomMessageModel>( + db::sqlx::AssertSqlSafe(format!( + "SELECT {RM_COLUMNS} FROM room_message \ + WHERE room = $1 AND seq > $2 AND deleted_at IS NULL \ + ORDER BY seq ASC \ + LIMIT 100" + )), + ) + .bind(room_id) + .bind(since_seq) + .fetch_all(self.db.reader()) + .await?; + + let missed: Vec = messages + .into_iter() + .map(|m| MissedMessage { + room_id: m.room, + message_id: m.id, + seq: m.seq, + content: m.content, + sender_id: m.author, + send_at: m.created_at, + }) + .collect(); + + Ok(missed) + } + + pub async fn handle_reconnect( + &self, + _user_id: Uuid, + room_states: HashMap, + ) -> ChannelResult>> { + let mut result = HashMap::new(); + + for (room_id, client_seq) in room_states { + let missed = self.get_missed_messages(room_id, client_seq).await?; + if !missed.is_empty() { + result.insert(room_id, missed); + } + } + + Ok(result) + } +} diff --git a/lib/channel/richtext.rs b/lib/channel/richtext.rs new file mode 100644 index 0000000..3dc4743 --- /dev/null +++ b/lib/channel/richtext.rs @@ -0,0 +1,112 @@ +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RichTextBlock { + pub block_type: BlockType, + pub content: String, + pub attributes: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum BlockType { + Text, + Code, + Quote, + Link, + Mention, + Emoji, + Image, +} + +pub struct RichTextRenderer; + +impl RichTextRenderer { + pub fn new() -> Self { + Self {} + } + + pub fn parse_markdown(&self, content: &str) -> Vec { + vec![RichTextBlock { + block_type: BlockType::Text, + content: content.to_string(), + attributes: None, + }] + } + + pub fn parse_mentions(&self, content: &str) -> Vec { + content + .split_whitespace() + .filter(|w| w.starts_with('@')) + .filter_map(|w| Uuid::parse_str(&w[1..]).ok()) + .collect() + } + + pub fn highlight_code(&self, code: &str, language: &str) -> String { + format!("```{}\n{}\n```", language, code) + } + + pub fn render_to_html(&self, blocks: &[RichTextBlock]) -> String { + blocks + .iter() + .map(|block| match block.block_type { + BlockType::Text => { + format!("

{}

", html_escape(&block.content)) + } + BlockType::Code => format!( + "
{}
", + html_escape(&block.content) + ), + BlockType::Quote => format!( + "
{}
", + html_escape(&block.content) + ), + BlockType::Link => { + let safe_href = sanitize_uri(&block.content); + format!( + "{}", + html_escape(&safe_href), + html_escape(&block.content) + ) + } + BlockType::Mention => format!( + "@{}", + html_escape(&block.content) + ), + BlockType::Emoji => format!( + "{}", + html_escape(&block.content) + ), + BlockType::Image => { + let safe_src = sanitize_uri(&block.content); + if safe_src.is_empty() { + String::new() + } else { + format!("", html_escape(&safe_src)) + } + } + }) + .collect::>() + .join("\n") + } +} +fn sanitize_uri(uri: &str) -> String { + let lower = uri.to_lowercase(); + if lower.starts_with("http://") + || lower.starts_with("https://") + || lower.starts_with("mailto:") + { + uri.to_string() + } else { + String::new() + } +} + +fn html_escape(s: &str) -> String { + s.replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) + .replace('\'', "'") +} diff --git a/lib/channel/rooms.rs b/lib/channel/rooms.rs new file mode 100644 index 0000000..b079ac8 --- /dev/null +++ b/lib/channel/rooms.rs @@ -0,0 +1,221 @@ +use cache::AppCache; +use db::{AppDatabase, sqlx}; +use model::room::RoomMessageModel; +use serde::Serialize; +use uuid::Uuid; + +use crate::{ChannelBusConfig, ChannelResult}; + +pub(crate) const RM_COLUMNS: &str = + "id, room, seq, thread, parent, author, content, content_type, pinned, \ + system_type, metadata, edited_at, created_at, updated_at, deleted_at"; + +pub(crate) fn room_socket_name(room: Uuid) -> String { + format!("room:{room}") +} + +pub(crate) fn user_rooms_cache_key(user: Uuid) -> String { + format!("channel:user:{user}:rooms") +} +#[derive(Debug, Serialize)] +pub struct RoomListItem { + pub id: Uuid, + pub name: String, + pub topic: Option, + pub room_type: String, + pub is_private: bool, + pub category: Option, + pub workspace_id: Uuid, +} +#[derive(Debug, Serialize)] +pub struct CategoryListItem { + pub id: Uuid, + pub name: String, + pub position: i32, +} +pub async fn user_rooms_for_api( + db: &AppDatabase, + cache: &AppCache, + config: &ChannelBusConfig, + user: Uuid, +) -> ChannelResult> { + let room_ids = user_rooms(db, cache, config, user).await?; + if room_ids.is_empty() { + return Ok(Vec::new()); + } + + let rows = sqlx::query_as::<_, (Uuid, String, Option, String, bool, Option, Uuid)>( + "SELECT id, name, topic, room_type, is_private, parent, wk \ + FROM room \ + WHERE id = ANY($1) AND deleted_at IS NULL AND is_archived = false \ + ORDER BY name", + ) + .bind(&room_ids) + .fetch_all(db.reader()) + .await?; + + Ok(rows + .into_iter() + .map(|(id, name, topic, room_type, is_private, category, workspace_id)| RoomListItem { + id, + name, + topic, + room_type, + is_private, + category, + workspace_id, + }) + .collect()) +} +pub async fn user_categories_for_api( + db: &AppDatabase, + cache: &AppCache, + config: &ChannelBusConfig, + user: Uuid, +) -> ChannelResult> { + let room_ids = user_rooms(db, cache, config, user).await?; + if room_ids.is_empty() { + return Ok(Vec::new()); + } + let wk_rows = sqlx::query_as::<_, (Uuid,)>( + "SELECT DISTINCT wk FROM room WHERE id = ANY($1) AND deleted_at IS NULL", + ) + .bind(&room_ids) + .fetch_all(db.reader()) + .await?; + + let wk_ids: Vec = wk_rows.into_iter().map(|r| r.0).collect(); + if wk_ids.is_empty() { + return Ok(Vec::new()); + } + + let rows = sqlx::query_as::<_, (Uuid, String, i32)>( + "SELECT id, name, position FROM room_category WHERE wk = ANY($1) ORDER BY position, name", + ) + .bind(&wk_ids) + .fetch_all(db.reader()) + .await?; + + Ok(rows + .into_iter() + .map(|(id, name, position)| CategoryListItem { id, name, position }) + .collect()) +} + +pub(crate) async fn user_rooms( + db: &AppDatabase, + cache: &AppCache, + config: &ChannelBusConfig, + user: Uuid, +) -> ChannelResult> { + let key = user_rooms_cache_key(user); + if let Some(rooms) = cache.get::>(&key).await? { + return Ok(rooms); + } + + let rooms = load_user_rooms(db, user).await?; + cache_set_with_ttl(cache, &key, &rooms, config.room_cache_ttl_hint).await?; + Ok(rooms) +} + +pub(crate) async fn refresh_user_rooms_cache( + db: &AppDatabase, + cache: &AppCache, + config: &ChannelBusConfig, + user: Uuid, +) -> ChannelResult> { + let key = user_rooms_cache_key(user); + cache.remove(&key).await?; + let rooms = load_user_rooms(db, user).await?; + cache_set_with_ttl(cache, &key, &rooms, config.room_cache_ttl_hint).await?; + Ok(rooms) +} + +async fn cache_set_with_ttl( + cache: &AppCache, + key: &str, + value: &T, + ttl: Option, +) -> ChannelResult<()> +where + T: serde::Serialize + serde::de::DeserializeOwned, +{ + cache.set(key, value).await?; + if let Some(ttl) = ttl { + if let Some(cluster) = &cache.cluster { + cluster.expire(key, ttl).await?; + } + } + Ok(()) +} + +pub(crate) async fn active_workspace_users( + db: &AppDatabase, + wk: Uuid, +) -> ChannelResult> { + let rows = sqlx::query_as::<_, (Uuid,)>( + "SELECT \"user\" FROM wk_member WHERE wk = $1 AND leave_at IS NULL", + ) + .bind(wk) + .fetch_all(db.reader()) + .await?; + Ok(rows.into_iter().map(|row| row.0).collect()) +} + +pub(crate) async fn room_workspace( + db: &AppDatabase, + room: Uuid, +) -> ChannelResult> { + let row = sqlx::query_as::<_, (Uuid,)>("SELECT wk FROM room WHERE id = $1") + .bind(room) + .fetch_optional(db.reader()) + .await?; + Ok(row.map(|row| row.0)) +} + +pub(crate) async fn catchup_messages( + db: &AppDatabase, + config: &ChannelBusConfig, + room: Uuid, + after_seq: i64, +) -> ChannelResult> { + let rows = sqlx::query_as::<_, RoomMessageModel>( + db::sqlx::AssertSqlSafe(format!( + "SELECT {RM_COLUMNS} FROM room_message \ + WHERE room = $1 AND seq > $2 AND deleted_at IS NULL \ + ORDER BY seq ASC \ + LIMIT $3" + )), + ) + .bind(room) + .bind(after_seq) + .bind(config.catchup_limit) + .fetch_all(db.reader()) + .await?; + Ok(rows) +} + +async fn load_user_rooms( + db: &AppDatabase, + user: Uuid, +) -> ChannelResult> { + let rows = sqlx::query_as::<_, (Uuid,)>( + "SELECT r.id \ + FROM room r \ + INNER JOIN wk_member wm ON wm.wk = r.wk \ + WHERE wm.\"user\" = $1 \ + AND wm.leave_at IS NULL \ + AND r.deleted_at IS NULL \ + AND r.is_archived = false \ + AND (r.is_private = false \ + OR EXISTS ( \ + SELECT 1 FROM room_permission_overwrite po \ + WHERE po.room = r.id AND po.target_id = $1 \ + )) \ + ORDER BY r.id", + ) + .bind(user) + .fetch_all(db.reader()) + .await?; + Ok(rows.into_iter().map(|row| row.0).collect()) +} diff --git a/lib/channel/search.rs b/lib/channel/search.rs new file mode 100644 index 0000000..8b52932 --- /dev/null +++ b/lib/channel/search.rs @@ -0,0 +1,137 @@ +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::ChannelResult; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchQuery { + pub query: String, + pub room_id: Option, + pub user_id: Option, + pub limit: u64, + pub offset: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResult { + pub total: u64, + pub hits: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchHit { + pub message_id: Uuid, + pub room_id: Uuid, + pub content: String, + pub highlighted: String, + pub sender_id: Uuid, + pub send_at: chrono::DateTime, + pub score: f64, +} + +pub struct SearchEngine { + db: db::AppDatabase, +} + +impl SearchEngine { + pub fn new(db: db::AppDatabase) -> Self { + Self { db } + } + + pub async fn search( + &self, + query: SearchQuery, + ) -> ChannelResult { + let search_term = format!("%{}%", escape_like(&query.query)); + let room_filter = query.room_id; + let user_filter = query.user_id; + + let count: (i64,) = db::sqlx::query_as( + "SELECT COUNT(*) FROM room_message \ + WHERE ($1::uuid IS NULL OR room = $1) \ + AND ($2::uuid IS NULL OR author = $2) \ + AND content LIKE $3 ESCAPE '\\' \ + AND deleted_at IS NULL", + ) + .bind(room_filter) + .bind(user_filter) + .bind(&search_term) + .fetch_one(self.db.reader()) + .await?; + + let total = count.0 as u64; + + let messages = db::sqlx::query_as::<_, model::room::RoomMessageModel>( + "SELECT id, room, seq, thread, parent, author, content, content_type, pinned, \ + system_type, metadata, edited_at, created_at, updated_at, deleted_at \ + FROM room_message \ + WHERE ($1::uuid IS NULL OR room = $1) \ + AND ($2::uuid IS NULL OR author = $2) \ + AND content LIKE $3 ESCAPE '\\' \ + AND deleted_at IS NULL \ + ORDER BY created_at DESC LIMIT $4 OFFSET $5" + ) + .bind(room_filter) + .bind(user_filter) + .bind(&search_term) + .bind(query.limit as i64) + .bind(query.offset as i64) + .fetch_all(self.db.reader()) + .await?; + + let hits: Vec = messages + .into_iter() + .map(|m| SearchHit { + message_id: m.id, + room_id: m.room, + content: m.content.clone(), + highlighted: highlight_text(&m.content, &query.query), + sender_id: m.author, + send_at: m.created_at, + score: 1.0, + }) + .collect(); + + Ok(SearchResult { total, hits }) + } +} +fn escape_like(input: &str) -> String { + let mut out = String::with_capacity(input.len()); + for ch in input.chars() { + match ch { + '\\' => out.push_str("\\\\"), + '%' => out.push_str("\\%"), + '_' => out.push_str("\\_"), + _ => out.push(ch), + } + } + out +} +fn highlight_text(content: &str, query: &str) -> String { + let lower_content = content.to_lowercase(); + let lower_query = query.to_lowercase(); + let char_pos = match lower_content.find(&lower_query) { + Some(p) => p, + None => return content.to_string(), + }; + let match_chars = lower_content[..char_pos].chars().count(); + let query_chars = lower_query.chars().count(); + + let mut before = String::new(); + let mut matched = String::new(); + let mut after = String::new(); + let mut char_idx = 0; + + for ch in content.chars() { + if char_idx < match_chars { + before.push(ch); + } else if char_idx < match_chars + query_chars { + matched.push(ch); + } else { + after.push(ch); + } + char_idx += 1; + } + + format!("{}{}{}", before, matched, after) +} diff --git a/lib/channel/security.rs b/lib/channel/security.rs new file mode 100644 index 0000000..7c2e297 --- /dev/null +++ b/lib/channel/security.rs @@ -0,0 +1,152 @@ +use std::time::Duration; +use uuid::Uuid; + +use crate::{ChannelError, ChannelResult}; + +const RATE_LIMIT_SCRIPT: &str = r#" +local key = KEYS[1] +local max = tonumber(ARGV[1]) +local window = tonumber(ARGV[2]) +local current = tonumber(redis.call('INCR', key)) +if current == 1 then + redis.call('EXPIRE', key, window) +end +if current > max then + return 0 +end +return 1 +"#; + +#[derive(Clone)] +pub struct RateLimiter { + cache: cache::AppCache, + max_requests: u32, + window: Duration, +} + +impl RateLimiter { + pub fn new(cache: cache::AppCache) -> Self { + Self { + cache, + max_requests: 100, + window: Duration::from_secs(60), + } + } + + pub fn with_config( + cache: cache::AppCache, + max_requests: u32, + window: Duration, + ) -> Self { + Self { + cache, + max_requests, + window, + } + } + + pub async fn check_rate_limit( + &self, + user_id: Uuid, + action: &str, + ) -> ChannelResult { + let cluster = require_cluster(&self.cache)?; + let key = format!("ratelimit:{}:{}", user_id, action); + let mut conn = cluster.conn(); + + let allowed: i64 = redis::Cmd::new() + .arg("EVAL") + .arg(RATE_LIMIT_SCRIPT) + .arg(1) + .arg(&key) + .arg(self.max_requests) + .arg(self.window.as_secs()) + .query_async(&mut conn) + .await + .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + + Ok(allowed == 1) + } + + pub async fn get_remaining( + &self, + user_id: Uuid, + action: &str, + ) -> ChannelResult { + let key = format!("ratelimit:{}:{}", user_id, action); + let count: Option = self.cache.get(&key).await?; + let current = count.unwrap_or(0); + Ok(self.max_requests.saturating_sub(current)) + } +} + +const CSRF_TTL_SECS: u64 = 3600; + +#[derive(Clone)] +pub struct CsrfProtection { + cache: cache::AppCache, +} + +impl CsrfProtection { + pub fn new(cache: cache::AppCache) -> Self { + Self { cache } + } + + pub async fn generate_token(&self, user_id: Uuid) -> ChannelResult { + let token = Uuid::new_v4().to_string(); + let key = format!("csrf:{}:{}", user_id, token); + let cluster = require_cluster(&self.cache)?; + let mut conn = cluster.conn(); + + let _: () = redis::Cmd::new() + .arg("SET") + .arg(&key) + .arg("1") + .arg("EX") + .arg(CSRF_TTL_SECS) + .query_async(&mut conn) + .await + .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + + Ok(token) + } + + pub async fn validate_token( + &self, + user_id: Uuid, + token: &str, + ) -> ChannelResult { + let key = format!("csrf:{}:{}", user_id, token); + let cluster = require_cluster(&self.cache)?; + let mut conn = cluster.conn(); + const VALIDATE_SCRIPT: &str = r#" +local key = KEYS[1] +local exists = redis.call('EXISTS', key) +if exists == 1 then + redis.call('DEL', key) + return 1 +end +return 0 +"#; + + let valid: i64 = redis::Cmd::new() + .arg("EVAL") + .arg(VALIDATE_SCRIPT) + .arg(1) + .arg(&key) + .query_async(&mut conn) + .await + .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + + Ok(valid == 1) + } +} + +pub(crate) fn require_cluster( + cache: &cache::AppCache, +) -> ChannelResult<&cache::ClusterCache> { + cache + .cluster + .as_ref() + .ok_or(ChannelError::Internal("no cluster cache".to_string())) +} diff --git a/lib/channel/seq.rs b/lib/channel/seq.rs new file mode 100644 index 0000000..129b417 --- /dev/null +++ b/lib/channel/seq.rs @@ -0,0 +1,185 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::{AtomicI64, Ordering}; + +use dashmap::DashMap; +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::{ChannelError, ChannelResult, security::require_cluster}; + +const SEQ_KEY_PREFIX: &str = "room:seq:"; +const DEFAULT_SEGMENT_SIZE: u64 = 1024; +const MAX_REFRESH_RETRIES: u32 = 3; + +const BOOTSTRAP_SCRIPT: &str = r#" +local key = KEYS[1] +local db_max = tonumber(ARGV[1]) +local current = tonumber(redis.call('GET', key) or '0') +if current < db_max then + redis.call('SET', key, db_max) +end +return tonumber(redis.call('GET', key)) +"#; + +struct SegmentState { + end: i64, + next: AtomicI64, +} + +pub struct SeqAllocator(Arc); + +struct SeqAllocatorInner { + cache: cache::AppCache, + db: db::AppDatabase, + segments: DashMap>, + refresh_locks: DashMap>>, + segment_size: u64, +} + +impl Clone for SeqAllocator { + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } +} + +impl SeqAllocator { + pub fn new(cache: cache::AppCache, db: db::AppDatabase) -> Self { + Self::with_segment_size(cache, db, DEFAULT_SEGMENT_SIZE) + } + + pub fn with_segment_size( + cache: cache::AppCache, + db: db::AppDatabase, + size: u64, + ) -> Self { + Self(Arc::new(SeqAllocatorInner { + cache, + db, + segments: DashMap::new(), + refresh_locks: DashMap::new(), + segment_size: if size > 0 { size } else { DEFAULT_SEGMENT_SIZE }, + })) + } + + pub async fn seq(&self, room: Uuid) -> ChannelResult { + for _ in 0..MAX_REFRESH_RETRIES { + if let Some(seq) = self.try_allocate(&room) { + return Ok(seq); + } + + let lock = self.get_refresh_lock(room); + let _guard = lock.lock().await; + + if let Some(seq) = self.try_allocate(&room) { + return Ok(seq); + } + + self.refresh_segment(room).await?; + self.0.refresh_locks.remove(&room); + } + + Err(ChannelError::Internal( + "seq allocation exhausted".to_string(), + )) + } + + pub async fn bootstrap(&self, room: Uuid) -> ChannelResult { + let db_max = self.db_max_seq(room).await?; + if db_max == 0 { + return Ok(0); + } + + let key = format!("{}{}", SEQ_KEY_PREFIX, room); + let cluster = require_cluster(&self.0.cache)?; + let mut conn = cluster.conn(); + + let current: i64 = redis::Cmd::new() + .arg("EVAL") + .arg(BOOTSTRAP_SCRIPT) + .arg(1) + .arg(&key) + .arg(db_max) + .query_async(&mut conn) + .await + .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + + self.0.segments.remove(&room); + Ok(current) + } + + pub async fn bootstrap_all( + &self, + rooms: Vec, + ) -> ChannelResult> { + let mut results = HashMap::with_capacity(rooms.len()); + for room in rooms { + results.insert(room, self.bootstrap(room).await?); + } + Ok(results) + } + + fn try_allocate(&self, room: &Uuid) -> Option { + let state = self.0.segments.get(room)?; + loop { + let current = state.next.load(Ordering::Acquire); + if current >= state.end { + return None; + } + if state + .next + .compare_exchange_weak(current, current + 1, Ordering::AcqRel, Ordering::Acquire) + .is_ok() + { + return Some(current); + } + } + } + + fn get_refresh_lock(&self, room: Uuid) -> Arc> { + Arc::clone( + self.0 + .refresh_locks + .entry(room) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .value(), + ) + } + + async fn refresh_segment(&self, room: Uuid) -> ChannelResult<()> { + let key = format!("{}{}", SEQ_KEY_PREFIX, room); + let cluster = require_cluster(&self.0.cache)?; + let mut conn = cluster.conn(); + + let counter: i64 = redis::Cmd::new() + .arg("INCRBY") + .arg(&key) + .arg(self.0.segment_size as i64) + .query_async(&mut conn) + .await + .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + + let start = counter - self.0.segment_size as i64 + 1; + let end = counter + 1; + + self.0.segments.insert( + room, + Arc::new(SegmentState { + end, + next: AtomicI64::new(start), + }), + ); + + Ok(()) + } + + async fn db_max_seq(&self, room: Uuid) -> ChannelResult { + let row: (i64,) = db::sqlx::query_as( + "SELECT COALESCE(MAX(seq), 0) FROM room_message WHERE room = $1 AND deleted_at IS NULL", + ) + .bind(room) + .fetch_one(self.0.db.reader()) + .await?; + Ok(row.0) + } +} diff --git a/lib/channel/token.rs b/lib/channel/token.rs new file mode 100644 index 0000000..1ad4240 --- /dev/null +++ b/lib/channel/token.rs @@ -0,0 +1,306 @@ +use base64::Engine; +use hmac::{KeyInit, Mac}; +use sha2::Sha256; +use uuid::Uuid; + +use crate::{ + ChannelBus, ChannelError, ChannelResult, security::require_cluster, +}; + +type HmacSha256 = hmac::Hmac; + +const VERSION: u8 = 0; +pub const TOKEN_TTL_SECS: u64 = 600; +const SESSION_TTL_SECS: u64 = 1800; +const MAX_LIFETIME_SECS: i64 = 3000; + +const TOKEN_PREFIX: &str = "token:access:"; +const SESSION_PREFIX: &str = "channel:session:"; + +const MAX_TOKEN_BASE64_LEN: usize = 256; + +#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] +pub struct ChannelAccessToken { + pub access_token: String, +} + +#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] +pub struct ChannelTokenApply { + pub client_id: String, + pub device_id: String, +} + +#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] +pub struct ChannelTokenContext { + pub user_id: Uuid, + pub device_id: String, + pub client_id: String, +} + +struct TokenPayload { + user_id: Uuid, + created_at: i64, +} + +impl TokenPayload { + const LEN: usize = 57; + + fn encode(&self, signing_key: &[u8]) -> ChannelResult> { + let mut buf = Vec::with_capacity(Self::LEN); + buf.push(VERSION); + buf.extend_from_slice(self.user_id.as_bytes()); + buf.extend_from_slice(&self.created_at.to_be_bytes()); + + let tag = hmac_sign(signing_key, &buf)?; + buf.extend_from_slice(&tag); + + Ok(buf) + } + + fn decode(bytes: &[u8], signing_key: &[u8]) -> ChannelResult { + if bytes.len() != Self::LEN || bytes[0] != VERSION { + return Err(ChannelError::TokenInvalidOrExpired); + } + + let expected_tag = hmac_sign(signing_key, &bytes[..25])?; + if !constant_time_eq(&expected_tag, &bytes[25..]) { + return Err(ChannelError::TokenInvalidOrExpired); + } + + let user_id_bytes: [u8; 16] = bytes[1..17].try_into().map_err( + |_: std::array::TryFromSliceError| { + ChannelError::TokenInvalidOrExpired + }, + )?; + let user_id = Uuid::from_bytes(user_id_bytes); + let created_at_bytes: [u8; 8] = bytes[17..25].try_into().map_err( + |_: std::array::TryFromSliceError| { + ChannelError::TokenInvalidOrExpired + }, + )?; + let created_at = i64::from_be_bytes(created_at_bytes); + + Ok(TokenPayload { + user_id, + created_at, + }) + } +} + +impl ChannelBus { + fn signing_key(&self) -> ChannelResult<[u8; 32]> { + let secret = + self.inner.config.signing_secret.as_deref().ok_or( + ChannelError::Internal("no signing secret".to_string()), + )?; + let mut mac = + HmacSha256::new_from_slice(secret.as_bytes()).map_err(|_| { + ChannelError::Internal("hmac init failed".to_string()) + })?; + mac.update(b"channel-access-token-signing-key"); + let result = mac.finalize().into_bytes(); + Ok(result.into()) + } + + fn session_hash_key(&self, user_id: &Uuid, created_at: i64) -> String { + format!("{}{}:{}", SESSION_PREFIX, user_id, created_at) + } + + fn token_redis_key(&self, token_str: &str) -> String { + format!("{}{}", TOKEN_PREFIX, token_str) + } + + pub async fn apply_access_token( + &self, + user_id: Uuid, + apply: ChannelTokenApply, + ) -> ChannelResult { + let created_at = chrono::Utc::now().timestamp(); + let signing_key = self.signing_key()?; + + let payload = TokenPayload { + user_id, + created_at, + }; + let token_bytes = payload.encode(&signing_key)?; + let access_token = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(&token_bytes); + + let session_key = self.session_hash_key(&user_id, created_at); + let token_key = self.token_redis_key(&access_token); + + let cluster = require_cluster(&self.inner.cache)?; + let mut conn = cluster.conn(); + let mut pipe = redis::Pipeline::new(); + pipe.hset(&session_key, "device_id", &apply.device_id) + .hset(&session_key, "client_id", &apply.client_id) + .expire(&session_key, SESSION_TTL_SECS as i64); + pipe.query_async::<()>(&mut conn) + .await + .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + redis::Cmd::set_ex(&token_key, &session_key, TOKEN_TTL_SECS) + .query_async::<()>(&mut conn) + .await + .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + + Ok(ChannelAccessToken { access_token }) + } + + pub async fn check_access_token( + &self, + access_token: String, + ) -> ChannelResult { + let token_bytes = decode_token_bytes(&access_token)?; + + let signing_key = self.signing_key()?; + let payload = TokenPayload::decode(&token_bytes, &signing_key)?; + + let elapsed = chrono::Utc::now().timestamp() - payload.created_at; + if elapsed > MAX_LIFETIME_SECS { + return Err(ChannelError::TokenInvalidOrExpired); + } + + let session_key = + self.session_hash_key(&payload.user_id, payload.created_at); + + let cluster = require_cluster(&self.inner.cache)?; + let mut conn = cluster.conn(); + + let token_key = self.token_redis_key(&access_token); + let token_exists: bool = redis::Cmd::new() + .arg("EXISTS") + .arg(&token_key) + .query_async(&mut conn) + .await + .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + + if !token_exists { + return Err(ChannelError::TokenInvalidOrExpired); + } + + let hash_data: std::collections::HashMap = + redis::Cmd::new() + .arg("HGETALL") + .arg(&session_key) + .query_async(&mut conn) + .await + .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + + let device_id = hash_data + .get("device_id") + .cloned() + .ok_or(ChannelError::TokenInvalidOrExpired)?; + let client_id = hash_data + .get("client_id") + .cloned() + .ok_or(ChannelError::TokenInvalidOrExpired)?; + + Ok(ChannelTokenContext { + user_id: payload.user_id, + device_id, + client_id, + }) + } + + pub async fn renew_access_token( + &self, + access_token: String, + ) -> ChannelResult { + let token_bytes = decode_token_bytes(&access_token)?; + + let signing_key = self.signing_key()?; + let payload = TokenPayload::decode(&token_bytes, &signing_key)?; + + let elapsed = chrono::Utc::now().timestamp() - payload.created_at; + if elapsed > MAX_LIFETIME_SECS { + return Err(ChannelError::RenewalLimitExceeded); + } + + let session_key = + self.session_hash_key(&payload.user_id, payload.created_at); + let token_key = self.token_redis_key(&access_token); + + let cluster = require_cluster(&self.inner.cache)?; + let mut conn = cluster.conn(); + let hash_data: std::collections::HashMap = + redis::Cmd::new() + .arg("HGETALL") + .arg(&session_key) + .query_async(&mut conn) + .await + .map_err(|e| { + ChannelError::Cache(cache::CacheError::Redis(e)) + })?; + + let device_id = hash_data + .get("device_id") + .cloned() + .ok_or(ChannelError::TokenInvalidOrExpired)?; + let client_id = hash_data + .get("client_id") + .cloned() + .ok_or(ChannelError::TokenInvalidOrExpired)?; + let _: () = redis::Cmd::new() + .arg("DEL") + .arg(&token_key) + .query_async(&mut conn) + .await + .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + let created_at = chrono::Utc::now().timestamp(); + let new_payload = TokenPayload { + user_id: payload.user_id, + created_at, + }; + let new_token_bytes = new_payload.encode(&signing_key)?; + let new_access_token = + base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(&new_token_bytes); + + let new_session_key = + self.session_hash_key(&payload.user_id, created_at); + let new_token_key = self.token_redis_key(&new_access_token); + let mut pipe = redis::Pipeline::new(); + pipe.hset(&new_session_key, "device_id", &device_id) + .hset(&new_session_key, "client_id", &client_id) + .expire(&new_session_key, SESSION_TTL_SECS as i64); + + pipe.query_async::<()>(&mut conn) + .await + .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + redis::Cmd::set_ex(&new_token_key, &new_session_key, TOKEN_TTL_SECS) + .query_async::<()>(&mut conn) + .await + .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; + + Ok(ChannelAccessToken { + access_token: new_access_token, + }) + } +} + +fn decode_token_bytes(token: &str) -> ChannelResult> { + if token.len() > MAX_TOKEN_BASE64_LEN { + return Err(ChannelError::TokenInvalidOrExpired); + } + base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(token) + .map_err(|_| ChannelError::TokenInvalidOrExpired) +} + +fn hmac_sign(key: &[u8], payload: &[u8]) -> ChannelResult<[u8; 32]> { + let mut mac = HmacSha256::new_from_slice(key) + .map_err(|_| ChannelError::Internal("hmac sign failed".to_string()))?; + mac.update(payload); + Ok(mac.finalize().into_bytes().into()) +} + +fn constant_time_eq(expected: &[u8; 32], actual: &[u8]) -> bool { + if actual.len() != 32 { + return false; + } + let mut diff = 0u8; + for i in 0..32 { + diff |= expected[i] ^ actual[i]; + } + diff == 0 +} diff --git a/lib/config/Cargo.toml b/lib/config/Cargo.toml new file mode 100644 index 0000000..7a44b88 --- /dev/null +++ b/lib/config/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "config" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true +[lib] +path = "lib.rs" +name = "config" +[dependencies] +dotenvy = { workspace = true } +anyhow = { workspace = true } +serde = { workspace = true, features = ["derive"] } +uuid = { workspace = true, features = ["v4"] } +num_cpus = { workspace = true } +tracing = { workspace = true } +[lints] +workspace = true diff --git a/lib/config/ai.rs b/lib/config/ai.rs new file mode 100644 index 0000000..e646bb7 --- /dev/null +++ b/lib/config/ai.rs @@ -0,0 +1,16 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn ai_basic_url(&self) -> anyhow::Result { + if let Some(url) = self.env.get("APP_AI_BASIC_URL") { + return Ok(url.to_string()); + } + Err(anyhow::anyhow!("APP_AI_BASIC_URL not found")) + } + pub fn ai_api_key(&self) -> anyhow::Result { + if let Some(api_key) = self.env.get("APP_AI_API_KEY") { + return Ok(api_key.to_string()); + } + Err(anyhow::anyhow!("APP_AI_API_KEY not found")) + } +} diff --git a/lib/config/app.rs b/lib/config/app.rs new file mode 100644 index 0000000..0335865 --- /dev/null +++ b/lib/config/app.rs @@ -0,0 +1,37 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn app_name(&self) -> anyhow::Result { + if let Some(name) = self.env.get("APP_NAME") { + return Ok(name.to_string()); + } + Ok(env!("CARGO_PKG_NAME").to_string()) + } + + pub fn app_version(&self) -> anyhow::Result { + if let Some(version) = self.env.get("APP_VERSION") { + return Ok(version.to_string()); + } + Ok(env!("CARGO_PKG_VERSION").to_string()) + } + pub fn app_description(&self) -> anyhow::Result { + if let Some(description) = self.env.get("APP_DESCRIPTION") { + return Ok(description.to_string()); + } + Ok(env!("CARGO_PKG_DESCRIPTION").to_string()) + } + + pub fn api_port(&self) -> anyhow::Result { + if let Some(port) = self.env.get("APP_API_PORT") { + return Ok(port.parse::()?); + } + Ok(8080) + } + + pub fn session_secret(&self) -> anyhow::Result { + if let Some(secret) = self.env.get("APP_SESSION_SECRET") { + return Ok(secret.to_string()); + } + Err(anyhow::anyhow!("APP_SESSION_SECRET not found")) + } +} diff --git a/lib/config/auth.rs b/lib/config/auth.rs new file mode 100644 index 0000000..d800763 --- /dev/null +++ b/lib/config/auth.rs @@ -0,0 +1,19 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn auth_rotation_interval(&self) -> anyhow::Result { + self.parse_env("APP_AUTH_ROTATION_INTERVAL_SECONDS", 10800) + } + + pub fn auth_key_ttl(&self) -> anyhow::Result { + self.parse_env("APP_AUTH_KEY_TTL_SECONDS", 2_592_000) + } + + pub fn auth_session_ttl(&self) -> anyhow::Result { + self.parse_env("APP_AUTH_SESSION_TTL_SECONDS", 2_592_000) + } + + pub fn auth_token_ttl(&self) -> anyhow::Result { + self.parse_env("APP_AUTH_TOKEN_TTL_SECONDS", 21_600) + } +} diff --git a/lib/config/avatar.rs b/lib/config/avatar.rs new file mode 100644 index 0000000..efe2bb9 --- /dev/null +++ b/lib/config/avatar.rs @@ -0,0 +1,10 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn avatar_path(&self) -> anyhow::Result { + if let Some(url) = self.env.get("APP_AVATAR_PATH") { + return Ok(url.to_string()); + } + Err(anyhow::anyhow!("APP_AVATAR_PATH not found")) + } +} diff --git a/lib/config/cache.rs b/lib/config/cache.rs new file mode 100644 index 0000000..eee340a --- /dev/null +++ b/lib/config/cache.rs @@ -0,0 +1,83 @@ +use std::time::Duration; + +use crate::AppConfig; + +impl AppConfig { + pub fn cache_local_max_capacity(&self) -> anyhow::Result { + self.parse_env("APP_CACHE_LOCAL_MAX_CAPACITY", 10_000) + } + + pub fn cache_local_ttl(&self) -> anyhow::Result> { + self.parse_optional_duration_secs( + "APP_CACHE_LOCAL_TTL_SECONDS", + Some(300), + ) + } + + pub fn cache_local_tti(&self) -> anyhow::Result> { + self.parse_optional_duration_secs("APP_CACHE_LOCAL_TTI_SECONDS", None) + } + + pub fn cache_default_ttl(&self) -> anyhow::Result> { + self.parse_optional_duration_secs( + "APP_CACHE_DEFAULT_TTL_SECONDS", + Some(300), + ) + } + + pub fn cache_cluster_enabled(&self) -> anyhow::Result { + self.parse_env("APP_CACHE_CLUSTER_ENABLED", false) + } + + pub fn cache_cluster_key_prefix(&self) -> Option { + self.env + .get("APP_CACHE_CLUSTER_KEY_PREFIX") + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + } + + pub fn cache_cluster_command_timeout(&self) -> anyhow::Result { + self.parse_duration_secs("APP_CACHE_CLUSTER_COMMAND_TIMEOUT_SECONDS", 3) + } + + pub fn cache_cluster_write_through(&self) -> anyhow::Result { + self.parse_env("APP_CACHE_CLUSTER_WRITE_THROUGH", true) + } + + pub(crate) fn parse_env( + &self, + key: &str, + default: T, + ) -> anyhow::Result + where + T: std::str::FromStr, + T::Err: std::error::Error + Send + Sync + 'static, + { + match self.env.get(key).map(|value| value.trim()) { + Some(value) if !value.is_empty() => Ok(value.parse::()?), + _ => Ok(default), + } + } + + pub(crate) fn parse_duration_secs( + &self, + key: &str, + default_secs: u64, + ) -> anyhow::Result { + Ok(Duration::from_secs(self.parse_env(key, default_secs)?)) + } + + pub(crate) fn parse_optional_duration_secs( + &self, + key: &str, + default_secs: Option, + ) -> anyhow::Result> { + match self.env.get(key).map(|value| value.trim()) { + Some("0") => Ok(None), + Some(value) if !value.is_empty() => { + Ok(Some(Duration::from_secs(value.parse()?))) + } + _ => Ok(default_secs.map(Duration::from_secs)), + } + } +} diff --git a/lib/config/database.rs b/lib/config/database.rs new file mode 100644 index 0000000..6409e19 --- /dev/null +++ b/lib/config/database.rs @@ -0,0 +1,83 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn database_url(&self) -> anyhow::Result { + if let Some(url) = self.env.get("APP_DATABASE_URL") { + return Ok(url.to_string()); + } + Err(anyhow::anyhow!("APP_DATABASE_URL not found")) + } + pub fn database_max_connections(&self) -> anyhow::Result { + if let Some(max_connections) = + self.env.get("APP_DATABASE_MAX_CONNECTIONS") + { + return Ok(max_connections.parse::()?); + } + Ok(10) + } + pub fn database_min_connections(&self) -> anyhow::Result { + if let Some(min_connections) = + self.env.get("APP_DATABASE_MIN_CONNECTIONS") + { + return Ok(min_connections.parse::()?); + } + Ok(2) + } + pub fn database_idle_timeout(&self) -> anyhow::Result { + if let Some(idle_timeout) = self.env.get("APP_DATABASE_IDLE_TIMEOUT") { + return Ok(idle_timeout.parse::()?); + } + Ok(600) // seconds + } + pub fn database_max_lifetime(&self) -> anyhow::Result { + if let Some(max_lifetime) = self.env.get("APP_DATABASE_MAX_LIFETIME") { + return Ok(max_lifetime.parse::()?); + } + Ok(3600) // seconds + } + pub fn database_connection_timeout(&self) -> anyhow::Result { + if let Some(connection_timeout) = + self.env.get("APP_DATABASE_CONNECTION_TIMEOUT") + { + return Ok(connection_timeout.parse::()?); + } + Ok(8) // seconds + } + pub fn database_schema_search_path(&self) -> anyhow::Result { + if let Some(schema_search_path) = + self.env.get("APP_DATABASE_SCHEMA_SEARCH_PATH") + { + return Ok(schema_search_path.to_string()); + } + Ok("public".to_string()) + } + pub fn database_read_replicas(&self) -> anyhow::Result> { + if let Some(replicas) = self.env.get("APP_DATABASE_REPLICAS") { + if replicas.is_empty() { + return Ok(None); + } + return Ok(Some(replicas.to_string())); + } + Ok(None) + } + pub fn database_health_check_interval(&self) -> anyhow::Result { + if let Some(interval) = + self.env.get("APP_DATABASE_HEALTH_CHECK_INTERVAL") + { + return Ok(interval.parse::()?); + } + Ok(30) + } + pub fn database_retry_attempts(&self) -> anyhow::Result { + if let Some(attempts) = self.env.get("APP_DATABASE_RETRY_ATTEMPTS") { + return Ok(attempts.parse::()?); + } + Ok(3) + } + pub fn database_retry_delay(&self) -> anyhow::Result { + if let Some(delay) = self.env.get("APP_DATABASE_RETRY_DELAY") { + return Ok(delay.parse::()?); + } + Ok(5) + } +} diff --git a/lib/config/domain.rs b/lib/config/domain.rs new file mode 100644 index 0000000..702d272 --- /dev/null +++ b/lib/config/domain.rs @@ -0,0 +1,29 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn main_domain(&self) -> anyhow::Result { + if let Some(domain_url) = self.env.get("APP_DOMAIN_URL") { + return Ok(domain_url.to_string()); + } + Ok("http://127.0.0.1".to_string()) + } + + pub fn static_domain(&self) -> anyhow::Result { + if let Some(static_domain) = self.env.get("APP_STATIC_DOMAIN") { + return Ok(static_domain.to_string()); + } + self.main_domain() + } + pub fn media_domain(&self) -> anyhow::Result { + if let Some(media_domain) = self.env.get("APP_MEDIA_DOMAIN") { + return Ok(media_domain.to_string()); + } + self.main_domain() + } + pub fn git_http_domain(&self) -> anyhow::Result { + if let Some(git_http_domain) = self.env.get("APP_GIT_HTTP_DOMAIN") { + return Ok(git_http_domain.to_string()); + } + self.main_domain() + } +} diff --git a/lib/config/embed.rs b/lib/config/embed.rs new file mode 100644 index 0000000..8a816b4 --- /dev/null +++ b/lib/config/embed.rs @@ -0,0 +1,37 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn get_embed_model_base_url(&self) -> anyhow::Result { + if let Some(url) = self.env.get("APP_EMBED_MODEL_BASE_URL") { + return Ok(url.to_string()); + } + Err(anyhow::anyhow!("APP_EMBED_MODEL_BASE_URL not found")) + } + pub fn get_embed_model_dimensions(&self) -> anyhow::Result { + if let Some(dimensions) = self.env.get("APP_EMBED_MODEL_DIMENSIONS") { + return Ok(dimensions.parse::()?); + } + Err(anyhow::anyhow!("APP_EMBED_MODEL_DIMENSIONS not found")) + } + pub fn get_embed_model_api_key(&self) -> anyhow::Result { + if let Some(api_key) = self.env.get("APP_EMBED_MODEL_API_KEY") { + return Ok(api_key.to_string()); + } + Err(anyhow::anyhow!("APP_EMBED_MODEL_API_KEY not found")) + } + pub fn get_embed_model_name(&self) -> anyhow::Result { + if let Some(model_name) = self.env.get("APP_EMBED_MODEL_NAME") { + return Ok(model_name.to_string()); + } + Err(anyhow::anyhow!("APP_EMBED_MODEL_NAME not found")) + } + pub fn get_qdrant_url(&self) -> anyhow::Result { + if let Some(url) = self.env.get("APP_QDRANT_URL") { + return Ok(url.to_string()); + } + Err(anyhow::anyhow!("APP_QDRANT_URL not found")) + } + pub fn get_qdrant_api_key(&self) -> Option { + self.env.get("APP_QDRANT_API_KEY").map(|s| s.to_string()) + } +} diff --git a/lib/config/git.rs b/lib/config/git.rs new file mode 100644 index 0000000..c6f05a5 --- /dev/null +++ b/lib/config/git.rs @@ -0,0 +1,39 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn git_http_port(&self) -> anyhow::Result { + if let Some(port) = self.env.get("APP_GIT_HTTP_PORT") { + return Ok(port.parse::()?); + } + Ok(8021) + } + + pub fn git_rpc_port(&self) -> anyhow::Result { + if let Some(port) = self.env.get("APP_GIT_RPC_PORT") { + return Ok(port.parse::()?); + } + Ok(8030) + } + + pub fn git_rpc_addr(&self) -> anyhow::Result { + self.env + .get("APP_GIT_RPC_ADDR") + .map(|v| v.trim().to_string()) + .ok_or_else(|| anyhow::anyhow!("APP_GIT_RPC_ADDR not set")) + } + + pub fn repos_root(&self) -> anyhow::Result { + if let Some(root) = self.env.get("APP_REPOS_ROOT") { + return Ok(root.to_string()); + } + let base = std::env::current_dir()?; + Ok(base.join("data").join("repos").to_string_lossy().to_string()) + } + + pub fn gitsync_health_port(&self) -> u16 { + self.env + .get("APP_GITSYNC_HEALTH_PORT") + .and_then(|v| v.trim().parse::().ok()) + .unwrap_or(8083) + } +} diff --git a/lib/config/hook.rs b/lib/config/hook.rs new file mode 100644 index 0000000..4056bbc --- /dev/null +++ b/lib/config/hook.rs @@ -0,0 +1,84 @@ +use serde::{Deserialize, Serialize}; + +use crate::AppConfig; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PoolConfig { + pub max_concurrent: usize, + pub cpu_threshold: f32, + pub redis_list_prefix: String, + pub redis_log_channel: String, + pub redis_block_timeout_secs: u64, + pub redis_max_retries: usize, + pub worker_id: String, +} + +impl PoolConfig { + pub fn from_env(config: &AppConfig) -> Self { + let max_concurrent = config + .env + .get("HOOK_POOL_MAX_CONCURRENT") + .and_then(|v| v.parse().ok()) + .unwrap_or_else(num_cpus::get); + + let cpu_threshold = config + .env + .get("HOOK_POOL_CPU_THRESHOLD") + .and_then(|v| v.parse().ok()) + .unwrap_or(80.0); + + let redis_list_prefix = config + .env + .get("HOOK_POOL_REDIS_LIST_PREFIX") + .cloned() + .unwrap_or_else(|| "{hook}".to_string()); + + let redis_log_channel = config + .env + .get("HOOK_POOL_REDIS_LOG_CHANNEL") + .cloned() + .unwrap_or_else(|| "hook:logs".to_string()); + + let redis_block_timeout_secs = config + .env + .get("HOOK_POOL_REDIS_BLOCK_TIMEOUT") + .and_then(|v| v.parse().ok()) + .unwrap_or(5); + + let redis_max_retries = config + .env + .get("HOOK_POOL_REDIS_MAX_RETRIES") + .and_then(|v| v.parse().ok()) + .unwrap_or(3); + + let worker_id = config + .env + .get("HOOK_POOL_WORKER_ID") + .cloned() + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + + Self { + max_concurrent, + cpu_threshold, + redis_list_prefix, + redis_log_channel, + redis_block_timeout_secs, + redis_max_retries, + worker_id, + } + } +} + +impl Default for PoolConfig { + fn default() -> Self { + Self { + max_concurrent: num_cpus::get(), + cpu_threshold: 80.0, + redis_list_prefix: "{hook}".to_string(), + redis_log_channel: "hook:logs".to_string(), + redis_block_timeout_secs: 5, + redis_max_retries: 3, + worker_id: uuid::Uuid::new_v4().to_string(), + } + } +} diff --git a/lib/config/lib.rs b/lib/config/lib.rs new file mode 100644 index 0000000..0d2598a --- /dev/null +++ b/lib/config/lib.rs @@ -0,0 +1,58 @@ +use std::{collections::HashMap, sync::OnceLock}; + +pub static GLOBAL_CONFIG: OnceLock = OnceLock::new(); + +#[derive(Clone, Debug)] +pub struct AppConfig { + pub env: HashMap, +} + +impl AppConfig { + const ENV_FILES: &'static [&'static str] = &[".env", ".env.local"]; + pub fn load() -> AppConfig { + let mut env = HashMap::new(); + for env_file in AppConfig::ENV_FILES { + if let Err(e) = dotenvy::from_path(env_file) { + tracing::debug!(file = %env_file, error = %e, "dotenv load skipped"); + } + if let Ok(env_file_content) = std::fs::read_to_string(env_file) { + for line in env_file_content.lines() { + if let Some((key, value)) = line.split_once('=') { + env.insert(key.to_string(), value.to_string()); + } + } + } + } + env = env.into_iter().chain(std::env::vars()).collect(); + let this = AppConfig { env }; + if GLOBAL_CONFIG.get().is_some() { + GLOBAL_CONFIG.get().unwrap().clone() + } else { + let _ = GLOBAL_CONFIG.set(this); + GLOBAL_CONFIG + .get() + .expect("global config should be set after load") + .clone() + } + } +} + +pub mod ai; +pub mod app; +pub mod auth; +pub mod avatar; +pub mod cache; +pub mod database; +pub mod domain; +pub mod embed; +pub mod git; +pub mod hook; +pub mod logs; +pub mod nats; +pub mod oauth; +pub mod pull_request; +pub mod qdrant; +pub mod redis; +pub mod smtp; +pub mod ssh; +pub mod storage; diff --git a/lib/config/logs.rs b/lib/config/logs.rs new file mode 100644 index 0000000..b3e5b98 --- /dev/null +++ b/lib/config/logs.rs @@ -0,0 +1,95 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn log_level(&self) -> anyhow::Result { + if let Some(level) = self.env.get("APP_LOG_LEVEL") { + return Ok(level.to_string()); + } + Ok("info".to_string()) + } + + pub fn log_format(&self) -> anyhow::Result { + if let Some(format) = self.env.get("APP_LOG_FORMAT") { + return Ok(format.to_string()); + } + Ok("json".to_string()) + } + + pub fn log_file_enabled(&self) -> anyhow::Result { + if let Some(enabled) = self.env.get("APP_LOG_FILE_ENABLED") { + return Ok(enabled.parse::()?); + } + Ok(false) + } + + pub fn log_file_path(&self) -> anyhow::Result { + if let Some(path) = self.env.get("APP_LOG_FILE_PATH") { + return Ok(path.to_string()); + } + Ok("./logs".to_string()) + } + + pub fn log_file_rotation(&self) -> anyhow::Result { + if let Some(rotation) = self.env.get("APP_LOG_FILE_ROTATION") { + return Ok(rotation.to_string()); + } + Ok("daily".to_string()) + } + + pub fn log_file_max_files(&self) -> anyhow::Result { + if let Some(max_files) = self.env.get("APP_LOG_FILE_MAX_FILES") { + return Ok(max_files.parse::()?); + } + Ok(7) + } + + pub fn log_file_max_size(&self) -> anyhow::Result { + if let Some(max_size) = self.env.get("APP_LOG_FILE_MAX_SIZE") { + return Ok(max_size.parse::()?); + } + Ok(104857600) // 100MB + } + + pub fn otel_enabled(&self) -> anyhow::Result { + if let Some(enabled) = self.env.get("APP_OTEL_ENABLED") { + return Ok(enabled.parse::()?); + } + Ok(false) + } + + pub fn otel_endpoint(&self) -> anyhow::Result { + if let Some(endpoint) = self.env.get("APP_OTEL_ENDPOINT") { + return Ok(endpoint.to_string()); + } + Ok("http://localhost:5080/api/default/v1/traces".to_string()) + } + + pub fn otel_service_name(&self) -> anyhow::Result { + if let Some(service_name) = self.env.get("APP_OTEL_SERVICE_NAME") { + return Ok(service_name.to_string()); + } + Ok(env!("CARGO_PKG_NAME").to_string()) + } + + pub fn otel_service_version(&self) -> anyhow::Result { + if let Some(service_version) = self.env.get("APP_OTEL_SERVICE_VERSION") + { + return Ok(service_version.to_string()); + } + Ok(env!("CARGO_PKG_VERSION").to_string()) + } + + pub fn otel_authorization(&self) -> anyhow::Result> { + if let Some(authorization) = self.env.get("APP_OTEL_AUTHORIZATION") { + return Ok(Some(authorization.to_string())); + } + Ok(None) + } + + pub fn otel_organization(&self) -> anyhow::Result> { + if let Some(organization) = self.env.get("APP_OTEL_ORGANIZATION") { + return Ok(Some(organization.to_string())); + } + Ok(None) + } +} diff --git a/lib/config/nats.rs b/lib/config/nats.rs new file mode 100644 index 0000000..a2e2233 --- /dev/null +++ b/lib/config/nats.rs @@ -0,0 +1,83 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn nats_url(&self) -> String { + self.env + .get("NATS_URL") + .cloned() + .unwrap_or_else(|| "localhost:4222".to_string()) + } + + pub fn nats_token(&self) -> Option { + self.env.get("NATS_TOKEN").cloned() + } + + pub fn nats_stream_name(&self) -> String { + self.env + .get("NATS_STREAM_NAME") + .cloned() + .unwrap_or_else(|| "APP_EVENTS".to_string()) + } + + pub fn nats_stream_subjects(&self) -> Vec { + self.env + .get("NATS_STREAM_SUBJECTS") + .map(|subjects| { + subjects + .split(',') + .map(str::trim) + .filter(|subject| !subject.is_empty()) + .map(ToOwned::to_owned) + .collect() + }) + .filter(|subjects: &Vec| !subjects.is_empty()) + .unwrap_or_else(|| { + let subject = self + .env + .get("APP_EMAIL_TOPIC") + .or_else(|| self.env.get("EMAIL_TOPIC")) + .cloned() + .unwrap_or_else(|| "email.send".to_string()); + vec![subject.clone(), format!("{subject}.>")] + }) + } + + pub fn nats_max_deliver(&self) -> i64 { + self.env + .get("NATS_MAX_DELIVER") + .and_then(|v| v.parse().ok()) + .unwrap_or(3) + } + + pub fn nats_ack_wait_secs(&self) -> u64 { + self.env + .get("NATS_ACK_WAIT_SECS") + .and_then(|v| v.parse().ok()) + .unwrap_or(10) + } + + pub fn nats_retry_delay_secs(&self) -> u64 { + self.env + .get("NATS_RETRY_DELAY_SECS") + .and_then(|v| v.parse().ok()) + .unwrap_or_else(|| self.nats_ack_wait_secs()) + } + + pub fn nats_max_age_secs(&self) -> u64 { + self.env + .get("NATS_MAX_AGE_SECS") + .and_then(|v| v.parse().ok()) + .unwrap_or(86_400) + } + + pub fn nats_buffer_size(&self) -> usize { + self.env + .get("NATS_BUFFER_SIZE") + .and_then(|v| v.parse().ok()) + .unwrap_or(256) + } + + pub fn nats_is_enabled(&self) -> bool { + !self.nats_url().is_empty() + } +} diff --git a/lib/config/oauth.rs b/lib/config/oauth.rs new file mode 100644 index 0000000..d58f749 --- /dev/null +++ b/lib/config/oauth.rs @@ -0,0 +1,7 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn oauth_server_port(&self) -> anyhow::Result { + self.parse_env("APP_OAUTH_SERVER_PORT", 8082) + } +} diff --git a/lib/config/pull_request.rs b/lib/config/pull_request.rs new file mode 100644 index 0000000..fd82ea0 --- /dev/null +++ b/lib/config/pull_request.rs @@ -0,0 +1,24 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn pr_rpc_addr(&self) -> anyhow::Result { + self.env + .get("APP_PR_RPC_ADDR") + .map(|v| v.trim().to_string()) + .ok_or_else(|| anyhow::anyhow!("APP_PR_RPC_ADDR not set")) + } + + pub fn pr_rpc_port(&self) -> anyhow::Result { + if let Some(port) = self.env.get("APP_PR_RPC_PORT") { + return Ok(port.parse::()?); + } + Ok(8040) + } + + pub fn pr_http_port(&self) -> anyhow::Result { + if let Some(port) = self.env.get("APP_PR_HTTP_PORT") { + return Ok(port.parse::()?); + } + Ok(8041) + } +} diff --git a/lib/config/qdrant.rs b/lib/config/qdrant.rs new file mode 100644 index 0000000..a507f08 --- /dev/null +++ b/lib/config/qdrant.rs @@ -0,0 +1,17 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn qdrant_url(&self) -> anyhow::Result { + if let Some(url) = self.env.get("APP_QDRANT_URL") { + return Ok(url.to_string()); + } + Err(anyhow::anyhow!("APP_QDRANT_URL not found")) + } + + pub fn qdrant_api_key(&self) -> anyhow::Result> { + if let Some(api_key) = self.env.get("APP_QDRANT_API_KEY") { + return Ok(Some(api_key.to_string())); + } + Ok(None) + } +} diff --git a/lib/config/redis.rs b/lib/config/redis.rs new file mode 100644 index 0000000..ca44d89 --- /dev/null +++ b/lib/config/redis.rs @@ -0,0 +1,41 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn redis_url(&self) -> anyhow::Result { + let urls = self.redis_urls()?; + urls.into_iter().next().ok_or_else(|| { + anyhow::anyhow!("APP_REDIS_URLS or APP_REDIS_URL is empty") + }) + } + + pub fn redis_urls(&self) -> anyhow::Result> { + if let Some(urls) = self.env.get("APP_REDIS_URLS") { + return Ok(urls.split(',').map(|s| s.trim().to_string()).collect()); + } + if let Some(url) = self.env.get("APP_REDIS_URL") { + return Ok(vec![url.to_string()]); + } + Err(anyhow::anyhow!("APP_REDIS_URLS or APP_REDIS_URL not found")) + } + + pub fn redis_pool_size(&self) -> anyhow::Result { + if let Some(pool_size) = self.env.get("APP_REDIS_POOL_SIZE") { + return Ok(pool_size.parse::()?); + } + Ok(10) + } + + pub fn redis_connect_timeout(&self) -> anyhow::Result { + if let Some(timeout) = self.env.get("APP_REDIS_CONNECT_TIMEOUT") { + return Ok(timeout.parse::()?); + } + Ok(5) + } + + pub fn redis_acquire_timeout(&self) -> anyhow::Result { + if let Some(timeout) = self.env.get("APP_REDIS_ACQUIRE_TIMEOUT") { + return Ok(timeout.parse::()?); + } + Ok(5) + } +} diff --git a/lib/config/smtp.rs b/lib/config/smtp.rs new file mode 100644 index 0000000..12ee602 --- /dev/null +++ b/lib/config/smtp.rs @@ -0,0 +1,84 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn smtp_host(&self) -> anyhow::Result { + if let Some(host) = self.env.get("APP_SMTP_HOST") { + return Ok(host.to_string()); + } + Err(anyhow::anyhow!("APP_SMTP_HOST not found")) + } + + pub fn smtp_port(&self) -> anyhow::Result { + if let Some(port) = self.env.get("APP_SMTP_PORT") { + return Ok(port.parse::()?); + } + Ok(587) + } + + pub fn smtp_username(&self) -> anyhow::Result { + if let Some(username) = self.env.get("APP_SMTP_USERNAME") { + return Ok(username.to_string()); + } + Err(anyhow::anyhow!("APP_SMTP_USERNAME not found")) + } + + pub fn smtp_password(&self) -> anyhow::Result { + if let Some(password) = self.env.get("APP_SMTP_PASSWORD") { + return Ok(password.to_string()); + } + Err(anyhow::anyhow!("APP_SMTP_PASSWORD not found")) + } + + pub fn smtp_from(&self) -> anyhow::Result { + if let Some(from) = self.env.get("APP_SMTP_FROM") { + return Ok(from.to_string()); + } + Err(anyhow::anyhow!("APP_SMTP_FROM not found")) + } + + pub fn smtp_tls(&self) -> anyhow::Result { + if let Some(tls) = self.env.get("APP_SMTP_TLS") { + return Ok(tls.parse::()?); + } + Ok(true) + } + + pub fn smtp_timeout(&self) -> anyhow::Result { + if let Some(timeout) = self.env.get("APP_SMTP_TIMEOUT") { + return Ok(timeout.parse::()?); + } + Ok(30) + } + + pub fn email_topic(&self) -> String { + self.env + .get("APP_EMAIL_TOPIC") + .or_else(|| self.env.get("EMAIL_TOPIC")) + .cloned() + .unwrap_or_else(|| "email.send".to_string()) + } + + pub fn email_consumer_group_id(&self) -> String { + self.env + .get("APP_EMAIL_CONSUMER_GROUP_ID") + .or_else(|| self.env.get("EMAIL_CONSUMER_GROUP_ID")) + .cloned() + .unwrap_or_else(|| "email-service".to_string()) + } + + pub fn email_send_retry_attempts(&self) -> u32 { + self.env + .get("APP_EMAIL_SEND_RETRY_ATTEMPTS") + .or_else(|| self.env.get("EMAIL_SEND_RETRY_ATTEMPTS")) + .and_then(|value| value.parse().ok()) + .unwrap_or(3) + } + + pub fn email_send_retry_base_delay_secs(&self) -> u64 { + self.env + .get("APP_EMAIL_SEND_RETRY_BASE_DELAY_SECS") + .or_else(|| self.env.get("EMAIL_SEND_RETRY_BASE_DELAY_SECS")) + .and_then(|value| value.parse().ok()) + .unwrap_or(1) + } +} diff --git a/lib/config/ssh.rs b/lib/config/ssh.rs new file mode 100644 index 0000000..73d755e --- /dev/null +++ b/lib/config/ssh.rs @@ -0,0 +1,38 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn ssh_domain(&self) -> anyhow::Result { + if let Some(ssh_domain) = self.env.get("APP_SSH_DOMAIN") { + return Ok(ssh_domain.to_string()); + } + let main_domain = self.main_domain()?; + if let Some(stripped) = main_domain.strip_prefix("https://") { + Ok(stripped.to_string()) + } else if let Some(stripped) = main_domain.strip_prefix("http://") { + Ok(stripped.to_string()) + } else { + Ok(main_domain) + } + } + + pub fn ssh_port(&self) -> anyhow::Result { + if let Some(ssh_port) = self.env.get("APP_SSH_PORT") { + return Ok(ssh_port.parse::()?); + } + Ok(8022) + } + + pub fn ssh_server_private_key_file(&self) -> anyhow::Result { + if let Some(path) = self.env.get("APP_SSH_SERVER_PRIVATE_KEY_FILE") { + return Ok(path.to_string()); + } + Ok("".to_string()) + } + + pub fn ssh_server_public_key_file(&self) -> anyhow::Result { + if let Some(path) = self.env.get("APP_SSH_SERVER_PUBLIC_KEY_FILE") { + return Ok(path.to_string()); + } + Ok("".to_string()) + } +} diff --git a/lib/config/storage.rs b/lib/config/storage.rs new file mode 100644 index 0000000..7c668e6 --- /dev/null +++ b/lib/config/storage.rs @@ -0,0 +1,123 @@ +use crate::AppConfig; + +impl AppConfig { + pub fn storage_backend(&self) -> anyhow::Result { + let backend = self + .env + .get("APP_STORAGE_BACKEND") + .or_else(|| self.env.get("STORAGE_BACKEND")) + .map(|value| value.trim().to_ascii_lowercase()) + .filter(|value| !value.is_empty()) + .unwrap_or_else(|| "s3".to_string()); + + Ok(backend) + } + + pub fn storage_path(&self) -> String { + self.env + .get("STORAGE_PATH") + .cloned() + .unwrap_or_else(|| "/data".to_string()) + } + + pub fn storage_public_url(&self) -> String { + self.env + .get("APP_STORAGE_PUBLIC_URL") + .or_else(|| self.env.get("STORAGE_PUBLIC_URL")) + .cloned() + .unwrap_or_else(|| "/files".to_string()) + } + + pub fn storage_public_url_base(&self) -> Option { + self.env + .get("APP_STORAGE_PUBLIC_URL") + .or_else(|| self.env.get("STORAGE_PUBLIC_URL")) + .map(|value| value.trim().trim_end_matches('/').to_string()) + .filter(|value| !value.is_empty()) + } + + pub fn storage_s3_bucket(&self) -> anyhow::Result { + self.env + .get("APP_STORAGE_S3_BUCKET") + .or_else(|| self.env.get("AWS_S3_BUCKET")) + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + .ok_or_else(|| anyhow::anyhow!("APP_STORAGE_S3_BUCKET is required")) + } + + pub fn storage_s3_region(&self) -> String { + self.env + .get("APP_STORAGE_S3_REGION") + .or_else(|| self.env.get("AWS_REGION")) + .or_else(|| self.env.get("AWS_DEFAULT_REGION")) + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + .unwrap_or_else(|| "us-east-1".to_string()) + } + + pub fn storage_s3_endpoint_url(&self) -> Option { + self.env + .get("APP_STORAGE_S3_ENDPOINT_URL") + .or_else(|| self.env.get("AWS_ENDPOINT_URL_S3")) + .or_else(|| self.env.get("AWS_ENDPOINT_URL")) + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + } + + pub fn storage_s3_access_key_id(&self) -> Option { + self.env + .get("APP_STORAGE_S3_ACCESS_KEY_ID") + .or_else(|| self.env.get("AWS_ACCESS_KEY_ID")) + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + } + + pub fn storage_s3_secret_access_key(&self) -> Option { + self.env + .get("APP_STORAGE_S3_SECRET_ACCESS_KEY") + .or_else(|| self.env.get("AWS_SECRET_ACCESS_KEY")) + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + } + + pub fn storage_s3_session_token(&self) -> Option { + self.env + .get("APP_STORAGE_S3_SESSION_TOKEN") + .or_else(|| self.env.get("AWS_SESSION_TOKEN")) + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + } + + pub fn storage_s3_force_path_style(&self) -> anyhow::Result { + self.parse_env("APP_STORAGE_S3_FORCE_PATH_STYLE", false) + } + + pub fn storage_presigned_url_ttl( + &self, + ) -> anyhow::Result { + self.parse_duration_secs("APP_STORAGE_PRESIGNED_URL_TTL_SECONDS", 900) + } + + pub fn storage_max_file_size(&self) -> usize { + self.env + .get("APP_STORAGE_MAX_FILE_SIZE") + .or_else(|| self.env.get("STORAGE_MAX_FILE_SIZE")) + .and_then(|s| s.parse::().ok()) + .unwrap_or(10 * 1024 * 1024) // 10MB default + } + + pub fn vapid_public_key(&self) -> Option { + self.env.get("VAPID_PUBLIC_KEY").cloned() + } + + pub fn vapid_private_key(&self) -> Option { + self.env.get("VAPID_PRIVATE_KEY").cloned() + } + + pub fn vapid_sender_email(&self) -> String { + self.env + .get("VAPID_SENDER_EMAIL") + .cloned() + .unwrap_or_else(|| "mailto:admin@example.com".to_string()) + } +} diff --git a/lib/db/Cargo.toml b/lib/db/Cargo.toml new file mode 100644 index 0000000..7529fbd --- /dev/null +++ b/lib/db/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "db" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[lib] +path = "lib.rs" +name = "db" +[dependencies] +config = { workspace = true } +sea-orm = { workspace = true } +anyhow = { workspace = true } +sqlx = { workspace = true, features = ["postgres","runtime-tokio"]} +sqlparser = { workspace = true, features = [] } +[lints] +workspace = true diff --git a/lib/db/database.rs b/lib/db/database.rs new file mode 100644 index 0000000..7c25bd7 --- /dev/null +++ b/lib/db/database.rs @@ -0,0 +1,232 @@ +use std::{str::FromStr, time::Duration}; + +use config::AppConfig; +use sqlx::{ + AssertSqlSafe, ConnectOptions, FromRow, PgPool, + postgres::{ + PgArguments, PgConnectOptions, PgPoolOptions, PgQueryResult, PgRow, + }, +}; + +use crate::{ + route::{SqlRoute, route_sql}, + transaction::AppTransaction, +}; + +#[derive(Clone)] +pub struct AppDatabase { + db_write: PgPool, + db_read: Option, +} + +impl AppDatabase { + pub async fn init(cfg: &AppConfig) -> anyhow::Result { + let db_url = cfg.database_url()?; + let max_connections = cfg.database_max_connections()?; + let min_connections = cfg.database_min_connections()?; + let idle_timeout = cfg.database_idle_timeout()?; + let max_lifetime = cfg.database_max_lifetime()?; + let connection_timeout = cfg.database_connection_timeout()?; + let schema_search_path = cfg.database_schema_search_path()?; + let read_replica = cfg.database_read_replicas()?; + + let write_options = build_pg_options(&db_url, &schema_search_path)?; + + let db_write = build_pool( + write_options, + max_connections, + min_connections, + idle_timeout, + max_lifetime, + connection_timeout, + ) + .await?; + + sqlx::query(AssertSqlSafe("SELECT 1".to_owned())) + .execute(&db_write) + .await?; + + let db_read = if let Some(replica_url) = read_replica { + let read_options = + build_pg_options(&replica_url, &schema_search_path)?; + + let pool = build_pool( + read_options, + max_connections, + min_connections, + idle_timeout, + max_lifetime, + connection_timeout, + ) + .await?; + + sqlx::query(AssertSqlSafe("SELECT 1".to_owned())) + .execute(&pool) + .await?; + + Some(pool) + } else { + None + }; + + Ok(Self { db_write, db_read }) + } + + pub fn writer(&self) -> &PgPool { + &self.db_write + } + + pub fn reader(&self) -> &PgPool { + self.db_read.as_ref().unwrap_or(&self.db_write) + } + + pub fn route_pool(&self, sql: &str) -> &PgPool { + match route_sql(sql) { + SqlRoute::Write => self.writer(), + SqlRoute::Read => self.reader(), + } + } + + pub async fn begin(&self) -> Result, sqlx::Error> { + let txn = self.db_write.begin().await?; + Ok(AppTransaction { inner: txn }) + } + + pub async fn begin_read_only( + &self, + ) -> Result, sqlx::Error> { + let mut txn = self.reader().begin().await?; + + sqlx::query(AssertSqlSafe("SET TRANSACTION READ ONLY".to_owned())) + .execute(&mut *txn) + .await?; + + Ok(AppTransaction { inner: txn }) + } + + pub async fn execute( + &self, + sql: &str, + ) -> Result { + self.execute_with_args(sql, PgArguments::default()).await + } + + pub async fn execute_with_args( + &self, + sql: &str, + args: PgArguments, + ) -> Result { + let pool = self.route_pool(sql); + + sqlx::query_with(AssertSqlSafe(sql.to_owned()), args) + .execute(pool) + .await + } + + pub async fn fetch_one(&self, sql: &str) -> Result + where + for<'r> T: FromRow<'r, PgRow> + Send + Unpin, + { + self.fetch_one_with_args(sql, PgArguments::default()).await + } + + pub async fn fetch_one_with_args( + &self, + sql: &str, + args: PgArguments, + ) -> Result + where + for<'r> T: FromRow<'r, PgRow> + Send + Unpin, + { + let pool = self.route_pool(sql); + + sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) + .fetch_one(pool) + .await + } + + pub async fn fetch_optional( + &self, + sql: &str, + ) -> Result, sqlx::Error> + where + for<'r> T: FromRow<'r, PgRow> + Send + Unpin, + { + self.fetch_optional_with_args(sql, PgArguments::default()) + .await + } + + pub async fn fetch_optional_with_args( + &self, + sql: &str, + args: PgArguments, + ) -> Result, sqlx::Error> + where + for<'r> T: FromRow<'r, PgRow> + Send + Unpin, + { + let pool = self.route_pool(sql); + + sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) + .fetch_optional(pool) + .await + } + + pub async fn fetch_all(&self, sql: &str) -> Result, sqlx::Error> + where + for<'r> T: FromRow<'r, PgRow> + Send + Unpin, + { + self.fetch_all_with_args(sql, PgArguments::default()).await + } + + pub async fn fetch_all_with_args( + &self, + sql: &str, + args: PgArguments, + ) -> Result, sqlx::Error> + where + for<'r> T: FromRow<'r, PgRow> + Send + Unpin, + { + let pool = self.route_pool(sql); + + sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) + .fetch_all(pool) + .await + } +} + +fn build_pg_options( + db_url: &str, + schema_search_path: &str, +) -> anyhow::Result { + let options = PgConnectOptions::from_str(db_url)? + .options([("search_path", schema_search_path)]) + .disable_statement_logging(); + + Ok(options) +} + +async fn build_pool( + options: PgConnectOptions, + max_connections: u32, + min_connections: u32, + idle_timeout_secs: u64, + max_lifetime_secs: u64, + connection_timeout_secs: u64, +) -> Result { + let mut pool_options = PgPoolOptions::new() + .max_connections(max_connections) + .min_connections(min_connections) + .acquire_timeout(Duration::from_secs(connection_timeout_secs.max(1))); + + if idle_timeout_secs > 0 { + pool_options = + pool_options.idle_timeout(Duration::from_secs(idle_timeout_secs)); + } + + if max_lifetime_secs > 0 { + pool_options = + pool_options.max_lifetime(Duration::from_secs(max_lifetime_secs)); + } + + pool_options.connect_with(options).await +} diff --git a/lib/db/lib.rs b/lib/db/lib.rs new file mode 100644 index 0000000..bcb6f73 --- /dev/null +++ b/lib/db/lib.rs @@ -0,0 +1,7 @@ +pub mod database; +pub mod route; +pub mod transaction; + +pub use database::AppDatabase; +pub use sqlx; +pub use transaction::AppTransaction; diff --git a/lib/db/route.rs b/lib/db/route.rs new file mode 100644 index 0000000..3f020a6 --- /dev/null +++ b/lib/db/route.rs @@ -0,0 +1,88 @@ +use sqlparser::{ + ast::{Query, SetExpr, Statement}, + dialect::PostgreSqlDialect, + parser::Parser, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SqlRoute { + Read, + Write, +} + +pub fn route_sql(sql: &str) -> SqlRoute { + let trimmed = sql.trim_start(); + + if starts_with_hint(trimmed, "write") { + return SqlRoute::Write; + } + + if starts_with_hint(trimmed, "read") { + return SqlRoute::Read; + } + + if is_read_query_by_ast(sql) { + SqlRoute::Read + } else { + SqlRoute::Write + } +} + +fn starts_with_hint(sql: &str, hint: &str) -> bool { + let expected = format!("/*+ {hint} */"); + sql.starts_with(&expected) +} + +fn is_read_query_by_ast(sql: &str) -> bool { + let dialect = PostgreSqlDialect {}; + + let Ok(statements) = Parser::parse_sql(&dialect, sql) else { + return false; + }; + + if statements.is_empty() { + return false; + } + + statements.iter().all(is_read_statement) +} + +fn is_read_statement(statement: &Statement) -> bool { + match statement { + Statement::Query(query) => is_read_query_ast(query), + Statement::ShowVariable { .. } + | Statement::ShowVariables { .. } + | Statement::ShowCreate { .. } + | Statement::ShowColumns { .. } + | Statement::ShowTables { .. } => true, + _ => false, + } +} + +fn is_read_query_ast(query: &Query) -> bool { + if !query.locks.is_empty() { + return false; + } + + match query.body.as_ref() { + SetExpr::Select(_) => true, + SetExpr::Query(inner) => is_read_query_ast(inner), + SetExpr::SetOperation { left, right, .. } => { + is_read_set_expr(left) && is_read_set_expr(right) + } + SetExpr::Values(_) => true, + _ => false, + } +} + +fn is_read_set_expr(expr: &SetExpr) -> bool { + match expr { + SetExpr::Select(_) => true, + SetExpr::Query(query) => is_read_query_ast(query), + SetExpr::SetOperation { left, right, .. } => { + is_read_set_expr(left) && is_read_set_expr(right) + } + SetExpr::Values(_) => true, + _ => false, + } +} diff --git a/lib/db/transaction.rs b/lib/db/transaction.rs new file mode 100644 index 0000000..1f3c5a5 --- /dev/null +++ b/lib/db/transaction.rs @@ -0,0 +1,98 @@ +use sqlx::{ + AssertSqlSafe, FromRow, Postgres, Transaction, + postgres::{PgArguments, PgQueryResult, PgRow}, +}; + +pub struct AppTransaction<'a> { + pub(crate) inner: Transaction<'a, Postgres>, +} + +impl<'a> AppTransaction<'a> { + pub fn inner_mut(&mut self) -> &mut Transaction<'a, Postgres> { + &mut self.inner + } + pub async fn commit(self) -> Result<(), sqlx::Error> { + self.inner.commit().await + } + pub async fn rollback(self) -> Result<(), sqlx::Error> { + self.inner.rollback().await + } + pub async fn execute( + &mut self, + sql: &str, + ) -> Result { + self.execute_with_args(sql, PgArguments::default()).await + } + pub async fn execute_with_args( + &mut self, + sql: &str, + args: PgArguments, + ) -> Result { + sqlx::query_with(AssertSqlSafe(sql.to_owned()), args) + .execute(&mut *self.inner) + .await + } + pub async fn fetch_one(&mut self, sql: &str) -> Result + where + for<'r> T: FromRow<'r, PgRow> + Send + Unpin, + { + self.fetch_one_with_args(sql, PgArguments::default()).await + } + pub async fn fetch_one_with_args( + &mut self, + sql: &str, + args: PgArguments, + ) -> Result + where + for<'r> T: FromRow<'r, PgRow> + Send + Unpin, + { + sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) + .fetch_one(&mut *self.inner) + .await + } + pub async fn fetch_optional( + &mut self, + sql: &str, + ) -> Result, sqlx::Error> + where + for<'r> T: FromRow<'r, PgRow> + Send + Unpin, + { + self.fetch_optional_with_args(sql, PgArguments::default()) + .await + } + + pub async fn fetch_optional_with_args( + &mut self, + sql: &str, + args: PgArguments, + ) -> Result, sqlx::Error> + where + for<'r> T: FromRow<'r, PgRow> + Send + Unpin, + { + sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) + .fetch_optional(&mut *self.inner) + .await + } + pub async fn fetch_all( + &mut self, + sql: &str, + ) -> Result, sqlx::Error> + where + for<'r> T: FromRow<'r, PgRow> + Send + Unpin, + { + self.fetch_all_with_args(sql, PgArguments::default()).await + } + + pub async fn fetch_all_with_args( + &mut self, + sql: &str, + args: PgArguments, + ) -> Result, sqlx::Error> + where + for<'r> T: FromRow<'r, PgRow> + Send + Unpin, + { + sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) + .fetch_all(&mut *self.inner) + .await + } +} diff --git a/lib/email/Cargo.toml b/lib/email/Cargo.toml new file mode 100644 index 0000000..6a68416 --- /dev/null +++ b/lib/email/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "email" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[lib] +path = "lib.rs" +name = "email" + +[dependencies] +anyhow = { workspace = true } +async-trait = { workspace = true } +config = { workspace = true } +lettre = { workspace = true } +queue = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +tokio = { workspace = true, features = ["rt", "time"] } +tracing = { workspace = true } + +[lints] +workspace = true diff --git a/lib/email/app.rs b/lib/email/app.rs new file mode 100644 index 0000000..cec57db --- /dev/null +++ b/lib/email/app.rs @@ -0,0 +1,39 @@ +use config::AppConfig; +use queue::NatsProducer; + +use crate::EmailMessage; + +#[derive(Clone)] +pub struct AppEmail { + producer: NatsProducer, + topic: String, +} + +impl AppEmail { + pub async fn init(config: &AppConfig) -> anyhow::Result { + Ok(Self { + producer: NatsProducer::new(config).await?, + topic: config.email_topic(), + }) + } + + pub fn with_topic( + producer: NatsProducer, + topic: impl Into, + ) -> Self { + Self { + producer, + topic: topic.into(), + } + } + + pub async fn send(&self, message: EmailMessage) -> anyhow::Result<()> { + self.producer + .send(&self.topic, &message.to, &message, None) + .await + } + + pub fn topic(&self) -> &str { + &self.topic + } +} diff --git a/lib/email/lib.rs b/lib/email/lib.rs new file mode 100644 index 0000000..dae02d9 --- /dev/null +++ b/lib/email/lib.rs @@ -0,0 +1,9 @@ +mod app; +mod message; +mod smtp; +mod worker; + +pub use app::AppEmail; +pub use message::EmailMessage; +pub use smtp::SmtpEmailSender; +pub use worker::EmailWorker; diff --git a/lib/email/message.rs b/lib/email/message.rs new file mode 100644 index 0000000..00cb48f --- /dev/null +++ b/lib/email/message.rs @@ -0,0 +1,8 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct EmailMessage { + pub to: String, + pub subject: String, + pub body: String, +} diff --git a/lib/email/smtp.rs b/lib/email/smtp.rs new file mode 100644 index 0000000..5160ba5 --- /dev/null +++ b/lib/email/smtp.rs @@ -0,0 +1,135 @@ +use std::time::Duration; + +use config::AppConfig; +use lettre::{ + Message, Transport, + message::{Mailbox, header::ContentType}, + transport::smtp::{PoolConfig, SmtpTransport, authentication::Credentials}, +}; +use tracing::warn; + +use crate::EmailMessage; + +#[derive(Clone)] +pub struct SmtpEmailSender { + mailer: SmtpTransport, + from: Mailbox, + retry_attempts: u32, + retry_base_delay: Duration, +} + +impl SmtpEmailSender { + pub fn new(config: &AppConfig) -> anyhow::Result { + let smtp_host = config.smtp_host()?; + let smtp_port = config.smtp_port()?; + let smtp_username = config.smtp_username()?; + let smtp_password = config.smtp_password()?; + let smtp_from = config.smtp_from()?; + let smtp_tls = config.smtp_tls()?; + let smtp_timeout = config.smtp_timeout()?; + + let builder = if smtp_tls { + if smtp_port == 465 { + SmtpTransport::relay(&smtp_host)? + } else { + SmtpTransport::starttls_relay(&smtp_host)? + } + } else { + SmtpTransport::builder_dangerous(&smtp_host) + }; + + let mailer = builder + .credentials(Credentials::new(smtp_username, smtp_password)) + .port(smtp_port) + .timeout(Some(Duration::from_secs(smtp_timeout))) + .pool_config(PoolConfig::new().min_idle(0).max_size(10)) + .build(); + + Ok(Self { + mailer, + from: smtp_from.parse()?, + retry_attempts: config.email_send_retry_attempts(), + retry_base_delay: Duration::from_secs( + config.email_send_retry_base_delay_secs(), + ), + }) + } + + pub async fn send(&self, message: EmailMessage) -> anyhow::Result<()> { + let recipient: Mailbox = message.to.parse()?; + let email = Message::builder() + .from(self.from.clone()) + .to(recipient) + .subject(message.subject) + .header(ContentType::TEXT_PLAIN) + .body(message.body)?; + + let attempts = self.retry_attempts.max(1); + let mut last_error = None; + + for attempt in 0..attempts { + let mailer = self.mailer.clone(); + let email = email.clone(); + let result = + tokio::task::spawn_blocking(move || mailer.send(&email)).await; + + match result { + Ok(Ok(_)) => return Ok(()), + Ok(Err(error)) => { + last_error = Some(anyhow::anyhow!(error)); + warn!(attempt = attempt + 1, "email send attempt failed"); + } + Err(error) => { + return Err(anyhow::anyhow!( + "email send task failed: {error}" + )); + } + } + + if attempt + 1 < attempts { + let multiplier = + 1u64.checked_shl(attempt).unwrap_or(u64::MAX).max(1); + tokio::time::sleep( + self.retry_base_delay.saturating_mul(multiplier as u32), + ) + .await; + } + } + + Err(last_error.unwrap_or_else(|| anyhow::anyhow!("email send failed"))) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use config::AppConfig; + + use super::SmtpEmailSender; + + #[test] + fn smtp_config_accepts_credentials_with_url_special_characters() { + let config = AppConfig { + env: HashMap::from([ + ("APP_SMTP_HOST".to_string(), "smtp.example.com".to_string()), + ("APP_SMTP_PORT".to_string(), "587".to_string()), + ( + "APP_SMTP_USERNAME".to_string(), + "user@example.com".to_string(), + ), + ( + "APP_SMTP_PASSWORD".to_string(), + "p@ss:word/with?chars".to_string(), + ), + ( + "APP_SMTP_FROM".to_string(), + "Gitdata ".to_string(), + ), + ("APP_SMTP_TLS".to_string(), "true".to_string()), + ]), + }; + + assert!(SmtpEmailSender::new(&config).is_ok()); + } +} diff --git a/lib/email/worker.rs b/lib/email/worker.rs new file mode 100644 index 0000000..f48693f --- /dev/null +++ b/lib/email/worker.rs @@ -0,0 +1,46 @@ +use config::AppConfig; +use queue::{AckAction, MessageHandler, NatsConsumer}; +use tracing::error; + +use crate::{EmailMessage, SmtpEmailSender}; + +pub struct EmailWorker { + sender: SmtpEmailSender, +} + +impl EmailWorker { + pub fn new(sender: SmtpEmailSender) -> Self { + Self { sender } + } + + pub async fn start(config: &AppConfig) -> anyhow::Result<()> { + let worker = Self::new(SmtpEmailSender::new(config)?); + let consumer = + NatsConsumer::new(config, &config.email_consumer_group_id()) + .await?; + let topic = config.email_topic(); + consumer.start_consuming(&[topic.as_str()], worker).await?; + std::future::pending().await + } +} + +#[async_trait::async_trait] +impl MessageHandler for EmailWorker { + async fn handle(&self, topic: &str, payload: &[u8]) -> AckAction { + let message = match serde_json::from_slice::(payload) { + Ok(message) => message, + Err(error) => { + error!(topic, error = %error, "invalid email message payload"); + return AckAction::Ack; + } + }; + + match self.sender.send(message).await { + Ok(()) => AckAction::Ack, + Err(error) => { + error!(topic, error = %error, "email message send failed"); + AckAction::Nack + } + } + } +} diff --git a/lib/git/Cargo.toml b/lib/git/Cargo.toml new file mode 100644 index 0000000..4ce22b3 --- /dev/null +++ b/lib/git/Cargo.toml @@ -0,0 +1,68 @@ +[package] +name = "git" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[lib] +path = "lib.rs" +name = "git" + +[dependencies] +thiserror = { workspace = true } +cache = { workspace = true } +db = { workspace = true } +storage = { workspace = true } +config = { workspace = true } +juniper = { workspace = true, features = [] } +tokio = { workspace = true, features = ["rt-multi-thread", "macros", "sync", "process", "io-util"] } +anyhow = { workspace = true } +duct = { workspace = true } +gix = { workspace = true } +gix-archive = { workspace = true } +gix-worktree-stream = { workspace = true } +model = { workspace = true } +sqlx = { workspace = true, features = ["derive", "postgres", "runtime-tokio", "uuid", "chrono"] } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +tonic = { workspace = true } +prost = { workspace = true } +dashmap = { workspace = true } +tracing = { workspace = true } +tokio-stream = { workspace = true } +actix-web = { workspace = true, features = [] } +russh = { workspace = true, features = ["async-trait","rsa","legacy-ed25519-pkcs8-parser","serde"] } +hmac = { workspace = true } +hex = { workspace = true } +reqwest = { workspace = true } +async-stream = { workspace = true } +async-trait = { workspace = true } +tokio-util = { workspace = true } +argon2 = { workspace = true } +password-hash = { workspace = true } +base64 = { workspace = true } +sha2 = { workspace = true } +uuid = { workspace = true, features = ["v4", "v7"] } +chrono = { workspace = true } +futures-util = { workspace = true } +redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp"] } +deadpool-redis = { workspace = true } +serde_yaml = { workspace = true } +num_cpus = { workspace = true } +miette = { workspace = true } +parsefile = { workspace = true } +tonic-prost = { workspace = true } + +[build-dependencies] +tonic-build = { workspace = true } +tonic-prost-build = "0.14.6" +[lints] +workspace = true diff --git a/lib/git/bare.rs b/lib/git/bare.rs new file mode 100644 index 0000000..98f223c --- /dev/null +++ b/lib/git/bare.rs @@ -0,0 +1,134 @@ +use std::path::PathBuf; + +use crate::errors::GitResult; + +#[derive(Clone)] +pub struct GitBare { + pub bare_dir: PathBuf, +} + +#[derive(Debug, Clone)] +pub struct LastCommitInfo { + pub message: String, + pub time: String, + pub author_name: String, + pub author_email: String, +} + +impl GitBare { + pub fn gix_repo(&self) -> crate::errors::GitResult { + gix::open(&self.bare_dir).map_err(|e| { + crate::errors::GitError::Internal(format!( + "failed to open gix repository: {e}" + )) + }) + } + pub fn set_default_branch(&self, branch_name: &str) -> GitResult<()> { + let output = self.git_command_trusted(vec![ + "symbolic-ref".to_string(), + "HEAD".to_string(), + format!("refs/heads/{}", branch_name), + ])?; + if !output.success { + return Err(crate::errors::GitError::CommandFailed { + status_code: output.status_code, + stderr: output.stderr_lossy(), + }); + } + Ok(()) + } + pub fn last_commits_for_paths(&self, paths: &[String]) -> GitResult>> { + use gix::traverse::commit::simple::CommitTimeOrder; + + if paths.is_empty() { + return Ok(vec![]); + } + + let repo = self.gix_repo()?; + let head_id = repo.head_id() + .map_err(|e| crate::errors::GitError::Internal(format!("no HEAD: {e}")))?; + + let walk = repo.rev_walk(vec![head_id.detach()]) + .sorting(gix::revision::walk::Sorting::ByCommitTime(CommitTimeOrder::NewestFirst)) + .first_parent_only() + .all() + .map_err(|e| crate::errors::GitError::Internal(format!("rev_walk: {e}")))?; + + let mut result: Vec> = vec![None; paths.len()]; + let mut remaining: std::collections::HashSet = (0..paths.len()).collect(); + + for walk_item in walk { + if remaining.is_empty() { + break; + } + let info = match walk_item { + Ok(i) => i, + Err(_) => continue, + }; + let oid = info.id().detach(); + let commit = match repo.find_commit(oid) { + Ok(c) => c, + Err(_) => continue, + }; + let decoded = match commit.decode() { + Ok(d) => d, + Err(_) => continue, + }; + + let current_tree = match repo.find_tree(decoded.tree()) { + Ok(t) => t, + Err(_) => continue, + }; + + let parent_tree = decoded.parents().next().and_then(|pid| { + let hex = pid.to_hex().to_string(); + let gid = gix::hash::ObjectId::from_hex(hex.as_bytes()).ok()?; + let pc = repo.find_commit(gid).ok()?; + let p_decoded = pc.decode().ok()?; + repo.find_tree(p_decoded.tree()).ok() + }); + + let mut diff_opts = gix::diff::Options::default(); + diff_opts.track_path(); + let changes = match repo.diff_tree_to_tree( + parent_tree.as_ref(), + Some(¤t_tree), + Some(diff_opts), + ) { + Ok(d) => d, + Err(_) => continue, + }; + + let changed: std::collections::HashSet = changes.iter() + .map(|c| c.location().to_string()) + .collect(); + + let is_root = parent_tree.is_none(); + let author_sig = match decoded.author() { + Ok(s) => s, + Err(_) => continue, + }; + let time = author_sig.time().unwrap_or(gix::date::Time { seconds: 0, offset: 0 }); + let msg = commit.message_raw() + .map(|r| r.to_string().trim_end_matches('\n').to_string()) + .unwrap_or_default(); + let summary = msg.lines().next().unwrap_or("").to_string(); + + let matched: Vec = remaining.iter().copied().filter(|&idx| { + is_root || changed.contains(&paths[idx]) + }).collect(); + + for idx in matched { + result[idx] = Some(LastCommitInfo { + message: summary.clone(), + time: format!("{}", time.seconds), + author_name: author_sig.name.to_string(), + author_email: author_sig.email.to_string(), + }); + remaining.remove(&idx); + } + } + + Ok(result) + } +} diff --git a/lib/git/build.rs b/lib/git/build.rs new file mode 100644 index 0000000..9c79321 --- /dev/null +++ b/lib/git/build.rs @@ -0,0 +1,25 @@ +fn main() { + let proto_dir = "proto"; + + let protos = [ + format!("{proto_dir}/common.proto"), + format!("{proto_dir}/archive.proto"), + format!("{proto_dir}/blame.proto"), + format!("{proto_dir}/blob.proto"), + format!("{proto_dir}/branch.proto"), + format!("{proto_dir}/commit.proto"), + format!("{proto_dir}/diff.proto"), + format!("{proto_dir}/fork.proto"), + format!("{proto_dir}/init.proto"), + format!("{proto_dir}/merge.proto"), + format!("{proto_dir}/tag.proto"), + format!("{proto_dir}/tree.proto"), + ]; + + tonic_prost_build::configure() + .type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]") + .type_attribute("ObjectId", "#[serde(transparent)]") + .field_attribute("ObjectId.value", "#[serde(rename = \"0\")]") + .compile_protos(&protos, &[proto_dir.to_string()]) + .unwrap_or_else(|e| panic!("Failed to compile protos: {e}")); +} diff --git a/lib/git/cmd/archive/mod.rs b/lib/git/cmd/archive/mod.rs new file mode 100644 index 0000000..d6e5ba8 --- /dev/null +++ b/lib/git/cmd/archive/mod.rs @@ -0,0 +1,18 @@ +use serde::{Deserialize, Serialize}; + +use crate::cmd::oid::ObjectId; + +pub mod tar; +pub mod zip; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ArchiveOptions { + pub tree: ObjectId, + pub prefix: Option, + pub pathspec: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ArchiveResult { + pub bytes: Vec, +} diff --git a/lib/git/cmd/archive/tar.rs b/lib/git/cmd/archive/tar.rs new file mode 100644 index 0000000..8ee884c --- /dev/null +++ b/lib/git/cmd/archive/tar.rs @@ -0,0 +1,53 @@ +use gix::error::ResultExt; + +use crate::{ + bare::GitBare, + cmd::archive::{ArchiveOptions, ArchiveResult}, + errors::{GitError, GitResult}, +}; + +impl GitBare { + pub fn archive_tar( + &self, + options: ArchiveOptions, + ) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&options.tree).try_into()?; + let tree_id = { + let gix_id = gix_id; + if let Ok(tree) = repo.find_tree(gix_id) { + tree.id().detach() + } else { + let commit = repo + .find_commit(gix_id) + .map_err(|e| GitError::Gix(e.to_string()))?; + commit + .tree_id() + .map_err(|e| GitError::Gix(e.to_string()))? + .detach() + } + }; + + let (mut stream, _index) = repo + .worktree_stream(tree_id) + .map_err(|e| GitError::Gix(e.to_string()))?; + + let mut buf: Vec = Vec::new(); + + let archive_opts = gix_archive::Options { + format: gix_archive::Format::Tar, + tree_prefix: options.prefix.map(|p| gix::bstr::BString::from(p)), + ..Default::default() + }; + + gix_archive::write_stream_seek( + &mut stream, + |stream| stream.next_entry().or_erased(), + std::io::Cursor::new(&mut buf), + archive_opts, + ) + .map_err(|e| GitError::Gix(e.to_string()))?; + + Ok(ArchiveResult { bytes: buf }) + } +} diff --git a/lib/git/cmd/archive/zip.rs b/lib/git/cmd/archive/zip.rs new file mode 100644 index 0000000..95f0726 --- /dev/null +++ b/lib/git/cmd/archive/zip.rs @@ -0,0 +1,55 @@ +use gix::error::ResultExt; + +use crate::{ + bare::GitBare, + cmd::archive::{ArchiveOptions, ArchiveResult}, + errors::{GitError, GitResult}, +}; + +impl GitBare { + pub fn archive_zip( + &self, + options: ArchiveOptions, + ) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&options.tree).try_into()?; + let tree_id = { + let gix_id = gix_id; + if let Ok(tree) = repo.find_tree(gix_id) { + tree.id().detach() + } else { + let commit = repo + .find_commit(gix_id) + .map_err(|e| GitError::Gix(e.to_string()))?; + commit + .tree_id() + .map_err(|e| GitError::Gix(e.to_string()))? + .detach() + } + }; + + let (mut stream, _index) = repo + .worktree_stream(tree_id) + .map_err(|e| GitError::Gix(e.to_string()))?; + + let mut buf: Vec = Vec::new(); + + let archive_opts = gix_archive::Options { + format: gix_archive::Format::Zip { + compression_level: None, + }, + tree_prefix: options.prefix.map(|p| gix::bstr::BString::from(p)), + ..Default::default() + }; + + gix_archive::write_stream_seek( + &mut stream, + |stream| stream.next_entry().or_erased(), + std::io::Cursor::new(&mut buf), + archive_opts, + ) + .map_err(|e| GitError::Gix(e.to_string()))?; + + Ok(ArchiveResult { bytes: buf }) + } +} diff --git a/lib/git/cmd/blame/blame_file.rs b/lib/git/cmd/blame/blame_file.rs new file mode 100644 index 0000000..5986b0f --- /dev/null +++ b/lib/git/cmd/blame/blame_file.rs @@ -0,0 +1,96 @@ +use std::path::Path; + +use crate::{ + bare::GitBare, + cmd::{ + blame::{BlameOptions, CommitBlameHunk}, + oid::ObjectId, + }, + errors::GitResult, +}; + +impl GitBare { + pub fn blame_file( + &self, + oid: ObjectId, + path: impl AsRef, + opts: Option, + ) -> GitResult> { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&oid).try_into()?; + + let blame_opts = opts.unwrap_or_default(); + let ranges = if let (Some(min), Some(max)) = + (blame_opts.min_line, blame_opts.max_line) + { + gix::blame::BlameRanges::from_one_based_inclusive_range( + min as u32..=max as u32, + )? + } else if let Some(min) = blame_opts.min_line { + gix::blame::BlameRanges::from_one_based_inclusive_range( + min as u32..=min as u32, + )? + } else { + gix::blame::BlameRanges::default() + }; + + let rewrites = if blame_opts.track_copies_same_file + || blame_opts.track_copies_same_commit_moves + { + Some(gix::diff::Rewrites::default()) + } else { + None + }; + + let gix_options = gix::repository::blame_file::Options { + diff_algorithm: None, + ranges, + since: None, + rewrites, + }; + + let outcome = repo.blame_file( + path.as_ref() + .as_os_str() + .to_string_lossy() + .as_bytes() + .as_ref(), + gix_id, + gix_options, + )?; + + let mut hunks = Vec::new(); + + for entry in outcome.entries { + let range_in_blamed = entry.range_in_blamed_file(); + let range_in_source = entry.range_in_source_file(); + + let commit_oid = + ObjectId::new(entry.commit_id.to_hex().to_string()); + let final_start_line = range_in_blamed.start as u32 + 1; + let final_lines = range_in_blamed.len() as u32; + let orig_start_line = range_in_source.start as u32 + 1; + let orig_lines = range_in_source.len() as u32; + + let is_boundary = repo + .find_commit(entry.commit_id) + .ok() + .is_some_and(|c| c.parent_ids().count() == 0); + + let orig_path = + entry.source_file_name.as_ref().map(|p| p.to_string()); + + hunks.push(CommitBlameHunk { + commit_oid, + final_start_line, + final_lines, + orig_start_line, + orig_lines, + boundary: is_boundary, + orig_path, + }); + } + + Ok(hunks) + } +} diff --git a/lib/git/cmd/blame/blame_hunk.rs b/lib/git/cmd/blame/blame_hunk.rs new file mode 100644 index 0000000..e048d1d --- /dev/null +++ b/lib/git/cmd/blame/blame_hunk.rs @@ -0,0 +1,69 @@ +use std::path::Path; + +use crate::{ + bare::GitBare, + cmd::{blame::CommitBlameHunk, oid::ObjectId}, + errors::{GitError, GitResult}, +}; + +impl GitBare { + pub fn blame_hunk( + &self, + oid: &ObjectId, + path: impl AsRef, + line_on: usize, + ) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (oid).try_into()?; + + let ranges = gix::blame::BlameRanges::from_one_based_inclusive_range( + line_on as u32..=line_on as u32, + )?; + + let gix_options = gix::repository::blame_file::Options { + diff_algorithm: None, + ranges, + since: None, + rewrites: None, + }; + + let outcome = repo.blame_file( + path.as_ref() + .as_os_str() + .to_string_lossy() + .as_bytes() + .as_ref(), + gix_id, + gix_options, + )?; + + let entry = outcome.entries.first().ok_or_else(|| { + GitError::ParseError(format!( + "no blame hunk found for line {} in {} at {}", + line_on, + path.as_ref().display(), + oid + )) + })?; + + let range_in_blamed = entry.range_in_blamed_file(); + let range_in_source = entry.range_in_source_file(); + + let commit_oid = ObjectId::new(entry.commit_id.to_hex().to_string()); + let is_boundary = repo + .find_commit(entry.commit_id) + .ok() + .is_some_and(|c| c.parent_ids().count() == 0); + let orig_path = entry.source_file_name.as_ref().map(|p| p.to_string()); + + Ok(CommitBlameHunk { + commit_oid, + final_start_line: range_in_blamed.start as u32 + 1, + final_lines: range_in_blamed.len() as u32, + orig_start_line: range_in_source.start as u32 + 1, + orig_lines: range_in_source.len() as u32, + boundary: is_boundary, + orig_path, + }) + } +} diff --git a/lib/git/cmd/blame/blame_line.rs b/lib/git/cmd/blame/blame_line.rs new file mode 100644 index 0000000..4ef8811 --- /dev/null +++ b/lib/git/cmd/blame/blame_line.rs @@ -0,0 +1,84 @@ +use std::path::Path; + +use crate::{ + bare::GitBare, + cmd::{ + blame::{BlameOptions, CommitBlameLine}, + oid::ObjectId, + }, + errors::GitResult, +}; + +impl GitBare { + pub fn blame_lines( + &self, + oid: ObjectId, + path: impl AsRef, + opts: Option, + ) -> GitResult> { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&oid).try_into()?; + + let blame_opts = opts.unwrap_or_default(); + + let ranges = if let (Some(min), Some(max)) = + (blame_opts.min_line, blame_opts.max_line) + { + gix::blame::BlameRanges::from_one_based_inclusive_range( + min as u32..=max as u32, + )? + } else if let Some(min) = blame_opts.min_line { + gix::blame::BlameRanges::from_one_based_inclusive_range( + min as u32..=min as u32, + )? + } else { + gix::blame::BlameRanges::default() + }; + + let rewrites = if blame_opts.track_copies_same_file + || blame_opts.track_copies_same_commit_moves + { + Some(gix::diff::Rewrites::default()) + } else { + None + }; + + let gix_options = gix::repository::blame_file::Options { + diff_algorithm: None, + ranges, + since: None, + rewrites, + }; + + let outcome = repo.blame_file( + path.as_ref() + .as_os_str() + .to_string_lossy() + .as_bytes() + .as_ref(), + gix_id, + gix_options, + )?; + + let mut result = Vec::new(); + + for (entry, lines) in outcome.entries_with_lines() { + let commit_oid = + ObjectId::new(entry.commit_id.to_hex().to_string()); + let orig_path = + entry.source_file_name.as_ref().map(|p| p.to_string()); + let base_line = entry.range_in_blamed_file().start as u32 + 1; + + for (i, line) in lines.iter().enumerate() { + result.push(CommitBlameLine { + commit_oid: commit_oid.clone(), + line_no: base_line + i as u32, + content: line.to_string(), + orig_path: orig_path.clone(), + }); + } + } + + Ok(result) + } +} diff --git a/lib/git/cmd/blame/mod.rs b/lib/git/cmd/blame/mod.rs new file mode 100644 index 0000000..b2b5992 --- /dev/null +++ b/lib/git/cmd/blame/mod.rs @@ -0,0 +1,65 @@ +use serde::{Deserialize, Serialize}; + +use crate::cmd::oid::ObjectId; + +pub mod blame_file; +pub mod blame_hunk; +pub mod blame_line; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CommitBlameHunk { + pub commit_oid: ObjectId, + pub final_start_line: u32, + pub final_lines: u32, + pub orig_start_line: u32, + pub orig_lines: u32, + pub boundary: bool, + pub orig_path: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CommitBlameLine { + pub commit_oid: ObjectId, + pub line_no: u32, + pub content: String, + pub orig_path: Option, +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct BlameOptions { + pub min_line: Option, + pub max_line: Option, + pub track_copies_same_file: bool, + pub track_copies_same_commit_moves: bool, + pub ignore_whitespace: bool, +} + +impl BlameOptions { + pub fn new() -> Self { + Self::default() + } + + pub fn min_line(mut self, line: usize) -> Self { + self.min_line = Some(line); + self + } + + pub fn max_line(mut self, line: usize) -> Self { + self.max_line = Some(line); + self + } + pub fn track_copies_same_file(mut self) -> Self { + self.track_copies_same_file = true; + self + } + + pub fn track_copies_same_commit_moves(mut self) -> Self { + self.track_copies_same_commit_moves = true; + self + } + + pub fn ignore_whitespace(mut self) -> Self { + self.ignore_whitespace = true; + self + } +} diff --git a/lib/git/cmd/blob/blob_chunk.rs b/lib/git/cmd/blob/blob_chunk.rs new file mode 100644 index 0000000..92c2c67 --- /dev/null +++ b/lib/git/cmd/blob/blob_chunk.rs @@ -0,0 +1,46 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::oid::ObjectId, + errors::{GitError, GitResult}, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BlobChunkParam { + pub path: String, + pub oid: ObjectId, + pub size: usize, + pub offset: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BlobChunk { + pub param: BlobChunkParam, + pub chunk: Vec, +} + +impl GitBare { + pub fn blob_chunk(&self, param: BlobChunkParam) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (¶m.oid).try_into()?; + + let blob = repo.find_blob(gix_id).map_err(|_| { + GitError::ObjectNotFound(param.oid.as_str().to_string()) + })?; + + let blob_bytes = blob.data.clone(); + let end = param.offset + param.size; + if end > blob_bytes.len() { + return Err(GitError::ParseError(format!( + "chunk offset+size ({}) exceeds blob length ({})", + end, + blob_bytes.len() + ))); + } + + let chunk = blob_bytes[param.offset..end].to_vec(); + + Ok(BlobChunk { param, chunk }) + } +} diff --git a/lib/git/cmd/blob/blob_helper.rs b/lib/git/cmd/blob/blob_helper.rs new file mode 100644 index 0000000..ab3cf3c --- /dev/null +++ b/lib/git/cmd/blob/blob_helper.rs @@ -0,0 +1,27 @@ +use crate::{ + bare::GitBare, + cmd::oid::ObjectId, + errors::{GitError, GitResult}, +}; + +impl GitBare { + pub fn blob_is_binary(&self, id: ObjectId) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&id).try_into()?; + + let blob = repo + .find_blob(gix_id) + .map_err(|_| GitError::ObjectNotFound(id.as_str().to_string()))?; + + let check_len = std::cmp::min(8000, blob.data.len()); + let is_binary = blob.data[..check_len].contains(&0); + + Ok(is_binary) + } + + pub fn blob_exists(&self, id: ObjectId) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&id).try_into()?; + Ok(repo.has_object(gix_id)) + } +} diff --git a/lib/git/cmd/blob/blob_load.rs b/lib/git/cmd/blob/blob_load.rs new file mode 100644 index 0000000..c899ff0 --- /dev/null +++ b/lib/git/cmd/blob/blob_load.rs @@ -0,0 +1,38 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::oid::ObjectId, + errors::{GitError, GitResult}, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BlobLoadParams { + pub id: ObjectId, + pub path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BlobLoadResult { + pub params: BlobLoadParams, + pub blob: Vec, +} + +impl GitBare { + pub fn blob_load( + &self, + params: &BlobLoadParams, + ) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (¶ms.id).try_into()?; + + let blob = repo.find_blob(gix_id).map_err(|_| { + GitError::ObjectNotFound(params.id.as_str().to_string()) + })?; + + Ok(BlobLoadResult { + params: params.clone(), + blob: blob.data.to_vec(), + }) + } +} diff --git a/lib/git/cmd/blob/blob_size.rs b/lib/git/cmd/blob/blob_size.rs new file mode 100644 index 0000000..829891d --- /dev/null +++ b/lib/git/cmd/blob/blob_size.rs @@ -0,0 +1,26 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::oid::ObjectId, + errors::{GitError, GitResult}, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BlobSizeParams { + pub id: ObjectId, + pub path: String, +} + +impl GitBare { + pub fn blob_size(&self, params: &BlobSizeParams) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (¶ms.id).try_into()?; + + let header = repo.find_header(gix_id).map_err(|_| { + GitError::ObjectNotFound(params.id.as_str().to_string()) + })?; + + Ok(header.size() as u64) + } +} diff --git a/lib/git/cmd/blob/blob_upload.rs b/lib/git/cmd/blob/blob_upload.rs new file mode 100644 index 0000000..e2e72a9 --- /dev/null +++ b/lib/git/cmd/blob/blob_upload.rs @@ -0,0 +1,47 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::{command::GitCommandParams, oid::ObjectId}, + errors::{GitError, GitResult}, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BlobUploadParams { + pub blob: Vec, + pub path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BlobUploadResult { + pub id: ObjectId, +} + +impl GitBare { + pub fn blob_upload( + &self, + params: BlobUploadParams, + ) -> GitResult { + let cmd_params = GitCommandParams::new(vec![ + "hash-object".to_string(), + "-w".to_string(), + "--stdin".to_string(), + ]) + .with_stdin(params.blob); + + let output = self.git_command_with(cmd_params)?; + + if !output.success { + return Err(GitError::CommandFailed { + status_code: output.status_code, + stderr: output.stderr_lossy(), + }); + } + + let stdout = output.stdout_lossy(); + let oid_str = stdout.trim(); + let oid = ObjectId::new(oid_str); + + Ok(BlobUploadResult { id: oid }) + } +} diff --git a/lib/git/cmd/blob/mod.rs b/lib/git/cmd/blob/mod.rs new file mode 100644 index 0000000..813dc54 --- /dev/null +++ b/lib/git/cmd/blob/mod.rs @@ -0,0 +1,10 @@ +pub mod blob_chunk; +pub mod blob_helper; +pub mod blob_load; +pub mod blob_size; +pub mod blob_upload; + +pub use blob_chunk::{BlobChunk, BlobChunkParam}; +pub use blob_load::{BlobLoadParams, BlobLoadResult}; +pub use blob_size::BlobSizeParams; +pub use blob_upload::{BlobUploadParams, BlobUploadResult}; diff --git a/lib/git/cmd/branch/branch_delete.rs b/lib/git/cmd/branch/branch_delete.rs new file mode 100644 index 0000000..dc28e1d --- /dev/null +++ b/lib/git/cmd/branch/branch_delete.rs @@ -0,0 +1,18 @@ +use crate::{bare::GitBare, errors::GitResult}; + +pub struct BranchDeleteParams { + pub name: String, + pub force: bool, +} + +impl GitBare { + pub fn branch_delete(&self, params: BranchDeleteParams) -> GitResult<()> { + let flag = if params.force { "-D" } else { "-d" }; + self.git_command_trusted(vec![ + "branch".to_string(), + flag.to_string(), + params.name, + ])?; + Ok(()) + } +} diff --git a/lib/git/cmd/branch/branch_fork.rs b/lib/git/cmd/branch/branch_fork.rs new file mode 100644 index 0000000..1502422 --- /dev/null +++ b/lib/git/cmd/branch/branch_fork.rs @@ -0,0 +1,20 @@ +use crate::{bare::GitBare, cmd::oid::ObjectId, errors::GitResult}; + +pub struct BranchForkParams { + pub name: String, + pub oid: ObjectId, + pub force: bool, +} + +impl GitBare { + pub fn branch_fork(&self, params: &BranchForkParams) -> GitResult<()> { + let mut args = vec!["branch".to_string()]; + if params.force { + args.push("--force".to_string()); + } + args.push(params.name.clone()); + args.push(params.oid.to_string()); + self.git_command_trusted(args)?; + Ok(()) + } +} diff --git a/lib/git/cmd/branch/branch_head.rs b/lib/git/cmd/branch/branch_head.rs new file mode 100644 index 0000000..78c42ea --- /dev/null +++ b/lib/git/cmd/branch/branch_head.rs @@ -0,0 +1,36 @@ +use crate::{bare::GitBare, errors::GitResult}; + +impl GitBare { + pub fn branch_head_name(&self) -> GitResult { + let repo = self.gix_repo()?; + let head_name = repo.head_name()?; + match head_name { + Some(name) => Ok(name.shorten().to_string()), + None => { + Err(crate::errors::GitError::RefNotFound("HEAD".to_string())) + } + } + } + + pub fn branch_ahead_behind(&self, branch_name: String) -> GitResult { + let upstream = self.branch_upstream_name(branch_name.clone())?; + let upstream_name = match upstream { + Some(name) => name, + None => return Ok(false), + }; + + let repo = self.gix_repo()?; + let branch_id = repo + .rev_parse_single(branch_name.as_str()) + .map_err(|_| crate::errors::GitError::RefNotFound(branch_name))?; + let upstream_id = repo + .rev_parse_single(upstream_name.as_str()) + .map_err(|_| crate::errors::GitError::RefNotFound(upstream_name))?; + + let branch_commit = branch_id.detach(); + let upstream_commit = upstream_id.detach(); + let merge_base = repo.merge_base(branch_commit, upstream_commit).ok(); + + Ok(merge_base.is_some_and(|base| base.detach() != branch_commit)) + } +} diff --git a/lib/git/cmd/branch/branch_list.rs b/lib/git/cmd/branch/branch_list.rs new file mode 100644 index 0000000..8e0fd30 --- /dev/null +++ b/lib/git/cmd/branch/branch_list.rs @@ -0,0 +1,133 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::oid::ObjectId, + errors::{GitError, GitResult}, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BranchListItem { + pub name: String, + pub oid: ObjectId, + pub is_head: bool, + pub is_remote: bool, + pub is_current: bool, + pub upstream: Option, +} + +impl GitBare { + pub fn branch_list_all(&self) -> GitResult> { + let repo = self.gix_repo()?; + let head_name = repo.head_name()?.map(|n| n.shorten().to_string()); + + let mut items = Vec::new(); + let platform = repo.references()?; + let local_iter = platform.local_branches()?; + for ref_result in local_iter { + let reference = ref_result?; + let name = reference.name().shorten().to_string(); + let is_head = head_name.as_ref() == Some(&name); + let is_current = is_head; + + let target = reference.target(); + let target_id = target.try_id().ok_or_else(|| { + GitError::Internal( + "local branch has no direct target".to_string(), + ) + })?; + let oid = ObjectId::new(target_id.to_hex().to_string()); + let upstream = reference + .remote_tracking_ref_name(gix::remote::Direction::Fetch) + .and_then(|r| r.ok()) + .map(|n| n.shorten().to_string()); + + items.push(BranchListItem { + name, + oid, + is_head, + is_remote: false, + is_current, + upstream, + }); + } + let remote_iter = platform.remote_branches()?; + for ref_result in remote_iter { + let reference = ref_result?; + let name = reference.name().shorten().to_string(); + let is_head = head_name.as_ref() == Some(&name); + let is_current = is_head; + + let target = reference.target(); + let target_id = target.try_id().ok_or_else(|| { + GitError::Internal( + "remote branch has no direct target".to_string(), + ) + })?; + let oid = ObjectId::new(target_id.to_hex().to_string()); + + items.push(BranchListItem { + name, + oid, + is_head, + is_remote: true, + is_current, + upstream: None, + }); + } + + Ok(items) + } + + pub fn branch_info(&self, branch: String) -> GitResult { + let repo = self.gix_repo()?; + let head_name = repo.head_name()?.map(|n| n.shorten().to_string()); + let local_ref_str = format!("refs/heads/{branch}"); + let local_ref = repo.try_find_reference(local_ref_str.as_str())?; + if let Some(reference) = local_ref { + let name = reference.name().shorten().to_string(); + let is_head = head_name.as_ref() == Some(&name); + let target = reference.target(); + let target_id = target.try_id().ok_or_else(|| { + GitError::Internal("branch has no direct target".to_string()) + })?; + let oid = ObjectId::new(target_id.to_hex().to_string()); + + let upstream = reference + .remote_tracking_ref_name(gix::remote::Direction::Fetch) + .and_then(|r| r.ok()) + .map(|n| n.shorten().to_string()); + + return Ok(BranchListItem { + name, + oid, + is_head, + is_remote: false, + is_current: is_head, + upstream, + }); + } + let remote_ref_str = format!("refs/remotes/{branch}"); + let remote_ref = repo.try_find_reference(remote_ref_str.as_str())?; + if let Some(reference) = remote_ref { + let name = reference.name().shorten().to_string(); + let is_head = head_name.as_ref() == Some(&name); + let target = reference.target(); + let target_id = target.try_id().ok_or_else(|| { + GitError::Internal("branch has no direct target".to_string()) + })?; + let oid = ObjectId::new(target_id.to_hex().to_string()); + + return Ok(BranchListItem { + name, + oid, + is_head, + is_remote: true, + is_current: is_head, + upstream: None, + }); + } + + Err(GitError::RefNotFound(branch)) + } +} diff --git a/lib/git/cmd/branch/branch_rename.rs b/lib/git/cmd/branch/branch_rename.rs new file mode 100644 index 0000000..0e451d0 --- /dev/null +++ b/lib/git/cmd/branch/branch_rename.rs @@ -0,0 +1,20 @@ +use crate::{bare::GitBare, errors::GitResult}; + +pub struct BranchReNameParams { + pub old_branch: String, + pub new_branch: String, + pub force: bool, +} + +impl GitBare { + pub fn branch_rename(&self, params: BranchReNameParams) -> GitResult<()> { + let flag = if params.force { "-M" } else { "-m" }; + self.git_command_trusted(vec![ + "branch".to_string(), + flag.to_string(), + params.old_branch, + params.new_branch, + ])?; + Ok(()) + } +} diff --git a/lib/git/cmd/branch/branch_summary.rs b/lib/git/cmd/branch/branch_summary.rs new file mode 100644 index 0000000..71d5202 --- /dev/null +++ b/lib/git/cmd/branch/branch_summary.rs @@ -0,0 +1,26 @@ +use serde::{Deserialize, Serialize}; + +use crate::{bare::GitBare, errors::GitResult}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BranchSummary { + pub local_count: usize, + pub remote_count: usize, + pub all_count: usize, +} + +impl GitBare { + pub fn branch_summary(&self) -> GitResult { + let repo = self.gix_repo()?; + let platform = repo.references()?; + + let local_count = platform.local_branches()?.count(); + let remote_count = platform.remote_branches()?.count(); + + Ok(BranchSummary { + local_count, + remote_count, + all_count: local_count + remote_count, + }) + } +} diff --git a/lib/git/cmd/branch/branch_upstream.rs b/lib/git/cmd/branch/branch_upstream.rs new file mode 100644 index 0000000..5b420ec --- /dev/null +++ b/lib/git/cmd/branch/branch_upstream.rs @@ -0,0 +1,40 @@ +use crate::{ + bare::GitBare, cmd::branch::branch_list::BranchListItem, errors::GitResult, +}; + +impl GitBare { + pub fn branch_upstream_name( + &self, + branch_name: String, + ) -> GitResult> { + let repo = self.gix_repo()?; + let ref_str = format!("refs/heads/{branch_name}"); + let reference = repo.find_reference(ref_str.as_str())?; + + let upstream = reference + .remote_tracking_ref_name(gix::remote::Direction::Fetch) + .and_then(|r| r.ok()) + .map(|n| n.shorten().to_string()); + + Ok(upstream) + } + + pub fn branch_upstream( + &self, + branch_name: String, + ) -> GitResult> { + let upstream_name = self.branch_upstream_name(branch_name)?; + match upstream_name { + Some(name) => { + let info = self.branch_info(name)?; + Ok(Some(info)) + } + None => Ok(None), + } + } + + pub fn branch_has_upstream(&self, branch_name: String) -> GitResult { + let upstream = self.branch_upstream_name(branch_name)?; + Ok(upstream.is_some()) + } +} diff --git a/lib/git/cmd/branch/mod.rs b/lib/git/cmd/branch/mod.rs new file mode 100644 index 0000000..474c76c --- /dev/null +++ b/lib/git/cmd/branch/mod.rs @@ -0,0 +1,13 @@ +pub mod branch_delete; +pub mod branch_fork; +pub mod branch_head; +pub mod branch_list; +pub mod branch_rename; +pub mod branch_summary; +pub mod branch_upstream; + +pub use branch_delete::BranchDeleteParams; +pub use branch_fork::BranchForkParams; +pub use branch_list::BranchListItem; +pub use branch_rename::BranchReNameParams; +pub use branch_summary::BranchSummary; diff --git a/lib/git/cmd/command.rs b/lib/git/cmd/command.rs new file mode 100644 index 0000000..4aaded7 --- /dev/null +++ b/lib/git/cmd/command.rs @@ -0,0 +1,420 @@ +use std::{ + ffi::{OsStr, OsString}, + path::{Component, Path, PathBuf}, +}; + +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + errors::{GitError, GitResult}, +}; + +const DANGEROUS_SUBCOMMANDS: &[&str] = &[ + "add", + "apply", + "archive", + "bisect", + "bundle", + "checkout", + "clean", + "clone", + "commit", + "config", + "fetch", + "gc", + "hook", + "init", + "maintenance", + "merge", + "mv", + "notes", + "pull", + "push", + "rebase", + "remote", + "repack", + "replace", + "reset", + "restore", + "rm", + "send-email", + "sparse-checkout", + "stash", + "submodule", + "switch", + "update-index", + "update-ref", + "worktree", +]; + +const DANGEROUS_OPTIONS: &[&str] = &[ + "-C", + "-c", + "--exec-path", + "--git-dir", + "--html-path", + "--man-path", + "--namespace", + "--paginate", + "--super-prefix", + "--upload-pack", + "--work-tree", +]; + +const DANGEROUS_OPTION_PREFIXES: &[&str] = &[ + "--exec-path=", + "--git-dir=", + "--namespace=", + "--super-prefix=", + "--upload-pack=", + "--work-tree=", +]; + +const DENIED_ENV_NAMES: &[&str] = &[ + "GIT_ALTERNATE_OBJECT_DIRECTORIES", + "GIT_CONFIG", + "GIT_CONFIG_COUNT", + "GIT_CONFIG_GLOBAL", + "GIT_CONFIG_NOSYSTEM", + "GIT_CONFIG_SYSTEM", + "GIT_DIR", + "GIT_EXEC_PATH", + "GIT_EXTERNAL_DIFF", + "GIT_INDEX_FILE", + "GIT_OBJECT_DIRECTORY", + "GIT_SSH", + "GIT_SSH_COMMAND", + "GIT_WORK_TREE", + "LD_LIBRARY_PATH", + "PATH", +]; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GitCommandParams { + pub args: Vec, + pub env: Vec<(String, String)>, + pub stdin: Option>, + pub check_status: bool, + pub bypass_subcommand_validation: bool, +} + +impl GitCommandParams { + pub fn new(args: Vec) -> Self { + Self { + args, + env: Vec::new(), + stdin: None, + check_status: true, + bypass_subcommand_validation: false, + } + } + + pub fn unchecked(mut self) -> Self { + self.check_status = false; + self + } + + pub fn trusted(mut self) -> Self { + self.bypass_subcommand_validation = true; + self + } + + pub fn with_stdin(mut self, stdin: Vec) -> Self { + self.stdin = Some(stdin); + self + } + + pub fn with_env(mut self, name: String, value: String) -> Self { + self.env.push((name, value)); + self + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GitCommandOutput { + pub status_code: Option, + pub success: bool, + pub stdout: Vec, + pub stderr: Vec, +} + +impl GitCommandOutput { + pub fn stdout_lossy(&self) -> String { + String::from_utf8_lossy(&self.stdout).into_owned() + } + + pub fn stderr_lossy(&self) -> String { + String::from_utf8_lossy(&self.stderr).into_owned() + } +} + +impl GitBare { + pub fn git_command( + &self, + args: Vec, + ) -> GitResult { + self.git_command_with(GitCommandParams::new(args)) + } + + pub fn git_command_with( + &self, + params: GitCommandParams, + ) -> GitResult { + let bare_dir = self.safe_bare_dir()?; + validate_git_command(¶ms)?; + + let mut args = Vec::with_capacity(params.args.len() + 2); + args.push(OsString::from("--git-dir")); + args.push(bare_dir.as_os_str().to_os_string()); + args.extend(params.args.iter().map(OsString::from)); + + let mut expression = duct::cmd("git", args) + .dir(&bare_dir) + .stdout_capture() + .stderr_capture() + .env("GIT_CONFIG_NOSYSTEM", "1") + .env("GIT_TERMINAL_PROMPT", "0") + .env_remove("GIT_DIR") + .env_remove("GIT_WORK_TREE") + .env_remove("GIT_INDEX_FILE") + .env_remove("GIT_OBJECT_DIRECTORY") + .env_remove("GIT_ALTERNATE_OBJECT_DIRECTORIES") + .env_remove("GIT_SSH") + .env_remove("GIT_SSH_COMMAND") + .unchecked(); + + for (name, value) in ¶ms.env { + expression = expression.env(name, value); + } + + if let Some(stdin) = params.stdin { + expression = expression.stdin_bytes(stdin); + } + + let output = expression.run()?; + let result = GitCommandOutput { + status_code: output.status.code(), + success: output.status.success(), + stdout: output.stdout, + stderr: output.stderr, + }; + + if params.check_status && !result.success { + return Err(GitError::CommandFailed { + status_code: result.status_code, + stderr: result.stderr_lossy(), + }); + } + + Ok(result) + } + + pub fn git_command_stdout(&self, args: Vec) -> GitResult { + let output = self.git_command(args)?; + Ok(output.stdout_lossy()) + } + + pub fn git_command_success(&self, args: Vec) -> GitResult { + let output = + self.git_command_with(GitCommandParams::new(args).unchecked())?; + Ok(output.success) + } + + pub fn git_command_trusted( + &self, + args: Vec, + ) -> GitResult { + self.git_command_with(GitCommandParams::new(args).trusted()) + } + + pub fn git_command_trusted_unchecked( + &self, + args: Vec, + ) -> GitResult { + self.git_command_with(GitCommandParams::new(args).trusted().unchecked()) + } + + pub fn git_command_trusted_stdout( + &self, + args: Vec, + ) -> GitResult { + let output = + self.git_command_with(GitCommandParams::new(args).trusted())?; + Ok(output.stdout_lossy()) + } + + pub fn git_command_trusted_success( + &self, + args: Vec, + ) -> GitResult { + let output = self.git_command_with( + GitCommandParams::new(args).trusted().unchecked(), + )?; + Ok(output.success) + } + + fn safe_bare_dir(&self) -> GitResult { + let bare_dir = self.bare_dir.canonicalize()?; + if !bare_dir.is_dir() { + return Err(GitError::NotBareRepository); + } + + Ok(bare_dir) + } +} + +fn validate_git_command(params: &GitCommandParams) -> GitResult<()> { + if params.bypass_subcommand_validation { + return Ok(()); + } + + if params.args.is_empty() { + return Err(GitError::UnsafeCommand( + "missing git subcommand".to_owned(), + )); + } + + let subcommand = first_subcommand(¶ms.args).ok_or_else(|| { + GitError::UnsafeCommand("missing git subcommand".to_owned()) + })?; + + if DANGEROUS_SUBCOMMANDS.contains(&subcommand.as_str()) { + return Err(GitError::UnsafeCommand(format!( + "subcommand `{subcommand}` is not allowed" + ))); + } + + for arg in ¶ms.args { + validate_git_arg(arg)?; + } + + for (name, value) in ¶ms.env { + validate_git_env(name, value)?; + } + + Ok(()) +} + +fn first_subcommand(args: &[String]) -> Option { + let mut iter = args.iter().peekable(); + while let Some(arg) = iter.next() { + if arg == "--" { + return None; + } + + if arg == "-c" + || arg == "-C" + || arg == "--git-dir" + || arg == "--work-tree" + { + iter.next(); + continue; + } + + if arg.starts_with('-') { + continue; + } + + return Some(arg.to_owned()); + } + + None +} + +fn validate_git_arg(arg: &str) -> GitResult<()> { + if DANGEROUS_OPTIONS.contains(&arg) { + return Err(GitError::UnsafeCommand(format!( + "option `{arg}` is not allowed" + ))); + } + + if DANGEROUS_OPTION_PREFIXES + .iter() + .any(|prefix| arg.starts_with(prefix)) + { + return Err(GitError::UnsafeCommand(format!( + "option `{arg}` is not allowed" + ))); + } + + if let Some((name, value)) = arg.split_once('=') { + if looks_like_path(value) { + validate_relative_path(value).map_err(|_| { + GitError::UnsafeCommand(format!( + "path value for `{name}` escapes bare_dir" + )) + })?; + } + } else if looks_like_path(arg) { + validate_relative_path(arg).map_err(|_| { + GitError::UnsafeCommand(format!( + "path argument `{arg}` escapes bare_dir" + )) + })?; + } + + if contains_shell_metachar(arg) { + return Err(GitError::UnsafeCommand(format!( + "argument `{arg}` contains shell metacharacters" + ))); + } + + Ok(()) +} + +fn validate_git_env(name: &str, value: &str) -> GitResult<()> { + let upper_name = name.to_ascii_uppercase(); + if DENIED_ENV_NAMES.contains(&upper_name.as_str()) + || upper_name.starts_with("GIT_CONFIG_") + { + return Err(GitError::UnsafeCommand(format!( + "environment variable `{name}` is not allowed" + ))); + } + + if looks_like_path(value) { + validate_relative_path(value).map_err(|_| { + GitError::UnsafeCommand(format!( + "environment variable `{name}` escapes bare_dir" + )) + })?; + } + + Ok(()) +} + +fn looks_like_path(value: &str) -> bool { + value.contains('/') || value.contains('\\') || value == "." || value == ".." +} + +fn validate_relative_path(value: &str) -> Result<(), ()> { + let path = Path::new(value); + if path.is_absolute() { + return Err(()); + } + + for component in path.components() { + match component { + Component::Normal(part) if is_safe_path_part(part) => {} + Component::CurDir => {} + _ => return Err(()), + } + } + + Ok(()) +} + +fn is_safe_path_part(part: &OsStr) -> bool { + let Some(part) = part.to_str() else { + return false; + }; + + !part.is_empty() && part != "." && part != ".." +} + +fn contains_shell_metachar(value: &str) -> bool { + value.chars().any(|ch| { + matches!(ch, ';' | '|' | '&' | '`' | '$' | '<' | '>' | '\n' | '\r') + }) +} diff --git a/lib/git/cmd/commit/commit_cherry_pick.rs b/lib/git/cmd/commit/commit_cherry_pick.rs new file mode 100644 index 0000000..2179e10 --- /dev/null +++ b/lib/git/cmd/commit/commit_cherry_pick.rs @@ -0,0 +1,186 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::{ + commit::CommitSignature, oid::ObjectId, parse::format_git_timestamp, + }, + errors::{GitError, GitResult}, +}; + +#[derive(Clone, Deserialize, Debug, Serialize)] +pub struct CommitCherryPickParams { + pub cherrypick_oid: ObjectId, + pub author: CommitSignature, + pub committer: CommitSignature, + pub message: Option, + pub mainline: u32, + pub update_ref: Option, +} + +#[derive(Clone, Deserialize, Debug, Serialize)] +pub struct CommitCherryPickSequence { + pub cherrypick_oids: Vec, + pub author: CommitSignature, + pub committer: CommitSignature, + pub update_ref: Option, +} + +impl GitBare { + pub fn commit_pick( + &self, + params: CommitCherryPickParams, + ) -> GitResult { + let cherry_info = self.commit_info(params.cherrypick_oid.clone())?; + let their_tree = cherry_info.tree_id; + let repo = self.gix_repo()?; + let head_id = + repo.head_id().map_err(|e| GitError::Gix(e.to_string()))?; + let head_oid = ObjectId::new(head_id.detach().to_hex().to_string()); + let head_info = self.commit_info(head_oid.clone())?; + let our_tree = head_info.tree_id; + let base_result = + self.merge_base(head_oid.clone(), params.cherrypick_oid.clone()); + let base_tree = match base_result { + Ok(base_oid) => { + let base_info = self.commit_info(base_oid)?; + base_info.tree_id.as_str().to_string() + } + Err(_) => { + let empty_tree = "4b825dc642cb6eb9a060e54bf899d15363c725d7"; + let empty_tree_gix_id: gix::hash::ObjectId = + gix::hash::ObjectId::from_hex(empty_tree.as_bytes()) + .map_err(|e| GitError::Gix(e.to_string()))?; + if repo.has_object(empty_tree_gix_id) { + empty_tree.to_string() + } else { + let mktree_output = + self.git_command_trusted_stdout(vec![ + "mktree".to_string(), + ])?; + mktree_output.trim().to_string() + } + } + }; + let merge_tree_args = vec![ + "merge-tree".to_string(), + "--write-tree".to_string(), + "--merge-base".to_string(), + base_tree.clone(), + our_tree.as_str().to_string(), + their_tree.as_str().to_string(), + ]; + + let merge_output = self.git_command_trusted(merge_tree_args); + + let merged_tree_id = match merge_output { + Ok(output) => { + if output.success { + ObjectId::new(output.stdout_lossy().trim()) + } else { + return Err(GitError::CommandFailed { + status_code: output.status_code, + stderr: output.stderr_lossy(), + }); + } + } + Err(err) => return Err(err), + }; + let message = params + .message + .clone() + .unwrap_or(cherry_info.message.clone()); + let mut commit_args = vec![ + "commit-tree".to_string(), + merged_tree_id.as_str().to_string(), + "-p".to_string(), + head_oid.as_str().to_string(), + ]; + if cherry_info.parent_ids.len() > 1 { + commit_args.push("-m".to_string()); + commit_args.push(format!( + "cherry picked from commit {}", + params.cherrypick_oid.as_str() + )); + if params.mainline > 0 { + commit_args.push(format!("(mainline {})", params.mainline)); + } + } + let cmd_params = + crate::cmd::command::GitCommandParams::new(commit_args) + .trusted() + .with_env( + "GIT_AUTHOR_NAME".to_string(), + params.author.name.clone(), + ) + .with_env( + "GIT_AUTHOR_EMAIL".to_string(), + params.author.email.clone(), + ) + .with_env( + "GIT_AUTHOR_DATE".to_string(), + format_git_timestamp( + params.author.time_secs, + params.author.offset_minutes, + ), + ) + .with_env( + "GIT_COMMITTER_NAME".to_string(), + params.committer.name.clone(), + ) + .with_env( + "GIT_COMMITTER_EMAIL".to_string(), + params.committer.email.clone(), + ) + .with_env( + "GIT_COMMITTER_DATE".to_string(), + format_git_timestamp( + params.committer.time_secs, + params.committer.offset_minutes, + ), + ) + .with_stdin(message.as_bytes().to_vec()); + let commit_output = self.git_command_with(cmd_params)?; + + if !commit_output.success { + return Err(GitError::CommandFailed { + status_code: commit_output.status_code, + stderr: commit_output.stderr_lossy(), + }); + } + + let new_oid = ObjectId::new(commit_output.stdout_lossy().trim()); + if let Some(ref_name) = ¶ms.update_ref { + self.git_command_trusted(vec![ + "update-ref".to_string(), + ref_name.clone(), + new_oid.as_str().to_string(), + ])?; + } + + Ok(new_oid) + } + pub fn commit_cherry_pick_sequence( + &self, + params: CommitCherryPickSequence, + ) -> GitResult { + let mut last_oid: Option = None; + + for cherry_oid in params.cherrypick_oids { + let pick_params = CommitCherryPickParams { + cherrypick_oid: cherry_oid, + author: params.author.clone(), + committer: params.committer.clone(), + message: None, // Use original commit message + mainline: 1, // Default mainline + update_ref: params.update_ref.clone(), + }; + + last_oid = Some(self.commit_pick(pick_params)?); + } + + last_oid.ok_or_else(|| { + GitError::ParseError("cherry-pick sequence was empty".to_string()) + }) + } +} diff --git a/lib/git/cmd/commit/commit_create.rs b/lib/git/cmd/commit/commit_create.rs new file mode 100644 index 0000000..a35c357 --- /dev/null +++ b/lib/git/cmd/commit/commit_create.rs @@ -0,0 +1,234 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::{command::GitCommandParams, oid::ObjectId}, + errors::{GitError, GitResult}, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FileChange { + pub path: String, + pub content: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CreateCommitParams { + pub branch: String, + pub message: String, + pub author_name: String, + pub author_email: String, + pub committer_name: String, + pub committer_email: String, + pub files: Vec, +} + +#[derive(Debug, Clone)] +struct TreeEntry { + mode: String, // e.g. "100644" + kind: String, // "blob" or "tree" + oid: ObjectId, + path: String, +} + +impl GitBare { + /// Create a commit with the given file changes on the specified branch. + /// + /// Returns the OID of the newly created commit. + pub fn commit_create( + &self, + params: CreateCommitParams, + ) -> GitResult { + let repo = self.gix_repo()?; + + // 1. Write each file as a blob and collect entries + let mut new_entries: Vec = Vec::with_capacity(params.files.len()); + for fc in ¶ms.files { + let blob_upload_result = self.blob_upload(crate::cmd::blob::BlobUploadParams { + blob: fc.content.clone(), + path: fc.path.clone(), + })?; + new_entries.push(TreeEntry { + mode: "100644".to_string(), + kind: "blob".to_string(), + oid: blob_upload_result.id, + path: fc.path.clone(), + }); + } + + // 2. Get the parent commit (HEAD of the branch, or initial commit) + let parent_oid: Option; + let existing_tree_entries: Vec; + + let branch_ref = format!("refs/heads/{}", params.branch); + + match repo.find_reference(&branch_ref) { + Ok(r) => { + let oid = r.into_fully_peeled_id() + .map_err(|e| GitError::Gix(e.to_string()))? + .detach(); + let oid_str = oid.to_hex().to_string(); + let commit_oid = ObjectId::new(oid_str); + let commit = self.commit_info(commit_oid.clone())?; + + let tree_entries = self._ls_tree(commit.tree_id)?; + parent_oid = Some(commit_oid); + existing_tree_entries = tree_entries; + } + Err(_) => { + parent_oid = None; + existing_tree_entries = Vec::new(); + } + } + + // 3. Merge new entries into existing tree (replace by path) + let mut merged: Vec = Vec::new(); + let mut seen_paths: std::collections::HashSet = + std::collections::HashSet::new(); + + for entry in new_entries { + seen_paths.insert(entry.path.clone()); + merged.push(entry); + } + for entry in existing_tree_entries { + if !seen_paths.contains(&entry.path) { + merged.push(entry); + } + } + + // 4. Create tree with git mktree + let tree_oid = self._mktree(&merged)?; + + // 5. Create commit with git commit-tree + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as i64; + let timestamp = crate::cmd::parse::format_git_timestamp(now, 0); + + let mut commit_tree_args = vec![ + "commit-tree".to_string(), + tree_oid.as_str().to_string(), + ]; + if let Some(parent) = &parent_oid { + commit_tree_args.push("-p".to_string()); + commit_tree_args.push(parent.as_str().to_string()); + } + commit_tree_args.push("-F".to_string()); + commit_tree_args.push("-".to_string()); + + let commit_output = self.git_command_with( + GitCommandParams::new(commit_tree_args) + .with_stdin(params.message.as_bytes().to_vec()) + .with_env("GIT_AUTHOR_NAME".to_string(), params.author_name.clone()) + .with_env("GIT_AUTHOR_EMAIL".to_string(), params.author_email.clone()) + .with_env("GIT_AUTHOR_DATE".to_string(), timestamp.clone()) + .with_env("GIT_COMMITTER_NAME".to_string(), params.committer_name.clone()) + .with_env("GIT_COMMITTER_EMAIL".to_string(), params.committer_email.clone()) + .with_env("GIT_COMMITTER_DATE".to_string(), timestamp), + )?; + + if !commit_output.success { + return Err(GitError::CommandFailed { + status_code: commit_output.status_code, + stderr: commit_output.stderr_lossy(), + }); + } + + let stdout_str = commit_output.stdout_lossy(); + let commit_oid_str = stdout_str.trim(); + if commit_oid_str.is_empty() { + return Err(GitError::CommandFailed { + status_code: commit_output.status_code, + stderr: "no commit OID produced".to_string(), + }); + } + + let commit_oid = ObjectId::new(commit_oid_str); + + // 6. Update the branch ref + let update_ref_args = vec![ + "update-ref".to_string(), + branch_ref, + commit_oid.as_str().to_string(), + ]; + let _update_output = self.git_command_with( + GitCommandParams::new(update_ref_args), + )?; + + Ok(commit_oid) + } + + /// List tree entries (mode, type, oid, path) using git ls-tree. + fn _ls_tree(&self, tree_oid: ObjectId) -> GitResult> { + let output = self.git_command_with(GitCommandParams::new(vec![ + "ls-tree".to_string(), + tree_oid.as_str().to_string(), + ]))?; + + if !output.success { + return Err(GitError::CommandFailed { + status_code: output.status_code, + stderr: output.stderr_lossy(), + }); + } + + let stdout = output.stdout_lossy(); + let mut entries = Vec::new(); + + for line in stdout.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + // Output format: \t + let parts: Vec<&str> = line.splitn(2, '\t').collect(); + if parts.len() < 2 { + continue; + } + let meta: Vec<&str> = parts[0].split_whitespace().collect(); + if meta.len() < 3 { + continue; + } + entries.push(TreeEntry { + mode: meta[0].to_string(), + kind: meta[1].to_string(), + oid: ObjectId::new(meta[2]), + path: parts[1].to_string(), + }); + } + + Ok(entries) + } + + /// Create a tree object from entries using git mktree. + fn _mktree(&self, entries: &[TreeEntry]) -> GitResult { + // Build input for git mktree: + // Each line: \t + let mut input = String::new(); + for entry in entries { + input.push_str(&format!( + "{} {} {}\t{}\n", + entry.mode, entry.kind, entry.oid.as_str(), entry.path + )); + } + + let output = self.git_command_with( + GitCommandParams::new(vec![ + "mktree".to_string(), + ]) + .with_stdin(input.into_bytes()), + )?; + + if !output.success { + return Err(GitError::CommandFailed { + status_code: output.status_code, + stderr: output.stderr_lossy(), + }); + } + + let stdout = output.stdout_lossy(); + let oid_str = stdout.trim(); + Ok(ObjectId::new(oid_str)) + } +} diff --git a/lib/git/cmd/commit/commit_history.rs b/lib/git/cmd/commit/commit_history.rs new file mode 100644 index 0000000..8aea12e --- /dev/null +++ b/lib/git/cmd/commit/commit_history.rs @@ -0,0 +1,31 @@ +use crate::{ + bare::GitBare, + cmd::{ + commit::{ + CommitMeta, + commit_walker::{CommitWalkParams, CommitWalkSort}, + }, + oid::ObjectId, + }, + errors::GitResult, +}; + +impl GitBare { + pub fn commit_history( + &self, + params: CommitWalkParams, + ) -> GitResult> { + self.commit_walk(params) + } + + pub fn commit_history_head( + &self, + limit: Option, + ) -> GitResult> { + let mut params = CommitWalkParams::default(); + params.start_oids = vec![ObjectId::new("HEAD")]; + params.limit = limit; + params.sort = CommitWalkSort::Time; + self.commit_history(params) + } +} diff --git a/lib/git/cmd/commit/commit_info.rs b/lib/git/cmd/commit/commit_info.rs new file mode 100644 index 0000000..4fcc9fc --- /dev/null +++ b/lib/git/cmd/commit/commit_info.rs @@ -0,0 +1,72 @@ +use crate::{ + bare::GitBare, + cmd::{ + commit::{CommitMeta, CommitSignature}, + oid::ObjectId, + }, + errors::GitResult, +}; + +impl GitBare { + pub fn commit_info(&self, oid: ObjectId) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&oid).try_into()?; + + let commit = repo.find_commit(gix_id)?; + let decoded = commit.decode()?; + + let message_raw = commit.message_raw()?; + let message = message_raw.to_string(); + let message = message.trim_end_matches('\n').to_string(); + let summary = message.lines().next().unwrap_or("").to_string(); + + let author_sig = decoded.author()?; + let author_time = author_sig.time()?; + let committer_sig = decoded.committer()?; + let committer_time = committer_sig.time()?; + + let tree_id = ObjectId::new(decoded.tree().to_hex().to_string()); + let parent_ids: Vec = decoded + .parents() + .map(|id| ObjectId::new(id.to_hex().to_string())) + .collect(); + + let encoding = decoded + .extra_headers() + .find("encoding") + .map(|v| v.to_string()); + + Ok(CommitMeta { + oid, + message, + summary, + author: CommitSignature { + name: author_sig.name.to_string(), + email: author_sig.email.to_string(), + time_secs: author_time.seconds as i64, + offset_minutes: (author_time.offset / 60) as i32, + }, + committer: CommitSignature { + name: committer_sig.name.to_string(), + email: committer_sig.email.to_string(), + time_secs: committer_time.seconds as i64, + offset_minutes: (committer_time.offset / 60) as i32, + }, + tree_id, + parent_ids, + encoding, + }) + } + + pub fn commit_exists(&self, oid: ObjectId) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&oid).try_into()?; + + if !repo.has_object(gix_id) { + return Ok(false); + } + + let header = repo.find_header(gix_id)?; + Ok(header.kind() == gix::object::Kind::Commit) + } +} diff --git a/lib/git/cmd/commit/commit_prefix.rs b/lib/git/cmd/commit/commit_prefix.rs new file mode 100644 index 0000000..ea7c027 --- /dev/null +++ b/lib/git/cmd/commit/commit_prefix.rs @@ -0,0 +1,25 @@ +use crate::{ + bare::GitBare, + cmd::oid::ObjectId, + errors::{GitError, GitResult}, +}; + +impl GitBare { + pub fn commit_oid_from_prefix(&self, prefix: &str) -> GitResult { + let repo = self.gix_repo()?; + let id = repo + .rev_parse_single(prefix) + .map_err(|_| GitError::ObjectNotFound(prefix.to_string()))?; + + let obj_id = id.detach(); + Ok(ObjectId::new(obj_id.to_hex().to_string())) + } + + pub fn commit_info_from_prefix( + &self, + prefix: &str, + ) -> GitResult { + let oid = self.commit_oid_from_prefix(prefix)?; + self.commit_info(oid) + } +} diff --git a/lib/git/cmd/commit/commit_refs.rs b/lib/git/cmd/commit/commit_refs.rs new file mode 100644 index 0000000..42839e4 --- /dev/null +++ b/lib/git/cmd/commit/commit_refs.rs @@ -0,0 +1,50 @@ +use crate::{ + bare::GitBare, + cmd::{commit::CommitRefInfo, oid::ObjectId}, + errors::GitResult, +}; + +impl GitBare { + pub fn commit_refs(&self, oid: ObjectId) -> GitResult> { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&oid).try_into()?; + let target_hex = gix_id.to_hex().to_string(); + + let mut refs = Vec::new(); + + let platform = repo.references()?; + let iter = platform.all()?; + + for ref_result in iter { + let reference = ref_result?; + let full_name = reference.name().as_bstr().to_string(); + + let is_branch = full_name.starts_with("refs/heads/"); + let is_tag = full_name.starts_with("refs/tags/"); + let is_remote = full_name.starts_with("refs/remotes/"); + + if !is_branch && !is_tag && !is_remote { + continue; + } + + let target = reference.target(); + let target_id = target.try_id().ok_or_else(|| { + crate::errors::GitError::Internal( + "ref has no direct target".to_string(), + ) + })?; + + if target_id.to_hex().to_string() == target_hex { + let short_name = reference.name().shorten(); + refs.push(CommitRefInfo { + name: short_name.to_string(), + target: oid.clone(), + is_remote, + is_tag, + }); + } + } + + Ok(refs) + } +} diff --git a/lib/git/cmd/commit/commit_summary.rs b/lib/git/cmd/commit/commit_summary.rs new file mode 100644 index 0000000..a8fda2f --- /dev/null +++ b/lib/git/cmd/commit/commit_summary.rs @@ -0,0 +1,41 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::{commit::CommitMeta, oid::ObjectId}, + errors::{GitError, GitResult}, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CommitSummary { + pub head: Option, + pub count: usize, +} + +impl GitBare { + pub fn commit_summary(&self) -> GitResult { + let repo = self.gix_repo()?; + let head_id = repo + .head_id() + .map_err(|_| GitError::RefNotFound("HEAD".to_string())); + let head_id = match head_id { + Ok(id) => id, + Err(_) => { + return Ok(CommitSummary { + head: None, + count: 0, + }); + } + }; + + let head_oid = ObjectId::new(head_id.detach().to_hex().to_string()); + let head_commit = self.commit_info(head_oid)?; + let walk = repo.rev_walk([head_id.detach()]).all()?; + let count = walk.count(); + + Ok(CommitSummary { + head: Some(head_commit), + count, + }) + } +} diff --git a/lib/git/cmd/commit/commit_walker.rs b/lib/git/cmd/commit/commit_walker.rs new file mode 100644 index 0000000..0d7629d --- /dev/null +++ b/lib/git/cmd/commit/commit_walker.rs @@ -0,0 +1,147 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::{commit::CommitMeta, oid::ObjectId}, + errors::{GitError, GitResult}, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum CommitWalkSort { + None, + Topological, + Time, + Reverse, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CommitWalkParams { + pub start_oids: Vec, + pub hide_oids: Vec, + pub limit: Option, + pub skip: usize, + pub first_parent: bool, + pub sort: CommitWalkSort, + pub branch: Option, +} + +impl Default for CommitWalkParams { + fn default() -> Self { + Self { + start_oids: Vec::new(), + hide_oids: Vec::new(), + limit: None, + skip: 0, + first_parent: false, + sort: CommitWalkSort::Time, + branch: None, + } + } +} + +impl GitBare { + pub fn commit_walk( + &self, + params: CommitWalkParams, + ) -> GitResult> { + let repo = self.gix_repo()?; + let tips: Vec = if let Some(ref branch_name) = params.branch { + if branch_name.is_empty() { + vec![repo.head_id()?.detach()] + } else { + let ref_name = format!("refs/heads/{branch_name}"); + match repo.try_find_reference(&ref_name)? { + Some(reference) => { + let target = reference.target(); + let target_id = target.try_id().ok_or_else(|| { + GitError::Internal(format!("branch '{branch_name}' has no direct target")) + })?; + vec![target_id.to_owned()] + } + None => { + let remote_ref = format!("refs/remotes/{branch_name}"); + match repo.try_find_reference(&remote_ref)? { + Some(reference) => { + let target = reference.target(); + let target_id = target.try_id().ok_or_else(|| { + GitError::Internal(format!("remote branch '{branch_name}' has no direct target")) + })?; + vec![target_id.to_owned()] + } + None => { + return Err(GitError::RefNotFound(branch_name.clone())); + } + } + } + } + } + } else if params.start_oids.is_empty() { + vec![repo.head_id()?.detach()] + } else { + params + .start_oids + .iter() + .map(|oid| oid.try_into()) + .collect::, _>>()? + }; + let hide: Vec = params + .hide_oids + .iter() + .map(|oid| oid.try_into()) + .collect::, _>>()?; + let mut platform = repo.rev_walk(tips); + match params.sort { + CommitWalkSort::None => {} + CommitWalkSort::Topological => { + platform = platform + .sorting(gix::revision::walk::Sorting::BreadthFirst); + } + CommitWalkSort::Time => { + platform = platform.sorting(gix::revision::walk::Sorting::ByCommitTime( + gix::traverse::commit::simple::CommitTimeOrder::NewestFirst, + )); + } + CommitWalkSort::Reverse => { + platform = platform.sorting(gix::revision::walk::Sorting::ByCommitTime( + gix::traverse::commit::simple::CommitTimeOrder::OldestFirst, + )); + } + } + if params.first_parent { + platform = platform.first_parent_only(); + } + if !hide.is_empty() { + platform = platform.with_boundary(hide); + } + + let walk = platform.all()?; + + let mut commits = Vec::new(); + let mut count = 0; + let skip = params.skip; + + for info in walk { + if count < skip { + count += 1; + continue; + } + + let info = info?; + let oid = ObjectId::new(info.id().detach().to_hex().to_string()); + let commit = self.commit_info(oid)?; + commits.push(commit); + if let Some(limit) = params.limit { + if commits.len() >= limit { + break; + } + } + + count += 1; + } + if matches!(params.sort, CommitWalkSort::Reverse) { + commits.reverse(); + } + + Ok(commits) + } +} diff --git a/lib/git/cmd/commit/mod.rs b/lib/git/cmd/commit/mod.rs new file mode 100644 index 0000000..79c834a --- /dev/null +++ b/lib/git/cmd/commit/mod.rs @@ -0,0 +1,46 @@ +use serde::{Deserialize, Serialize}; + +use crate::cmd::oid::ObjectId; + +pub mod commit_cherry_pick; +pub mod commit_create; +pub mod commit_history; +pub mod commit_info; +pub mod commit_prefix; +pub mod commit_refs; +pub mod commit_summary; +pub mod commit_walker; +pub use commit_cherry_pick::{ + CommitCherryPickParams, CommitCherryPickSequence, +}; +pub use commit_create::{CreateCommitParams, FileChange}; +pub use commit_summary::CommitSummary; +pub use commit_walker::{CommitWalkParams, CommitWalkSort}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CommitSignature { + pub name: String, + pub email: String, + pub time_secs: i64, + pub offset_minutes: i32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CommitMeta { + pub oid: ObjectId, + pub message: String, + pub summary: String, + pub author: CommitSignature, + pub committer: CommitSignature, + pub tree_id: ObjectId, + pub parent_ids: Vec, + pub encoding: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CommitRefInfo { + pub name: String, + pub target: ObjectId, + pub is_remote: bool, + pub is_tag: bool, +} diff --git a/lib/git/cmd/diff/diff_index_to_tree.rs b/lib/git/cmd/diff/diff_index_to_tree.rs new file mode 100644 index 0000000..df1b1f4 --- /dev/null +++ b/lib/git/cmd/diff/diff_index_to_tree.rs @@ -0,0 +1,152 @@ +use gix::bstr::ByteSlice; + +use crate::{ + bare::GitBare, + cmd::{ + diff::{ + DiffOptions, DiffResult, DiffStats, + diff_tree_to_tree::{ + HunkCollector, change_to_delta, matches_pathspec, + }, + }, + oid::ObjectId, + }, + errors::GitResult, +}; + +impl GitBare { + pub fn diff_index_to_tree( + &self, + tree: ObjectId, + opts: Option, + ) -> GitResult { + let options = opts.unwrap_or_default(); + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&tree).try_into()?; + let tree_obj = if let Ok(tree) = repo.find_tree(gix_id) { + tree + } else { + let commit = repo + .find_commit(gix_id) + .map_err(|e| crate::errors::GitError::Gix(e.to_string()))?; + let tree_id = commit + .tree_id() + .map_err(|e| crate::errors::GitError::Gix(e.to_string()))? + .detach(); + repo.find_tree(tree_id) + .map_err(|e| crate::errors::GitError::Gix(e.to_string()))? + }; + let head_id = repo + .head_id() + .map_err(|e| crate::errors::GitError::Gix(e.to_string()))?; + let head_commit = repo + .find_commit(head_id.detach()) + .map_err(|e| crate::errors::GitError::Gix(e.to_string()))?; + let head_tree_id = head_commit + .tree_id() + .map_err(|e| crate::errors::GitError::Gix(e.to_string()))? + .detach(); + let head_tree = repo + .find_tree(head_tree_id) + .map_err(|e| crate::errors::GitError::Gix(e.to_string()))?; + let tree_oid = tree_obj.id().detach(); + let head_tree_oid = head_tree.id().detach(); + if tree_oid == head_tree_oid { + return Ok(DiffResult { + stats: DiffStats { + files_changed: 0, + insertions: 0, + deletions: 0, + }, + deltas: Vec::new(), + }); + } + + let mut diff_opts = gix::diff::Options::default(); + diff_opts.track_path(); + + let changes = repo + .diff_tree_to_tree(&tree_obj, &head_tree, Some(diff_opts)) + .map_err(|e| crate::errors::GitError::Gix(e.to_string()))?; + + let mut resource_cache = repo + .diff_resource_cache_for_tree_diff() + .map_err(|e| crate::errors::GitError::Gix(e.to_string()))?; + + let mut deltas = Vec::new(); + let mut stats = DiffStats { + files_changed: 0, + insertions: 0, + deletions: 0, + }; + + for change in &changes { + let location = change.location().to_str().unwrap_or(""); + if !matches_pathspec(&options.pathspec, location) { + continue; + } + + let mut delta = change_to_delta(change); + + resource_cache + .set_resource_by_change(change.to_ref(), &repo.objects) + .map_err(|e| crate::errors::GitError::Gix(e.to_string()))?; + + let is_binary = { + use gix::diff::blob::platform::prepare_diff::Operation; + + let prep = resource_cache + .prepare_diff() + .map_err(|e| crate::errors::GitError::Gix(e.to_string()))?; + + match prep.operation { + Operation::InternalDiff { algorithm } => { + let input = prep.interned_input(); + let diff = gix::diff::blob::diff_with_slider_heuristics( + algorithm, &input, + ); + + stats.files_changed += 1; + stats.insertions += diff.count_additions() as usize; + stats.deletions += diff.count_removals() as usize; + + if options.context_lines > 0 { + let ctx = gix::diff::blob::unified_diff::ContextSize::symmetrical( + options.context_lines.max(3), + ); + let collector = HunkCollector::new(); + let unified = gix::diff::blob::UnifiedDiff::new( + &diff, &input, collector, ctx, + ); + let (hunks, lines) = + unified.consume().map_err(|e| { + crate::errors::GitError::Gix(e.to_string()) + })?; + delta.hunks = hunks; + delta.lines = lines; + } + false + } + Operation::SourceOrDestinationIsBinary => { + stats.files_changed += 1; + true + } + Operation::ExternalCommand { .. } => { + stats.files_changed += 1; + false + } + } + }; + + if is_binary { + delta.old_file.is_binary = true; + delta.new_file.is_binary = true; + } + + resource_cache.clear_resource_cache_keep_allocation(); + deltas.push(delta); + } + + Ok(DiffResult { stats, deltas }) + } +} diff --git a/lib/git/cmd/diff/diff_patch.rs b/lib/git/cmd/diff/diff_patch.rs new file mode 100644 index 0000000..57c1fdb --- /dev/null +++ b/lib/git/cmd/diff/diff_patch.rs @@ -0,0 +1,366 @@ +use gix::{ + bstr::ByteSlice, + diff::blob::unified_diff::{ + ConsumeHunk, ContextSize, DiffLineKind, HunkHeader, + }, +}; + +use crate::{ + bare::GitBare, + cmd::{ + diff::{ + DiffDeltaStatus, DiffOptions, DiffResult, DiffStats, + SideBySideChangeType, SideBySideDiffResult, SideBySideFile, + SideBySideLine, + diff_tree_to_tree::{ + HunkCollector, change_to_delta, matches_pathspec, peel_to_tree, + }, + }, + oid::ObjectId, + }, + errors::{GitError, GitResult}, +}; + +impl GitBare { + pub fn diff_patch( + &self, + old_commit: ObjectId, + new_commit: ObjectId, + opts: Option, + ) -> GitResult { + let repo = self.gix_repo()?; + let options = opts.unwrap_or_default(); + let context_lines = options.context_lines.max(3); + + let old_tree_obj = peel_to_tree(&repo, old_commit)?; + let new_tree_obj = peel_to_tree(&repo, new_commit)?; + + let mut diff_opts = gix::diff::Options::default(); + diff_opts.track_path(); + + let changes = repo + .diff_tree_to_tree(&old_tree_obj, &new_tree_obj, Some(diff_opts)) + .map_err(|e| GitError::Gix(e.to_string()))?; + + let mut resource_cache = repo + .diff_resource_cache_for_tree_diff() + .map_err(|e| GitError::Gix(e.to_string()))?; + + let mut deltas = Vec::new(); + let mut stats = DiffStats { + files_changed: 0, + insertions: 0, + deletions: 0, + }; + + for change in &changes { + let location = change.location().to_str().unwrap_or(""); + if !matches_pathspec(&options.pathspec, location) { + continue; + } + + let mut delta = change_to_delta(change); + + // Only diff blobs — skip trees (directories) + let entry_mode = change.entry_mode(); + if entry_mode.is_tree() { + stats.files_changed += 1; + resource_cache.clear_resource_cache_keep_allocation(); + deltas.push(delta); + continue; + } + + resource_cache + .set_resource_by_change(change.to_ref(), &repo.objects) + .map_err(|e| GitError::Gix(e.to_string()))?; + + let is_binary = { + use gix::diff::blob::platform::prepare_diff::Operation; + + let prep = resource_cache + .prepare_diff() + .map_err(|e| GitError::Gix(e.to_string()))?; + + match prep.operation { + Operation::InternalDiff { algorithm } => { + let input = prep.interned_input(); + let diff = gix::diff::blob::diff_with_slider_heuristics( + algorithm, &input, + ); + + stats.files_changed += 1; + stats.insertions += diff.count_additions() as usize; + stats.deletions += diff.count_removals() as usize; + + let ctx = ContextSize::symmetrical(context_lines); + let collector = HunkCollector::new(); + let unified = gix::diff::blob::UnifiedDiff::new( + &diff, &input, collector, ctx, + ); + let (hunks, lines) = unified + .consume() + .map_err(|e| GitError::Gix(e.to_string()))?; + delta.hunks = hunks; + delta.lines = lines; + + false + } + Operation::SourceOrDestinationIsBinary => { + stats.files_changed += 1; + true + } + Operation::ExternalCommand { .. } => { + stats.files_changed += 1; + false + } + } + }; + + if is_binary { + delta.old_file.is_binary = true; + delta.new_file.is_binary = true; + } + + resource_cache.clear_resource_cache_keep_allocation(); + deltas.push(delta); + } + + Ok(DiffResult { stats, deltas }) + } + + pub fn diff_patch_side_by_side( + &self, + old_commit: ObjectId, + new_commit: ObjectId, + opts: Option, + ) -> GitResult { + let repo = self.gix_repo()?; + let options = opts.unwrap_or_default(); + let context_lines = options.context_lines.max(3); + + let old_tree_obj = peel_to_tree(&repo, old_commit)?; + let new_tree_obj = peel_to_tree(&repo, new_commit)?; + + let mut diff_opts = gix::diff::Options::default(); + diff_opts.track_path(); + + let changes = repo + .diff_tree_to_tree(&old_tree_obj, &new_tree_obj, Some(diff_opts)) + .map_err(|e| GitError::Gix(e.to_string()))?; + + let mut resource_cache = repo + .diff_resource_cache_for_tree_diff() + .map_err(|e| GitError::Gix(e.to_string()))?; + + let mut files: Vec = Vec::new(); + let mut total_additions = 0; + let mut total_deletions = 0; + + for change in &changes { + let location = change.location().to_str().unwrap_or(""); + if !matches_pathspec(&options.pathspec, location) { + continue; + } + + let delta = change_to_delta(change); + let is_rename = delta.status == DiffDeltaStatus::Renamed; + let path = delta.new_file.path.clone().unwrap_or_default(); + + // Skip directories — only diff blobs + if change.entry_mode().is_tree() { + total_additions += 0; + total_deletions += 0; + files.push(SideBySideFile { + path, + additions: 0, + deletions: 0, + is_binary: false, + is_rename, + lines: Vec::new(), + }); + resource_cache.clear_resource_cache_keep_allocation(); + continue; + } + + resource_cache + .set_resource_by_change(change.to_ref(), &repo.objects) + .map_err(|e| GitError::Gix(e.to_string()))?; + + let (file_additions, file_deletions, is_binary, sbs_lines) = { + use gix::diff::blob::platform::prepare_diff::Operation; + + let prep = resource_cache + .prepare_diff() + .map_err(|e| GitError::Gix(e.to_string()))?; + + match prep.operation { + Operation::InternalDiff { algorithm } => { + let input = prep.interned_input(); + let diff = gix::diff::blob::diff_with_slider_heuristics( + algorithm, &input, + ); + + let adds = diff.count_additions() as usize; + let dels = diff.count_removals() as usize; + + let ctx = ContextSize::symmetrical(context_lines); + let collector = SideBySideCollector::new(); + let unified = gix::diff::blob::UnifiedDiff::new( + &diff, &input, collector, ctx, + ); + let lines = unified + .consume() + .map_err(|e| GitError::Gix(e.to_string()))?; + + (adds, dels, false, lines) + } + Operation::SourceOrDestinationIsBinary => { + (0, 0, true, Vec::new()) + } + Operation::ExternalCommand { .. } => { + (0, 0, false, Vec::new()) + } + } + }; + + total_additions += file_additions; + total_deletions += file_deletions; + + files.push(SideBySideFile { + path, + additions: file_additions, + deletions: file_deletions, + is_binary, + is_rename, + lines: sbs_lines, + }); + + resource_cache.clear_resource_cache_keep_allocation(); + } + + Ok(SideBySideDiffResult { + files, + total_additions, + total_deletions, + }) + } +} +struct SideBySideCollector { + lines: Vec, + pending_removed: Vec<(u32, String)>, + pending_added: Vec<(u32, String)>, + current_old_lineno: u32, + current_new_lineno: u32, +} + +impl SideBySideCollector { + fn new() -> Self { + SideBySideCollector { + lines: Vec::new(), + pending_removed: Vec::new(), + pending_added: Vec::new(), + current_old_lineno: 0, + current_new_lineno: 0, + } + } + + fn flush_pending(&mut self) { + let removed_count = self.pending_removed.len(); + let added_count = self.pending_added.len(); + let common = removed_count.min(added_count); + + for i in 0..common { + let (left_no, old_content) = &self.pending_removed[i]; + let (right_no, new_content) = &self.pending_added[i]; + + let change_type = if old_content == new_content { + SideBySideChangeType::Unchanged + } else { + SideBySideChangeType::Modified + }; + + self.lines.push(SideBySideLine { + left_line_no: Some(*left_no), + right_line_no: Some(*right_no), + left_content: old_content.clone(), + right_content: new_content.clone(), + change_type, + }); + } + + for i in common..removed_count { + let (left_no, old_content) = &self.pending_removed[i]; + self.lines.push(SideBySideLine { + left_line_no: Some(*left_no), + right_line_no: None, + left_content: old_content.clone(), + right_content: String::new(), + change_type: SideBySideChangeType::Removed, + }); + } + + for i in common..added_count { + let (right_no, new_content) = &self.pending_added[i]; + self.lines.push(SideBySideLine { + left_line_no: None, + right_line_no: Some(*right_no), + left_content: String::new(), + right_content: new_content.clone(), + change_type: SideBySideChangeType::Added, + }); + } + + self.pending_removed.clear(); + self.pending_added.clear(); + } +} + +impl ConsumeHunk for SideBySideCollector { + type Out = Vec; + + fn consume_hunk( + &mut self, + header: HunkHeader, + entries: &[(DiffLineKind, &[u8])], + ) -> std::io::Result<()> { + self.flush_pending(); + self.current_old_lineno = header.before_hunk_start; + self.current_new_lineno = header.after_hunk_start; + + for (kind, content) in entries { + let content_str = String::from_utf8_lossy(content).to_string(); + + match kind { + DiffLineKind::Context => { + self.flush_pending(); + self.lines.push(SideBySideLine { + left_line_no: Some(self.current_old_lineno), + right_line_no: Some(self.current_new_lineno), + left_content: content_str.clone(), + right_content: content_str, + change_type: SideBySideChangeType::Unchanged, + }); + self.current_old_lineno += 1; + self.current_new_lineno += 1; + } + DiffLineKind::Add => { + self.pending_added + .push((self.current_new_lineno, content_str)); + self.current_new_lineno += 1; + } + DiffLineKind::Remove => { + self.pending_removed + .push((self.current_old_lineno, content_str)); + self.current_old_lineno += 1; + } + } + } + + Ok(()) + } + + fn finish(mut self) -> Vec { + self.flush_pending(); + self.lines + } +} diff --git a/lib/git/cmd/diff/diff_stats.rs b/lib/git/cmd/diff/diff_stats.rs new file mode 100644 index 0000000..cd2f42f --- /dev/null +++ b/lib/git/cmd/diff/diff_stats.rs @@ -0,0 +1,104 @@ +use gix::bstr::ByteSlice; + +use crate::{ + bare::GitBare, + cmd::{ + diff::{ + DiffOptions, DiffStats, + diff_tree_to_tree::{matches_pathspec, peel_to_tree}, + }, + oid::ObjectId, + }, + errors::{GitError, GitResult}, +}; + +impl GitBare { + pub fn diff_stats( + &self, + old_commit: ObjectId, + new_commit: ObjectId, + opts: Option, + ) -> GitResult { + let repo = self.gix_repo()?; + let options = opts.unwrap_or_default(); + + let old_tree_obj = peel_to_tree(&repo, old_commit)?; + let new_tree_obj = peel_to_tree(&repo, new_commit)?; + let mut platform = old_tree_obj + .changes() + .map_err(|e| GitError::Gix(e.to_string()))?; + platform.options(|opts| { + opts.track_rewrites(None); + }); + + let gix_stats = platform + .stats(&new_tree_obj) + .map_err(|e| GitError::Gix(e.to_string()))?; + if options.pathspec.is_empty() { + Ok(DiffStats { + files_changed: gix_stats.files_changed as usize, + insertions: gix_stats.lines_added as usize, + deletions: gix_stats.lines_removed as usize, + }) + } else { + let changes = repo + .diff_tree_to_tree(&old_tree_obj, &new_tree_obj, None) + .map_err(|e| GitError::Gix(e.to_string()))?; + + let mut resource_cache = + repo.diff_resource_cache_for_tree_diff() + .map_err(|e| GitError::Gix(e.to_string()))?; + + let mut stats = DiffStats { + files_changed: 0, + insertions: 0, + deletions: 0, + }; + + for change in &changes { + let location = change.location().to_str().unwrap_or(""); + if !matches_pathspec(&options.pathspec, location) { + continue; + } + + // Skip directories — only diff blobs + if change.entry_mode().is_tree() { + stats.files_changed += 1; + continue; + } + + resource_cache + .set_resource_by_change(change.to_ref(), &repo.objects) + .map_err(|e| GitError::Gix(e.to_string()))?; + + { + use gix::diff::blob::platform::prepare_diff::Operation; + + let prep = resource_cache + .prepare_diff() + .map_err(|e| GitError::Gix(e.to_string()))?; + + if let Operation::InternalDiff { algorithm } = + prep.operation + { + let input = prep.interned_input(); + let diff = gix::diff::blob::diff_with_slider_heuristics( + algorithm, &input, + ); + stats.files_changed += 1; + stats.insertions += diff.count_additions() as usize; + stats.deletions += diff.count_removals() as usize; + } else if let Operation::SourceOrDestinationIsBinary = + prep.operation + { + stats.files_changed += 1; + } + } + + resource_cache.clear_resource_cache_keep_allocation(); + } + + Ok(stats) + } + } +} diff --git a/lib/git/cmd/diff/diff_tree_to_tree.rs b/lib/git/cmd/diff/diff_tree_to_tree.rs new file mode 100644 index 0000000..cf90878 --- /dev/null +++ b/lib/git/cmd/diff/diff_tree_to_tree.rs @@ -0,0 +1,362 @@ +use gix::{ + bstr::ByteSlice, + diff::blob::unified_diff::{ + ConsumeHunk, ContextSize, DiffLineKind, HunkHeader, + }, +}; + +use crate::{ + bare::GitBare, + cmd::{ + diff::{ + DiffDelta, DiffDeltaStatus, DiffFile, DiffHunk, DiffLine, + DiffOptions, DiffResult, DiffStats, + }, + oid::ObjectId, + }, + errors::{GitError, GitResult}, +}; +pub fn peel_to_tree( + repo: &gix::Repository, + oid: ObjectId, +) -> GitResult> { + let gix_id: gix::hash::ObjectId = (&oid).try_into()?; + if let Ok(tree) = repo.find_tree(gix_id) { + Ok(tree) + } else { + let commit = repo + .find_commit(gix_id) + .map_err(|e| GitError::Gix(e.to_string()))?; + let tree_id = commit + .tree_id() + .map_err(|e| GitError::Gix(e.to_string()))? + .detach(); + repo.find_tree(tree_id) + .map_err(|e| GitError::Gix(e.to_string())) + } +} +pub fn matches_pathspec(pathspec: &[String], path: &str) -> bool { + if pathspec.is_empty() { + return true; + } + pathspec + .iter() + .any(|spec| path == spec || path.starts_with(spec)) +} +pub fn change_to_delta( + change: &gix::diff::tree_with_rewrites::Change, +) -> DiffDelta { + use gix::diff::tree_with_rewrites::Change; + + match change { + Change::Addition { location, id, .. } => { + let path = location.to_str().unwrap_or("").to_string(); + DiffDelta { + status: DiffDeltaStatus::Added, + old_file: DiffFile { + oid: None, + path: Some(path.clone()), + size: 0, + is_binary: false, + }, + new_file: DiffFile { + oid: oid_to_option(id), + path: Some(path), + size: 0, + is_binary: false, + }, + nfiles: 1, + hunks: Vec::new(), + lines: Vec::new(), + } + } + Change::Deletion { location, id, .. } => { + let path = location.to_str().unwrap_or("").to_string(); + DiffDelta { + status: DiffDeltaStatus::Deleted, + old_file: DiffFile { + oid: oid_to_option(id), + path: Some(path.clone()), + size: 0, + is_binary: false, + }, + new_file: DiffFile { + oid: None, + path: Some(path), + size: 0, + is_binary: false, + }, + nfiles: 1, + hunks: Vec::new(), + lines: Vec::new(), + } + } + Change::Modification { + location, + previous_entry_mode, + previous_id, + entry_mode, + id, + } => { + let path = location.to_str().unwrap_or("").to_string(); + let status = if previous_entry_mode.kind() != entry_mode.kind() { + DiffDeltaStatus::Typechange + } else { + DiffDeltaStatus::Modified + }; + DiffDelta { + status, + old_file: DiffFile { + oid: oid_to_option(previous_id), + path: Some(path.clone()), + size: 0, + is_binary: false, + }, + new_file: DiffFile { + oid: oid_to_option(id), + path: Some(path), + size: 0, + is_binary: false, + }, + nfiles: 1, + hunks: Vec::new(), + lines: Vec::new(), + } + } + Change::Rewrite { + source_location, + source_id, + location, + id, + copy, + .. + } => { + let old_path = source_location.to_str().unwrap_or("").to_string(); + let new_path = location.to_str().unwrap_or("").to_string(); + DiffDelta { + status: if *copy { + DiffDeltaStatus::Copied + } else { + DiffDeltaStatus::Renamed + }, + old_file: DiffFile { + oid: oid_to_option(source_id), + path: Some(old_path), + size: 0, + is_binary: false, + }, + new_file: DiffFile { + oid: oid_to_option(id), + path: Some(new_path), + size: 0, + is_binary: false, + }, + nfiles: 2, + hunks: Vec::new(), + lines: Vec::new(), + } + } + } +} + +pub fn oid_to_option(id: &gix::hash::oid) -> Option { + if id.is_null() { + None + } else { + Some(ObjectId::new(id.to_hex().to_string())) + } +} +pub struct HunkCollector { + hunks: Vec, + lines: Vec, + current_old_lineno: u32, + current_new_lineno: u32, +} + +impl HunkCollector { + pub fn new() -> Self { + HunkCollector { + hunks: Vec::new(), + lines: Vec::new(), + current_old_lineno: 0, + current_new_lineno: 0, + } + } +} + +impl ConsumeHunk for HunkCollector { + type Out = (Vec, Vec); + + fn consume_hunk( + &mut self, + header: HunkHeader, + entries: &[(DiffLineKind, &[u8])], + ) -> std::io::Result<()> { + self.current_old_lineno = header.before_hunk_start; + self.current_new_lineno = header.after_hunk_start; + + self.hunks.push(DiffHunk { + old_start: header.before_hunk_start, + old_lines: header.before_hunk_len, + new_start: header.after_hunk_start, + new_lines: header.after_hunk_len, + header: format!( + "@@ -{},{} +{},{} @@", + header.before_hunk_start, + header.before_hunk_len, + header.after_hunk_start, + header.after_hunk_len + ), + }); + + for (kind, content) in entries { + let origin = kind.to_prefix(); + let content_str = String::from_utf8_lossy(content).to_string(); + + let (old_lineno, new_lineno) = match kind { + DiffLineKind::Context => { + let old = Some(self.current_old_lineno); + let new = Some(self.current_new_lineno); + self.current_old_lineno += 1; + self.current_new_lineno += 1; + (old, new) + } + DiffLineKind::Add => { + let new = Some(self.current_new_lineno); + self.current_new_lineno += 1; + (None, new) + } + DiffLineKind::Remove => { + let old = self.current_old_lineno; + self.current_old_lineno += 1; + (Some(old), None) + } + }; + + self.lines.push(DiffLine { + content: content_str, + origin, + old_lineno, + new_lineno, + num_lines: 1, + content_offset: -1, + }); + } + + Ok(()) + } + + fn finish(self) -> (Vec, Vec) { + (self.hunks, self.lines) + } +} + +impl GitBare { + pub fn diff_tree_to_tree( + &self, + old_tree: ObjectId, + new_tree: ObjectId, + opts: Option, + ) -> GitResult { + let repo = self.gix_repo()?; + let options = opts.unwrap_or_default(); + + let old_tree_obj = peel_to_tree(&repo, old_tree)?; + let new_tree_obj = peel_to_tree(&repo, new_tree)?; + + let mut diff_opts = gix::diff::Options::default(); + diff_opts.track_path(); + + let changes = repo + .diff_tree_to_tree(&old_tree_obj, &new_tree_obj, Some(diff_opts)) + .map_err(|e| GitError::Gix(e.to_string()))?; + + let mut resource_cache = repo + .diff_resource_cache_for_tree_diff() + .map_err(|e| GitError::Gix(e.to_string()))?; + + let mut deltas = Vec::new(); + let mut stats = DiffStats { + files_changed: 0, + insertions: 0, + deletions: 0, + }; + + for change in &changes { + let location = change.location().to_str().unwrap_or(""); + if !matches_pathspec(&options.pathspec, location) { + continue; + } + + let mut delta = change_to_delta(change); + + // Skip directories — only diff blobs + if change.entry_mode().is_tree() { + stats.files_changed += 1; + resource_cache.clear_resource_cache_keep_allocation(); + deltas.push(delta); + continue; + } + + resource_cache + .set_resource_by_change(change.to_ref(), &repo.objects) + .map_err(|e| GitError::Gix(e.to_string()))?; + + let is_binary = { + use gix::diff::blob::platform::prepare_diff::Operation; + + let prep = resource_cache + .prepare_diff() + .map_err(|e| GitError::Gix(e.to_string()))?; + + match prep.operation { + Operation::InternalDiff { algorithm } => { + let input = prep.interned_input(); + let diff = gix::diff::blob::diff_with_slider_heuristics( + algorithm, &input, + ); + + stats.files_changed += 1; + stats.insertions += diff.count_additions() as usize; + stats.deletions += diff.count_removals() as usize; + + if options.context_lines > 0 { + let ctx = ContextSize::symmetrical( + options.context_lines.max(3), + ); + let collector = HunkCollector::new(); + let unified = gix::diff::blob::UnifiedDiff::new( + &diff, &input, collector, ctx, + ); + let (hunks, lines) = unified + .consume() + .map_err(|e| GitError::Gix(e.to_string()))?; + delta.hunks = hunks; + delta.lines = lines; + } + + false + } + Operation::SourceOrDestinationIsBinary => { + stats.files_changed += 1; + true + } + Operation::ExternalCommand { .. } => { + stats.files_changed += 1; + false + } + } + }; + + if is_binary { + delta.old_file.is_binary = true; + delta.new_file.is_binary = true; + } + + resource_cache.clear_resource_cache_keep_allocation(); + deltas.push(delta); + } + + Ok(DiffResult { stats, deltas }) + } +} diff --git a/lib/git/cmd/diff/mod.rs b/lib/git/cmd/diff/mod.rs new file mode 100644 index 0000000..8090352 --- /dev/null +++ b/lib/git/cmd/diff/mod.rs @@ -0,0 +1,113 @@ +use serde::{Deserialize, Serialize}; + +use crate::cmd::oid::ObjectId; + +pub mod diff_index_to_tree; +pub mod diff_patch; +pub mod diff_stats; +pub mod diff_tree_to_tree; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum DiffDeltaStatus { + Unmodified, + Added, + Deleted, + Modified, + Renamed, + Copied, + Typechange, + Conflicted, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiffFile { + pub oid: Option, + pub path: Option, + pub size: u64, + pub is_binary: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiffHunk { + pub old_start: u32, + pub old_lines: u32, + pub new_start: u32, + pub new_lines: u32, + pub header: String, +} +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiffDelta { + pub status: DiffDeltaStatus, + pub old_file: DiffFile, + pub new_file: DiffFile, + pub nfiles: u16, + pub hunks: Vec, + pub lines: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiffLine { + pub content: String, + pub origin: char, + pub old_lineno: Option, + pub new_lineno: Option, + pub num_lines: u32, + pub content_offset: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiffStats { + pub files_changed: usize, + pub insertions: usize, + pub deletions: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiffResult { + pub stats: DiffStats, + pub deltas: Vec, +} +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum SideBySideChangeType { + Unchanged, + Added, + Removed, + Modified, + Empty, +} +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SideBySideLine { + pub left_line_no: Option, + pub right_line_no: Option, + pub left_content: String, + pub right_content: String, + pub change_type: SideBySideChangeType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SideBySideFile { + pub path: String, + pub additions: usize, + pub deletions: usize, + pub is_binary: bool, + pub is_rename: bool, + pub lines: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SideBySideDiffResult { + pub files: Vec, + pub total_additions: usize, + pub total_deletions: usize, +} + +#[derive(Debug, Clone, Default)] +pub struct DiffOptions { + pub context_lines: u32, + pub pathspec: Vec, + pub ignore_whitespace: bool, + pub force_text: bool, + pub reverse: bool, +} diff --git a/lib/git/cmd/fork.rs b/lib/git/cmd/fork.rs new file mode 100644 index 0000000..97c7273 --- /dev/null +++ b/lib/git/cmd/fork.rs @@ -0,0 +1,111 @@ +use std::path::PathBuf; + +use crate::{ + bare::GitBare, + errors::{GitError, GitResult}, +}; + +pub struct ForkRepoParams { + pub namespace: String, + pub repo_name: String, + pub default_branch: String, + pub description: Option, + pub enable_lfs: bool, +} + +impl ForkRepoParams { + pub async fn fork_bare( + storage_root: String, + source_storage_path: String, + params: ForkRepoParams, + ) -> GitResult { + let target_dir = PathBuf::from(&storage_root) + .join("repo") + .join(¶ms.namespace) + .join(¶ms.repo_name); + + if target_dir.exists() { + return Err(GitError::Internal(format!( + "repository directory already exists: {}", + target_dir.display() + ))); + } + + let source_dir = PathBuf::from(&source_storage_path); + if !source_dir.exists() { + return Err(GitError::Internal(format!( + "source repository does not exist: {}", + source_dir.display() + ))); + } + let output = duct::cmd( + "git", + &[ + "clone", + "--bare", + source_dir.to_string_lossy().as_ref(), + target_dir.to_string_lossy().as_ref(), + ], + ) + .stdout_capture() + .stderr_capture() + .env("GIT_CONFIG_NOSYSTEM", "1") + .env("GIT_TERMINAL_PROMPT", "0") + .unchecked() + .run()?; + + if !output.status.success() { + std::fs::remove_dir_all(&target_dir).ok(); + return Err(GitError::CommandFailed { + status_code: output.status.code(), + stderr: String::from_utf8_lossy(&output.stderr).into_owned(), + }); + } + + let bare = GitBare { + bare_dir: target_dir.clone(), + }; + let symref_output = bare.git_command_trusted(vec![ + "symbolic-ref".to_string(), + "HEAD".to_string(), + format!("refs/heads/{}", params.default_branch), + ])?; + if !symref_output.success { + return Err(GitError::CommandFailed { + status_code: symref_output.status_code, + stderr: symref_output.stderr_lossy(), + }); + } + let remote_output = bare.git_command_trusted(vec![ + "remote".to_string(), + "add".to_string(), + "upstream".to_string(), + source_dir.to_string_lossy().to_string(), + ])?; + if !remote_output.success { + tracing::warn!( + "failed to add upstream remote: {}", + remote_output.stderr_lossy() + ); + } + if let Some(desc) = ¶ms.description { + let desc_path = target_dir.join("description"); + std::fs::write(&desc_path, desc)?; + } + if params.enable_lfs { + let gitattributes_path = target_dir.join("info").join("attributes"); + if let Some(parent) = gitattributes_path.parent() { + std::fs::create_dir_all(parent)?; + } + let lfs_attributes = "*.psd filter=lfs diff=lfs merge=lfs -text\n\ + *.zip filter=lfs diff=lfs merge=lfs -text\n\ + *.tar filter=lfs diff=lfs merge=lfs -text\n\ + *.gz filter=lfs diff=lfs merge=lfs -text\n\ + *.mp4 filter=lfs diff=lfs merge=lfs -text\n\ + *.mov filter=lfs diff=lfs merge=lfs -text\n"; + std::fs::write(&gitattributes_path, lfs_attributes)?; + } + + Ok(target_dir.to_string_lossy().to_string()) + } +} diff --git a/lib/git/cmd/init.rs b/lib/git/cmd/init.rs new file mode 100644 index 0000000..7341b3e --- /dev/null +++ b/lib/git/cmd/init.rs @@ -0,0 +1,235 @@ +use std::path::PathBuf; + +use crate::{ + bare::GitBare, + errors::{GitError, GitResult}, +}; + +pub struct InitRepositoriesParams { + pub namespace: String, + pub repo_name: String, + pub default_branch: String, + pub description: Option, + pub initialize_with_readme: bool, + pub enable_lfs: bool, +} + +pub struct CloneRepoParams { + pub namespace: String, + pub repo_name: String, + pub source_url: String, +} + +impl CloneRepoParams { + pub async fn clone_bare( + storage_root: String, + params: CloneRepoParams, + ) -> GitResult { + let repo_dir = PathBuf::from(&storage_root) + .join("repo") + .join(¶ms.namespace) + .join(¶ms.repo_name); + + if repo_dir.exists() { + return Err(GitError::Internal(format!( + "repository directory already exists: {}", + repo_dir.display() + ))); + } + + if let Some(parent) = repo_dir.parent() { + std::fs::create_dir_all(parent)?; + } + + // Clone as bare repo from source URL + let output = duct::cmd( + "git", + &[ + "clone", + "--bare", + ¶ms.source_url, + repo_dir.to_string_lossy().as_ref(), + ], + ) + .stdout_capture() + .stderr_capture() + .env("GIT_CONFIG_NOSYSTEM", "1") + .env("GIT_TERMINAL_PROMPT", "0") + .unchecked() + .run()?; + + if !output.status.success() { + std::fs::remove_dir_all(&repo_dir).ok(); + return Err(GitError::CommandFailed { + status_code: output.status.code(), + stderr: String::from_utf8_lossy(&output.stderr).into_owned(), + }); + } + + Ok(repo_dir.to_string_lossy().to_string()) + } +} + +impl InitRepositoriesParams { + pub async fn init_bare( + basic_path: String, + params: InitRepositoriesParams, + ) -> GitResult { + let repo_dir = PathBuf::from(&basic_path) + .join("repo") + .join(¶ms.namespace) + .join(¶ms.repo_name); + + if repo_dir.exists() { + return Err(GitError::Internal(format!( + "repository directory already exists: {}", + repo_dir.display() + ))); + } + + std::fs::create_dir_all(&repo_dir)?; + let output = duct::cmd("git", &["init", "--bare"]) + .dir(&repo_dir) + .stdout_capture() + .stderr_capture() + .env("GIT_CONFIG_NOSYSTEM", "1") + .env("GIT_TERMINAL_PROMPT", "0") + .unchecked() + .run()?; + + if !output.status.success() { + std::fs::remove_dir_all(&repo_dir).ok(); + return Err(GitError::CommandFailed { + status_code: output.status.code(), + stderr: String::from_utf8_lossy(&output.stderr).into_owned(), + }); + } + + let bare = GitBare { + bare_dir: repo_dir.clone(), + }; + let symref_output = bare.git_command_trusted(vec![ + "symbolic-ref".to_string(), + "HEAD".to_string(), + format!("refs/heads/{}", params.default_branch), + ])?; + if !symref_output.success { + return Err(GitError::CommandFailed { + status_code: symref_output.status_code, + stderr: symref_output.stderr_lossy(), + }); + } + if let Some(desc) = ¶ms.description { + let desc_path = repo_dir.join("description"); + std::fs::write(&desc_path, desc)?; + } + if params.enable_lfs { + let gitattributes_path = repo_dir.join("info").join("attributes"); + if let Some(parent) = gitattributes_path.parent() { + std::fs::create_dir_all(parent)?; + } + let lfs_attributes = "*.psd filter=lfs diff=lfs merge=lfs -text\n\ + *.zip filter=lfs diff=lfs merge=lfs -text\n\ + *.tar filter=lfs diff=lfs merge=lfs -text\n\ + *.gz filter=lfs diff=lfs merge=lfs -text\n\ + *.mp4 filter=lfs diff=lfs merge=lfs -text\n\ + *.mov filter=lfs diff=lfs merge=lfs -text\n"; + std::fs::write(&gitattributes_path, lfs_attributes)?; + } + if params.initialize_with_readme { + init_initial_commit(&bare, ¶ms)?; + } + + Ok(repo_dir.to_string_lossy().to_string()) + } +} + +fn duct_output_to_error(output: &std::process::Output) -> GitError { + GitError::CommandFailed { + status_code: output.status.code(), + stderr: String::from_utf8_lossy(&output.stderr).into_owned(), + } +} + +fn init_initial_commit( + bare: &GitBare, + params: &InitRepositoriesParams, +) -> GitResult<()> { + let tmp_dir = bare.bare_dir.with_extension("tmp-init"); + let clone_output = duct::cmd( + "git", + &[ + "clone", + bare.bare_dir.to_string_lossy().as_ref(), + tmp_dir.to_string_lossy().as_ref(), + ], + ) + .stdout_capture() + .stderr_capture() + .unchecked() + .run()?; + + if !clone_output.status.success() { + return Err(duct_output_to_error(&clone_output)); + } + let checkout_output = + duct::cmd("git", &["checkout", "-b", ¶ms.default_branch]) + .dir(&tmp_dir) + .stdout_capture() + .stderr_capture() + .unchecked() + .run()?; + + if !checkout_output.status.success() { + std::fs::remove_dir_all(&tmp_dir).ok(); + return Err(duct_output_to_error(&checkout_output)); + } + let readme_content = format!( + "# {}\n\n{}", + params.repo_name, + params.description.as_deref().unwrap_or("") + ); + let readme_path = tmp_dir.join("README.md"); + std::fs::write(&readme_path, readme_content)?; + let add_output = duct::cmd("git", &["add", "README.md"]) + .dir(&tmp_dir) + .stdout_capture() + .stderr_capture() + .unchecked() + .run()?; + + if !add_output.status.success() { + std::fs::remove_dir_all(&tmp_dir).ok(); + return Err(duct_output_to_error(&add_output)); + } + let commit_output = duct::cmd("git", &["commit", "-m", "Initial commit"]) + .dir(&tmp_dir) + .stdout_capture() + .stderr_capture() + .env("GIT_CONFIG_NOSYSTEM", "1") + .env("GIT_COMMITTER_NAME", "panda") + .env("GIT_COMMITTER_EMAIL", "panda@gitdata.ai") + .env("GIT_AUTHOR_NAME", "panda") + .env("GIT_AUTHOR_EMAIL", "panda@gitdata.ai") + .unchecked() + .run()?; + + if !commit_output.status.success() { + std::fs::remove_dir_all(&tmp_dir).ok(); + return Err(duct_output_to_error(&commit_output)); + } + let push_output = + duct::cmd("git", &["push", "origin", ¶ms.default_branch]) + .dir(&tmp_dir) + .stdout_capture() + .stderr_capture() + .unchecked() + .run()?; + std::fs::remove_dir_all(&tmp_dir).ok(); + + if !push_output.status.success() { + return Err(duct_output_to_error(&push_output)); + } + + Ok(()) +} diff --git a/lib/git/cmd/merge/merge_abort.rs b/lib/git/cmd/merge/merge_abort.rs new file mode 100644 index 0000000..de0758d --- /dev/null +++ b/lib/git/cmd/merge/merge_abort.rs @@ -0,0 +1,13 @@ +use crate::{bare::GitBare, errors::GitResult}; + +impl GitBare { + pub fn merge_abort(&self) -> GitResult<()> { + self.git_command_trusted_stdout(vec![ + "update-ref".to_string(), + "-d".to_string(), + "MERGE_HEAD".to_string(), + ])?; + + Ok(()) + } +} diff --git a/lib/git/cmd/merge/merge_analysis.rs b/lib/git/cmd/merge/merge_analysis.rs new file mode 100644 index 0000000..d0ca1d4 --- /dev/null +++ b/lib/git/cmd/merge/merge_analysis.rs @@ -0,0 +1,76 @@ +use crate::{ + bare::GitBare, + cmd::{ + merge::{MergeAnalysisResult, MergePreferenceResult}, + oid::ObjectId, + }, + errors::GitResult, +}; + +impl GitBare { + pub fn merge_analysis( + &self, + their_commit: ObjectId, + ) -> GitResult<(MergeAnalysisResult, MergePreferenceResult)> { + let repo = self.gix_repo()?; + let their_gix_id: gix::hash::ObjectId = (&their_commit).try_into()?; + let head_id = repo.head_id()?; + let head_gix_id = head_id.detach(); + let is_up_to_date = repo + .merge_base(their_gix_id, head_gix_id) + .ok() + .is_some_and(|base| base.detach() == their_gix_id); + + if is_up_to_date { + return Ok(( + MergeAnalysisResult { + is_none: false, + is_normal: false, + is_up_to_date: true, + is_fast_forward: false, + is_unborn: false, + }, + MergePreferenceResult { + is_none: false, + is_no_fast_forward: false, + is_fastforward_only: false, + }, + )); + } + let is_fast_forward = repo + .merge_base(head_gix_id, their_gix_id) + .ok() + .is_some_and(|base| base.detach() == head_gix_id); + + if is_fast_forward { + return Ok(( + MergeAnalysisResult { + is_none: false, + is_normal: false, + is_up_to_date: false, + is_fast_forward: true, + is_unborn: false, + }, + MergePreferenceResult { + is_none: false, + is_no_fast_forward: false, + is_fastforward_only: false, + }, + )); + } + Ok(( + MergeAnalysisResult { + is_none: false, + is_normal: true, + is_up_to_date: false, + is_fast_forward: false, + is_unborn: false, + }, + MergePreferenceResult { + is_none: false, + is_no_fast_forward: false, + is_fastforward_only: false, + }, + )) + } +} diff --git a/lib/git/cmd/merge/merge_analysis_for_ref.rs b/lib/git/cmd/merge/merge_analysis_for_ref.rs new file mode 100644 index 0000000..77f28ae --- /dev/null +++ b/lib/git/cmd/merge/merge_analysis_for_ref.rs @@ -0,0 +1,81 @@ +use crate::{ + bare::GitBare, + cmd::merge::{MergeAnalysisResult, MergePreferenceResult}, + errors::{GitError, GitResult}, +}; + +impl GitBare { + pub fn merge_analysis_for_ref( + &self, + ref_name: String, + ) -> GitResult<(MergeAnalysisResult, MergePreferenceResult)> { + let repo = self.gix_repo()?; + let their_id = + repo.rev_parse_single(ref_name.as_str()).map_err(|_| { + GitError::RefNotFound(format!( + "reference '{}' could not be resolved", + ref_name + )) + })?; + let their_gix_id = their_id.detach(); + + let head_id = repo.head_id()?; + let head_gix_id = head_id.detach(); + let is_up_to_date = repo + .merge_base(their_gix_id, head_gix_id) + .ok() + .is_some_and(|base| base.detach() == their_gix_id); + + if is_up_to_date { + return Ok(( + MergeAnalysisResult { + is_none: false, + is_normal: false, + is_up_to_date: true, + is_fast_forward: false, + is_unborn: false, + }, + MergePreferenceResult { + is_none: false, + is_no_fast_forward: false, + is_fastforward_only: false, + }, + )); + } + let is_fast_forward = repo + .merge_base(head_gix_id, their_gix_id) + .ok() + .is_some_and(|base| base.detach() == head_gix_id); + + if is_fast_forward { + return Ok(( + MergeAnalysisResult { + is_none: false, + is_normal: false, + is_up_to_date: false, + is_fast_forward: true, + is_unborn: false, + }, + MergePreferenceResult { + is_none: false, + is_no_fast_forward: false, + is_fastforward_only: false, + }, + )); + } + Ok(( + MergeAnalysisResult { + is_none: false, + is_normal: true, + is_up_to_date: false, + is_fast_forward: false, + is_unborn: false, + }, + MergePreferenceResult { + is_none: false, + is_no_fast_forward: false, + is_fastforward_only: false, + }, + )) + } +} diff --git a/lib/git/cmd/merge/merge_base.rs b/lib/git/cmd/merge/merge_base.rs new file mode 100644 index 0000000..e06c82c --- /dev/null +++ b/lib/git/cmd/merge/merge_base.rs @@ -0,0 +1,26 @@ +use crate::{ + bare::GitBare, + cmd::oid::ObjectId, + errors::{GitError, GitResult}, +}; + +impl GitBare { + pub fn merge_base( + &self, + one: ObjectId, + two: ObjectId, + ) -> GitResult { + let repo = self.gix_repo()?; + let gix_one: gix::hash::ObjectId = (&one).try_into()?; + let gix_two: gix::hash::ObjectId = (&two).try_into()?; + + let base = repo.merge_base(gix_one, gix_two).map_err(|_| { + GitError::ObjectNotFound(format!( + "no merge base found between {} and {}", + one, two + )) + })?; + + Ok(ObjectId::new(base.detach().to_hex().to_string())) + } +} diff --git a/lib/git/cmd/merge/merge_base_many.rs b/lib/git/cmd/merge/merge_base_many.rs new file mode 100644 index 0000000..85f789b --- /dev/null +++ b/lib/git/cmd/merge/merge_base_many.rs @@ -0,0 +1,36 @@ +use crate::{ + bare::GitBare, + cmd::oid::ObjectId, + errors::{GitError, GitResult}, +}; + +impl GitBare { + pub fn merge_base_many( + &self, + commits: Vec, + ) -> GitResult { + if commits.len() < 2 { + return Err(GitError::ParseError( + "merge_base_many requires at least 2 commits".to_string(), + )); + } + + let repo = self.gix_repo()?; + let gix_commits: Vec = commits + .iter() + .map(|id| id.try_into()) + .collect::, _>>()?; + + let first = gix_commits[0]; + let others = &gix_commits[1..]; + + let bases = repo.merge_bases_many(first, others)?; + let base = bases.first().ok_or_else(|| { + GitError::ObjectNotFound( + "no common merge base found for the given commits".to_string(), + ) + })?; + + Ok(ObjectId::new(base.detach().to_hex().to_string())) + } +} diff --git a/lib/git/cmd/merge/merge_base_octopus.rs b/lib/git/cmd/merge/merge_base_octopus.rs new file mode 100644 index 0000000..5debeef --- /dev/null +++ b/lib/git/cmd/merge/merge_base_octopus.rs @@ -0,0 +1,32 @@ +use crate::{ + bare::GitBare, + cmd::oid::ObjectId, + errors::{GitError, GitResult}, +}; + +impl GitBare { + pub fn merge_base_octopus( + &self, + commits: Vec, + ) -> GitResult { + if commits.len() < 2 { + return Err(GitError::ParseError( + "merge_base_octopus requires at least 2 commits".to_string(), + )); + } + + let repo = self.gix_repo()?; + let gix_commits: Vec = commits + .iter() + .map(|id| id.try_into()) + .collect::, _>>()?; + + let base = repo.merge_base_octopus(gix_commits).map_err(|_| { + GitError::ObjectNotFound( + "no octopus merge base found for the given commits".to_string(), + ) + })?; + + Ok(ObjectId::new(base.detach().to_hex().to_string())) + } +} diff --git a/lib/git/cmd/merge/merge_commit.rs b/lib/git/cmd/merge/merge_commit.rs new file mode 100644 index 0000000..a814181 --- /dev/null +++ b/lib/git/cmd/merge/merge_commit.rs @@ -0,0 +1,127 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::{ + commit::CommitSignature, merge::MergeOptions, oid::ObjectId, + parse::format_git_timestamp, + }, + errors::{GitError, GitResult}, +}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct MergeCommitParams { + pub their_commit: ObjectId, + pub author: CommitSignature, + pub committer: CommitSignature, + pub message: String, + pub update_ref: Option, + pub options: Option, +} + +impl GitBare { + pub fn merge_commit( + &self, + params: MergeCommitParams, + ) -> GitResult { + let repo = self.gix_repo()?; + let head_id = + repo.head_id().map_err(|e| GitError::Gix(e.to_string()))?; + let head_oid = ObjectId::new(head_id.detach().to_hex().to_string()); + let mut merge_tree_args = + vec!["merge-tree".to_string(), "--write-tree".to_string()]; + if let Some(opts) = ¶ms.options { + if opts.find_renames { + merge_tree_args.push("--find-renames".to_string()); + if opts.rename_threshold > 0 { + merge_tree_args.push(format!( + "--rename-threshold={}", + opts.rename_threshold + )); + } + } + if opts.fail_on_conflict { + merge_tree_args.push("--fail-on-conflict".to_string()); + } + } + merge_tree_args.push(head_oid.as_str().to_string()); + merge_tree_args.push(params.their_commit.as_str().to_string()); + + let merge_tree_output = self.git_command_stdout(merge_tree_args)?; + let tree_oid_str = + merge_tree_output.lines().next().unwrap_or("").trim(); + if tree_oid_str.is_empty() { + return Err(GitError::CommandFailed { + status_code: None, + stderr: "merge-tree produced no output".to_string(), + }); + } + + let tree_oid = ObjectId::new(tree_oid_str); + let commit_tree_args = vec![ + "commit-tree".to_string(), + tree_oid.as_str().to_string(), + "-p".to_string(), + head_oid.as_str().to_string(), + "-p".to_string(), + params.their_commit.as_str().to_string(), + "-F".to_string(), + "-".to_string(), // read message from stdin + ]; + + let author_timestamp = format_git_timestamp( + params.author.time_secs, + params.author.offset_minutes, + ); + let committer_timestamp = format_git_timestamp( + params.committer.time_secs, + params.committer.offset_minutes, + ); + + let commit_output = self.git_command_with( + crate::cmd::command::GitCommandParams::new(commit_tree_args) + .with_stdin(params.message.as_bytes().to_vec()) + .with_env( + "GIT_AUTHOR_NAME".to_string(), + params.author.name.clone(), + ) + .with_env( + "GIT_AUTHOR_EMAIL".to_string(), + params.author.email.clone(), + ) + .with_env("GIT_AUTHOR_DATE".to_string(), author_timestamp) + .with_env( + "GIT_COMMITTER_NAME".to_string(), + params.committer.name.clone(), + ) + .with_env( + "GIT_COMMITTER_EMAIL".to_string(), + params.committer.email.clone(), + ) + .with_env( + "GIT_COMMITTER_DATE".to_string(), + committer_timestamp, + ), + )?; + + let stdout_str = commit_output.stdout_lossy(); + let commit_oid_str = stdout_str.trim(); + if commit_oid_str.is_empty() { + return Err(GitError::CommandFailed { + status_code: commit_output.status_code, + stderr: commit_output.stderr_lossy(), + }); + } + + let result_oid = ObjectId::new(commit_oid_str); + if let Some(ref_name) = ¶ms.update_ref { + self.git_command_trusted_stdout(vec![ + "update-ref".to_string(), + ref_name.clone(), + result_oid.as_str().to_string(), + ])?; + } + + Ok(result_oid) + } +} diff --git a/lib/git/cmd/merge/merge_is_conflicted.rs b/lib/git/cmd/merge/merge_is_conflicted.rs new file mode 100644 index 0000000..9dac48e --- /dev/null +++ b/lib/git/cmd/merge/merge_is_conflicted.rs @@ -0,0 +1,8 @@ +use crate::{bare::GitBare, errors::GitResult}; + +impl GitBare { + pub fn merge_is_conflicted(&self) -> GitResult { + let repo = self.gix_repo()?; + Ok(repo.try_find_reference("MERGE_HEAD")?.is_some()) + } +} diff --git a/lib/git/cmd/merge/merge_tree.rs b/lib/git/cmd/merge/merge_tree.rs new file mode 100644 index 0000000..fd7da4a --- /dev/null +++ b/lib/git/cmd/merge/merge_tree.rs @@ -0,0 +1,76 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::{merge::MergeOptions, oid::ObjectId}, + errors::{GitError, GitResult}, +}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct MergeTreeResult { + pub tree_id: ObjectId, + pub has_conflicts: bool, +} + +impl GitBare { + pub fn merge_tree( + &self, + ours: ObjectId, + theirs: ObjectId, + options: Option, + ) -> GitResult { + let mut args = + vec!["merge-tree".to_string(), "--write-tree".to_string()]; + + if let Some(opts) = &options { + if opts.find_renames { + args.push("--find-renames".to_string()); + if opts.rename_threshold > 0 { + args.push(format!( + "--rename-threshold={}", + opts.rename_threshold + )); + } + } + if opts.fail_on_conflict { + args.push("--fail-on-conflict".to_string()); + } + if opts.no_recursive { + args.push("--no-recursive".to_string()); + } + if opts.target_limit > 0 { + args.push(format!("--target-limit={}", opts.target_limit)); + } + } + + args.push(ours.as_str().to_string()); + args.push(theirs.as_str().to_string()); + + let output = self.git_command_with( + crate::cmd::command::GitCommandParams::new(args).unchecked(), + )?; + + let stdout = output.stdout_lossy(); + let mut lines = stdout.lines(); + let tree_oid_str = lines.next().unwrap_or("").trim(); + + if tree_oid_str.is_empty() { + if !output.success { + return Err(GitError::CommandFailed { + status_code: output.status_code, + stderr: output.stderr_lossy(), + }); + } + return Err(GitError::ParseError( + "merge-tree produced no tree OID output".to_string(), + )); + } + let has_conflicts = + !output.success || lines.any(|l| !l.trim().is_empty()); + + Ok(MergeTreeResult { + tree_id: ObjectId::new(tree_oid_str), + has_conflicts, + }) + } +} diff --git a/lib/git/cmd/merge/mergehead_list.rs b/lib/git/cmd/merge/mergehead_list.rs new file mode 100644 index 0000000..8f0ff19 --- /dev/null +++ b/lib/git/cmd/merge/mergehead_list.rs @@ -0,0 +1,20 @@ +use crate::{bare::GitBare, cmd::oid::ObjectId, errors::GitResult}; + +impl GitBare { + pub fn mergehead_list(&self) -> GitResult> { + let merge_head_path = self.bare_dir.canonicalize()?.join("MERGE_HEAD"); + + if !merge_head_path.exists() { + return Ok(Vec::new()); + } + + let content = std::fs::read_to_string(&merge_head_path)?; + let oids: Vec = content + .lines() + .filter(|line| !line.trim().is_empty()) + .map(|line| ObjectId::new(line.trim())) + .collect(); + + Ok(oids) + } +} diff --git a/lib/git/cmd/merge/mod.rs b/lib/git/cmd/merge/mod.rs new file mode 100644 index 0000000..c2e9c9d --- /dev/null +++ b/lib/git/cmd/merge/mod.rs @@ -0,0 +1,43 @@ +use serde::{Deserialize, Serialize}; + +pub mod merge_abort; +pub mod merge_analysis; +pub mod merge_analysis_for_ref; +pub mod merge_base; +pub mod merge_base_many; +pub mod merge_base_octopus; +pub mod merge_commit; +pub mod merge_is_conflicted; +pub mod merge_tree; +pub mod mergehead_list; +pub mod squash_commit; + +pub use merge_commit::MergeCommitParams; +pub use merge_tree::MergeTreeResult; +pub use squash_commit::SquashCommitParams; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MergeAnalysisResult { + pub is_none: bool, + pub is_normal: bool, + pub is_up_to_date: bool, + pub is_fast_forward: bool, + pub is_unborn: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MergePreferenceResult { + pub is_none: bool, + pub is_no_fast_forward: bool, + pub is_fastforward_only: bool, +} +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct MergeOptions { + pub find_renames: bool, + pub fail_on_conflict: bool, + pub skip_reuc: bool, + pub no_recursive: bool, + pub rename_threshold: u32, + pub target_limit: u32, + pub recursion_limit: u32, +} diff --git a/lib/git/cmd/merge/squash_commit.rs b/lib/git/cmd/merge/squash_commit.rs new file mode 100644 index 0000000..31e4079 --- /dev/null +++ b/lib/git/cmd/merge/squash_commit.rs @@ -0,0 +1,120 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::{merge::MergeOptions, oid::ObjectId, parse::format_git_timestamp}, + errors::{GitError, GitResult}, +}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SquashCommitParams { + pub their_commit: ObjectId, + pub options: Option, +} + +impl GitBare { + pub fn squash_commit( + &self, + params: SquashCommitParams, + ) -> GitResult { + let repo = self.gix_repo()?; + let head_id = + repo.head_id().map_err(|e| GitError::Gix(e.to_string()))?; + let head_oid = ObjectId::new(head_id.detach().to_hex().to_string()); + let mut merge_tree_args = + vec!["merge-tree".to_string(), "--write-tree".to_string()]; + if let Some(opts) = ¶ms.options { + if opts.find_renames { + merge_tree_args.push("--find-renames".to_string()); + if opts.rename_threshold > 0 { + merge_tree_args.push(format!( + "--rename-threshold={}", + opts.rename_threshold + )); + } + } + if opts.fail_on_conflict { + merge_tree_args.push("--fail-on-conflict".to_string()); + } + } + merge_tree_args.push(head_oid.as_str().to_string()); + merge_tree_args.push(params.their_commit.as_str().to_string()); + + let merge_tree_output = self.git_command_with( + crate::cmd::command::GitCommandParams::new(merge_tree_args) + .unchecked(), + )?; + + let stdout = merge_tree_output.stdout_lossy(); + let tree_oid_str = stdout.lines().next().unwrap_or("").trim(); + + if tree_oid_str.is_empty() { + if !merge_tree_output.success { + return Err(GitError::CommandFailed { + status_code: merge_tree_output.status_code, + stderr: merge_tree_output.stderr_lossy(), + }); + } + return Err(GitError::ParseError( + "merge-tree produced no tree OID".to_string(), + )); + } + + let tree_oid = ObjectId::new(tree_oid_str); + let their_info = self.commit_info(params.their_commit.clone())?; + let squash_message = their_info.message.clone(); + let commit_tree_args = vec![ + "commit-tree".to_string(), + tree_oid.as_str().to_string(), + "-p".to_string(), + head_oid.as_str().to_string(), + "-F".to_string(), + "-".to_string(), + ]; + let author_timestamp = format_git_timestamp( + their_info.author.time_secs, + their_info.author.offset_minutes, + ); + let committer_timestamp = format_git_timestamp( + their_info.committer.time_secs, + their_info.committer.offset_minutes, + ); + + let commit_output = self.git_command_with( + crate::cmd::command::GitCommandParams::new(commit_tree_args) + .with_stdin(squash_message.as_bytes().to_vec()) + .with_env( + "GIT_AUTHOR_NAME".to_string(), + their_info.author.name.clone(), + ) + .with_env( + "GIT_AUTHOR_EMAIL".to_string(), + their_info.author.email.clone(), + ) + .with_env("GIT_AUTHOR_DATE".to_string(), author_timestamp) + .with_env( + "GIT_COMMITTER_NAME".to_string(), + their_info.committer.name.clone(), + ) + .with_env( + "GIT_COMMITTER_EMAIL".to_string(), + their_info.committer.email.clone(), + ) + .with_env( + "GIT_COMMITTER_DATE".to_string(), + committer_timestamp, + ), + )?; + + let stdout_str = commit_output.stdout_lossy(); + let commit_oid_str = stdout_str.trim(); + if commit_oid_str.is_empty() { + return Err(GitError::CommandFailed { + status_code: commit_output.status_code, + stderr: commit_output.stderr_lossy(), + }); + } + + Ok(ObjectId::new(commit_oid_str)) + } +} diff --git a/lib/git/cmd/mod.rs b/lib/git/cmd/mod.rs new file mode 100644 index 0000000..a5fc408 --- /dev/null +++ b/lib/git/cmd/mod.rs @@ -0,0 +1,15 @@ +pub mod archive; +pub mod blame; +pub mod blob; +pub mod branch; +pub mod command; +pub mod commit; +pub mod diff; +pub mod fork; +pub mod init; +pub mod merge; +pub mod oid; +pub mod parse; +pub mod tag; +pub mod tagger; +pub mod tree; diff --git a/lib/git/cmd/oid.rs b/lib/git/cmd/oid.rs new file mode 100644 index 0000000..beb134a --- /dev/null +++ b/lib/git/cmd/oid.rs @@ -0,0 +1,34 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)] +pub struct ObjectId(pub String); + +impl ObjectId { + pub fn new(hex: impl AsRef) -> Self { + Self(hex.as_ref().to_lowercase()) + } + pub fn as_str(&self) -> &str { + &self.0 + } +} +impl std::fmt::Display for ObjectId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl AsRef for ObjectId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl TryFrom<&ObjectId> for gix::hash::ObjectId { + type Error = crate::errors::GitError; + + fn try_from(id: &ObjectId) -> Result { + gix::hash::ObjectId::from_hex(id.as_str().as_bytes()).map_err(|e| { + crate::errors::GitError::InvalidOid(format!("invalid hex oid: {e}")) + }) + } +} diff --git a/lib/git/cmd/parse.rs b/lib/git/cmd/parse.rs new file mode 100644 index 0000000..e67f37c --- /dev/null +++ b/lib/git/cmd/parse.rs @@ -0,0 +1,45 @@ +pub fn parse_timezone_offset(tz: &str) -> crate::errors::GitResult { + use crate::errors::GitError; + + if tz.len() != 5 { + return Err(GitError::ParseError(format!( + "invalid timezone format: {tz}" + ))); + } + + let sign: i32 = if tz.starts_with('+') { + 1 + } else if tz.starts_with('-') { + -1 + } else { + return Err(GitError::ParseError(format!( + "invalid timezone sign: {tz}" + ))); + }; + + let hours: i32 = tz[1..3].parse().map_err(|_| { + GitError::ParseError(format!("invalid timezone hours: {tz}")) + })?; + let minutes: i32 = tz[3..5].parse().map_err(|_| { + GitError::ParseError(format!("invalid timezone minutes: {tz}")) + })?; + + Ok(sign * (hours * 60 + minutes)) +} +pub fn parse_iso_timezone(iso_date: &str) -> crate::errors::GitResult { + let tz_part = + iso_date.rsplit_once(' ').map(|(_, tz)| tz).ok_or_else(|| { + crate::errors::GitError::ParseError(format!( + "ISO date missing timezone: {iso_date}" + )) + })?; + + parse_timezone_offset(tz_part) +} +pub fn format_git_timestamp(secs: i64, offset_minutes: i32) -> String { + let sign = if offset_minutes >= 0 { '+' } else { '-' }; + let abs_offset = offset_minutes.abs(); + let hours = abs_offset / 60; + let mins = abs_offset % 60; + format!("{} {}{:02}{:02}", secs, sign, hours, mins) +} diff --git a/lib/git/cmd/tag/mod.rs b/lib/git/cmd/tag/mod.rs new file mode 100644 index 0000000..bc5eae4 --- /dev/null +++ b/lib/git/cmd/tag/mod.rs @@ -0,0 +1,33 @@ +use serde::{Deserialize, Serialize}; + +use crate::cmd::oid::ObjectId; + +pub mod tag_count; +pub mod tag_delete; +pub mod tag_info; +pub mod tag_init; +pub mod tag_list; +pub mod tag_rename; +pub mod tag_summary; +pub mod tag_upmsg; + +pub use tag_delete::TagDeleteParams; +pub use tag_init::TagInitParams; +pub use tag_rename::TagRenameParams; +pub use tag_upmsg::TagUpdateMessageParams; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TagItem { + pub name: String, + pub oid: ObjectId, + pub target: ObjectId, + pub is_annotated: bool, + pub message: Option, + pub tagger: Option, + pub tagger_email: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TagSummary { + pub total_count: usize, +} diff --git a/lib/git/cmd/tag/tag_count.rs b/lib/git/cmd/tag/tag_count.rs new file mode 100644 index 0000000..55a2794 --- /dev/null +++ b/lib/git/cmd/tag/tag_count.rs @@ -0,0 +1,9 @@ +use crate::{bare::GitBare, errors::GitResult}; + +impl GitBare { + pub fn tag_count(&self) -> GitResult { + let repo = self.gix_repo()?; + let platform = repo.references()?; + Ok(platform.tags()?.count()) + } +} diff --git a/lib/git/cmd/tag/tag_delete.rs b/lib/git/cmd/tag/tag_delete.rs new file mode 100644 index 0000000..6a38be9 --- /dev/null +++ b/lib/git/cmd/tag/tag_delete.rs @@ -0,0 +1,38 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::command::GitCommandParams, + errors::{GitError, GitResult}, +}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TagDeleteParams { + pub name: String, +} + +impl GitBare { + pub fn tag_delete(&self, params: TagDeleteParams) -> GitResult<()> { + let cmd_params = GitCommandParams::new(vec![ + "tag".to_string(), + "-d".to_string(), + params.name.clone(), + ]) + .trusted() + .unchecked(); + + let output = self.git_command_with(cmd_params)?; + if !output.success { + let stderr = output.stderr_lossy(); + if stderr.contains("not found") { + return Err(GitError::RefNotFound(params.name)); + } + return Err(GitError::CommandFailed { + status_code: output.status_code, + stderr, + }); + } + + Ok(()) + } +} diff --git a/lib/git/cmd/tag/tag_info.rs b/lib/git/cmd/tag/tag_info.rs new file mode 100644 index 0000000..b831bf4 --- /dev/null +++ b/lib/git/cmd/tag/tag_info.rs @@ -0,0 +1,70 @@ +use crate::{ + bare::GitBare, + cmd::{oid::ObjectId, tag::TagItem}, + errors::{GitError, GitResult}, +}; + +impl GitBare { + pub fn tag_info(&self, name: String) -> GitResult { + let repo = self.gix_repo()?; + let ref_str = format!("refs/tags/{name}"); + let mut reference = repo.find_reference(ref_str.as_str())?; + + let short_name = reference.name().shorten().to_string(); + let direct_id_hex = reference + .target() + .try_id() + .map(|id| id.to_hex().to_string()) + .ok_or_else(|| { + GitError::Internal("tag ref has no direct target".to_string()) + })?; + + let direct_gix_id: gix::hash::ObjectId = + gix::hash::ObjectId::from_hex(direct_id_hex.as_bytes()) + .map_err(|e| GitError::InvalidOid(e.to_string()))?; + let obj_header = repo.find_header(direct_gix_id)?; + let is_annotated = obj_header.kind() == gix::object::Kind::Tag; + let peeled_id = reference.peel_to_id()?; + let peeled_hex = peeled_id.detach().to_hex().to_string(); + + let (oid, target_oid) = if is_annotated { + (ObjectId::new(direct_id_hex), ObjectId::new(peeled_hex)) + } else { + let target_oid = ObjectId::new(direct_id_hex); + (target_oid.clone(), target_oid) + }; + + let (message, tagger, tagger_email) = if is_annotated { + let tag_obj = repo.find_tag(direct_gix_id)?; + let decoded = tag_obj.decode()?; + + let message = decoded.message.to_string().trim_end().to_string(); + let msg = if message.is_empty() { + None + } else { + Some(message) + }; + + let (tg_name, tg_email) = decoded + .tagger() + .ok() + .flatten() + .map(|s| (Some(s.name.to_string()), Some(s.email.to_string()))) + .unwrap_or((None, None)); + + (msg, tg_name, tg_email) + } else { + (None, None, None) + }; + + Ok(TagItem { + name: short_name, + oid, + target: target_oid, + is_annotated, + message, + tagger, + tagger_email, + }) + } +} diff --git a/lib/git/cmd/tag/tag_init.rs b/lib/git/cmd/tag/tag_init.rs new file mode 100644 index 0000000..feb2e15 --- /dev/null +++ b/lib/git/cmd/tag/tag_init.rs @@ -0,0 +1,83 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::{command::GitCommandParams, oid::ObjectId, tagger::GitTagger}, + errors::{GitError, GitResult}, +}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TagInitParams { + pub name: String, + pub target: ObjectId, + pub message: Option, + pub tagger: Option, + pub force: bool, +} + +impl GitBare { + pub fn tag_init(&self, params: TagInitParams) -> GitResult { + if let Some(message) = ¶ms.message { + let mut cmd_args = vec!["tag".to_string(), "-a".to_string()]; + if params.force { + cmd_args.push("--force".to_string()); + } + cmd_args.push(params.name.clone()); + cmd_args.push(params.target.as_str().to_string()); + cmd_args.push("-F".to_string()); + cmd_args.push("-".to_string()); + + let mut cmd_params = GitCommandParams::new(cmd_args) + .trusted() + .with_stdin(message.as_bytes().to_vec()); + + if let Some(tagger) = ¶ms.tagger { + cmd_params = cmd_params + .with_env( + "GIT_COMMITTER_NAME".to_string(), + tagger.name.clone(), + ) + .with_env( + "GIT_COMMITTER_EMAIL".to_string(), + tagger.email.clone(), + ); + } + + let output = self.git_command_with(cmd_params)?; + if !output.success { + return Err(GitError::CommandFailed { + status_code: output.status_code, + stderr: output.stderr_lossy(), + }); + } + } else { + let mut cmd_args = vec!["tag".to_string()]; + if params.force { + cmd_args.push("--force".to_string()); + } + cmd_args.push(params.name.clone()); + cmd_args.push(params.target.as_str().to_string()); + + let output = self.git_command_trusted(cmd_args)?; + if !output.success { + return Err(GitError::CommandFailed { + status_code: output.status_code, + stderr: output.stderr_lossy(), + }); + } + } + let repo = self.gix_repo()?; + let refname = format!("refs/tags/{}", params.name); + let ref_name_gix = gix::refs::PartialName::try_from(refname.as_str()) + .map_err(|e| GitError::Gix(e.to_string()))?; + let tag_ref = repo + .find_reference(&ref_name_gix) + .map_err(|e| GitError::Gix(e.to_string()))?; + let target = tag_ref.target(); + let target_id = target.try_id().ok_or_else(|| { + GitError::Gix(format!("tag ref '{}' has no direct target", refname)) + })?; + + Ok(ObjectId::new(target_id.to_hex().to_string())) + } +} diff --git a/lib/git/cmd/tag/tag_list.rs b/lib/git/cmd/tag/tag_list.rs new file mode 100644 index 0000000..9d773f1 --- /dev/null +++ b/lib/git/cmd/tag/tag_list.rs @@ -0,0 +1,59 @@ +use crate::{ + bare::GitBare, + cmd::{oid::ObjectId, tag::TagItem}, + errors::GitResult, +}; + +impl GitBare { + pub fn tag_list(&self) -> GitResult> { + let repo = self.gix_repo()?; + let platform = repo.references()?; + let tags_iter = platform.tags()?; + + let mut items = Vec::new(); + for ref_result in tags_iter { + let mut reference = ref_result?; + let name = reference.name().shorten().to_string(); + let direct_id_hex = reference + .target() + .try_id() + .map(|id| id.to_hex().to_string()) + .ok_or_else(|| { + crate::errors::GitError::Internal( + "tag ref has no direct target".to_string(), + ) + })?; + + let direct_gix_id: gix::hash::ObjectId = + gix::hash::ObjectId::from_hex(direct_id_hex.as_bytes()) + .map_err(|e| { + crate::errors::GitError::InvalidOid(e.to_string()) + })?; + let is_annotated = repo + .find_header(direct_gix_id) + .ok() + .is_some_and(|h| h.kind() == gix::object::Kind::Tag); + let peeled_id = reference.peel_to_id()?; + let peeled_hex = peeled_id.detach().to_hex().to_string(); + + let (oid, target_oid) = if is_annotated { + (ObjectId::new(direct_id_hex), ObjectId::new(peeled_hex)) + } else { + let target_oid = ObjectId::new(direct_id_hex); + (target_oid.clone(), target_oid) + }; + + items.push(TagItem { + name, + oid, + target: target_oid, + is_annotated, + message: None, + tagger: None, + tagger_email: None, + }); + } + + Ok(items) + } +} diff --git a/lib/git/cmd/tag/tag_rename.rs b/lib/git/cmd/tag/tag_rename.rs new file mode 100644 index 0000000..b9cba40 --- /dev/null +++ b/lib/git/cmd/tag/tag_rename.rs @@ -0,0 +1,44 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::{ + tag::{tag_delete::TagDeleteParams, tag_init::TagInitParams}, + tagger::GitTagger, + }, + errors::GitResult, +}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TagRenameParams { + pub old_name: String, + pub new_name: String, + pub force: bool, +} + +impl GitBare { + pub fn tag_rename(&self, params: TagRenameParams) -> GitResult<()> { + let old_info = self.tag_info(params.old_name.clone())?; + let init_params = TagInitParams { + name: params.new_name.clone(), + target: old_info.target.clone(), + message: old_info.message.clone(), + tagger: if old_info.is_annotated { + Some(GitTagger { + name: old_info.tagger.clone().unwrap_or_default(), + email: old_info.tagger_email.clone().unwrap_or_default(), + }) + } else { + None + }, + force: params.force, + }; + + self.tag_init(init_params)?; + self.tag_delete(TagDeleteParams { + name: params.old_name.clone(), + })?; + + Ok(()) + } +} diff --git a/lib/git/cmd/tag/tag_summary.rs b/lib/git/cmd/tag/tag_summary.rs new file mode 100644 index 0000000..8ca06d9 --- /dev/null +++ b/lib/git/cmd/tag/tag_summary.rs @@ -0,0 +1,8 @@ +use crate::{bare::GitBare, cmd::tag::TagSummary, errors::GitResult}; + +impl GitBare { + pub fn tag_summary(&self) -> GitResult { + let total_count = self.tag_count()?; + Ok(TagSummary { total_count }) + } +} diff --git a/lib/git/cmd/tag/tag_upmsg.rs b/lib/git/cmd/tag/tag_upmsg.rs new file mode 100644 index 0000000..bc4b331 --- /dev/null +++ b/lib/git/cmd/tag/tag_upmsg.rs @@ -0,0 +1,40 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + bare::GitBare, + cmd::{ + oid::ObjectId, + tag::{tag_delete::TagDeleteParams, tag_init::TagInitParams}, + tagger::GitTagger, + }, + errors::GitResult, +}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TagUpdateMessageParams { + pub name: String, + pub message: String, + pub tagger: GitTagger, + pub force: bool, +} + +impl GitBare { + pub fn tag_update_message( + &self, + params: TagUpdateMessageParams, + ) -> GitResult { + let current_info = self.tag_info(params.name.clone())?; + self.tag_delete(TagDeleteParams { + name: params.name.clone(), + })?; + let init_params = TagInitParams { + name: params.name.clone(), + target: current_info.target.clone(), + message: Some(params.message.clone()), + tagger: Some(params.tagger.clone()), + force: params.force, + }; + + self.tag_init(init_params) + } +} diff --git a/lib/git/cmd/tagger.rs b/lib/git/cmd/tagger.rs new file mode 100644 index 0000000..29cc03c --- /dev/null +++ b/lib/git/cmd/tagger.rs @@ -0,0 +1,7 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GitTagger { + pub name: String, + pub email: String, +} diff --git a/lib/git/cmd/tree/mod.rs b/lib/git/cmd/tree/mod.rs new file mode 100644 index 0000000..b8a1187 --- /dev/null +++ b/lib/git/cmd/tree/mod.rs @@ -0,0 +1,34 @@ +use serde::{Deserialize, Serialize}; + +use crate::cmd::oid::ObjectId; + +pub mod resolve_tree; +pub mod tree_entry; +pub mod tree_entry_by_path_from_commit; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TreeInfo { + pub oid: ObjectId, + pub entry_count: usize, + pub is_empty: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TreeEntry { + pub name: String, + pub oid: ObjectId, + pub kind: TreeKind, + pub filemode: u32, + pub is_binary: bool, + pub is_lfs: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum TreeKind { + #[serde(rename = "blob")] + Blob, + #[serde(rename = "tree")] + Tree, + #[serde(rename = "lfs_pointer")] + LfsPointer, +} diff --git a/lib/git/cmd/tree/resolve_tree.rs b/lib/git/cmd/tree/resolve_tree.rs new file mode 100644 index 0000000..282222f --- /dev/null +++ b/lib/git/cmd/tree/resolve_tree.rs @@ -0,0 +1,29 @@ +use crate::{ + bare::GitBare, + cmd::{oid::ObjectId, tree::TreeInfo}, + errors::{GitError, GitResult}, +}; + +impl GitBare { + pub fn resolve_tree(&self, oid: ObjectId) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&oid).try_into()?; + let header = repo.find_header(gix_id)?; + if header.kind() != gix::object::Kind::Tree { + return Err(GitError::ParseError(format!( + "object {} is not a tree (type: {})", + oid, + header.kind() + ))); + } + + let gix_tree = repo.find_tree(gix_id)?; + let entry_count = gix_tree.iter().count(); + + Ok(TreeInfo { + oid, + entry_count, + is_empty: entry_count == 0, + }) + } +} diff --git a/lib/git/cmd/tree/tree_entry.rs b/lib/git/cmd/tree/tree_entry.rs new file mode 100644 index 0000000..9ff9614 --- /dev/null +++ b/lib/git/cmd/tree/tree_entry.rs @@ -0,0 +1,143 @@ +use crate::{ + bare::GitBare, + cmd::{ + oid::ObjectId, + tree::{TreeEntry, TreeKind}, + }, + errors::{GitError, GitResult}, +}; + +const LFS_POINTER_PREFIX: &[u8] = b"version https://git-lfs.github.com/spec/v"; +const LFS_OID_MARKER: &[u8] = b"\noid sha256:"; +const LFS_POINTER_MAX_SIZE: usize = 512; + +impl GitBare { + fn blob_is_lfs_pointer(&self, id: ObjectId) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&id).try_into()?; + let blob = repo + .find_blob(gix_id) + .map_err(|_| GitError::ObjectNotFound(id.as_str().to_string()))?; + + if blob.data.len() > LFS_POINTER_MAX_SIZE { + return Ok(false); + } + + let data = &blob.data; + Ok(data.starts_with(LFS_POINTER_PREFIX) + && contains_subslice(data, LFS_OID_MARKER)) + } + + pub fn tree_entries(&self, tree: ObjectId) -> GitResult> { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&tree).try_into()?; + let gix_tree = repo.find_tree(gix_id)?; + + let mut entries = Vec::new(); + for entry_result in gix_tree.iter() { + let entry = entry_result?; + let name = entry.inner.filename.to_string(); + let oid = ObjectId::new(entry.inner.oid.to_hex().to_string()); + + let is_tree = entry.inner.mode.is_tree(); + let is_lfs = if !is_tree { + self.blob_is_lfs_pointer(oid.clone()).unwrap_or(false) + } else { + false + }; + + let kind = if is_tree { + TreeKind::Tree + } else if is_lfs { + TreeKind::LfsPointer + } else { + TreeKind::Blob + }; + + let filemode = entry.inner.mode.value() as u32; + + let is_binary = match kind { + TreeKind::LfsPointer => false, + TreeKind::Blob => { + self.blob_is_binary(oid.clone()).unwrap_or(false) + } + TreeKind::Tree => false, + }; + + entries.push(TreeEntry { + name, + oid, + kind, + filemode, + is_binary, + is_lfs, + }); + } + + Ok(entries) + } + + pub fn tree_entry_by_path( + &self, + tree: ObjectId, + path: String, + ) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&tree).try_into()?; + let gix_tree = repo.find_tree(gix_id)?; + + let entry = gix_tree.find_entry(path.as_str()).ok_or_else(|| { + GitError::ObjectNotFound(format!( + "path '{}' not found in tree {}", + path, tree + )) + })?; + + let name = entry.inner.filename.to_string(); + let oid = ObjectId::new(entry.inner.oid.to_hex().to_string()); + + let is_tree = entry.inner.mode.is_tree(); + let is_lfs = if !is_tree { + self.blob_is_lfs_pointer(oid.clone()).unwrap_or(false) + } else { + false + }; + + let kind = if is_tree { + TreeKind::Tree + } else if is_lfs { + TreeKind::LfsPointer + } else { + TreeKind::Blob + }; + + let filemode = entry.inner.mode.value() as u32; + + let is_binary = match kind { + TreeKind::LfsPointer => false, + TreeKind::Blob => self.blob_is_binary(oid.clone()).unwrap_or(false), + TreeKind::Tree => false, + }; + + Ok(TreeEntry { + name, + oid, + kind, + filemode, + is_binary, + is_lfs, + }) + } +} + +fn contains_subslice(data: &[u8], pattern: &[u8]) -> bool { + if pattern.len() > data.len() { + return false; + } + for window in data.windows(pattern.len()) { + if window == pattern { + return true; + } + } + false +} diff --git a/lib/git/cmd/tree/tree_entry_by_path_from_commit.rs b/lib/git/cmd/tree/tree_entry_by_path_from_commit.rs new file mode 100644 index 0000000..7fdcf9f --- /dev/null +++ b/lib/git/cmd/tree/tree_entry_by_path_from_commit.rs @@ -0,0 +1,22 @@ +use crate::{ + bare::GitBare, + cmd::{oid::ObjectId, tree::TreeEntry}, + errors::GitResult, +}; + +impl GitBare { + pub fn tree_entry_by_path_from_commit( + &self, + commit: ObjectId, + path: String, + ) -> GitResult { + let repo = self.gix_repo()?; + let gix_id: gix::hash::ObjectId = (&commit).try_into()?; + + let gix_commit = repo.find_commit(gix_id)?; + let tree_id = gix_commit.tree_id()?.detach(); + let tree_oid = ObjectId::new(tree_id.to_hex().to_string()); + + self.tree_entry_by_path(tree_oid, path) + } +} diff --git a/lib/git/errors.rs b/lib/git/errors.rs new file mode 100644 index 0000000..97bad42 --- /dev/null +++ b/lib/git/errors.rs @@ -0,0 +1,81 @@ +pub type GitResult = Result; + +#[derive(Debug, thiserror::Error)] +pub enum GitError { + #[error("repository is not bare")] + NotBareRepository, + #[error("git command failed with status {status_code:?}: {stderr}")] + CommandFailed { + status_code: Option, + stderr: String, + }, + #[error("unsafe git command rejected: {0}")] + UnsafeCommand(String), + #[error("object not found: {0}")] + ObjectNotFound(String), + #[error("reference not found: {0}")] + RefNotFound(String), + #[error("parse error: {0}")] + ParseError(String), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error("gix error: {0}")] + Gix(String), + #[error("database error: {0}")] + DatabaseError(db::sqlx::Error), + #[error("repository not found")] + RepoNotFound, + #[error("internal error: {0}")] + Internal(String), + #[error("not found: {0}")] + NotFound(String), + #[error("invalid oid: {0}")] + InvalidOid(String), + #[error("locked: {0}")] + Locked(String), + #[error("permission denied: {0}")] + PermissionDenied(String), + #[error("authentication failed: {0}")] + AuthFailed(String), +} + +impl From for GitError { + fn from(e: db::sqlx::Error) -> Self { + GitError::DatabaseError(e) + } +} + +macro_rules! impl_gix_error { + ($err_type:path) => { + impl From<$err_type> for GitError { + fn from(e: $err_type) -> Self { + GitError::Gix(e.to_string()) + } + } + }; +} + +impl_gix_error!(gix::object::find::existing::Error); +impl_gix_error!(gix::object::find::existing::with_conversion::Error); +impl_gix_error!(gix::object::find::Error); +impl_gix_error!(gix::reference::iter::Error); +impl_gix_error!(gix::reference::iter::init::Error); +impl_gix_error!(gix::reference::find::existing::Error); +impl_gix_error!(gix::reference::find::Error); +impl_gix_error!(gix::reference::head_id::Error); +impl_gix_error!(gix::repository::merge_bases_many::Error); +impl_gix_error!(gix::reference::peel::Error); +impl_gix_error!(gix::repository::blame_file::Error); +impl_gix_error!(gix::blame::Error); +impl_gix_error!(gix::revision::walk::Error); +impl_gix_error!(gix::revision::walk::iter::Error); +impl_gix_error!(gix::revision::spec::parse::single::Error); +impl_gix_error!(gix::open::Error); +impl_gix_error!(gix::objs::decode::Error); +impl_gix_error!(gix::date::Error); + +impl From> for GitError { + fn from(e: Box) -> Self { + GitError::Gix(e.to_string()) + } +} diff --git a/lib/git/graphql/blob.rs b/lib/git/graphql/blob.rs new file mode 100644 index 0000000..ccf5552 --- /dev/null +++ b/lib/git/graphql/blob.rs @@ -0,0 +1,93 @@ +use std::time::Duration; + +use juniper::graphql_object; + +use crate::{ + cmd::{ + blob::{blob_load::BlobLoadParams, blob_size::BlobSizeParams}, + oid::ObjectId, + }, + graphql::{ + GraphqlContext, + cache_helper::{IMMUTABLE_TTL, cache_key, cached_json}, + }, +}; + +const BLOB_SIZE_TTL: Duration = IMMUTABLE_TTL; + +#[derive(Clone, serde::Serialize, serde::Deserialize)] +pub struct BlobResult { + pub oid: String, + pub size: i32, + pub is_binary: bool, +} + +#[graphql_object(context = GraphqlContext)] +impl BlobResult { + fn oid(&self) -> &str { + &self.oid + } + fn size(&self) -> i32 { + self.size + } + fn is_binary(&self) -> bool { + self.is_binary + } +} + +pub async fn resolve_blob_size( + ctx: &GraphqlContext, + oid: String, +) -> anyhow::Result { + let key = cache_key("query:git:blob_size", &[&oid]); + cached_json(&ctx.cache, &key, BLOB_SIZE_TTL, || { + let repo = ctx.repo.clone(); + let oid_obj = ObjectId::new(&oid); + async move { + let size = tokio::task::spawn_blocking(move || { + repo.blob_size(&BlobSizeParams { + id: oid_obj, + path: String::new(), + }) + }) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + Ok(size as i32) + } + }) + .await +} + +pub async fn resolve_blob( + ctx: &GraphqlContext, + oid: String, +) -> anyhow::Result { + let repo = ctx.repo.clone(); + let oid_obj = ObjectId::new(&oid); + let result = tokio::task::spawn_blocking(move || { + let size = repo.blob_size(&BlobSizeParams { + id: oid_obj.clone(), + path: String::new(), + })?; + let loaded = repo.blob_load(&BlobLoadParams { + id: oid_obj, + path: String::new(), + })?; + Ok::<(u64, Vec, String), crate::errors::GitError>(( + size, + loaded.blob, + oid, + )) + }) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + + let (size, bytes, oid_str) = result; + let is_binary = bytes.contains(&0); + + Ok(BlobResult { + oid: oid_str, + size: size as i32, + is_binary, + }) +} diff --git a/lib/git/graphql/branch.rs b/lib/git/graphql/branch.rs new file mode 100644 index 0000000..e3d61d1 --- /dev/null +++ b/lib/git/graphql/branch.rs @@ -0,0 +1,173 @@ +use std::time::Duration; + +use juniper::graphql_object; +use serde::{Deserialize, Serialize}; + +use crate::{ + cmd::branch::{branch_list::BranchListItem, branch_summary::BranchSummary}, + graphql::{ + GraphqlContext, + cache_helper::{ + MUTABLE_TTL, cached_json, mutable_cache_key, repo_revision, + }, + }, +}; + +const BRANCH_TTL: Duration = MUTABLE_TTL; + +#[derive(Clone, Serialize, Deserialize)] +pub struct BranchGql { + pub name: String, + pub oid: String, + pub is_head: bool, + pub is_remote: bool, + pub is_current: bool, + pub upstream: Option, +} + +#[graphql_object(context = GraphqlContext)] +impl BranchGql { + fn name(&self) -> &str { + &self.name + } + fn oid(&self) -> &str { + &self.oid + } + fn is_head(&self) -> bool { + self.is_head + } + fn is_remote(&self) -> bool { + self.is_remote + } + fn is_current(&self) -> bool { + self.is_current + } + fn upstream(&self) -> Option<&str> { + self.upstream.as_deref() + } +} + +impl From for BranchGql { + fn from(item: BranchListItem) -> Self { + BranchGql { + name: item.name, + oid: item.oid.0, + is_head: item.is_head, + is_remote: item.is_remote, + is_current: item.is_current, + upstream: item.upstream, + } + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct BranchSummaryGql { + pub local_count: i32, + pub remote_count: i32, + pub all_count: i32, +} + +#[graphql_object(context = GraphqlContext)] +impl BranchSummaryGql { + fn local_count(&self) -> i32 { + self.local_count + } + fn remote_count(&self) -> i32 { + self.remote_count + } + fn all_count(&self) -> i32 { + self.all_count + } +} + +impl From for BranchSummaryGql { + fn from(s: BranchSummary) -> Self { + BranchSummaryGql { + local_count: s.local_count as i32, + remote_count: s.remote_count as i32, + all_count: s.all_count as i32, + } + } +} + +pub async fn resolve_branches( + ctx: &GraphqlContext, +) -> anyhow::Result> { + let revision = repo_revision(ctx).await; + let key = mutable_cache_key(ctx, "query:git:branches", &[], &revision); + cached_json(&ctx.cache, &key, BRANCH_TTL, || { + let repo = ctx.repo.clone(); + async move { + let items = + tokio::task::spawn_blocking(move || repo.branch_list_all()) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + Ok(items.into_iter().map(BranchGql::from).collect()) + } + }) + .await +} + +pub async fn resolve_branch( + ctx: &GraphqlContext, + name: String, +) -> anyhow::Result { + let revision = repo_revision(ctx).await; + let key = mutable_cache_key(ctx, "query:git:branch", &[&name], &revision); + cached_json(&ctx.cache, &key, BRANCH_TTL, || { + let repo = ctx.repo.clone(); + let branch_name = name.clone(); + async move { + let item = tokio::task::spawn_blocking(move || { + repo.branch_info(branch_name) + }) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + Ok(BranchGql::from(item)) + } + }) + .await +} + +pub async fn resolve_head_branch( + ctx: &GraphqlContext, +) -> anyhow::Result { + let revision = repo_revision(ctx).await; + let key = mutable_cache_key(ctx, "query:git:head_branch", &[], &revision); + cached_json(&ctx.cache, &key, BRANCH_TTL, || { + let repo1 = ctx.repo.clone(); + let repo2 = ctx.repo.clone(); + async move { + let head_name = + tokio::task::spawn_blocking(move || repo1.branch_head_name()) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + let branch = tokio::task::spawn_blocking(move || { + repo2.branch_info(head_name) + }) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + Ok(BranchGql::from(branch)) + } + }) + .await +} + +pub async fn resolve_branch_summary( + ctx: &GraphqlContext, +) -> anyhow::Result { + let revision = repo_revision(ctx).await; + let key = + mutable_cache_key(ctx, "query:git:branch_summary", &[], &revision); + cached_json(&ctx.cache, &key, BRANCH_TTL, || { + let repo = ctx.repo.clone(); + async move { + let summary = + tokio::task::spawn_blocking(move || repo.branch_summary()) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + Ok(BranchSummaryGql::from(summary)) + } + }) + .await +} diff --git a/lib/git/graphql/cache_helper.rs b/lib/git/graphql/cache_helper.rs new file mode 100644 index 0000000..17ad6f2 --- /dev/null +++ b/lib/git/graphql/cache_helper.rs @@ -0,0 +1,93 @@ +use std::{fmt::Display, future::Future, time::Duration}; + +use cache::AppCache; +use serde::{Serialize, de::DeserializeOwned}; + +use crate::{bare::GitBare, graphql::GraphqlContext}; + +const KEY_SEPARATOR: &str = ":"; + +pub fn cache_key(namespace: &str, parts: &[&str]) -> String { + let mut segments: Vec<&str> = vec![namespace]; + for part in parts { + if !part.is_empty() { + segments.push(part); + } + } + segments.join(KEY_SEPARATOR) +} + +pub fn cache_key_with_revision( + namespace: &str, + revision: impl Display, + parts: &[&str], +) -> String { + let mut key = cache_key(namespace, parts); + key.push_str(KEY_SEPARATOR); + key.push_str(&revision.to_string()); + key +} + +fn path_hash(repo: &GitBare) -> String { + let path_str = repo.bare_dir.to_string_lossy(); + let hash = simple_hash(&path_str); + hash[..8].to_string() +} + +fn simple_hash(s: &str) -> String { + let mut h: u64 = 0; + for b in s.bytes() { + h = h.wrapping_mul(31).wrapping_add(b as u64); + } + format!("{:016x}", h) +} + +pub async fn repo_revision(ctx: &GraphqlContext) -> String { + let repo = ctx.repo.clone(); + let result = tokio::task::spawn_blocking(move || { + repo.git_command_stdout(vec![ + "rev-parse".to_string(), + "HEAD".to_string(), + ]) + }) + .await; + match result { + Ok(Ok(oid)) => oid.trim().to_string(), + _ => "unknown".to_string(), + } +} + +pub fn mutable_cache_key( + ctx: &GraphqlContext, + namespace: &str, + parts: &[&str], + revision: &str, +) -> String { + let ph = path_hash(&ctx.repo); + let mut all_parts: Vec<&str> = vec![&ph]; + all_parts.extend_from_slice(parts); + cache_key_with_revision(namespace, revision, &all_parts) +} + +pub async fn cached_json( + cache: &AppCache, + key: &str, + _ttl: Duration, + build: F, +) -> anyhow::Result +where + T: Serialize + DeserializeOwned, + F: FnOnce() -> Fut, + Fut: Future>, +{ + if let Ok(Some(cached)) = cache.get::(key).await { + return Ok(cached); + } + + let value = build().await?; + cache.set(key, &value).await.ok(); + Ok(value) +} + +pub const IMMUTABLE_TTL: Duration = Duration::from_secs(86400); +pub const MUTABLE_TTL: Duration = Duration::from_secs(300); diff --git a/lib/git/graphql/commit.rs b/lib/git/graphql/commit.rs new file mode 100644 index 0000000..e4d8996 --- /dev/null +++ b/lib/git/graphql/commit.rs @@ -0,0 +1,224 @@ +use std::time::Duration; + +use juniper::graphql_object; +use serde::{Deserialize, Serialize}; + +use crate::{ + cmd::{ + commit::{ + CommitMeta, CommitSignature, + commit_summary::CommitSummary, + commit_walker::{CommitWalkParams, CommitWalkSort}, + }, + oid::ObjectId, + }, + graphql::{ + GraphqlContext, + cache_helper::{ + IMMUTABLE_TTL, MUTABLE_TTL, cache_key, cached_json, + mutable_cache_key, repo_revision, + }, + }, +}; + +const COMMIT_TTL: Duration = IMMUTABLE_TTL; + +#[derive(Clone, Serialize, Deserialize)] +pub struct CommitGql { + pub oid: String, + pub message: String, + pub summary: String, + pub author: CommitSignatureGql, + pub committer: CommitSignatureGql, + pub tree_id: String, + pub parent_ids: Vec, + pub encoding: Option, +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct CommitSignatureGql { + pub name: String, + pub email: String, + pub time_secs: f64, + pub offset_minutes: i32, +} + +#[graphql_object(context = GraphqlContext)] +impl CommitSignatureGql { + fn name(&self) -> &str { + &self.name + } + fn email(&self) -> &str { + &self.email + } + fn time_secs(&self) -> f64 { + self.time_secs + } + fn offset_minutes(&self) -> i32 { + self.offset_minutes + } +} + +#[graphql_object(context = GraphqlContext)] +impl CommitGql { + fn oid(&self) -> &str { + &self.oid + } + fn message(&self) -> &str { + &self.message + } + fn summary(&self) -> &str { + &self.summary + } + fn author(&self) -> &CommitSignatureGql { + &self.author + } + fn committer(&self) -> &CommitSignatureGql { + &self.committer + } + fn tree_id(&self) -> &str { + &self.tree_id + } + fn parent_ids(&self) -> &[String] { + &self.parent_ids + } + fn encoding(&self) -> Option<&str> { + self.encoding.as_deref() + } +} + +impl From for CommitSignatureGql { + fn from(s: CommitSignature) -> Self { + CommitSignatureGql { + name: s.name, + email: s.email, + time_secs: s.time_secs as f64, + offset_minutes: s.offset_minutes, + } + } +} + +impl From for CommitGql { + fn from(m: CommitMeta) -> Self { + CommitGql { + oid: m.oid.0, + message: m.message, + summary: m.summary, + author: CommitSignatureGql::from(m.author), + committer: CommitSignatureGql::from(m.committer), + tree_id: m.tree_id.0, + parent_ids: m.parent_ids.iter().map(|p| p.0.clone()).collect(), + encoding: m.encoding, + } + } +} + +pub async fn resolve_commit( + ctx: &GraphqlContext, + oid: String, +) -> anyhow::Result { + let key = cache_key("query:git:commit", &[&oid]); + cached_json(&ctx.cache, &key, COMMIT_TTL, || { + let repo = ctx.repo.clone(); + let oid_obj = ObjectId::new(&oid); + async move { + let meta = + tokio::task::spawn_blocking(move || repo.commit_info(oid_obj)) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + Ok(CommitGql::from(meta)) + } + }) + .await +} + +pub async fn resolve_commit_history( + ctx: &GraphqlContext, + limit: Option, + skip: Option, + sort: Option, +) -> anyhow::Result> { + let revision = repo_revision(ctx).await; + let sort_str = sort.as_deref().unwrap_or("time"); + let key_parts = format!( + "{}:{}:{}:{}", + limit.map(|l| l.to_string()).unwrap_or_default(), + skip.map(|s| s.to_string()).unwrap_or_default(), + sort_str, + revision, + ); + let key = mutable_cache_key( + ctx, + "query:git:commit_history", + &[&key_parts], + &revision, + ); + cached_json(&ctx.cache, &key, MUTABLE_TTL, || { + let repo = ctx.repo.clone(); + let walk_sort = match sort_str { + "topological" => CommitWalkSort::Topological, + "time" => CommitWalkSort::Time, + "reverse" => CommitWalkSort::Reverse, + _ => CommitWalkSort::Time, + }; + let params = CommitWalkParams { + limit: limit.map(|l| l as usize), + skip: skip.unwrap_or(0) as usize, + sort: walk_sort, + ..Default::default() + }; + async move { + let commits = tokio::task::spawn_blocking(move || { + repo.commit_history(params) + }) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + Ok(commits.into_iter().map(CommitGql::from).collect()) + } + }) + .await +} + +pub async fn resolve_commit_summary( + ctx: &GraphqlContext, +) -> anyhow::Result { + let revision = repo_revision(ctx).await; + let key = + mutable_cache_key(ctx, "query:git:commit_summary", &[], &revision); + cached_json(&ctx.cache, &key, MUTABLE_TTL, || { + let repo = ctx.repo.clone(); + async move { + let summary = + tokio::task::spawn_blocking(move || repo.commit_summary()) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + Ok(CommitSummaryGql::from(summary)) + } + }) + .await +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct CommitSummaryGql { + pub head: Option, + pub count: i32, +} + +#[graphql_object(context = GraphqlContext)] +impl CommitSummaryGql { + fn head(&self) -> Option<&CommitGql> { + self.head.as_ref() + } + fn count(&self) -> i32 { + self.count + } +} + +impl From for CommitSummaryGql { + fn from(s: CommitSummary) -> Self { + CommitSummaryGql { + head: s.head.map(CommitGql::from), + count: s.count as i32, + } + } +} diff --git a/lib/git/graphql/mod.rs b/lib/git/graphql/mod.rs new file mode 100644 index 0000000..31892b1 --- /dev/null +++ b/lib/git/graphql/mod.rs @@ -0,0 +1,191 @@ +pub mod blob; +pub mod branch; +pub mod cache_helper; +pub mod commit; +pub mod tag; +pub mod tree; + +use std::path::PathBuf; + +use actix_web::{HttpResponse, web}; +use cache::AppCache; +use db::database::AppDatabase; +use juniper::{ + EmptyMutation, EmptySubscription, FieldResult, RootNode, graphql_object, +}; +use serde_json::json; + +use crate::{ + bare::GitBare, + graphql::{ + blob::{BlobResult, resolve_blob, resolve_blob_size}, + branch::{ + BranchGql, BranchSummaryGql, resolve_branch, + resolve_branch_summary, resolve_branches, resolve_head_branch, + }, + commit::{ + CommitGql, CommitSummaryGql, resolve_commit, + resolve_commit_history, resolve_commit_summary, + }, + tag::{ + TagGql, TagSummaryGql, resolve_tag, resolve_tag_summary, + resolve_tags, + }, + tree::{TreeEntryGql, TreeInfoGql, resolve_tree, resolve_tree_entries}, + }, +}; + +type Schema = RootNode< + GraphqlQuery, + EmptyMutation, + EmptySubscription, +>; + +fn schema() -> Schema { + Schema::new( + GraphqlQuery, + EmptyMutation::::new(), + EmptySubscription::::new(), + ) +} + +pub async fn graphql_handle( + path: web::Path<(String, String)>, + state: web::Data, + body: web::Json, +) -> HttpResponse { + let (wk, repo_name) = path.into_inner(); + let repo = match state.git_state.repo(wk, repo_name).await { + Ok(repo) => repo, + Err(err) => { + return HttpResponse::InternalServerError().json(json!({ + "message": err.to_string() + })); + } + }; + + let ctx = GraphqlContext { + repo: GitBare { + bare_dir: PathBuf::from(&repo.repo.storage_path), + }, + cache: state.git_state.cache.clone(), + db: state.git_state.db.clone(), + }; + + let schema = schema(); + let response = body.execute(&schema, &ctx).await; + + let status_code = if response.is_ok() { 200 } else { 400 }; + HttpResponse::build( + actix_web::http::StatusCode::from_u16(status_code).unwrap(), + ) + .content_type("application/json") + .json(response) +} + +#[derive(Clone)] +pub struct GraphqlContext { + pub repo: GitBare, + pub cache: AppCache, + pub db: AppDatabase, +} + +impl juniper::Context for GraphqlContext {} + +pub struct GraphqlQuery; + +fn to_field_error(e: anyhow::Error) -> juniper::FieldError { + juniper::FieldError::new(e.to_string(), juniper::Value::null()) +} + +#[graphql_object] +#[graphql(context = GraphqlContext)] +impl GraphqlQuery { + fn api_version() -> &'static str { + env!("CARGO_PKG_VERSION") + } + + async fn head_branch(ctx: &GraphqlContext) -> FieldResult { + resolve_head_branch(ctx).await.map_err(to_field_error) + } + + async fn branches(ctx: &GraphqlContext) -> FieldResult> { + resolve_branches(ctx).await.map_err(to_field_error) + } + + async fn branch( + ctx: &GraphqlContext, + name: String, + ) -> FieldResult { + resolve_branch(ctx, name).await.map_err(to_field_error) + } + + async fn branch_summary( + ctx: &GraphqlContext, + ) -> FieldResult { + resolve_branch_summary(ctx).await.map_err(to_field_error) + } + + async fn tags(ctx: &GraphqlContext) -> FieldResult> { + resolve_tags(ctx).await.map_err(to_field_error) + } + + async fn tag(ctx: &GraphqlContext, name: String) -> FieldResult { + resolve_tag(ctx, name).await.map_err(to_field_error) + } + + async fn tag_summary(ctx: &GraphqlContext) -> FieldResult { + resolve_tag_summary(ctx).await.map_err(to_field_error) + } + + async fn commit( + ctx: &GraphqlContext, + oid: String, + ) -> FieldResult { + resolve_commit(ctx, oid).await.map_err(to_field_error) + } + + async fn commit_history( + ctx: &GraphqlContext, + limit: Option, + skip: Option, + sort: Option, + ) -> FieldResult> { + resolve_commit_history(ctx, limit, skip, sort) + .await + .map_err(to_field_error) + } + + async fn commit_summary( + ctx: &GraphqlContext, + ) -> FieldResult { + resolve_commit_summary(ctx).await.map_err(to_field_error) + } + + async fn tree( + ctx: &GraphqlContext, + oid: String, + ) -> FieldResult { + resolve_tree(ctx, oid).await.map_err(to_field_error) + } + + async fn tree_entries( + ctx: &GraphqlContext, + tree_oid: String, + ) -> FieldResult> { + resolve_tree_entries(ctx, tree_oid) + .await + .map_err(to_field_error) + } + + async fn blob_size(ctx: &GraphqlContext, oid: String) -> FieldResult { + resolve_blob_size(ctx, oid).await.map_err(to_field_error) + } + + async fn blob( + ctx: &GraphqlContext, + oid: String, + ) -> FieldResult { + resolve_blob(ctx, oid).await.map_err(to_field_error) + } +} diff --git a/lib/git/graphql/tag.rs b/lib/git/graphql/tag.rs new file mode 100644 index 0000000..1aa06c4 --- /dev/null +++ b/lib/git/graphql/tag.rs @@ -0,0 +1,139 @@ +use std::time::Duration; + +use juniper::graphql_object; +use serde::{Deserialize, Serialize}; + +use crate::{ + cmd::tag::{TagItem, TagSummary}, + graphql::{ + GraphqlContext, + cache_helper::{ + MUTABLE_TTL, cached_json, mutable_cache_key, repo_revision, + }, + }, +}; + +const TAG_TTL: Duration = MUTABLE_TTL; + +#[derive(Clone, Serialize, Deserialize)] +pub struct TagGql { + pub name: String, + pub oid: String, + pub target: String, + pub is_annotated: bool, + pub message: Option, + pub tagger: Option, + pub tagger_email: Option, +} + +#[graphql_object(context = GraphqlContext)] +impl TagGql { + fn name(&self) -> &str { + &self.name + } + fn oid(&self) -> &str { + &self.oid + } + fn target(&self) -> &str { + &self.target + } + fn is_annotated(&self) -> bool { + self.is_annotated + } + fn message(&self) -> Option<&str> { + self.message.as_deref() + } + fn tagger(&self) -> Option<&str> { + self.tagger.as_deref() + } + fn tagger_email(&self) -> Option<&str> { + self.tagger_email.as_deref() + } +} + +impl From for TagGql { + fn from(item: TagItem) -> Self { + TagGql { + name: item.name, + oid: item.oid.0, + target: item.target.0, + is_annotated: item.is_annotated, + message: item.message, + tagger: item.tagger, + tagger_email: item.tagger_email, + } + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct TagSummaryGql { + pub total_count: i32, +} + +#[graphql_object(context = GraphqlContext)] +impl TagSummaryGql { + fn total_count(&self) -> i32 { + self.total_count + } +} + +impl From for TagSummaryGql { + fn from(s: TagSummary) -> Self { + TagSummaryGql { + total_count: s.total_count as i32, + } + } +} + +pub async fn resolve_tags(ctx: &GraphqlContext) -> anyhow::Result> { + let revision = repo_revision(ctx).await; + let key = mutable_cache_key(ctx, "query:git:tags", &[], &revision); + cached_json(&ctx.cache, &key, TAG_TTL, || { + let repo = ctx.repo.clone(); + async move { + let items = tokio::task::spawn_blocking(move || repo.tag_list()) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + Ok(items.into_iter().map(TagGql::from).collect()) + } + }) + .await +} + +pub async fn resolve_tag( + ctx: &GraphqlContext, + name: String, +) -> anyhow::Result { + let revision = repo_revision(ctx).await; + let key = mutable_cache_key(ctx, "query:git:tag", &[&name], &revision); + cached_json(&ctx.cache, &key, TAG_TTL, || { + let repo = ctx.repo.clone(); + let tag_name = name.clone(); + async move { + let item = + tokio::task::spawn_blocking(move || repo.tag_info(tag_name)) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + Ok(TagGql::from(item)) + } + }) + .await +} + +pub async fn resolve_tag_summary( + ctx: &GraphqlContext, +) -> anyhow::Result { + let revision = repo_revision(ctx).await; + let key = mutable_cache_key(ctx, "query:git:tag_summary", &[], &revision); + cached_json(&ctx.cache, &key, TAG_TTL, || { + let repo = ctx.repo.clone(); + async move { + let summary = + tokio::task::spawn_blocking(move || repo.tag_summary()) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + Ok(TagSummaryGql::from(summary)) + } + }) + .await +} diff --git a/lib/git/graphql/tree.rs b/lib/git/graphql/tree.rs new file mode 100644 index 0000000..4fa0bac --- /dev/null +++ b/lib/git/graphql/tree.rs @@ -0,0 +1,134 @@ +use std::time::Duration; + +use juniper::graphql_object; +use serde::{Deserialize, Serialize}; + +use crate::{ + cmd::{ + oid::ObjectId, + tree::{TreeEntry, TreeInfo, TreeKind}, + }, + graphql::{ + GraphqlContext, + cache_helper::{IMMUTABLE_TTL, cache_key, cached_json}, + }, +}; + +const TREE_TTL: Duration = IMMUTABLE_TTL; + +#[derive(Clone, Serialize, Deserialize)] +pub struct TreeInfoGql { + pub oid: String, + pub entry_count: i32, + pub is_empty: bool, +} + +#[graphql_object(context = GraphqlContext)] +impl TreeInfoGql { + fn oid(&self) -> &str { + &self.oid + } + fn entry_count(&self) -> i32 { + self.entry_count + } + fn is_empty(&self) -> bool { + self.is_empty + } +} + +impl From for TreeInfoGql { + fn from(info: TreeInfo) -> Self { + TreeInfoGql { + oid: info.oid.0, + entry_count: info.entry_count as i32, + is_empty: info.is_empty, + } + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct TreeEntryGql { + pub name: String, + pub oid: String, + pub kind: String, + pub filemode: i32, + pub is_binary: bool, + pub is_lfs: bool, +} + +#[graphql_object(context = GraphqlContext)] +impl TreeEntryGql { + fn name(&self) -> &str { + &self.name + } + fn oid(&self) -> &str { + &self.oid + } + fn kind(&self) -> &str { + &self.kind + } + fn filemode(&self) -> i32 { + self.filemode + } + fn is_binary(&self) -> bool { + self.is_binary + } + fn is_lfs(&self) -> bool { + self.is_lfs + } +} + +impl From for TreeEntryGql { + fn from(entry: TreeEntry) -> Self { + TreeEntryGql { + name: entry.name, + oid: entry.oid.0, + kind: match entry.kind { + TreeKind::Blob => "blob".to_string(), + TreeKind::Tree => "tree".to_string(), + TreeKind::LfsPointer => "lfs_pointer".to_string(), + }, + filemode: entry.filemode as i32, + is_binary: entry.is_binary, + is_lfs: entry.is_lfs, + } + } +} + +pub async fn resolve_tree( + ctx: &GraphqlContext, + oid: String, +) -> anyhow::Result { + let key = cache_key("query:git:tree", &[&oid]); + cached_json(&ctx.cache, &key, TREE_TTL, || { + let repo = ctx.repo.clone(); + let oid_obj = ObjectId::new(&oid); + async move { + let info = + tokio::task::spawn_blocking(move || repo.resolve_tree(oid_obj)) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + Ok(TreeInfoGql::from(info)) + } + }) + .await +} + +pub async fn resolve_tree_entries( + ctx: &GraphqlContext, + tree_oid: String, +) -> anyhow::Result> { + let key = cache_key("query:git:tree_entries", &[&tree_oid]); + cached_json(&ctx.cache, &key, TREE_TTL, || { + let repo = ctx.repo.clone(); + let oid_obj = ObjectId::new(&tree_oid); + async move { + let entries = + tokio::task::spawn_blocking(move || repo.tree_entries(oid_obj)) + .await? + .map_err(|e| anyhow::anyhow!(e))?; + Ok(entries.into_iter().map(TreeEntryGql::from).collect()) + } + }) + .await +} diff --git a/lib/git/http/action.rs b/lib/git/http/action.rs new file mode 100644 index 0000000..260403e --- /dev/null +++ b/lib/git/http/action.rs @@ -0,0 +1,101 @@ +use actix_web::{HttpResponse, web}; +use serde::{Deserialize, Serialize}; + +use crate::{ + http::{HttpAppState, utils::get_repo_model}, + sync::cicheck::poll_ci_task_for_repo, +}; + +#[derive(Debug, Deserialize)] +pub struct ActionQuery { + #[serde(default = "default_timeout")] + pub timeout: usize, +} + +fn default_timeout() -> usize { + 5 +} + +#[derive(Debug, Serialize)] +struct ActionTask { + id: String, + repo_id: String, + pipeline_name: String, + trigger: String, +} + +#[derive(Debug, Serialize)] +struct ActionResponse { + status: String, + task: Option, +} +pub async fn action_poll( + path: web::Path<(String, String)>, + query: web::Query, + state: web::Data, +) -> HttpResponse { + let (namespace, repo_name) = path.into_inner(); + + let repo_model = + match get_repo_model(&namespace, &repo_name, &state.db).await { + Ok(m) => m, + Err(_) => { + return HttpResponse::NotFound().json(ActionResponse { + status: "not_found".into(), + task: None, + }); + } + }; + + let task_json = + poll_ci_task_for_repo(&state.sync.pool(), repo_model.id, query.timeout) + .await; + + match task_json { + Some(json) => { + let hook_task: serde_json::Value = match serde_json::from_str(&json) + { + Ok(v) => v, + Err(_) => { + return HttpResponse::InternalServerError().json( + ActionResponse { + status: "error".into(), + task: None, + }, + ); + } + }; + + let task_id = hook_task + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let pipeline_name = hook_task + .get("payload") + .and_then(|p| p.get("pipeline_name")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let trigger = hook_task + .get("payload") + .and_then(|p| p.get("trigger")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + HttpResponse::Ok().json(ActionResponse { + status: "pending".into(), + task: Some(ActionTask { + id: task_id, + repo_id: repo_model.id.to_string(), + pipeline_name, + trigger, + }), + }) + } + None => HttpResponse::NoContent().finish(), + } +} diff --git a/lib/git/http/auth.rs b/lib/git/http/auth.rs new file mode 100644 index 0000000..eebc07c --- /dev/null +++ b/lib/git/http/auth.rs @@ -0,0 +1,94 @@ +use actix_web::{Error, HttpRequest}; +use argon2::{ + Argon2, + password_hash::{PasswordHash, PasswordVerifier}, +}; +use db::database::AppDatabase; +use model::{ + repos::RepoModel, + users::{user::UserModel, user_token::UserTokenModel}, +}; + +use crate::{ + http::utils::extract_basic_credentials, ssh::authz::SshAuthService, +}; + +pub async fn verify_access_token( + db: &AppDatabase, + username: &str, + access_key: &str, +) -> Result { + let user = sqlx::query_as::<_, UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, allow_use, can_search, \ + last_sign_in_at, created_at, updated_at \ + FROM \"user\" \ + WHERE username = $1", + ) + .bind(username) + .fetch_optional(db.reader()) + .await + .map_err(|_| actix_web::error::ErrorUnauthorized("Invalid username or access key"))? + .ok_or_else(|| actix_web::error::ErrorUnauthorized("Invalid username or access key"))?; + + let tokens: Vec = sqlx::query_as::<_, UserTokenModel>( + "SELECT id, \"user\", name, token_hash, scopes, expires_at, is_revoked, created_at, updated_at \ + FROM user_token \ + WHERE \"user\" = $1 AND is_revoked = false", + ) + .bind(user.id) + .fetch_all(db.reader()) + .await + .map_err(|_| actix_web::error::ErrorUnauthorized("Invalid username or access key"))? + .into_iter() + .filter(|token| { + token + .expires_at + .map(|expires_at| expires_at >= chrono::Utc::now()) + .unwrap_or(true) + }) + .collect(); + + for token in tokens { + let Ok(hash) = PasswordHash::new(&token.token_hash) else { + tracing::warn!( + token_id = token.id, + "invalid stored access key hash" + ); + continue; + }; + if Argon2::default() + .verify_password(access_key.as_bytes(), &hash) + .is_ok() + { + return Ok(user); + } + } + + Err(actix_web::error::ErrorUnauthorized( + "Invalid username or access key", + )) +} + +pub async fn authorize_repo_access( + req: &HttpRequest, + db: &AppDatabase, + repo: &RepoModel, + is_write: bool, +) -> Result<(), Error> { + if !is_write && repo.visibility == "public" { + return Ok(()); + } + + let (username, access_key) = extract_basic_credentials(req)?; + let user = verify_access_token(db, &username, &access_key).await?; + let authz = SshAuthService::new(db.clone()); + + let can_access = authz.check_repo_permission(&user, repo, is_write).await; + if !can_access { + return Err(actix_web::error::ErrorForbidden( + "No permission for repository", + )); + } + + Ok(()) +} diff --git a/lib/git/http/handler.rs b/lib/git/http/handler.rs new file mode 100644 index 0000000..0c5c264 --- /dev/null +++ b/lib/git/http/handler.rs @@ -0,0 +1,327 @@ +use std::{ + path::PathBuf, + pin::Pin, + time::{Duration, Instant}, +}; + +use actix_web::{Error, HttpResponse, web}; +use async_stream::stream; +use db::database::AppDatabase; +use futures_util::{Stream, StreamExt}; +use model::repos::{RepoModel, repo_protect::RepoProtectModel}; +use tokio::io::AsyncWriteExt; + +use crate::ssh::{ + branch_protect::check_branch_protection, ref_update::RefUpdate, +}; + +type ByteStream = Pin, std::io::Error>>>>; + +const PRE_PACK_LIMIT: usize = 1_048_576; +const GIT_OPERATION_TIMEOUT: Duration = Duration::from_secs(30); + +pub fn is_valid_oid(oid: &str) -> bool { + oid.len() == 40 && oid.chars().all(|c| c.is_ascii_hexdigit()) +} + +pub fn is_valid_lfs_oid(oid: &str) -> bool { + oid.len() == 64 && oid.chars().all(|c| c.is_ascii_hexdigit()) +} + +pub struct GitHttpHandler { + storage_path: PathBuf, + repo: RepoModel, + db: AppDatabase, +} + +impl GitHttpHandler { + pub fn new( + storage_path: PathBuf, + repo: RepoModel, + db: AppDatabase, + ) -> Self { + Self { + storage_path, + repo, + db, + } + } + + pub async fn upload_pack( + &self, + payload: web::Payload, + ) -> Result { + self.handle_git_rpc("upload-pack", payload).await + } + + pub async fn receive_pack( + &self, + payload: web::Payload, + ) -> Result { + self.handle_git_rpc("receive-pack", payload).await + } + + pub async fn info_refs( + &self, + service: &str, + ) -> Result { + let git_cmd = match service { + "git-upload-pack" => "upload-pack", + "git-receive-pack" => "receive-pack", + _ => { + return Err(actix_web::error::ErrorBadRequest( + "Invalid service", + )); + } + }; + + let output = tokio::time::timeout(GIT_OPERATION_TIMEOUT, async { + tokio::process::Command::new("git") + .arg(git_cmd) + .arg("--stateless-rpc") + .arg("--advertise-refs") + .arg(&self.storage_path) + .output() + .await + }) + .await + .map_err(|_| { + actix_web::error::ErrorInternalServerError("Git info-refs timeout") + })? + .map_err(|e| { + actix_web::error::ErrorInternalServerError(format!( + "Failed to execute git: {}", + e + )) + })?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(actix_web::error::ErrorInternalServerError(format!( + "Git command failed: {}", + stderr + ))); + } + + let mut response_body = Vec::new(); + let header = format!("# service={}\n", service); + write_pkt_line(&mut response_body, header.as_bytes()); + write_flush_pkt(&mut response_body); + response_body.extend_from_slice(&output.stdout); + + Ok(HttpResponse::Ok() + .content_type(format!("application/x-{}-advertisement", service)) + .insert_header(("Cache-Control", "no-cache")) + .body(response_body)) + } + + async fn handle_git_rpc( + &self, + service: &str, + mut payload: web::Payload, + ) -> Result { + let started = Instant::now(); + tracing::info!( + "git_rpc_started service={} repo={} repo_id={}", + service, + self.repo.name, + self.repo.id.to_string() + ); + let mut child = tokio::process::Command::new("git") + .arg(service) + .arg("--stateless-rpc") + .arg(&self.storage_path) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .kill_on_drop(true) + .spawn() + .map_err(|e| { + actix_web::error::ErrorInternalServerError(format!( + "Failed to spawn git: {}", + e + )) + })?; + + let stream = stream! { + while let Some(chunk) = payload.next().await { + match chunk { + Ok(bytes) => { yield Ok(bytes.to_vec()); } + Err(e) => { yield Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())); } + } + } + }; + let mut stream: ByteStream = Box::pin(stream); + + if service == "receive-pack" { + let branch_protects: Vec = sqlx::query_as::<_, RepoProtectModel>( + "SELECT id, repo, pattern, require_pull_request, required_approvals, \ + require_status_checks, required_status_contexts, enforce_admins, \ + allow_force_pushes, allow_deletions, created_at, updated_at \ + FROM repo_protect \ + WHERE repo = $1", + ) + .bind(self.repo.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| actix_web::error::ErrorInternalServerError(e.to_string()))?; + + let mut pre_pack: Vec = Vec::with_capacity(65536); + + while let Some(chunk) = stream.next().await { + let bytes = match chunk { + Ok(b) => b, + Err(e) => return Err(Error::from(e)), + }; + + if pre_pack.len() + bytes.len() > PRE_PACK_LIMIT { + tracing::warn!( + "git_rpc_payload_too_large service={} repo={} repo_id={}", + service, + self.repo.name, + self.repo.id.to_string() + ); + return Err(actix_web::error::ErrorPayloadTooLarge( + format!( + "Ref negotiation exceeds {} byte limit", + PRE_PACK_LIMIT + ), + )); + } + + if let Some(pos) = bytes.windows(4).position(|w| w == b"0000") { + let end = pos + 4; + pre_pack.extend_from_slice(&bytes[..end]); + + let refs = RefUpdate::parse_ref_updates(&pre_pack) + .map_err(actix_web::error::ErrorBadRequest)?; + if let Some(msg) = refs.iter().find_map(|r#ref| { + check_branch_protection(&branch_protects, r#ref) + }) { + tracing::warn!( + "branch_protection_violation repo={} repo_id={} message={}", + self.repo.name, + self.repo.id.to_string(), + msg + ); + return Err(actix_web::error::ErrorForbidden(msg)); + } + + let remaining: ByteStream = Box::pin(stream! { + yield Ok(pre_pack); + if end < bytes.len() { + yield Ok(bytes[end..].to_vec()); + } + while let Some(chunk) = stream.next().await { + yield chunk; + } + }); + stream = remaining; + break; + } else { + pre_pack.extend_from_slice(&bytes); + } + } + } + + if let Some(mut stdin) = child.stdin.take() { + let write_task = actix_web::rt::spawn(async move { + while let Some(chunk) = stream.next().await { + match chunk { + Ok(bytes) => { + if let Err(e) = stdin.write_all(&bytes).await { + return Err(e); + } + } + Err(e) => { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + e, + )); + } + } + } + drop(stdin); + Ok::<_, std::io::Error>(()) + }); + + let write_result = + tokio::time::timeout(GIT_OPERATION_TIMEOUT, write_task) + .await + .map_err(|_| { + actix_web::error::ErrorInternalServerError( + "Git stdin write timeout", + ) + })? + .map_err(|e| { + actix_web::error::ErrorInternalServerError(format!( + "Write error: {}", + e + )) + })?; + + if let Err(e) = write_result { + return Err(actix_web::error::ErrorInternalServerError( + format!("Failed to write to git: {}", e), + )); + } + } + + let output = tokio::time::timeout( + GIT_OPERATION_TIMEOUT, + child.wait_with_output(), + ) + .await + .map_err(|_| { + actix_web::error::ErrorInternalServerError("Git operation timeout") + })? + .map_err(|e| { + actix_web::error::ErrorInternalServerError(format!( + "Git wait failed: {}", + e + )) + })?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + let ms = started.elapsed().as_millis() as u64; + tracing::error!( + "git_rpc_failed service={} repo={} repo_id={} duration_ms={} stderr={}", + service, + self.repo.name, + self.repo.id.to_string(), + ms, + stderr.to_string() + ); + return Err(actix_web::error::ErrorInternalServerError(format!( + "Git command failed: {}", + stderr + ))); + } + + let ms = started.elapsed().as_millis() as u64; + tracing::info!( + "git_rpc_completed service={} repo={} repo_id={} duration_ms={} bytes_out={}", + service, + self.repo.name, + self.repo.id.to_string(), + ms, + output.stdout.len() + ); + + Ok(HttpResponse::Ok() + .content_type(format!("application/x-git-{}-result", service)) + .insert_header(("Cache-Control", "no-cache")) + .body(output.stdout)) + } +} + +fn write_pkt_line(buf: &mut Vec, data: &[u8]) { + let len = data.len() + 4; + buf.extend_from_slice(format!("{:04x}", len).as_bytes()); + buf.extend_from_slice(data); +} + +fn write_flush_pkt(buf: &mut Vec) { + buf.extend_from_slice(b"0000"); +} diff --git a/lib/git/http/lfs.rs b/lib/git/http/lfs.rs new file mode 100644 index 0000000..e77c407 --- /dev/null +++ b/lib/git/http/lfs.rs @@ -0,0 +1,726 @@ +use std::{collections::HashMap, path::PathBuf}; + +use actix_web::{HttpResponse, web}; +use cache::AppCache; +use db::database::AppDatabase; +use model::repos::{ + RepoModel, repo_lfs_lock::RepoLfsLockModel, + repo_lfs_object::RepoLfsObjectModel, +}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::{errors::GitError, http::handler::is_valid_lfs_oid}; + +const LFS_AUTH_TOKEN_EXPIRY: u64 = 3600; +const LFS_MAX_OBJECT_SIZE: i64 = 50 * 1024 * 1024 * 1024; + +#[derive(Deserialize, Serialize)] +pub struct BatchRequest { + pub operation: String, + pub objects: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub transfers: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub r#ref: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub hash_algo: Option, +} + +#[derive(Deserialize, Serialize)] +pub struct LfsRef { + pub name: String, +} + +#[derive(Deserialize, Serialize, Clone)] +pub struct LfsObjectReq { + pub oid: String, + pub size: i64, +} + +#[derive(Serialize)] +pub struct BatchResponse { + pub transfer: String, + pub objects: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub hash_algo: Option, +} + +#[derive(Serialize)] +pub struct LfsObjectResponse { + pub oid: String, + pub size: i64, + #[serde(skip_serializing_if = "Option::is_none")] + pub authenticated: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub actions: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Serialize)] +pub struct LfsAction { + pub href: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub header: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_in: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_at: Option, +} + +#[derive(Serialize)] +pub struct LfsError { + pub code: i32, + pub message: String, +} + +#[derive(Deserialize)] +pub struct CreateLockRequest { + pub oid: String, +} + +#[derive(Serialize)] +pub struct LockResponse { + pub id: Uuid, + pub path: String, + pub locked_by: Uuid, + pub locked_at: String, +} + +pub struct LfsHandler { + pub storage_path: PathBuf, + pub model: RepoModel, + pub namespace: String, + pub db: AppDatabase, +} + +impl LfsHandler { + pub fn new( + storage_path: PathBuf, + model: RepoModel, + namespace: String, + db: AppDatabase, + ) -> Self { + Self { + storage_path, + model, + namespace, + db, + } + } + + fn get_lfs_storage_path(&self) -> PathBuf { + self.storage_path.join(".lfs") + } + + fn get_object_path(&self, oid: &str) -> PathBuf { + let prefix = &oid[..2]; + self.get_lfs_storage_path() + .join("objects") + .join(prefix) + .join(oid) + } + + fn build_object_url(&self, base_url: &str, oid: &str) -> String { + format!( + "{}/{}/{}.git/info/lfs/objects/{}", + base_url, self.namespace, self.model.name, oid + ) + } + + pub async fn batch( + &self, + req: BatchRequest, + base_url: &str, + ) -> Result { + let operation = req.operation.as_str(); + + if operation != "upload" && operation != "download" { + return Err(GitError::InvalidOid(format!( + "Invalid operation: {}", + operation + ))); + } + + for obj in &req.objects { + if obj.size > LFS_MAX_OBJECT_SIZE { + return Err(GitError::InvalidOid(format!( + "Object size {} exceeds maximum allowed size {}", + obj.size, LFS_MAX_OBJECT_SIZE + ))); + } + } + + let oids: Vec = + req.objects.iter().map(|o| o.oid.clone()).collect(); + + let existing: Vec = + sqlx::query_as::<_, RepoLfsObjectModel>( + "SELECT repo, oid, size_bytes, storage_key, created_at \ + FROM repo_lfs_object \ + WHERE oid = ANY($1) AND repo = $2", + ) + .bind(&oids) + .bind(self.model.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))?; + + let existing_map: HashMap<&str, &RepoLfsObjectModel> = + existing.iter().map(|m| (m.oid.as_str(), m)).collect(); + + let mut response_objects = Vec::with_capacity(req.objects.len()); + + for obj in req.objects { + let existing = existing_map.get(obj.oid.as_str()); + + let mut actions = HashMap::new(); + + match operation { + "upload" => { + if existing.is_none() { + let upload_url = + self.build_object_url(base_url, &obj.oid); + + let token = Uuid::now_v7().to_string(); + let mut headers = HashMap::new(); + headers.insert( + "authorization".to_string(), + format!("Bearer {}", token), + ); + + actions.insert( + "upload".to_string(), + LfsAction { + href: upload_url, + header: Some(headers), + expires_in: Some(LFS_AUTH_TOKEN_EXPIRY as i64), + expires_at: None, + }, + ); + } + } + "download" => match existing { + Some(_) => { + let download_url = + self.build_object_url(base_url, &obj.oid); + + let token = Uuid::now_v7().to_string(); + let mut headers = HashMap::new(); + headers.insert( + "authorization".to_string(), + format!("Bearer {}", token), + ); + + actions.insert( + "download".to_string(), + LfsAction { + href: download_url, + header: Some(headers), + expires_in: Some(LFS_AUTH_TOKEN_EXPIRY as i64), + expires_at: None, + }, + ); + } + None => { + response_objects.push(LfsObjectResponse { + oid: obj.oid, + size: obj.size, + authenticated: None, + actions: None, + error: Some(LfsError { + code: 404, + message: "Object does not exist".to_string(), + }), + }); + continue; + } + }, + _ => {} + } + + response_objects.push(LfsObjectResponse { + oid: obj.oid, + size: obj.size, + authenticated: Some(true), + actions: if actions.is_empty() { + None + } else { + Some(actions) + }, + error: None, + }); + } + + Ok(BatchResponse { + transfer: "basic".to_string(), + objects: response_objects, + hash_algo: req.hash_algo, + }) + } + + pub async fn batch_with_auth( + &self, + req: BatchRequest, + base_url: &str, + user_id: uuid::Uuid, + cache: &AppCache, + ) -> Result { + let operation = req.operation.as_str(); + + if operation != "upload" && operation != "download" { + return Err(GitError::InvalidOid(format!( + "Invalid operation: {}", + operation + ))); + } + + for obj in &req.objects { + if obj.size > LFS_MAX_OBJECT_SIZE { + return Err(GitError::InvalidOid(format!( + "Object size {} exceeds maximum allowed size {}", + obj.size, LFS_MAX_OBJECT_SIZE + ))); + } + } + + let oids: Vec = + req.objects.iter().map(|o| o.oid.clone()).collect(); + + let existing: Vec = + sqlx::query_as::<_, RepoLfsObjectModel>( + "SELECT repo, oid, size_bytes, storage_key, created_at \ + FROM repo_lfs_object \ + WHERE oid = ANY($1) AND repo = $2", + ) + .bind(&oids) + .bind(self.model.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))?; + + let existing_map: HashMap<&str, &RepoLfsObjectModel> = + existing.iter().map(|m| (m.oid.as_str(), m)).collect(); + + let mut response_objects = Vec::with_capacity(req.objects.len()); + + for obj in req.objects { + let existing = existing_map.get(obj.oid.as_str()); + + let mut actions = HashMap::new(); + + match operation { + "upload" => { + if existing.is_none() { + let upload_url = + self.build_object_url(base_url, &obj.oid); + + let token = Uuid::now_v7().to_string(); + crate::http::lfs_routes::store_lfs_token( + cache, + &token, + self.model.id, + user_id, + "upload", + ) + .await; + + let mut headers = HashMap::new(); + headers.insert( + "authorization".to_string(), + format!("Bearer {}", token), + ); + + actions.insert( + "upload".to_string(), + LfsAction { + href: upload_url, + header: Some(headers), + expires_in: Some(LFS_AUTH_TOKEN_EXPIRY as i64), + expires_at: None, + }, + ); + } + } + "download" => match existing { + Some(_) => { + let download_url = + self.build_object_url(base_url, &obj.oid); + + let token = Uuid::now_v7().to_string(); + crate::http::lfs_routes::store_lfs_token( + cache, + &token, + self.model.id, + user_id, + "download", + ) + .await; + + let mut headers = HashMap::new(); + headers.insert( + "authorization".to_string(), + format!("Bearer {}", token), + ); + + actions.insert( + "download".to_string(), + LfsAction { + href: download_url, + header: Some(headers), + expires_in: Some(LFS_AUTH_TOKEN_EXPIRY as i64), + expires_at: None, + }, + ); + } + None => { + response_objects.push(LfsObjectResponse { + oid: obj.oid, + size: obj.size, + authenticated: None, + actions: None, + error: Some(LfsError { + code: 404, + message: "Object does not exist".to_string(), + }), + }); + continue; + } + }, + _ => {} + } + + response_objects.push(LfsObjectResponse { + oid: obj.oid, + size: obj.size, + authenticated: Some(true), + actions: if actions.is_empty() { + None + } else { + Some(actions) + }, + error: None, + }); + } + + Ok(BatchResponse { + transfer: "basic".to_string(), + objects: response_objects, + hash_algo: req.hash_algo, + }) + } + + pub async fn upload_object( + &self, + oid: &str, + payload: web::Payload, + ) -> Result { + if !is_valid_lfs_oid(oid) { + return Err(GitError::InvalidOid(format!( + "Invalid OID format: {}", + oid + ))); + } + + let object_path = self.get_object_path(oid); + if let Some(parent) = object_path.parent() { + tokio::fs::create_dir_all(parent).await.map_err(|e| { + GitError::Internal(format!("Failed to create directory: {}", e)) + })?; + } + + let temp_path = object_path.with_extension("tmp"); + let mut file = + tokio::fs::File::create(&temp_path).await.map_err(|e| { + GitError::Internal(format!("Failed to create temp file: {}", e)) + })?; + + use futures_util::stream::StreamExt; + use sha2::Digest; + use tokio::io::AsyncWriteExt; + + let mut payload = payload; + let mut size = 0i64; + let mut hasher = sha2::Sha256::new(); + + while let Some(chunk) = payload.next().await { + let chunk = chunk.map_err(|e| { + GitError::Internal(format!("Payload error: {}", e)) + })?; + size += chunk.len() as i64; + if size > LFS_MAX_OBJECT_SIZE { + let _ = tokio::fs::remove_file(&temp_path).await; + return Err(GitError::InvalidOid(format!( + "Object size exceeds maximum allowed size {}", + LFS_MAX_OBJECT_SIZE + ))); + } + hasher.update(&chunk); + if let Err(e) = file.write_all(&chunk).await { + let _ = tokio::fs::remove_file(&temp_path).await; + return Err(GitError::Internal(format!( + "Failed to write file: {}", + e + ))); + } + } + + file.flush().await.map_err(|e| { + GitError::Internal(format!("Failed to flush file: {}", e)) + })?; + drop(file); + + let hash_bytes = hasher.finalize(); + let calculated_oid = hex::encode(hash_bytes.as_slice()); + + if calculated_oid != oid { + let _ = tokio::fs::remove_file(&temp_path).await; + return Err(GitError::InvalidOid(format!( + "OID mismatch: expected {}, got {}", + oid, calculated_oid + ))); + } + + if let Err(e) = tokio::fs::rename(&temp_path, &object_path).await { + let _ = tokio::fs::remove_file(&temp_path).await; + return Err(GitError::Internal(format!( + "Failed to move file: {}", + e + ))); + } + + let now = chrono::Utc::now(); + sqlx::query( + "INSERT INTO repo_lfs_object (repo, oid, size_bytes, storage_key, created_at) \ + VALUES ($1, $2, $3, $4, $5)", + ) + .bind(self.model.id) + .bind(oid) + .bind(size) + .bind(object_path.to_string_lossy().to_string()) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| GitError::Internal(e.to_string()))?; + + Ok(HttpResponse::Ok().finish()) + } + + pub async fn download_object( + &self, + oid: &str, + ) -> Result { + if !is_valid_lfs_oid(oid) { + return Err(GitError::InvalidOid(format!( + "Invalid OID format: {}", + oid + ))); + } + + let obj = sqlx::query_as::<_, RepoLfsObjectModel>( + "SELECT repo, oid, size_bytes, storage_key, created_at \ + FROM repo_lfs_object \ + WHERE oid = $1 AND repo = $2", + ) + .bind(oid) + .bind(self.model.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))? + .ok_or_else(|| GitError::NotFound("Object not found".to_string()))?; + + let expected_base = self.get_lfs_storage_path(); + let obj_path = PathBuf::from(&obj.storage_key); + if !obj_path.starts_with(&expected_base) { + tracing::error!( + "LFS object path outside storage directory: {}", + obj.storage_key + ); + return Err(GitError::AuthFailed( + "Invalid object path".to_string(), + )); + } + + let file = tokio::fs::File::open(&obj_path).await.map_err(|e| { + GitError::Internal(format!("Failed to open file: {}", e)) + })?; + + use actix_web::body::BodyStream; + use futures_util::stream; + use tokio::io::AsyncReadExt; + + let chunk_size: usize = 65536; + + let stream = stream::unfold(file, move |mut file| async move { + let mut buffer = vec![0u8; chunk_size]; + match file.read(&mut buffer).await { + Ok(0) => None, + Ok(n) => { + buffer.truncate(n); + Some(( + Ok::<_, std::io::Error>(actix_web::web::Bytes::from( + buffer, + )), + file, + )) + } + Err(e) => Some((Err(e), file)), + } + }); + + Ok(HttpResponse::Ok() + .content_type("application/octet-stream") + .insert_header(("Content-Length", obj.size_bytes.to_string())) + .body(BodyStream::new(stream))) + } + + pub async fn lock_object( + &self, + oid: &str, + user_id: uuid::Uuid, + ) -> Result { + if !is_valid_lfs_oid(oid) { + return Err(GitError::InvalidOid(format!( + "Invalid OID format: {}", + oid + ))); + } + + let now = chrono::Utc::now(); + let lock_id = Uuid::now_v7(); + + let result = sqlx::query( + "INSERT INTO repo_lfs_lock (id, repo, path, locked_by, ref_name, created_at) \ + VALUES ($1, $2, $3, $4, NULL, $5)", + ) + .bind(lock_id) + .bind(self.model.id) + .bind(oid) + .bind(user_id) + .bind(now) + .execute(self.db.writer()) + .await; + + match result { + Ok(_) => Ok(LockResponse { + id: lock_id, + path: oid.to_string(), + locked_by: user_id, + locked_at: now.to_rfc3339(), + }), + Err(e) => { + let err_msg = format!("{}", e); + if err_msg.contains("duplicate key") + || err_msg.contains("23505") + { + return Err(GitError::Locked("Already locked".to_string())); + } + Err(GitError::Internal(format!("DB error: {}", e))) + } + } + } + + pub async fn unlock_object( + &self, + lock_id: &str, + user_id: uuid::Uuid, + ) -> Result<(), GitError> { + let lock_uuid = Uuid::parse_str(lock_id) + .map_err(|_| GitError::NotFound("Invalid lock ID".to_string()))?; + + let existing = sqlx::query_as::<_, RepoLfsLockModel>( + "SELECT id, repo, path, locked_by, ref_name, created_at \ + FROM repo_lfs_lock \ + WHERE id = $1 AND repo = $2", + ) + .bind(lock_uuid) + .bind(self.model.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))? + .ok_or_else(|| GitError::NotFound("Lock not found".to_string()))?; + + if existing.locked_by != user_id + && existing.locked_by != self.model.created_by + { + return Err(GitError::PermissionDenied( + "Not allowed to unlock".to_string(), + )); + } + + sqlx::query("DELETE FROM repo_lfs_lock WHERE id = $1 AND repo = $2") + .bind(lock_uuid) + .bind(self.model.id) + .execute(self.db.writer()) + .await + .map_err(|e| GitError::Internal(e.to_string()))?; + + Ok(()) + } + + pub async fn list_locks( + &self, + maybe_oid: Option<&str>, + ) -> Result, GitError> { + let rows: Vec = if let Some(oid) = maybe_oid { + sqlx::query_as::<_, RepoLfsLockModel>( + "SELECT id, repo, path, locked_by, ref_name, created_at \ + FROM repo_lfs_lock \ + WHERE repo = $1 AND path = $2", + ) + .bind(self.model.id) + .bind(oid) + .fetch_all(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))? + } else { + sqlx::query_as::<_, RepoLfsLockModel>( + "SELECT id, repo, path, locked_by, ref_name, created_at \ + FROM repo_lfs_lock \ + WHERE repo = $1", + ) + .bind(self.model.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))? + }; + + Ok(rows + .into_iter() + .map(|r| LockResponse { + id: r.id, + path: r.path, + locked_by: r.locked_by, + locked_at: r.created_at.to_rfc3339(), + }) + .collect()) + } + + pub async fn get_lock( + &self, + lock_id: &str, + ) -> Result { + let lock_uuid = Uuid::parse_str(lock_id) + .map_err(|_| GitError::NotFound("Invalid lock ID".to_string()))?; + + let r = sqlx::query_as::<_, RepoLfsLockModel>( + "SELECT id, repo, path, locked_by, ref_name, created_at \ + FROM repo_lfs_lock \ + WHERE id = $1 AND repo = $2", + ) + .bind(lock_uuid) + .bind(self.model.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))? + .ok_or_else(|| GitError::NotFound("Lock not found".to_string()))?; + + Ok(LockResponse { + id: r.id, + path: r.path, + locked_by: r.locked_by, + locked_at: r.created_at.to_rfc3339(), + }) + } +} diff --git a/lib/git/http/lfs_routes.rs b/lib/git/http/lfs_routes.rs new file mode 100644 index 0000000..2671883 --- /dev/null +++ b/lib/git/http/lfs_routes.rs @@ -0,0 +1,550 @@ +use std::path::PathBuf; + +use actix_web::{Error, HttpRequest, HttpResponse, web}; +use argon2::{ + Argon2, + password_hash::{PasswordHash, PasswordVerifier}, +}; +use model::{ + repos::RepoModel, + users::{user::UserModel, user_token::UserTokenModel}, +}; + +use crate::{ + errors::GitError, + http::{ + HttpAppState, + auth::{authorize_repo_access, verify_access_token}, + handler::is_valid_lfs_oid, + lfs::{BatchRequest, CreateLockRequest, LfsHandler}, + utils::{extract_basic_credentials, get_repo_model}, + }, + ssh::authz::SshAuthService, + sync::push_queue::{ + PushQueueEvent, PushQueueLease, PushQueueWaitError, + wait_for_push_queue_slot, + }, +}; + +fn base_url(req: &HttpRequest) -> String { + let conn_info = req.connection_info(); + format!("{}://{}", conn_info.scheme(), conn_info.host()) +} + +fn bearer_token(req: &HttpRequest) -> Result { + let auth_header = req + .headers() + .get("authorization") + .ok_or_else(|| { + actix_web::error::ErrorUnauthorized("Missing authorization header") + })? + .to_str() + .map_err(|_| { + actix_web::error::ErrorUnauthorized("Invalid authorization header") + })?; + + if let Some(token) = auth_header.strip_prefix("Bearer ") { + Ok(token.to_string()) + } else { + Err(actix_web::error::ErrorUnauthorized( + "Invalid authorization format", + )) + } +} + +async fn user_uid( + req: &HttpRequest, + db: &db::database::AppDatabase, +) -> Result { + if let Ok((username, access_key)) = extract_basic_credentials(req) { + return verify_access_token(db, &username, &access_key) + .await + .map(|user| user.id); + } + + let token = bearer_token(req)?; + find_user_by_bearer_token(&token, db).await +} +pub async fn store_lfs_token( + cache: &cache::AppCache, + token: &str, + repo_id: uuid::Uuid, + user_uid: uuid::Uuid, + operation: &str, +) { + if let Some(mut conn) = cache.conn() { + use redis::AsyncCommands; + let value = format!("{}:{}:{}", repo_id, user_uid, operation); + let _: () = conn + .set_ex(format!("lfs:token:{}", token), value, 3600_u64) + .await + .map_err( + |e| tracing::warn!(error = %e, "failed to store lfs token"), + ) + .unwrap_or(()); + } +} + +async fn validate_lfs_token( + token: &str, + cache: &cache::AppCache, + db: &db::database::AppDatabase, + expected_repo_id: uuid::Uuid, + expected_operation: &str, +) -> Result { + if let Some(mut conn) = cache.conn() { + use redis::AsyncCommands; + let stored: Option = conn + .get::(format!("lfs:token:{}", token)) + .await + .ok(); + if let Some(value) = stored { + let parts: Vec<&str> = value.split(':').collect(); + if parts.len() == 3 { + let repo_id = + uuid::Uuid::parse_str(parts[0]).map_err(|_| { + actix_web::error::ErrorUnauthorized( + "Invalid batch token", + ) + })?; + let user_uid = + uuid::Uuid::parse_str(parts[1]).map_err(|_| { + actix_web::error::ErrorUnauthorized( + "Invalid batch token", + ) + })?; + let operation = parts[2]; + + if repo_id != expected_repo_id { + return Err(actix_web::error::ErrorUnauthorized( + "Token not valid for this repo", + )); + } + if operation != expected_operation { + return Err(actix_web::error::ErrorUnauthorized( + "Token not valid for this operation", + )); + } + + let _: Result<(), redis::RedisError> = + conn.del(format!("lfs:token:{}", token)).await; + + return Ok(user_uid); + } + } + } + + find_user_by_bearer_token(token, db).await +} + +async fn find_user_by_bearer_token( + token: &str, + db: &db::database::AppDatabase, +) -> Result { + let tokens: Vec = sqlx::query_as::<_, UserTokenModel>( + "SELECT id, \"user\", name, token_hash, scopes, expires_at, is_revoked, created_at, updated_at \ + FROM user_token \ + WHERE is_revoked = false", + ) + .fetch_all(db.reader()) + .await + .map_err(|_| actix_web::error::ErrorUnauthorized("Authentication failed"))?; + + for token_model in tokens { + if token_model + .expires_at + .map(|expires_at| expires_at < chrono::Utc::now()) + .unwrap_or(false) + { + continue; + } + + let Ok(hash) = PasswordHash::new(&token_model.token_hash) else { + tracing::warn!( + token_id = token_model.id, + "invalid stored bearer token hash" + ); + continue; + }; + if Argon2::default() + .verify_password(token.as_bytes(), &hash) + .is_ok() + { + return Ok(token_model.user); + } + } + + Err(actix_web::error::ErrorUnauthorized("Invalid token")) +} + +async fn authorize_user_repo_access( + db: &db::database::AppDatabase, + user_uid: uuid::Uuid, + repo: &RepoModel, + is_write: bool, +) -> Result<(), Error> { + let user = sqlx::query_as::<_, UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, allow_use, can_search, \ + last_sign_in_at, created_at, updated_at \ + FROM \"user\" \ + WHERE id = $1", + ) + .bind(user_uid) + .fetch_optional(db.reader()) + .await + .map_err(|_| actix_web::error::ErrorUnauthorized("Authentication failed"))? + .ok_or_else(|| actix_web::error::ErrorUnauthorized("Invalid token user"))?; + + let authz = SshAuthService::new(db.clone()); + if authz.check_repo_permission(&user, repo, is_write).await { + Ok(()) + } else { + Err(actix_web::error::ErrorForbidden( + "No permission for repository", + )) + } +} + +async fn acquire_lfs_write_queue( + state: &HttpAppState, + repo: &RepoModel, + operation: &'static str, +) -> Result { + match wait_for_push_queue_slot( + state.sync.clone(), + repo.id, + |event, request_id| { + let request_id = request_id.to_string(); + match event { + PushQueueEvent::Waiting(position) => { + tracing::info!( + repo = %repo.name, + repo_id = %repo.id, + request_id = %request_id, + operation = operation, + position = position.position, + total = position.total, + "lfs_write_queue_waiting" + ); + } + PushQueueEvent::Acquired => { + tracing::info!( + repo = %repo.name, + repo_id = %repo.id, + request_id = %request_id, + operation = operation, + "lfs_write_queue_acquired" + ); + } + } + }, + ) + .await + { + Ok(lease) => Ok(lease), + Err(PushQueueWaitError::Join(e)) => { + tracing::error!( + error = %e, + repo = %repo.name, + repo_id = %repo.id, + operation = operation, + "lfs_write_queue_join_failed" + ); + Err(actix_web::error::ErrorServiceUnavailable( + "LFS write queue is temporarily unavailable. Please retry later.", + )) + } + Err(PushQueueWaitError::Lock(e)) => { + tracing::error!( + error = %e, + repo = %repo.name, + repo_id = %repo.id, + operation = operation, + "lfs_write_queue_lock_failed" + ); + Err(actix_web::error::ErrorServiceUnavailable( + "LFS write queue lock failed. Please retry later.", + )) + } + Err(PushQueueWaitError::Timeout) => { + tracing::warn!( + repo = %repo.name, + repo_id = %repo.id, + operation = operation, + "lfs_write_queue_timeout" + ); + Err(actix_web::error::ErrorServiceUnavailable( + "LFS write queue timed out. Please retry in a moment.", + )) + } + } +} + +pub async fn lfs_batch( + req: HttpRequest, + path: web::Path<(String, String)>, + body: web::Json, + state: web::Data, +) -> Result { + let (namespace, repo_name) = path.into_inner(); + let batch_req = body.into_inner(); + let is_write = batch_req.operation == "upload"; + + let repo = get_repo_model(&namespace, &repo_name, &state.db).await?; + + if repo.visibility != "public" || is_write { + let uid = user_uid(&req, &state.db).await?; + authorize_repo_access(&req, &state.db, &repo, is_write).await?; + + let handler = LfsHandler::new( + PathBuf::from(&repo.storage_path), + repo, + namespace, + state.db.clone(), + ); + let response = handler + .batch_with_auth(batch_req, &base_url(&req), uid, &state.cache) + .await + .map_err(|_| { + actix_web::error::ErrorInternalServerError("LFS batch failed") + })?; + Ok(HttpResponse::Ok() + .content_type("application/vnd.git-lfs+json") + .json(response)) + } else { + let handler = LfsHandler::new( + PathBuf::from(&repo.storage_path), + repo, + namespace, + state.db.clone(), + ); + let response = handler + .batch(batch_req, &base_url(&req)) + .await + .map_err(|_| { + actix_web::error::ErrorInternalServerError("LFS batch failed") + })?; + Ok(HttpResponse::Ok() + .content_type("application/vnd.git-lfs+json") + .json(response)) + } +} + +pub async fn lfs_upload( + req: HttpRequest, + path: web::Path<(String, String, String)>, + payload: web::Payload, + state: web::Data, +) -> Result { + let (namespace, repo_name, oid) = path.into_inner(); + + if !is_valid_lfs_oid(&oid) { + return Err(actix_web::error::ErrorBadRequest("Invalid OID format")); + } + + let repo = get_repo_model(&namespace, &repo_name, &state.db).await?; + let token = bearer_token(&req)?; + + let uid = + validate_lfs_token(&token, &state.cache, &state.db, repo.id, "upload") + .await?; + authorize_user_repo_access(&state.db, uid, &repo, true).await?; + + let handler = LfsHandler::new( + PathBuf::from(&repo.storage_path), + repo.clone(), + namespace, + state.db.clone(), + ); + let mut queue_lease = + acquire_lfs_write_queue(&state, &handler.model, "upload").await?; + + let result = match handler.upload_object(&oid, payload).await { + Ok(response) => Ok(response), + Err(GitError::InvalidOid(_)) => { + Err(actix_web::error::ErrorBadRequest("Invalid OID")) + } + Err(GitError::AuthFailed(_)) => { + Err(actix_web::error::ErrorUnauthorized("Unauthorized")) + } + Err(_e) => { + Err(actix_web::error::ErrorInternalServerError("Upload failed")) + } + }; + queue_lease.release().await; + result +} + +pub async fn lfs_download( + req: HttpRequest, + path: web::Path<(String, String, String)>, + state: web::Data, +) -> Result { + let (namespace, repo_name, oid) = path.into_inner(); + + if !is_valid_lfs_oid(&oid) { + return Err(actix_web::error::ErrorBadRequest("Invalid OID format")); + } + + let repo = get_repo_model(&namespace, &repo_name, &state.db).await?; + + if repo.visibility != "public" { + let token = bearer_token(&req)?; + let uid = validate_lfs_token( + &token, + &state.cache, + &state.db, + repo.id, + "download", + ) + .await?; + authorize_user_repo_access(&state.db, uid, &repo, false).await?; + } + + let handler = LfsHandler::new( + PathBuf::from(&repo.storage_path), + repo, + namespace, + state.db.clone(), + ); + + match handler.download_object(&oid).await { + Ok(response) => Ok(response), + Err(GitError::NotFound(_)) => { + Err(actix_web::error::ErrorNotFound("Object not found")) + } + Err(GitError::AuthFailed(_)) => { + Err(actix_web::error::ErrorUnauthorized("Unauthorized")) + } + Err(_e) => Err(actix_web::error::ErrorInternalServerError( + "Download failed", + )), + } +} + +pub async fn lfs_lock_create( + req: HttpRequest, + path: web::Path<(String, String)>, + body: web::Json, + state: web::Data, +) -> Result { + let (namespace, repo_name) = path.into_inner(); + + let repo = get_repo_model(&namespace, &repo_name, &state.db).await?; + let uid = user_uid(&req, &state.db).await?; + authorize_repo_access(&req, &state.db, &repo, true).await?; + let handler = LfsHandler::new( + PathBuf::from(&repo.storage_path), + repo.clone(), + namespace, + state.db.clone(), + ); + let mut queue_lease = + acquire_lfs_write_queue(&state, &handler.model, "lock_create").await?; + + let result = match handler.lock_object(&body.oid, uid).await { + Ok(lock) => Ok(HttpResponse::Created().json(lock)), + Err(GitError::Locked(msg)) => Ok(HttpResponse::Conflict().body(msg)), + Err(_e) => { + Err(actix_web::error::ErrorInternalServerError("Lock failed")) + } + }; + queue_lease.release().await; + result +} + +pub async fn lfs_lock_list( + req: HttpRequest, + path: web::Path<(String, String)>, + query: web::Query>, + state: web::Data, +) -> Result { + let (namespace, repo_name) = path.into_inner(); + let repo = get_repo_model(&namespace, &repo_name, &state.db).await?; + + if repo.visibility != "public" { + let uid = user_uid(&req, &state.db).await?; + authorize_user_repo_access(&state.db, uid, &repo, false).await?; + } + + let maybe_oid = query.get("oid").map(|s| s.as_str()); + let handler = LfsHandler::new( + PathBuf::from(&repo.storage_path), + repo, + namespace, + state.db.clone(), + ); + + match handler.list_locks(maybe_oid).await { + Ok(list) => Ok(HttpResponse::Ok().json(list)), + Err(_e) => Err(actix_web::error::ErrorInternalServerError( + "Lock list failed", + )), + } +} + +pub async fn lfs_lock_get( + req: HttpRequest, + path: web::Path<(String, String, String)>, + state: web::Data, +) -> Result { + let (namespace, repo_name, lock_id) = path.into_inner(); + let repo = get_repo_model(&namespace, &repo_name, &state.db).await?; + + if repo.visibility != "public" { + let uid = user_uid(&req, &state.db).await?; + authorize_user_repo_access(&state.db, uid, &repo, false).await?; + } + + let handler = LfsHandler::new( + PathBuf::from(&repo.storage_path), + repo, + namespace, + state.db.clone(), + ); + + match handler.get_lock(&lock_id).await { + Ok(lock) => Ok(HttpResponse::Ok().json(lock)), + Err(GitError::NotFound(_)) => { + Err(actix_web::error::ErrorNotFound("Lock not found")) + } + Err(_e) => Err(actix_web::error::ErrorInternalServerError( + "Lock get failed", + )), + } +} + +pub async fn lfs_lock_delete( + req: HttpRequest, + path: web::Path<(String, String, String)>, + state: web::Data, +) -> Result { + let (namespace, repo_name, lock_id) = path.into_inner(); + + let repo = get_repo_model(&namespace, &repo_name, &state.db).await?; + let uid = user_uid(&req, &state.db).await?; + authorize_repo_access(&req, &state.db, &repo, true).await?; + let handler = LfsHandler::new( + PathBuf::from(&repo.storage_path), + repo.clone(), + namespace, + state.db.clone(), + ); + let mut queue_lease = + acquire_lfs_write_queue(&state, &handler.model, "lock_delete").await?; + + let result = match handler.unlock_object(&lock_id, uid).await { + Ok(()) => Ok(HttpResponse::NoContent().finish()), + Err(GitError::PermissionDenied(_)) => { + Err(actix_web::error::ErrorForbidden("Not allowed")) + } + Err(GitError::NotFound(_)) => { + Err(actix_web::error::ErrorNotFound("Lock not found")) + } + Err(_e) => Err(actix_web::error::ErrorInternalServerError( + "Lock delete failed", + )), + }; + queue_lease.release().await; + result +} diff --git a/lib/git/http/mod.rs b/lib/git/http/mod.rs new file mode 100644 index 0000000..3c524c0 --- /dev/null +++ b/lib/git/http/mod.rs @@ -0,0 +1,219 @@ +use std::{sync::Arc, time::Instant}; + +use actix_web::{App, HttpResponse, HttpServer, dev::Service, web}; +use cache::AppCache; +use config::AppConfig; +use db::database::AppDatabase; +use sqlx; + +pub mod action; +pub mod auth; +pub mod handler; +pub mod lfs; +pub mod lfs_routes; +pub mod rate_limit; +pub mod routes; +pub mod utils; + +const REQUEST_LOG_EXCLUDED_PATHS: &[&str] = &[ + "/health", + "/live", + "/ready", + "/metrics", + "/favicon.ico", + "/robots.txt", +]; + +fn should_log_request(path: &str) -> bool { + !REQUEST_LOG_EXCLUDED_PATHS.contains(&path) +} + +#[derive(Clone)] +pub struct HttpAppState { + pub db: AppDatabase, + pub cache: AppCache, + pub sync: crate::sync::ReceiveSyncService, + pub rate_limiter: Arc, + pub config: AppConfig, + pub git_state: crate::AppGitState, +} + +async fn robots(state: web::Data) -> HttpResponse { + let sitemap_url = state + .config + .git_http_domain() + .map(|d| format!("{}/sitemap.xml", d.trim_end_matches('/'))) + .unwrap_or_default(); + + let body = if sitemap_url.is_empty() { + "User-agent: *\nDisallow: /\n".to_string() + } else { + format!("User-agent: *\nDisallow: /\n\nSitemap: {sitemap_url}\n") + }; + + HttpResponse::Ok() + .content_type("text/plain; charset=utf-8") + .body(body) +} + +async fn health(state: web::Data) -> HttpResponse { + let db_ok = sqlx::query("SELECT 1") + .execute(state.db.reader()) + .await + .is_ok(); + let cache_ok = state.cache.ping_cluster().await.is_ok(); + + if db_ok && cache_ok { + HttpResponse::Ok().json(serde_json::json!({ + "status": "ok", + "db": "ok", + "cache": "ok", + })) + } else { + HttpResponse::ServiceUnavailable().json(serde_json::json!({ + "status": "unhealthy", + "db": if db_ok { "ok" } else { "error" }, + "cache": if cache_ok { "ok" } else { "error" }, + })) + } +} + +pub fn git_http_cfg(cfg: &mut web::ServiceConfig) { + cfg.route("/robots.txt", web::get().to(robots)) + .route("/health", web::get().to(health)) + .route( + "/{namespace}/{repo_name}.git/action", + web::get().to(action::action_poll), + ) + .route( + "/{namespace}/{repo_name}.git/info/refs", + web::get().to(routes::info_refs), + ) + .route( + "/{namespace}/{repo_name}.git/git-upload-pack", + web::post().to(routes::upload_pack), + ) + .route( + "/{namespace}/{repo_name}.git/git-receive-pack", + web::post().to(routes::receive_pack), + ) + .route( + "/{namespace}/{repo_name}.git/info/lfs/objects/batch", + web::post().to(lfs_routes::lfs_batch), + ) + .route( + "/{namespace}/{repo_name}.git/info/lfs/objects/{oid}", + web::put().to(lfs_routes::lfs_upload), + ) + .route( + "/{namespace}/{repo_name}.git/info/lfs/objects/{oid}", + web::get().to(lfs_routes::lfs_download), + ) + .route( + "/{namespace}/{repo_name}.git/info/lfs/locks", + web::post().to(lfs_routes::lfs_lock_create), + ) + .route( + "/{namespace}/{repo_name}.git/info/lfs/locks", + web::get().to(lfs_routes::lfs_lock_list), + ) + .route( + "/{namespace}/{repo_name}.git/info/lfs/locks/{id}", + web::get().to(lfs_routes::lfs_lock_get), + ) + .route( + "/{namespace}/{repo_name}.git/info/lfs/locks/{id}", + web::delete().to(lfs_routes::lfs_lock_delete), + ) + .route( + "/{namespace}/{repo_name}.git/graphql", + web::post().to(crate::graphql::graphql_handle), + ); +} + +pub async fn run_http( + config: AppConfig, + db: AppDatabase, + cache: AppCache, + redis_pool: deadpool_redis::cluster::Pool, +) -> anyhow::Result<()> { + let sync = crate::sync::ReceiveSyncService::new(redis_pool); + + let rate_limiter = Arc::new(rate_limit::RateLimiter::new( + rate_limit::RateLimitConfig::default(), + )); + let _cleanup = rate_limiter.clone().start_cleanup(); + + let git_state = crate::AppGitState { + cache: cache.clone(), + db: db.clone(), + }; + + let state = HttpAppState { + db: db.clone(), + cache: cache.clone(), + sync, + rate_limiter, + config: config.clone(), + git_state, + }; + + let http_port = config.git_http_port()?; + tracing::info!("Starting git HTTP server on 0.0.0.0:{}", http_port); + + let server = HttpServer::new(move || { + App::new() + .app_data(web::Data::new(state.clone())) + .wrap_fn(|req, srv| { + let should_log = should_log_request(req.path()); + let method = req.method().clone(); + let path = req.path().to_owned(); + let peer_addr = + req.connection_info().peer_addr().map(str::to_owned); + let started_at = Instant::now(); + let fut = srv.call(req); + + async move { + match fut.await { + Ok(res) => { + if should_log { + tracing::info!( + method = %method, + path = %path, + status = res.status().as_u16(), + elapsed_ms = started_at.elapsed().as_millis(), + peer_addr = peer_addr.as_deref().unwrap_or("-"), + "http request" + ); + } + Ok(res) + } + Err(err) => { + if should_log { + tracing::warn!( + method = %method, + path = %path, + elapsed_ms = started_at.elapsed().as_millis(), + peer_addr = peer_addr.as_deref().unwrap_or("-"), + error = %err, + "http request failed" + ); + } + Err(err) + } + } + } + }) + .configure(git_http_cfg) + }) + .bind(format!("0.0.0.0:{}", http_port))? + .run(); + + let result = server.await; + if let Err(e) = result { + tracing::error!("HTTP server error: {}", e); + } + + tracing::info!("Git HTTP server stopped"); + Ok(()) +} diff --git a/lib/git/http/rate_limit.rs b/lib/git/http/rate_limit.rs new file mode 100644 index 0000000..41d3107 --- /dev/null +++ b/lib/git/http/rate_limit.rs @@ -0,0 +1,128 @@ +use std::{ + collections::HashMap, + sync::Arc, + time::{Duration, Instant}, +}; + +use tokio::{sync::RwLock, time::interval}; + +#[derive(Debug, Clone)] +pub struct RateLimitConfig { + pub read_requests_per_window: u32, + pub write_requests_per_window: u32, + pub window_secs: u64, +} + +impl Default for RateLimitConfig { + fn default() -> Self { + Self { + read_requests_per_window: 120, + write_requests_per_window: 30, + window_secs: 60, + } + } +} + +#[derive(Debug)] +struct RateLimitBucket { + read_count: u32, + write_count: u32, + reset_time: Instant, +} + +#[derive(Clone, Copy)] +enum BucketOp { + Read, + Write, +} + +pub struct RateLimiter { + buckets: Arc>>, + config: RateLimitConfig, +} + +impl RateLimiter { + pub fn new(config: RateLimitConfig) -> Self { + Self { + buckets: Arc::new(RwLock::new(HashMap::new())), + config, + } + } + + pub async fn is_read_allowed(&self) -> bool { + self.is_allowed( + "global:read", + BucketOp::Read, + self.config.read_requests_per_window, + ) + .await + } + + pub async fn is_write_allowed(&self) -> bool { + self.is_allowed( + "global:write", + BucketOp::Write, + self.config.write_requests_per_window, + ) + .await + } + + pub async fn is_repo_write_allowed(&self, repo_path: &str) -> bool { + let key = format!("repo:write:{}", repo_path); + self.is_allowed( + &key, + BucketOp::Write, + self.config.write_requests_per_window, + ) + .await + } + + async fn is_allowed(&self, key: &str, op: BucketOp, limit: u32) -> bool { + let now = Instant::now(); + let mut buckets = self.buckets.write().await; + + let bucket = + buckets + .entry(key.to_string()) + .or_insert_with(|| RateLimitBucket { + read_count: 0, + write_count: 0, + reset_time: now + + Duration::from_secs(self.config.window_secs), + }); + + if now >= bucket.reset_time { + bucket.read_count = 0; + bucket.write_count = 0; + bucket.reset_time = + now + Duration::from_secs(self.config.window_secs); + } + + let over_limit = match op { + BucketOp::Read => bucket.read_count >= limit, + BucketOp::Write => bucket.write_count >= limit, + }; + + if over_limit { + return false; + } + + match op { + BucketOp::Read => bucket.read_count += 1, + BucketOp::Write => bucket.write_count += 1, + } + true + } + + pub fn start_cleanup(self: Arc) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + let mut ticker = interval(Duration::from_secs(300)); + loop { + ticker.tick().await; + let now = Instant::now(); + let mut buckets = self.buckets.write().await; + buckets.retain(|_, bucket| now < bucket.reset_time); + } + }) + } +} diff --git a/lib/git/http/routes.rs b/lib/git/http/routes.rs new file mode 100644 index 0000000..6498201 --- /dev/null +++ b/lib/git/http/routes.rs @@ -0,0 +1,179 @@ +use std::{path::PathBuf, time::Duration}; + +use actix_web::{Error, HttpRequest, HttpResponse, web}; +use tokio::time::timeout; + +use crate::{ + http::{ + HttpAppState, auth::authorize_repo_access, handler::GitHttpHandler, + utils::get_repo_model, + }, + sync::{ + RepoReceiveSyncTask, + push_queue::{ + PushQueueEvent, PushQueueWaitError, wait_for_push_queue_slot, + }, + }, +}; + +pub async fn info_refs( + req: HttpRequest, + path: web::Path<(String, String)>, + state: web::Data, +) -> Result { + if !state.rate_limiter.is_read_allowed().await { + return Err(actix_web::error::ErrorTooManyRequests( + "Rate limit exceeded", + )); + } + + let service_param = req + .query_string() + .split('&') + .find(|s| s.starts_with("service=")) + .and_then(|s| s.strip_prefix("service=")) + .ok_or_else(|| { + actix_web::error::ErrorBadRequest("Missing service parameter") + })?; + + if service_param != "git-upload-pack" && service_param != "git-receive-pack" + { + return Err(actix_web::error::ErrorBadRequest("Invalid service")); + } + + let path_inner = path.into_inner(); + let model = get_repo_model(&path_inner.0, &path_inner.1, &state.db).await?; + let is_write = service_param == "git-receive-pack"; + authorize_repo_access(&req, &state.db, &model, is_write).await?; + + let storage_path = PathBuf::from(&model.storage_path); + let handler = GitHttpHandler::new(storage_path, model, state.db.clone()); + handler.info_refs(service_param).await +} + +pub async fn upload_pack( + req: HttpRequest, + path: web::Path<(String, String)>, + payload: web::Payload, + state: web::Data, +) -> Result { + if !state.rate_limiter.is_read_allowed().await { + return Err(actix_web::error::ErrorTooManyRequests( + "Rate limit exceeded", + )); + } + + let path_inner = path.into_inner(); + let model = get_repo_model(&path_inner.0, &path_inner.1, &state.db).await?; + authorize_repo_access(&req, &state.db, &model, false).await?; + + let storage_path = PathBuf::from(&model.storage_path); + let handler = GitHttpHandler::new(storage_path, model, state.db.clone()); + handler.upload_pack(payload).await +} + +pub async fn receive_pack( + req: HttpRequest, + path: web::Path<(String, String)>, + payload: web::Payload, + state: web::Data, +) -> Result { + if !state.rate_limiter.is_write_allowed().await { + return Err(actix_web::error::ErrorTooManyRequests( + "Rate limit exceeded", + )); + } + + let path_inner = path.into_inner(); + let model = get_repo_model(&path_inner.0, &path_inner.1, &state.db).await?; + authorize_repo_access(&req, &state.db, &model, true).await?; + + let mut push_queue_lease = match wait_for_push_queue_slot( + state.sync.clone(), + model.id, + |event, request_id| { + let request_id = request_id.to_string(); + let repo_name = model.name.clone(); + let repo_id = model.id; + match event { + PushQueueEvent::Waiting(position) => { + tracing::info!( + repo = %repo_name, + repo_id = %repo_id, + request_id = %request_id, + position = position.position, + total = position.total, + "http_push_queue_waiting" + ); + } + PushQueueEvent::Acquired => { + tracing::info!( + repo = %repo_name, + repo_id = %repo_id, + request_id = %request_id, + "http_push_queue_acquired" + ); + } + } + }, + ) + .await + { + Ok(lease) => lease, + Err(PushQueueWaitError::Join(e)) => { + tracing::error!( + error = %e, + repo = %model.name, + repo_id = %model.id, + "http_push_queue_join_failed" + ); + return Err(actix_web::error::ErrorServiceUnavailable( + "Push queue is temporarily unavailable. Please retry later.", + )); + } + Err(PushQueueWaitError::Lock(e)) => { + tracing::error!( + error = %e, + repo = %model.name, + repo_id = %model.id, + "http_push_queue_lock_failed" + ); + return Err(actix_web::error::ErrorServiceUnavailable( + "Push queue lock failed. Please retry later.", + )); + } + Err(PushQueueWaitError::Timeout) => { + tracing::info!( + repo = %model.name, + repo_id = %model.id, + "http_push_queue_timeout" + ); + return Ok(HttpResponse::ServiceUnavailable() + .insert_header(("Retry-After", "5")) + .content_type("text/plain; charset=utf-8") + .body("Push queue timed out. Please retry in a moment.\n")); + } + }; + + let storage_path = PathBuf::from(&model.storage_path); + let handler = + GitHttpHandler::new(storage_path, model.clone(), state.db.clone()); + let result = handler.receive_pack(payload).await; + push_queue_lease.release().await; + + if result.is_ok() { + let _ = tokio::spawn({ + let sync = state.sync.clone(); + let repo_uid = model.id; + async move { + let _ = timeout( + Duration::from_secs(5), + sync.send(RepoReceiveSyncTask { repo_uid }), + ) + .await; + } + }); + } + + result +} diff --git a/lib/git/http/utils.rs b/lib/git/http/utils.rs new file mode 100644 index 0000000..1990124 --- /dev/null +++ b/lib/git/http/utils.rs @@ -0,0 +1,98 @@ +use actix_web::{Error, HttpRequest}; +use base64::{Engine, engine::general_purpose::STANDARD}; +use db::database::AppDatabase; +use model::{ + repos::RepoModel, + workspace::{ + wk_history_name::WkHistoryNameModel, workspace::WorkspaceModel, + }, +}; + +pub async fn get_repo_model( + namespace: &str, + repo_name: &str, + db: &AppDatabase, +) -> Result { + let wk_id = if let Some(wk) = sqlx::query_as::<_, WorkspaceModel>( + "SELECT id, name, description, avatar_url, created_at FROM workspace WHERE name = $1", + ) + .bind(namespace) + .fetch_optional(db.reader()) + .await + .map_err(|_| actix_web::error::ErrorInternalServerError("Database error"))? + { + wk.id + } else if let Some(history) = sqlx::query_as::<_, WkHistoryNameModel>( + "SELECT id, wk, name, changed_by, created_at FROM wk_history_name WHERE name = $1", + ) + .bind(namespace) + .fetch_optional(db.reader()) + .await + .map_err(|_| actix_web::error::ErrorInternalServerError("Database error"))? + { + history.wk + } else { + return Err(actix_web::error::ErrorNotFound("Project not found").into()); + }; + + let repo = sqlx::query_as::<_, RepoModel>( + "SELECT id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, \ + created_at, updated_at, deleted_at \ + FROM repo \ + WHERE name = $1 AND wk = $2 AND deleted_at IS NULL", + ) + .bind(repo_name) + .bind(wk_id) + .fetch_optional(db.reader()) + .await + .map_err(|_| actix_web::error::ErrorInternalServerError("Database error"))? + .ok_or_else(|| actix_web::error::ErrorNotFound("Repository not found"))?; + + Ok(repo) +} + +pub fn extract_basic_credentials( + req: &HttpRequest, +) -> Result<(String, String), Error> { + let auth_header = req + .headers() + .get("authorization") + .ok_or_else(|| { + actix_web::error::ErrorUnauthorized("Missing authorization header") + })? + .to_str() + .map_err(|_| { + actix_web::error::ErrorUnauthorized("Invalid authorization header") + })?; + + let encoded = auth_header.strip_prefix("Basic ").ok_or_else(|| { + actix_web::error::ErrorUnauthorized("Invalid authorization scheme") + })?; + + let decoded = STANDARD.decode(encoded).map_err(|_| { + actix_web::error::ErrorUnauthorized( + "Invalid basic authorization encoding", + ) + })?; + + let decoded = String::from_utf8(decoded).map_err(|_| { + actix_web::error::ErrorUnauthorized( + "Invalid basic authorization payload", + ) + })?; + + let (username, access_key) = decoded.split_once(':').ok_or_else(|| { + actix_web::error::ErrorUnauthorized( + "Invalid basic authorization format", + ) + })?; + + if username.is_empty() || access_key.is_empty() { + return Err(actix_web::error::ErrorUnauthorized( + "Username or access key is empty", + )); + } + + Ok((username.to_string(), access_key.to_string())) +} diff --git a/lib/git/lib.rs b/lib/git/lib.rs new file mode 100644 index 0000000..ef04d76 --- /dev/null +++ b/lib/git/lib.rs @@ -0,0 +1,18 @@ +use cache::AppCache; +use db::database::AppDatabase; + +pub mod bare; +pub mod cmd; +pub mod errors; +pub mod graphql; +pub mod http; +pub mod rpc; +pub mod ssh; +pub mod sync; +#[derive(Clone)] +pub struct AppGitState { + pub cache: AppCache, + pub db: AppDatabase, +} + +pub mod role; diff --git a/lib/git/proto/archive.proto b/lib/git/proto/archive.proto new file mode 100644 index 0000000..15d2593 --- /dev/null +++ b/lib/git/proto/archive.proto @@ -0,0 +1,40 @@ +syntax = "proto3"; + +package git.v1; + +import "common.proto"; + +// Mirrors: cmd/archive/mod.rs — ArchiveOptions +message ArchiveOptions { + ObjectId tree = 1; + optional string prefix = 2; + repeated string pathspec = 3; +} + +// Mirrors: cmd/archive/mod.rs — ArchiveResult +message ArchiveResult { + bytes bytes = 1; +} + +message ArchiveTarRequest { + string repo_id = 1; + ArchiveOptions options = 2; +} + +message ArchiveTarResponse { + bytes data = 1; +} + +message ArchiveZipRequest { + string repo_id = 1; + ArchiveOptions options = 2; +} + +message ArchiveZipResponse { + bytes data = 1; +} + +service ArchiveService { + rpc ArchiveTar(ArchiveTarRequest) returns (ArchiveTarResponse); + rpc ArchiveZip(ArchiveZipRequest) returns (ArchiveZipResponse); +} \ No newline at end of file diff --git a/lib/git/proto/blame.proto b/lib/git/proto/blame.proto new file mode 100644 index 0000000..a3c9a6c --- /dev/null +++ b/lib/git/proto/blame.proto @@ -0,0 +1,75 @@ +syntax = "proto3"; + +package git.v1; + +import "common.proto"; + +// Mirrors: cmd/blame/mod.rs — CommitBlameHunk +message CommitBlameHunk { + ObjectId commit_oid = 1; + uint32 final_start_line = 2; + uint32 final_lines = 3; + uint32 orig_start_line = 4; + uint32 orig_lines = 5; + bool boundary = 6; + optional string orig_path = 7; +} + +// Mirrors: cmd/blame/mod.rs — CommitBlameLine +message CommitBlameLine { + ObjectId commit_oid = 1; + uint32 line_no = 2; + string content = 3; + optional string orig_path = 4; +} + +// Mirrors: cmd/blame/mod.rs — BlameOptions +message BlameOptions { + optional uint64 min_line = 1; + optional uint64 max_line = 2; + bool track_copies_same_file = 3; + bool track_copies_same_commit_moves = 4; + bool ignore_whitespace = 5; +} + +message BlameFileRequest { + string repo_id = 1; + string path = 2; + optional string rev = 3; + optional BlameOptions options = 4; +} + +message BlameFileResponse { + repeated CommitBlameHunk hunks = 1; +} + +message BlameHunkRequest { + string repo_id = 1; + string path = 2; + optional string rev = 3; + uint32 start_line = 4; + uint32 end_line = 5; +} + +message BlameHunkResponse { + repeated CommitBlameHunk hunks = 1; +} + +message BlameLinesRequest { + string repo_id = 1; + string path = 2; + optional string rev = 3; + uint32 start_line = 4; + uint32 end_line = 5; +} + +message BlameLinesResponse { + repeated CommitBlameLine lines = 1; +} + +service BlameService { + rpc BlameFile(BlameFileRequest) returns (BlameFileResponse); + rpc BlameStream(BlameFileRequest) returns (stream CommitBlameHunk); + rpc BlameHunk(BlameHunkRequest) returns (BlameHunkResponse); + rpc BlameLines(BlameLinesRequest) returns (BlameLinesResponse); +} \ No newline at end of file diff --git a/lib/git/proto/blob.proto b/lib/git/proto/blob.proto new file mode 100644 index 0000000..8e65709 --- /dev/null +++ b/lib/git/proto/blob.proto @@ -0,0 +1,113 @@ +syntax = "proto3"; + +package git.v1; + +import "common.proto"; + +// Mirrors: cmd/blob/blob_load.rs — BlobLoadParams +message BlobLoadParams { + ObjectId id = 1; + string path = 2; +} + +// Mirrors: cmd/blob/blob_load.rs — BlobLoadResult +message BlobLoadResult { + BlobLoadParams params = 1; + bytes blob = 2; +} + +// Mirrors: cmd/blob/blob_size.rs — BlobSizeParams +message BlobSizeParams { + ObjectId id = 1; + string path = 2; +} + +// Mirrors: cmd/blob/blob_upload.rs — BlobUploadParams +message BlobUploadParams { + bytes blob = 1; + string path = 2; +} + +// Mirrors: cmd/blob/blob_upload.rs — BlobUploadResult +message BlobUploadResult { + ObjectId id = 1; +} + +// Mirrors: cmd/blob/blob_chunk.rs — BlobChunkParam +message BlobChunkParam { + string path = 1; + ObjectId oid = 2; + uint64 size = 3; + uint64 offset = 4; +} + +// Mirrors: cmd/blob/blob_chunk.rs — BlobChunk +message BlobChunk { + BlobChunkParam param = 1; + bytes chunk = 2; +} + +message BlobLoadRequest { + string repo_id = 1; + ObjectId id = 2; + string path = 3; +} + +message BlobLoadResponse { + bytes blob = 1; +} + +message BlobSizeRequest { + string repo_id = 1; + ObjectId id = 2; + string path = 3; +} + +message BlobSizeResponse { + uint64 size = 1; +} + +message BlobExistsRequest { + string repo_id = 1; + ObjectId id = 2; +} + +message BlobExistsResponse { + bool exists = 1; +} + +message BlobIsBinaryRequest { + string repo_id = 1; + ObjectId id = 2; +} + +message BlobIsBinaryResponse { + bool is_binary = 1; +} + +message BlobUploadRequest { + string repo_id = 1; + bytes blob = 2; + string path = 3; +} + +message BlobUploadResponse { + ObjectId id = 1; +} + +message BlobChunkRequest { + string repo_id = 1; + string path = 2; + ObjectId oid = 3; + uint64 size = 4; + uint64 offset = 5; +} + +service BlobService { + rpc BlobLoad(BlobLoadRequest) returns (BlobLoadResponse); + rpc BlobSize(BlobSizeRequest) returns (BlobSizeResponse); + rpc BlobExists(BlobExistsRequest) returns (BlobExistsResponse); + rpc BlobIsBinary(BlobIsBinaryRequest) returns (BlobIsBinaryResponse); + rpc BlobUpload(BlobUploadRequest) returns (BlobUploadResponse); + rpc BlobChunkStream(BlobChunkRequest) returns (stream BlobChunk); +} \ No newline at end of file diff --git a/lib/git/proto/branch.proto b/lib/git/proto/branch.proto new file mode 100644 index 0000000..9516d86 --- /dev/null +++ b/lib/git/proto/branch.proto @@ -0,0 +1,128 @@ +syntax = "proto3"; + +package git.v1; + +import "common.proto"; + +// Mirrors: cmd/branch/branch_list.rs — BranchListItem +message BranchListItem { + string name = 1; + ObjectId oid = 2; + bool is_head = 3; + bool is_remote = 4; + bool is_current = 5; + optional string upstream = 6; +} + +// Mirrors: cmd/branch/branch_summary.rs — BranchSummary +message BranchSummary { + uint64 local_count = 1; + uint64 remote_count = 2; + uint64 all_count = 3; +} + +// Mirrors: cmd/branch/branch_fork.rs — BranchForkParams +message BranchForkParams { + string name = 1; + ObjectId oid = 2; + bool force = 3; +} + +// Mirrors: cmd/branch/branch_delete.rs — BranchDeleteParams +message BranchDeleteParams { + string name = 1; + bool force = 2; +} + +// Mirrors: cmd/branch/branch_rename.rs — BranchReNameParams +message BranchReNameParams { + string old_branch = 1; + string new_branch = 2; + bool force = 3; +} + +message BranchListRequest { + string repo_id = 1; +} + +message BranchListResponse { + repeated BranchListItem branches = 1; +} + +message BranchInfoRequest { + string repo_id = 1; + string branch = 2; +} + +message BranchInfoResponse { + BranchListItem branch = 1; +} + +message BranchSummaryRequest { + string repo_id = 1; +} + +message BranchSummaryResponse { + BranchSummary summary = 1; +} + +message BranchHeadRequest { + string repo_id = 1; +} + +message BranchHeadResponse { + string head_name = 1; +} + +message BranchAheadBehindRequest { + string repo_id = 1; + string local_branch = 2; + string remote_branch = 3; +} + +message BranchAheadBehindResponse { + uint64 ahead = 1; + uint64 behind = 2; +} + +message BranchUpstreamRequest { + string repo_id = 1; + string branch = 2; +} + +message BranchUpstreamResponse { + string upstream_name = 1; +} + +message BranchForkRequest { + string repo_id = 1; + BranchForkParams params = 2; +} + +message BranchForkResponse {} + +message BranchDeleteRequest { + string repo_id = 1; + BranchDeleteParams params = 2; +} + +message BranchDeleteResponse {} + +message BranchRenameRequest { + string repo_id = 1; + BranchReNameParams params = 2; +} + +message BranchRenameResponse {} + +service BranchService { + rpc BranchList(BranchListRequest) returns (BranchListResponse); + rpc BranchInfo(BranchInfoRequest) returns (BranchInfoResponse); + rpc BranchSummary(BranchSummaryRequest) returns (BranchSummaryResponse); + rpc BranchHead(BranchHeadRequest) returns (BranchHeadResponse); + rpc BranchAheadBehind(BranchAheadBehindRequest) returns (BranchAheadBehindResponse); + rpc BranchUpstream(BranchUpstreamRequest) returns (BranchUpstreamResponse); + rpc BranchFork(BranchForkRequest) returns (BranchForkResponse); + rpc BranchDelete(BranchDeleteRequest) returns (BranchDeleteResponse); + rpc BranchRename(BranchRenameRequest) returns (BranchRenameResponse); +} \ No newline at end of file diff --git a/lib/git/proto/commit.proto b/lib/git/proto/commit.proto new file mode 100644 index 0000000..b3ef1bd --- /dev/null +++ b/lib/git/proto/commit.proto @@ -0,0 +1,183 @@ +syntax = "proto3"; + +package git.v1; + +import "common.proto"; + +// Mirrors: cmd/commit/mod.rs — CommitMeta +message CommitMeta { + ObjectId oid = 1; + string message = 2; + string summary = 3; + CommitSignature author = 4; + CommitSignature committer = 5; + ObjectId tree_id = 6; + repeated ObjectId parent_ids = 7; + optional string encoding = 8; +} + +// Mirrors: cmd/commit/mod.rs — CommitRefInfo +message CommitRefInfo { + string name = 1; + ObjectId target = 2; + bool is_remote = 3; + bool is_tag = 4; +} + +// Mirrors: cmd/commit/commit_summary.rs — CommitSummary +message CommitSummary { + optional CommitMeta head = 1; + uint64 count = 2; +} + +// Mirrors: cmd/commit/commit_walker.rs — CommitWalkSort +enum CommitWalkSort { + COMMIT_WALK_SORT_NONE = 0; + COMMIT_WALK_SORT_TOPOLOGICAL = 1; + COMMIT_WALK_SORT_TIME = 2; + COMMIT_WALK_SORT_REVERSE = 3; +} + +// Mirrors: cmd/commit/commit_walker.rs — CommitWalkParams +message CommitWalkParams { + repeated ObjectId start_oids = 1; + repeated ObjectId hide_oids = 2; + optional uint64 limit = 3; + uint64 skip = 4; + bool first_parent = 5; + CommitWalkSort sort = 6; +} + +// Mirrors: cmd/commit/commit_cherry_pick.rs — CommitCherryPickParams +message CommitCherryPickParams { + ObjectId cherrypick_oid = 1; + CommitSignature author = 2; + CommitSignature committer = 3; + optional string message = 4; + uint32 mainline = 5; + optional string update_ref = 6; +} + +// Mirrors: cmd/commit/commit_cherry_pick.rs — CommitCherryPickSequence +message CommitCherryPickSequence { + repeated ObjectId cherrypick_oids = 1; + CommitSignature author = 2; + CommitSignature committer = 3; + optional string update_ref = 4; +} + +message CommitInfoRequest { + string repo_id = 1; + ObjectId oid = 2; +} + +message CommitInfoResponse { + CommitMeta commit = 1; +} + +message CommitHistoryRequest { + string repo_id = 1; + uint64 limit = 2; + uint64 skip = 3; + CommitWalkSort sort = 4; + optional string branch = 5; +} + +message CommitHistoryResponse { + repeated CommitMeta commits = 1; +} + +message CommitSummaryRequest { + string repo_id = 1; +} + +message CommitSummaryResponse { + CommitSummary summary = 1; +} + +message CommitWalkRequest { + string repo_id = 1; + CommitWalkParams params = 2; +} + +message CommitWalkResponse { + repeated CommitMeta commits = 1; +} + +message CommitRefsRequest { + string repo_id = 1; +} + +message CommitRefsResponse { + repeated CommitRefInfo refs = 1; +} + +message CommitPrefixRequest { + string repo_id = 1; + string prefix = 2; +} + +message CommitPrefixResponse { + ObjectId oid = 1; +} + +message CommitExistsRequest { + string repo_id = 1; + ObjectId oid = 2; +} + +message CommitExistsResponse { + bool exists = 1; +} + +message CherryPickRequest { + string repo_id = 1; + CommitCherryPickParams params = 2; +} + +message CherryPickResponse { + ObjectId oid = 1; +} + +message CherryPickSequenceRequest { + string repo_id = 1; + CommitCherryPickSequence params = 2; +} + +message CherryPickSequenceResponse { + ObjectId oid = 1; +} + +message FileChange { + string path = 1; + bytes content = 2; +} + +message CreateCommitRequest { + string repo_id = 1; + string branch = 2; + string message = 3; + string author_name = 4; + string author_email = 5; + string committer_name = 6; + string committer_email = 7; + repeated FileChange files = 8; +} + +message CreateCommitResponse { + ObjectId oid = 1; +} + +service CommitService { + rpc CommitInfo(CommitInfoRequest) returns (CommitInfoResponse); + rpc CommitHistory(CommitHistoryRequest) returns (CommitHistoryResponse); + rpc CommitHistoryStream(CommitHistoryRequest) returns (stream CommitMeta); + rpc CommitSummary(CommitSummaryRequest) returns (CommitSummaryResponse); + rpc CommitWalk(CommitWalkRequest) returns (CommitWalkResponse); + rpc CommitRefs(CommitRefsRequest) returns (CommitRefsResponse); + rpc CommitPrefix(CommitPrefixRequest) returns (CommitPrefixResponse); + rpc CommitExists(CommitExistsRequest) returns (CommitExistsResponse); + rpc CherryPick(CherryPickRequest) returns (CherryPickResponse); + rpc CherryPickSequence(CherryPickSequenceRequest) returns (CherryPickSequenceResponse); + rpc CreateCommit(CreateCommitRequest) returns (CreateCommitResponse); +} \ No newline at end of file diff --git a/lib/git/proto/common.proto b/lib/git/proto/common.proto new file mode 100644 index 0000000..48f636f --- /dev/null +++ b/lib/git/proto/common.proto @@ -0,0 +1,22 @@ +syntax = "proto3"; + +package git.v1; + +// Mirrors: cmd/oid.rs — ObjectId(pub String) +message ObjectId { + string value = 1; +} + +// Mirrors: cmd/commit/mod.rs — CommitSignature +message CommitSignature { + string name = 1; + string email = 2; + int64 time_secs = 3; + int32 offset_minutes = 4; +} + +// Mirrors: cmd/tagger.rs — GitTagger +message GitTagger { + string name = 1; + string email = 2; +} \ No newline at end of file diff --git a/lib/git/proto/diff.proto b/lib/git/proto/diff.proto new file mode 100644 index 0000000..eff8c23 --- /dev/null +++ b/lib/git/proto/diff.proto @@ -0,0 +1,174 @@ +syntax = "proto3"; + +package git.v1; + +import "common.proto"; + +// Mirrors: cmd/diff/mod.rs — DiffDeltaStatus +enum DiffDeltaStatus { + DIFF_DELTA_STATUS_UNMODIFIED = 0; + DIFF_DELTA_STATUS_ADDED = 1; + DIFF_DELTA_STATUS_DELETED = 2; + DIFF_DELTA_STATUS_MODIFIED = 3; + DIFF_DELTA_STATUS_RENAMED = 4; + DIFF_DELTA_STATUS_COPIED = 5; + DIFF_DELTA_STATUS_TYPECHANGE = 6; + DIFF_DELTA_STATUS_CONFLICTED = 7; +} + +// Mirrors: cmd/diff/mod.rs — DiffFile +message DiffFile { + ObjectId oid = 1; + optional string path = 2; + uint64 size = 3; + bool is_binary = 4; +} + +// Mirrors: cmd/diff/mod.rs — DiffHunk +message DiffHunk { + uint32 old_start = 1; + uint32 old_lines = 2; + uint32 new_start = 3; + uint32 new_lines = 4; + string header = 5; +} + +// Mirrors: cmd/diff/mod.rs — DiffDelta +message DiffDelta { + DiffDeltaStatus status = 1; + DiffFile old_file = 2; + DiffFile new_file = 3; + uint32 nfiles = 4; + repeated DiffHunk hunks = 5; + repeated DiffLine lines = 6; +} + +// Mirrors: cmd/diff/mod.rs — DiffLine +message DiffLine { + string content = 1; + string origin = 2; + optional uint32 old_lineno = 3; + optional uint32 new_lineno = 4; + uint32 num_lines = 5; + int64 content_offset = 6; +} + +// Mirrors: cmd/diff/mod.rs — DiffStats +message DiffStats { + uint64 files_changed = 1; + uint64 insertions = 2; + uint64 deletions = 3; +} + +// Mirrors: cmd/diff/mod.rs — DiffResult +message DiffResult { + DiffStats stats = 1; + repeated DiffDelta deltas = 2; +} + +// Mirrors: cmd/diff/mod.rs — DiffOptions +message DiffOptions { + uint32 context_lines = 1; + repeated string pathspec = 2; + bool ignore_whitespace = 3; + bool force_text = 4; + bool reverse = 5; +} + +// Mirrors: cmd/diff/mod.rs — SideBySideChangeType +enum SideBySideChangeType { + SIDE_BY_SIDE_CHANGE_TYPE_UNCHANGED = 0; + SIDE_BY_SIDE_CHANGE_TYPE_ADDED = 1; + SIDE_BY_SIDE_CHANGE_TYPE_REMOVED = 2; + SIDE_BY_SIDE_CHANGE_TYPE_MODIFIED = 3; + SIDE_BY_SIDE_CHANGE_TYPE_EMPTY = 4; +} + +// Mirrors: cmd/diff/mod.rs — SideBySideLine +message SideBySideLine { + optional uint32 left_line_no = 1; + optional uint32 right_line_no = 2; + string left_content = 3; + string right_content = 4; + SideBySideChangeType change_type = 5; +} + +// Mirrors: cmd/diff/mod.rs — SideBySideFile +message SideBySideFile { + string path = 1; + uint64 additions = 2; + uint64 deletions = 3; + bool is_binary = 4; + bool is_rename = 5; + repeated SideBySideLine lines = 6; +} + +// Mirrors: cmd/diff/mod.rs — SideBySideDiffResult +message SideBySideDiffResult { + repeated SideBySideFile files = 1; + uint64 total_additions = 2; + uint64 total_deletions = 3; +} + +message DiffStatsRequest { + string repo_id = 1; + ObjectId old_oid = 2; + ObjectId new_oid = 3; + optional DiffOptions options = 4; +} + +message DiffStatsResponse { + DiffResult result = 1; +} + +message DiffPatchRequest { + string repo_id = 1; + ObjectId old_oid = 2; + ObjectId new_oid = 3; + optional DiffOptions options = 4; +} + +message DiffPatchResponse { + DiffResult result = 1; +} + +message DiffPatchSideBySideRequest { + string repo_id = 1; + ObjectId old_oid = 2; + ObjectId new_oid = 3; + optional DiffOptions options = 4; +} + +message DiffPatchSideBySideResponse { + SideBySideDiffResult result = 1; +} + +message DiffTreeToTreeRequest { + string repo_id = 1; + ObjectId old_tree = 2; + ObjectId new_tree = 3; + optional DiffOptions options = 4; +} + +message DiffTreeToTreeResponse { + DiffResult result = 1; +} + +message DiffIndexToTreeRequest { + string repo_id = 1; + ObjectId tree_oid = 2; + optional DiffOptions options = 3; +} + +message DiffIndexToTreeResponse { + DiffResult result = 1; +} + +service DiffService { + rpc DiffStats(DiffStatsRequest) returns (DiffStatsResponse); + rpc DiffPatch(DiffPatchRequest) returns (DiffPatchResponse); + rpc DiffStream(DiffPatchRequest) returns (stream DiffDelta); + rpc DiffPatchSideBySide(DiffPatchSideBySideRequest) returns (DiffPatchSideBySideResponse); + rpc DiffTreeToTree(DiffTreeToTreeRequest) returns (DiffTreeToTreeResponse); + rpc DiffIndexToTree(DiffIndexToTreeRequest) returns (DiffIndexToTreeResponse); +} \ No newline at end of file diff --git a/lib/git/proto/fork.proto b/lib/git/proto/fork.proto new file mode 100644 index 0000000..cfc0561 --- /dev/null +++ b/lib/git/proto/fork.proto @@ -0,0 +1,28 @@ +syntax = "proto3"; + +package git.v1; + +import "common.proto"; + +// Fork a bare repository from an existing source repo. +message ForkRepoParams { + string namespace = 1; + string repo_name = 2; + string default_branch = 3; + optional string description = 4; + bool enable_lfs = 5; +} + +message ForkBareRequest { + string storage_root = 1; + string source_storage_path = 2; + ForkRepoParams params = 3; +} + +message ForkBareResponse { + string storage_path = 1; +} + +service ForkService { + rpc ForkBare(ForkBareRequest) returns (ForkBareResponse); +} \ No newline at end of file diff --git a/lib/git/proto/init.proto b/lib/git/proto/init.proto new file mode 100644 index 0000000..56a8b70 --- /dev/null +++ b/lib/git/proto/init.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; + +package git.v1; + +// Mirrors: cmd/init.rs — InitRepositoriesParams +message InitRepoParams { + string namespace = 1; + string repo_name = 2; + string default_branch = 3; + optional string description = 4; + bool initialize_with_readme = 5; + bool enable_lfs = 6; +} + +message InitBareRequest { + string storage_root = 1; + InitRepoParams params = 2; +} + +message InitBareResponse { + string storage_path = 1; +} + +message SetDefaultBranchRequest { + string repo_id = 1; + string branch_name = 2; +} + +message SetDefaultBranchResponse {} + +message CloneBareRequest { + string storage_root = 1; + string source_url = 2; + string namespace = 3; + string repo_name = 4; +} + +message CloneBareResponse { + string storage_path = 1; +} + +service InitService { + rpc InitBare(InitBareRequest) returns (InitBareResponse); + rpc SetDefaultBranch(SetDefaultBranchRequest) returns (SetDefaultBranchResponse); + rpc CloneBare(CloneBareRequest) returns (CloneBareResponse); +} \ No newline at end of file diff --git a/lib/git/proto/merge.proto b/lib/git/proto/merge.proto new file mode 100644 index 0000000..cdbdd17 --- /dev/null +++ b/lib/git/proto/merge.proto @@ -0,0 +1,170 @@ +syntax = "proto3"; + +package git.v1; + +import "common.proto"; + +// Mirrors: cmd/merge/mod.rs — MergeAnalysisResult +message MergeAnalysisResult { + bool is_none = 1; + bool is_normal = 2; + bool is_up_to_date = 3; + bool is_fast_forward = 4; + bool is_unborn = 5; +} + +// Mirrors: cmd/merge/mod.rs — MergePreferenceResult +message MergePreferenceResult { + bool is_none = 1; + bool is_no_fast_forward = 2; + bool is_fastforward_only = 3; +} + +// Mirrors: cmd/merge/mod.rs — MergeOptions +message MergeOptions { + bool find_renames = 1; + bool fail_on_conflict = 2; + bool skip_reuc = 3; + bool no_recursive = 4; + uint32 rename_threshold = 5; + uint32 target_limit = 6; + uint32 recursion_limit = 7; +} + +// Mirrors: cmd/merge/merge_commit.rs — MergeCommitParams +message MergeCommitParams { + ObjectId their_commit = 1; + CommitSignature author = 2; + CommitSignature committer = 3; + string message = 4; + optional string update_ref = 5; + optional MergeOptions options = 6; +} + +// Mirrors: cmd/merge/merge_tree.rs — MergeTreeResult +message MergeTreeResult { + ObjectId tree_id = 1; + bool has_conflicts = 2; +} + +// Mirrors: cmd/merge/squash_commit.rs — SquashCommitParams +message SquashCommitParams { + ObjectId their_commit = 1; + optional MergeOptions options = 2; +} + +message MergeBaseRequest { + string repo_id = 1; + ObjectId oid_a = 2; + ObjectId oid_b = 3; +} + +message MergeBaseResponse { + ObjectId base_oid = 1; +} + +message MergeBaseManyRequest { + string repo_id = 1; + repeated ObjectId oids = 2; +} + +message MergeBaseManyResponse { + ObjectId base_oid = 1; +} + +message MergeBaseOctopusRequest { + string repo_id = 1; + repeated ObjectId oids = 2; +} + +message MergeBaseOctopusResponse { + ObjectId base_oid = 1; +} + +message MergeAnalysisRequest { + string repo_id = 1; + ObjectId oid_a = 2; + ObjectId oid_b = 3; +} + +message MergeAnalysisResponse { + MergeAnalysisResult analysis = 1; + MergePreferenceResult preference = 2; +} + +message MergeAnalysisForRefRequest { + string repo_id = 1; + string ref_name = 2; + ObjectId oid_a = 3; + ObjectId oid_b = 4; +} + +message MergeAnalysisForRefResponse { + MergeAnalysisResult analysis = 1; + MergePreferenceResult preference = 2; +} + +message MergeIsConflictedRequest { + string repo_id = 1; +} + +message MergeIsConflictedResponse { + bool is_conflicted = 1; +} + +message MergeheadListRequest { + string repo_id = 1; +} + +message MergeheadListResponse { + repeated ObjectId oids = 1; +} + +message MergeTreeRequest { + string repo_id = 1; + ObjectId ours = 2; + ObjectId theirs = 3; + optional MergeOptions options = 4; +} + +message MergeTreeResponse { + MergeTreeResult result = 1; +} + +message MergeCommitRequest { + string repo_id = 1; + MergeCommitParams params = 2; +} + +message MergeCommitResponse { + ObjectId oid = 1; +} + +message SquashCommitRequest { + string repo_id = 1; + SquashCommitParams params = 2; +} + +message SquashCommitResponse { + ObjectId oid = 1; +} + +message MergeAbortRequest { + string repo_id = 1; +} + +message MergeAbortResponse {} + +service MergeService { + rpc MergeBase(MergeBaseRequest) returns (MergeBaseResponse); + rpc MergeBaseMany(MergeBaseManyRequest) returns (MergeBaseManyResponse); + rpc MergeBaseOctopus(MergeBaseOctopusRequest) returns (MergeBaseOctopusResponse); + rpc MergeAnalysis(MergeAnalysisRequest) returns (MergeAnalysisResponse); + rpc MergeAnalysisForRef(MergeAnalysisForRefRequest) returns (MergeAnalysisForRefResponse); + rpc MergeIsConflicted(MergeIsConflictedRequest) returns (MergeIsConflictedResponse); + rpc MergeheadList(MergeheadListRequest) returns (MergeheadListResponse); + rpc MergeTree(MergeTreeRequest) returns (MergeTreeResponse); + rpc MergeCommit(MergeCommitRequest) returns (MergeCommitResponse); + rpc SquashCommit(SquashCommitRequest) returns (SquashCommitResponse); + rpc MergeAbort(MergeAbortRequest) returns (MergeAbortResponse); +} \ No newline at end of file diff --git a/lib/git/proto/tag.proto b/lib/git/proto/tag.proto new file mode 100644 index 0000000..1a75bee --- /dev/null +++ b/lib/git/proto/tag.proto @@ -0,0 +1,117 @@ +syntax = "proto3"; + +package git.v1; + +import "common.proto"; + +// Mirrors: cmd/tag/mod.rs — TagItem +message TagItem { + string name = 1; + ObjectId oid = 2; + ObjectId target = 3; + bool is_annotated = 4; + optional string message = 5; + optional string tagger = 6; + optional string tagger_email = 7; +} + +// Mirrors: cmd/tag/mod.rs — TagSummary +message TagSummary { + uint64 total_count = 1; +} + +// Mirrors: cmd/tag/tag_init.rs — TagInitParams +message TagInitParams { + string name = 1; + ObjectId target = 2; + optional string message = 3; + optional GitTagger tagger = 4; + bool force = 5; +} + +// Mirrors: cmd/tag/tag_delete.rs — TagDeleteParams +message TagDeleteParams { + string name = 1; +} + +// Mirrors: cmd/tag/tag_rename.rs — TagRenameParams +message TagRenameParams { + string old_name = 1; + string new_name = 2; + bool force = 3; +} + +// Mirrors: cmd/tag/tag_upmsg.rs — TagUpdateMessageParams +message TagUpdateMessageParams { + string name = 1; + string message = 2; + GitTagger tagger = 3; + bool force = 4; +} + +message TagListRequest { + string repo_id = 1; +} + +message TagListResponse { + repeated TagItem tags = 1; +} + +message TagInfoRequest { + string repo_id = 1; + string name = 2; +} + +message TagInfoResponse { + TagItem tag = 1; +} + +message TagSummaryRequest { + string repo_id = 1; +} + +message TagSummaryResponse { + TagSummary summary = 1; +} + +message TagInitRequest { + string repo_id = 1; + TagInitParams params = 2; +} + +message TagInitResponse { + ObjectId oid = 1; +} + +message TagDeleteRequest { + string repo_id = 1; + TagDeleteParams params = 2; +} + +message TagDeleteResponse {} + +message TagRenameRequest { + string repo_id = 1; + TagRenameParams params = 2; +} + +message TagRenameResponse {} + +message TagUpdateMessageRequest { + string repo_id = 1; + TagUpdateMessageParams params = 2; +} + +message TagUpdateMessageResponse { + ObjectId oid = 1; +} + +service TagService { + rpc TagList(TagListRequest) returns (TagListResponse); + rpc TagInfo(TagInfoRequest) returns (TagInfoResponse); + rpc TagSummary(TagSummaryRequest) returns (TagSummaryResponse); + rpc TagInit(TagInitRequest) returns (TagInitResponse); + rpc TagDelete(TagDeleteRequest) returns (TagDeleteResponse); + rpc TagRename(TagRenameRequest) returns (TagRenameResponse); + rpc TagUpdateMessage(TagUpdateMessageRequest) returns (TagUpdateMessageResponse); +} \ No newline at end of file diff --git a/lib/git/proto/tree.proto b/lib/git/proto/tree.proto new file mode 100644 index 0000000..ba39c54 --- /dev/null +++ b/lib/git/proto/tree.proto @@ -0,0 +1,80 @@ +syntax = "proto3"; + +package git.v1; + +import "common.proto"; + +// Mirrors: cmd/tree/mod.rs — TreeKind +enum TreeKind { + TREE_KIND_BLOB = 0; + TREE_KIND_TREE = 1; + TREE_KIND_LFS_POINTER = 2; +} + +// Mirrors: cmd/tree/mod.rs — TreeInfo +message TreeInfo { + ObjectId oid = 1; + uint64 entry_count = 2; + bool is_empty = 3; +} + +// Mirrors: cmd/tree/mod.rs — TreeEntry +message TreeEntry { + string name = 1; + ObjectId oid = 2; + TreeKind kind = 3; + uint32 filemode = 4; + bool is_binary = 5; + bool is_lfs = 6; + string last_commit_message = 7; + string last_commit_time = 8; + string last_commit_author_name = 9; + string last_commit_author_email = 10; +} + +message TreeEntriesRequest { + string repo_id = 1; + ObjectId oid = 2; + string base_path = 3; + bool last = 4; +} + +message TreeEntriesResponse { + repeated TreeEntry entries = 1; +} + +message TreeEntryByPathRequest { + string repo_id = 1; + ObjectId tree_oid = 2; + string path = 3; +} + +message TreeEntryByPathResponse { + optional TreeEntry entry = 1; +} + +message TreeEntryByPathFromCommitRequest { + string repo_id = 1; + ObjectId commit_oid = 2; + string path = 3; +} + +message TreeEntryByPathFromCommitResponse { + optional TreeEntry entry = 1; +} + +message ResolveTreeRequest { + string repo_id = 1; + ObjectId oid = 2; +} + +message ResolveTreeResponse { + TreeInfo info = 1; +} + +service TreeService { + rpc TreeEntries(TreeEntriesRequest) returns (TreeEntriesResponse); + rpc TreeEntryByPath(TreeEntryByPathRequest) returns (TreeEntryByPathResponse); + rpc TreeEntryByPathFromCommit(TreeEntryByPathFromCommitRequest) returns (TreeEntryByPathFromCommitResponse); + rpc ResolveTree(ResolveTreeRequest) returns (ResolveTreeResponse); +} \ No newline at end of file diff --git a/lib/git/role.rs b/lib/git/role.rs new file mode 100644 index 0000000..0c46ecf --- /dev/null +++ b/lib/git/role.rs @@ -0,0 +1,162 @@ +use model::{ + repos::{repo::RepoModel, repo_history_name::RepoHistoryNameModel}, + users::user::UserModel, + workspace::{ + wk_history_name::WkHistoryNameModel, wk_member::WkMemberModel, + workspace::WorkspaceModel, + }, +}; +use uuid::Uuid; + +use crate::{ + AppGitState, + errors::{GitError, GitResult}, +}; + +#[derive(Clone, Debug)] +pub struct DbRepoStatus { + pub repo: RepoModel, + pub wk: WorkspaceModel, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum MemberRole { + Owner, + Admin, + Member, +} + +impl MemberRole { + pub fn can_write(self) -> bool { + matches!(self, Self::Owner | Self::Admin) + } +} + +impl AppGitState { + pub async fn repo( + &self, + wk_name: String, + repo_name: String, + ) -> GitResult { + let wk = self.resolve_wk(&wk_name).await?; + let repo = self.resolve_repo(wk.id, &repo_name).await?; + + Ok(DbRepoStatus { repo, wk }) + } + + pub async fn member_check( + &self, + status: &DbRepoStatus, + user: &UserModel, + ) -> GitResult> { + self.member_check_by_user_id(status, user.id).await + } + + pub async fn member_check_by_user_id( + &self, + status: &DbRepoStatus, + user_id: Uuid, + ) -> GitResult> { + let member = db::sqlx::query_as::<_, WkMemberModel>( + "SELECT wk, \"user\", owner, admin, join_at, leave_at \ + FROM wk_member \ + WHERE wk = $1 AND \"user\" = $2 AND leave_at IS NULL", + ) + .bind(status.wk.id) + .bind(user_id) + .fetch_optional(self.db.reader()) + .await?; + + Ok(member.map(|member| { + if member.owner { + MemberRole::Owner + } else if member.admin { + MemberRole::Admin + } else { + MemberRole::Member + } + })) + } + + async fn resolve_wk(&self, name: &str) -> GitResult { + if let Some(wk) = db::sqlx::query_as::<_, WorkspaceModel>( + "SELECT id, name, description, avatar_url, created_at \ + FROM workspace \ + WHERE name = $1", + ) + .bind(&name) + .fetch_optional(self.db.reader()) + .await? + { + return Ok(wk); + } + + let Some(history) = db::sqlx::query_as::<_, WkHistoryNameModel>( + "SELECT id, wk, name, changed_by, created_at \ + FROM wk_history_name \ + WHERE name = $1 \ + ORDER BY created_at DESC \ + LIMIT 1", + ) + .bind(&name) + .fetch_optional(self.db.reader()) + .await? + else { + return Err(GitError::RepoNotFound); + }; + + db::sqlx::query_as::<_, WorkspaceModel>( + "SELECT id, name, description, avatar_url, created_at \ + FROM workspace \ + WHERE id = $1", + ) + .bind(history.wk) + .fetch_optional(self.db.reader()) + .await? + .ok_or(GitError::RepoNotFound) + } + + async fn resolve_repo(&self, wk: Uuid, name: &str) -> GitResult { + if let Some(repo) = db::sqlx::query_as::<_, RepoModel>( + "SELECT id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, created_at, updated_at, deleted_at \ + FROM repo \ + WHERE wk = $1 AND name = $2 AND deleted_at IS NULL", + ) + .bind(wk) + .bind(&name) + .fetch_optional(self.db.reader()) + .await? + { + return Ok(repo); + } + + let Some(history) = db::sqlx::query_as::<_, RepoHistoryNameModel>( + "SELECT h.id, h.repo, h.name, h.changed_by, h.created_at \ + FROM repo_history_name h \ + INNER JOIN repo r ON h.repo = r.id \ + WHERE h.name = $1 AND r.wk = $2 AND r.deleted_at IS NULL \ + ORDER BY h.created_at DESC \ + LIMIT 1", + ) + .bind(&name) + .bind(wk) + .fetch_optional(self.db.reader()) + .await? + else { + return Err(GitError::RepoNotFound); + }; + + db::sqlx::query_as::<_, RepoModel>( + "SELECT id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, created_at, updated_at, deleted_at \ + FROM repo \ + WHERE id = $1 AND wk = $2 AND deleted_at IS NULL", + ) + .bind(history.repo) + .bind(wk) + .fetch_optional(self.db.reader()) + .await? + .ok_or(GitError::RepoNotFound) + } +} diff --git a/lib/git/rpc/archive.rs b/lib/git/rpc/archive.rs new file mode 100644 index 0000000..17d92c6 --- /dev/null +++ b/lib/git/rpc/archive.rs @@ -0,0 +1,46 @@ +use std::sync::Arc; + +use tonic::{Request, Response, Status}; + +use crate::rpc::{ + error::{spawn_blocking_error, to_status}, + proto as p, + registry::RepoRegistry, +}; + +pub struct ArchiveServiceImpl { + pub registry: Arc, +} + +#[tonic::async_trait] +impl p::archive_service_server::ArchiveService for ArchiveServiceImpl { + async fn archive_tar( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let options = inner.options.unwrap_or_default().into(); + let result = + tokio::task::spawn_blocking(move || bare.archive_tar(options)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::ArchiveTarResponse { data: result.bytes })) + } + + async fn archive_zip( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let options = inner.options.unwrap_or_default().into(); + let result = + tokio::task::spawn_blocking(move || bare.archive_zip(options)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::ArchiveZipResponse { data: result.bytes })) + } +} diff --git a/lib/git/rpc/blame.rs b/lib/git/rpc/blame.rs new file mode 100644 index 0000000..17e9004 --- /dev/null +++ b/lib/git/rpc/blame.rs @@ -0,0 +1,130 @@ +use std::sync::Arc; + +use cache::AppCache; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{Request, Response, Status}; + +use crate::rpc::{ + error::{spawn_blocking_error, to_status}, + proto as p, + registry::RepoRegistry, +}; + +pub struct BlameServiceImpl { + pub registry: Arc, + pub cache: AppCache, +} + +type BlameStream = ReceiverStream>; + +#[tonic::async_trait] +impl p::blame_service_server::BlameService for BlameServiceImpl { + type BlameStreamStream = BlameStream; + + async fn blame_file( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let repo_id = inner.repo_id.clone(); + let rev_str = inner.rev.unwrap_or_default(); + let path = inner.path.clone(); + let cache_key = format!("git:rpc:cache:blame:file:{}:{}:{}", repo_id, path, rev_str); + + if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + return Ok(Response::new(cached)); + } + + let bare = self.registry.get(&repo_id).await?; + let oid = crate::cmd::oid::ObjectId::new(&rev_str); + let opts = inner.options.map(Into::into); + let result = tokio::task::spawn_blocking(move || { + bare.blame_file(oid, &path, opts) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + let resp = p::BlameFileResponse { + hunks: result.into_iter().map(Into::into).collect(), + }; + let _ = self.cache.set(&cache_key, &resp).await; + Ok(Response::new(resp)) + } + + async fn blame_stream( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let rev_str = inner.rev.unwrap_or_default(); + let oid = crate::cmd::oid::ObjectId::new(&rev_str); + let path = inner.path.clone(); + let opts = inner.options.map(Into::into); + let (tx, rx) = tokio::sync::mpsc::channel(128); + tokio::task::spawn_blocking(move || { + let result = bare.blame_file(oid, &path, opts); + match result { + Ok(hunks) => { + for hunk in hunks { + if tx.blocking_send(Ok(hunk.into())).is_err() { + break; + } + } + } + Err(e) => { + let _ = tx.blocking_send(Err(to_status(e))); + } + } + }); + Ok(Response::new(ReceiverStream::new(rx))) + } + + async fn blame_hunk( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let rev_str = inner.rev.unwrap_or_default(); + let oid = crate::cmd::oid::ObjectId::new(&rev_str); + let path = inner.path.clone(); + let line_on = inner.start_line as usize; + let result = tokio::task::spawn_blocking(move || { + bare.blame_hunk(&oid, &path, line_on) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BlameHunkResponse { + hunks: vec![result.into()], + })) + } + + async fn blame_lines( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let rev_str = inner.rev.unwrap_or_default(); + let oid = crate::cmd::oid::ObjectId::new(&rev_str); + let path = inner.path.clone(); + let opts = Some(crate::cmd::blame::BlameOptions { + min_line: Some(inner.start_line as usize), + max_line: Some(inner.end_line as usize), + track_copies_same_file: false, + track_copies_same_commit_moves: false, + ignore_whitespace: false, + }); + let result = tokio::task::spawn_blocking(move || { + bare.blame_lines(oid, &path, opts) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BlameLinesResponse { + lines: result.into_iter().map(Into::into).collect(), + })) + } +} diff --git a/lib/git/rpc/blob.rs b/lib/git/rpc/blob.rs new file mode 100644 index 0000000..cd40309 --- /dev/null +++ b/lib/git/rpc/blob.rs @@ -0,0 +1,159 @@ +use std::sync::Arc; + +use cache::AppCache; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{Request, Response, Status}; + +use crate::rpc::{ + error::{spawn_blocking_error, to_status}, + proto as p, + registry::RepoRegistry, +}; + +pub struct BlobServiceImpl { + pub registry: Arc, + pub cache: AppCache, +} + +type BlobChunkStream = ReceiverStream>; + +#[tonic::async_trait] +impl p::blob_service_server::BlobService for BlobServiceImpl { + async fn blob_load( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let repo_id = inner.repo_id.clone(); + let oid_str = inner.id.clone().map(|o| o.value).unwrap_or_default(); + let path = inner.path.clone(); + let cache_key = format!("git:rpc:cache:blob:load:{}:{}:{}", repo_id, oid_str, path); + + if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + return Ok(Response::new(cached)); + } + + let bare = self.registry.get(&repo_id).await?; + let params = crate::cmd::blob::BlobLoadParams { + id: inner.id.unwrap_or_default().into(), + path: inner.path.clone(), + }; + let params_clone = params.clone(); + let result = + tokio::task::spawn_blocking(move || bare.blob_load(¶ms_clone)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + let resp = p::BlobLoadResponse { blob: result.blob }; + let _ = self.cache.set(&cache_key, &resp).await; + Ok(Response::new(resp)) + } + + async fn blob_size( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = crate::cmd::blob::BlobSizeParams { + id: inner.id.unwrap_or_default().into(), + path: inner.path.clone(), + }; + let result = + tokio::task::spawn_blocking(move || bare.blob_size(¶ms)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BlobSizeResponse { size: result })) + } + + async fn blob_exists( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let oid: crate::cmd::oid::ObjectId = + inner.id.unwrap_or_default().into(); + let result = tokio::task::spawn_blocking(move || bare.blob_exists(oid)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BlobExistsResponse { exists: result })) + } + + async fn blob_is_binary( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let repo_id = inner.repo_id.clone(); + let oid_str = inner.id.clone().map(|o| o.value).unwrap_or_default(); + let cache_key = format!("git:rpc:cache:blob:binary:{}:{}", repo_id, oid_str); + + if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + return Ok(Response::new(cached)); + } + + let bare = self.registry.get(&repo_id).await?; + let oid: crate::cmd::oid::ObjectId = + inner.id.unwrap_or_default().into(); + let result = + tokio::task::spawn_blocking(move || bare.blob_is_binary(oid)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + let resp = p::BlobIsBinaryResponse { is_binary: result }; + let _ = self.cache.set(&cache_key, &resp).await; + Ok(Response::new(resp)) + } + + async fn blob_upload( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = crate::cmd::blob::BlobUploadParams { + blob: inner.blob.clone(), + path: inner.path.clone(), + }; + let result = + tokio::task::spawn_blocking(move || bare.blob_upload(params)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BlobUploadResponse { + id: Some(result.id.into()), + })) + } + + type BlobChunkStreamStream = BlobChunkStream; + + async fn blob_chunk_stream( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let param = crate::cmd::blob::BlobChunkParam { + path: inner.path.clone(), + oid: inner.oid.unwrap_or_default().into(), + size: inner.size as usize, + offset: inner.offset as usize, + }; + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::task::spawn_blocking(move || { + let result = bare.blob_chunk(param); + match result { + Ok(chunk) => { + let _ = tx.blocking_send(Ok(chunk.into())); + } + Err(e) => { + let _ = tx.blocking_send(Err(to_status(e))); + } + } + }); + Ok(Response::new(ReceiverStream::new(rx))) + } +} diff --git a/lib/git/rpc/branch.rs b/lib/git/rpc/branch.rs new file mode 100644 index 0000000..6bf1c48 --- /dev/null +++ b/lib/git/rpc/branch.rs @@ -0,0 +1,160 @@ +use std::sync::Arc; + +use tonic::{Request, Response, Status}; + +use crate::rpc::{ + error::{spawn_blocking_error, to_status}, + proto as p, + registry::RepoRegistry, +}; + +pub struct BranchServiceImpl { + pub registry: Arc, +} + +#[tonic::async_trait] +impl p::branch_service_server::BranchService for BranchServiceImpl { + async fn branch_list( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let result = + tokio::task::spawn_blocking(move || bare.branch_list_all()) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BranchListResponse { + branches: result.into_iter().map(Into::into).collect(), + })) + } + + async fn branch_info( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let branch = inner.branch.clone(); + let result = + tokio::task::spawn_blocking(move || bare.branch_info(branch)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BranchInfoResponse { + branch: Some(result.into()), + })) + } + + async fn branch_summary( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let result = tokio::task::spawn_blocking(move || bare.branch_summary()) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BranchSummaryResponse { + summary: Some(result.into()), + })) + } + + async fn branch_head( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let result = + tokio::task::spawn_blocking(move || bare.branch_head_name()) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BranchHeadResponse { head_name: result })) + } + + async fn branch_ahead_behind( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let branch = inner.local_branch.clone(); + let result = tokio::task::spawn_blocking(move || { + bare.branch_ahead_behind(branch) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BranchAheadBehindResponse { + ahead: if result { 1 } else { 0 }, + behind: 0, + })) + } + + async fn branch_upstream( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let branch = inner.branch.clone(); + let result = tokio::task::spawn_blocking(move || { + bare.branch_upstream_name(branch) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BranchUpstreamResponse { + upstream_name: result.unwrap_or_default(), + })) + } + + async fn branch_fork( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = inner.params.unwrap_or_default().into(); + let _result = + tokio::task::spawn_blocking(move || bare.branch_fork(¶ms)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BranchForkResponse {})) + } + + async fn branch_delete( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = inner.params.unwrap_or_default().into(); + let _result = + tokio::task::spawn_blocking(move || bare.branch_delete(params)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BranchDeleteResponse {})) + } + + async fn branch_rename( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = inner.params.unwrap_or_default().into(); + let _result = + tokio::task::spawn_blocking(move || bare.branch_rename(params)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::BranchRenameResponse {})) + } +} diff --git a/lib/git/rpc/commit.rs b/lib/git/rpc/commit.rs new file mode 100644 index 0000000..83edad2 --- /dev/null +++ b/lib/git/rpc/commit.rs @@ -0,0 +1,291 @@ +use std::sync::Arc; + +use cache::AppCache; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{Request, Response, Status}; +use uuid::Uuid; + +use crate::rpc::{ + error::{spawn_blocking_error, to_status}, + proto as p, + registry::RepoRegistry, +}; +use crate::sync::ReceiveSyncService; + +pub struct CommitServiceImpl { + pub registry: Arc, + pub cache: AppCache, + pub sync: ReceiveSyncService, +} + +type CommitHistoryStream = ReceiverStream>; + +#[tonic::async_trait] +impl p::commit_service_server::CommitService for CommitServiceImpl { + async fn commit_info( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let oid = inner.oid.unwrap_or_default().into(); + let result = tokio::task::spawn_blocking(move || bare.commit_info(oid)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::CommitInfoResponse { + commit: Some(result.into()), + })) + } + + async fn commit_history( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let repo_id = inner.repo_id.clone(); + let cache_key = format!( + "git:rpc:cache:commit:history:{}:{}:{}:{}:{}", + repo_id, inner.limit, inner.skip, inner.sort, inner.branch.as_deref().unwrap_or("") + ); + + if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + return Ok(Response::new(cached)); + } + + let bare = self.registry.get(&repo_id).await?; + let params = crate::cmd::commit::CommitWalkParams { + start_oids: vec![], + hide_oids: vec![], + limit: if inner.limit > 0 { Some(inner.limit as usize) } else { None }, + skip: inner.skip as usize, + first_parent: false, + sort: inner.sort.into(), + branch: inner.branch.clone(), + }; + let result = + tokio::task::spawn_blocking(move || bare.commit_history(params)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + let resp = p::CommitHistoryResponse { + commits: result.into_iter().map(Into::into).collect(), + }; + let _ = self.cache.set(&cache_key, &resp).await; + Ok(Response::new(resp)) + } + + type CommitHistoryStreamStream = CommitHistoryStream; + + async fn commit_history_stream( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = crate::cmd::commit::CommitWalkParams { + start_oids: vec![], + hide_oids: vec![], + limit: if inner.limit > 0 { Some(inner.limit as usize) } else { None }, + skip: inner.skip as usize, + first_parent: false, + sort: inner.sort.into(), + branch: inner.branch.clone(), + }; + let (tx, rx) = tokio::sync::mpsc::channel(128); + tokio::task::spawn_blocking(move || { + let result = bare.commit_history(params); + match result { + Ok(commits) => { + for c in commits { + if tx.blocking_send(Ok(c.into())).is_err() { + break; + } + } + } + Err(e) => { + let _ = tx.blocking_send(Err(to_status(e))); + } + } + }); + Ok(Response::new(ReceiverStream::new(rx))) + } + + async fn commit_summary( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let repo_id = inner.repo_id.clone(); + let cache_key = format!("git:rpc:cache:commit:summary:{}", repo_id); + + if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + return Ok(Response::new(cached)); + } + + let bare = self.registry.get(&repo_id).await?; + let result = tokio::task::spawn_blocking(move || bare.commit_summary()) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + let resp = p::CommitSummaryResponse { summary: Some(result.into()) }; + let _ = self.cache.set(&cache_key, &resp).await; + Ok(Response::new(resp)) + } + + async fn commit_walk( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = inner.params.unwrap_or_default().into(); + let result = + tokio::task::spawn_blocking(move || bare.commit_walk(params)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::CommitWalkResponse { + commits: result.into_iter().map(Into::into).collect(), + })) + } + + async fn commit_refs( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let result = tokio::task::spawn_blocking(move || { + let repo = bare.gix_repo()?; + let head_id = repo.head_id()?.detach(); + let oid = crate::cmd::oid::ObjectId::new(head_id.to_hex().to_string()); + bare.commit_refs(oid) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::CommitRefsResponse { + refs: result.into_iter().map(Into::into).collect(), + })) + } + + async fn commit_prefix( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let prefix = inner.prefix.clone(); + let result = tokio::task::spawn_blocking(move || { + bare.commit_oid_from_prefix(&prefix) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::CommitPrefixResponse { + oid: Some(result.into()), + })) + } + + async fn commit_exists( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let oid: crate::cmd::oid::ObjectId = + inner.oid.unwrap_or_default().into(); + let result = + tokio::task::spawn_blocking(move || bare.commit_exists(oid)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::CommitExistsResponse { exists: result })) + } + + async fn cherry_pick( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let repo_id = inner.repo_id.clone(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = inner.params.unwrap_or_default().into(); + let result = + tokio::task::spawn_blocking(move || bare.commit_pick(params)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + + if let Ok(repo_uid) = Uuid::parse_str(&repo_id) { + self.sync.send(crate::sync::RepoReceiveSyncTask { repo_uid }).await; + } + + Ok(Response::new(p::CherryPickResponse { + oid: Some(result.into()), + })) + } + + async fn cherry_pick_sequence( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let repo_id = inner.repo_id.clone(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = inner.params.unwrap_or_default().into(); + let result = tokio::task::spawn_blocking(move || { + bare.commit_cherry_pick_sequence(params) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + + if let Ok(repo_uid) = Uuid::parse_str(&repo_id) { + self.sync.send(crate::sync::RepoReceiveSyncTask { repo_uid }).await; + } + + Ok(Response::new(p::CherryPickSequenceResponse { + oid: Some(result.into()), + })) + } + + async fn create_commit( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let repo_id = inner.repo_id.clone(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = crate::cmd::commit::CreateCommitParams { + branch: inner.branch.clone(), + message: inner.message.clone(), + author_name: inner.author_name.clone(), + author_email: inner.author_email.clone(), + committer_name: inner.committer_name.clone(), + committer_email: inner.committer_email.clone(), + files: inner + .files + .into_iter() + .map(|f| crate::cmd::commit::FileChange { + path: f.path, + content: f.content, + }) + .collect(), + }; + let result = tokio::task::spawn_blocking(move || bare.commit_create(params)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + + // Trigger sync after write + if let Ok(repo_uid) = Uuid::parse_str(&repo_id) { + self.sync.send(crate::sync::RepoReceiveSyncTask { repo_uid }).await; + } + + Ok(Response::new(p::CreateCommitResponse { + oid: Some(result.into()), + })) + } +} diff --git a/lib/git/rpc/convert.rs b/lib/git/rpc/convert.rs new file mode 100644 index 0000000..44020a5 --- /dev/null +++ b/lib/git/rpc/convert.rs @@ -0,0 +1,704 @@ +use crate::{cmd::oid::ObjectId, rpc::proto as p}; + +impl From for p::ObjectId { + fn from(oid: ObjectId) -> Self { + p::ObjectId { value: oid.0 } + } +} + +impl From for ObjectId { + fn from(oid: p::ObjectId) -> Self { + ObjectId::new(&oid.value) + } +} + +impl From<&ObjectId> for p::ObjectId { + fn from(oid: &ObjectId) -> Self { + p::ObjectId { + value: oid.0.clone(), + } + } +} + +impl From for p::CommitSignature { + fn from(sig: crate::cmd::commit::CommitSignature) -> Self { + p::CommitSignature { + name: sig.name, + email: sig.email, + time_secs: sig.time_secs, + offset_minutes: sig.offset_minutes, + } + } +} + +impl From for crate::cmd::commit::CommitSignature { + fn from(sig: p::CommitSignature) -> Self { + crate::cmd::commit::CommitSignature { + name: sig.name, + email: sig.email, + time_secs: sig.time_secs, + offset_minutes: sig.offset_minutes, + } + } +} + +impl From for p::GitTagger { + fn from(t: crate::cmd::tagger::GitTagger) -> Self { + p::GitTagger { + name: t.name, + email: t.email, + } + } +} + +impl From for crate::cmd::tagger::GitTagger { + fn from(t: p::GitTagger) -> Self { + crate::cmd::tagger::GitTagger { + name: t.name, + email: t.email, + } + } +} + +impl From for p::CommitMeta { + fn from(c: crate::cmd::commit::CommitMeta) -> Self { + p::CommitMeta { + oid: Some(c.oid.into()), + message: c.message, + summary: c.summary, + author: Some(c.author.into()), + committer: Some(c.committer.into()), + tree_id: Some(c.tree_id.into()), + parent_ids: c.parent_ids.into_iter().map(Into::into).collect(), + encoding: c.encoding, + } + } +} + +impl From for p::CommitRefInfo { + fn from(r: crate::cmd::commit::CommitRefInfo) -> Self { + p::CommitRefInfo { + name: r.name, + target: Some(r.target.into()), + is_remote: r.is_remote, + is_tag: r.is_tag, + } + } +} + +impl From for p::CommitSummary { + fn from(s: crate::cmd::commit::CommitSummary) -> Self { + p::CommitSummary { + head: s.head.map(Into::into), + count: s.count as u64, + } + } +} + +impl From for p::CommitWalkSort { + fn from(s: crate::cmd::commit::CommitWalkSort) -> Self { + match s { + crate::cmd::commit::CommitWalkSort::None => p::CommitWalkSort::None, + crate::cmd::commit::CommitWalkSort::Topological => { + p::CommitWalkSort::Topological + } + crate::cmd::commit::CommitWalkSort::Time => p::CommitWalkSort::Time, + crate::cmd::commit::CommitWalkSort::Reverse => { + p::CommitWalkSort::Reverse + } + } + } +} + +impl From for crate::cmd::commit::CommitWalkSort { + fn from(s: p::CommitWalkSort) -> Self { + match s { + p::CommitWalkSort::None => crate::cmd::commit::CommitWalkSort::None, + p::CommitWalkSort::Topological => { + crate::cmd::commit::CommitWalkSort::Topological + } + p::CommitWalkSort::Time => crate::cmd::commit::CommitWalkSort::Time, + p::CommitWalkSort::Reverse => { + crate::cmd::commit::CommitWalkSort::Reverse + } + } + } +} + +impl From for crate::cmd::commit::CommitWalkSort { + fn from(val: i32) -> Self { + let proto_sort: p::CommitWalkSort = + p::CommitWalkSort::try_from(val).unwrap_or(p::CommitWalkSort::Time); + proto_sort.into() + } +} + +impl From for crate::cmd::commit::CommitWalkParams { + fn from(p: p::CommitWalkParams) -> Self { + crate::cmd::commit::CommitWalkParams { + start_oids: p.start_oids.into_iter().map(Into::into).collect(), + hide_oids: p.hide_oids.into_iter().map(Into::into).collect(), + limit: p.limit.map(|l| l as usize), + skip: p.skip as usize, + first_parent: p.first_parent, + sort: p.sort.into(), + branch: None, + } + } +} + +impl From + for crate::cmd::commit::CommitCherryPickParams +{ + fn from(p: p::CommitCherryPickParams) -> Self { + crate::cmd::commit::CommitCherryPickParams { + cherrypick_oid: p.cherrypick_oid.unwrap_or_default().into(), + author: p.author.unwrap_or_default().into(), + committer: p.committer.unwrap_or_default().into(), + message: p.message, + mainline: p.mainline, + update_ref: p.update_ref, + } + } +} + +impl From + for crate::cmd::commit::CommitCherryPickSequence +{ + fn from(p: p::CommitCherryPickSequence) -> Self { + crate::cmd::commit::CommitCherryPickSequence { + cherrypick_oids: p + .cherrypick_oids + .into_iter() + .map(Into::into) + .collect(), + author: p.author.unwrap_or_default().into(), + committer: p.committer.unwrap_or_default().into(), + update_ref: p.update_ref, + } + } +} + +impl From for p::BranchListItem { + fn from(b: crate::cmd::branch::BranchListItem) -> Self { + p::BranchListItem { + name: b.name, + oid: Some(b.oid.into()), + is_head: b.is_head, + is_remote: b.is_remote, + is_current: b.is_current, + upstream: b.upstream, + } + } +} + +impl From for p::BranchSummary { + fn from(s: crate::cmd::branch::BranchSummary) -> Self { + p::BranchSummary { + local_count: s.local_count as u64, + remote_count: s.remote_count as u64, + all_count: s.all_count as u64, + } + } +} + +impl From for crate::cmd::branch::BranchForkParams { + fn from(p: p::BranchForkParams) -> Self { + crate::cmd::branch::BranchForkParams { + name: p.name, + oid: p.oid.unwrap_or_default().into(), + force: p.force, + } + } +} + +impl From for crate::cmd::branch::BranchDeleteParams { + fn from(p: p::BranchDeleteParams) -> Self { + crate::cmd::branch::BranchDeleteParams { + name: p.name, + force: p.force, + } + } +} + +impl From for crate::cmd::branch::BranchReNameParams { + fn from(p: p::BranchReNameParams) -> Self { + crate::cmd::branch::BranchReNameParams { + old_branch: p.old_branch, + new_branch: p.new_branch, + force: p.force, + } + } +} + +impl From for p::DiffDeltaStatus { + fn from(s: crate::cmd::diff::DiffDeltaStatus) -> Self { + match s { + crate::cmd::diff::DiffDeltaStatus::Unmodified => { + p::DiffDeltaStatus::Unmodified + } + crate::cmd::diff::DiffDeltaStatus::Added => { + p::DiffDeltaStatus::Added + } + crate::cmd::diff::DiffDeltaStatus::Deleted => { + p::DiffDeltaStatus::Deleted + } + crate::cmd::diff::DiffDeltaStatus::Modified => { + p::DiffDeltaStatus::Modified + } + crate::cmd::diff::DiffDeltaStatus::Renamed => { + p::DiffDeltaStatus::Renamed + } + crate::cmd::diff::DiffDeltaStatus::Copied => { + p::DiffDeltaStatus::Copied + } + crate::cmd::diff::DiffDeltaStatus::Typechange => { + p::DiffDeltaStatus::Typechange + } + crate::cmd::diff::DiffDeltaStatus::Conflicted => { + p::DiffDeltaStatus::Conflicted + } + } + } +} + +impl From for p::DiffFile { + fn from(f: crate::cmd::diff::DiffFile) -> Self { + p::DiffFile { + oid: f.oid.map(Into::into), + path: f.path, + size: f.size, + is_binary: f.is_binary, + } + } +} + +impl From for p::DiffHunk { + fn from(h: crate::cmd::diff::DiffHunk) -> Self { + p::DiffHunk { + old_start: h.old_start, + old_lines: h.old_lines, + new_start: h.new_start, + new_lines: h.new_lines, + header: h.header, + } + } +} + +impl From for p::DiffLine { + fn from(l: crate::cmd::diff::DiffLine) -> Self { + p::DiffLine { + content: l.content, + origin: l.origin.to_string(), + old_lineno: l.old_lineno, + new_lineno: l.new_lineno, + num_lines: l.num_lines, + content_offset: l.content_offset, + } + } +} + +impl From for p::DiffDelta { + fn from(d: crate::cmd::diff::DiffDelta) -> Self { + let status_proto: p::DiffDeltaStatus = d.status.into(); + p::DiffDelta { + status: status_proto as i32, + old_file: Some(d.old_file.into()), + new_file: Some(d.new_file.into()), + nfiles: d.nfiles as u32, + hunks: d.hunks.into_iter().map(Into::into).collect(), + lines: d.lines.into_iter().map(Into::into).collect(), + } + } +} + +impl From for p::DiffStats { + fn from(s: crate::cmd::diff::DiffStats) -> Self { + p::DiffStats { + files_changed: s.files_changed as u64, + insertions: s.insertions as u64, + deletions: s.deletions as u64, + } + } +} + +impl From for p::DiffResult { + fn from(r: crate::cmd::diff::DiffResult) -> Self { + p::DiffResult { + stats: Some(r.stats.into()), + deltas: r.deltas.into_iter().map(Into::into).collect(), + } + } +} + +impl From for crate::cmd::diff::DiffOptions { + fn from(o: p::DiffOptions) -> Self { + crate::cmd::diff::DiffOptions { + context_lines: o.context_lines, + pathspec: o.pathspec, + ignore_whitespace: o.ignore_whitespace, + force_text: o.force_text, + reverse: o.reverse, + } + } +} + +impl From for p::SideBySideChangeType { + fn from(t: crate::cmd::diff::SideBySideChangeType) -> Self { + match t { + crate::cmd::diff::SideBySideChangeType::Unchanged => { + p::SideBySideChangeType::Unchanged + } + crate::cmd::diff::SideBySideChangeType::Added => { + p::SideBySideChangeType::Added + } + crate::cmd::diff::SideBySideChangeType::Removed => { + p::SideBySideChangeType::Removed + } + crate::cmd::diff::SideBySideChangeType::Modified => { + p::SideBySideChangeType::Modified + } + crate::cmd::diff::SideBySideChangeType::Empty => { + p::SideBySideChangeType::Empty + } + } + } +} + +impl From for p::SideBySideLine { + fn from(l: crate::cmd::diff::SideBySideLine) -> Self { + let change_type_proto: p::SideBySideChangeType = l.change_type.into(); + p::SideBySideLine { + left_line_no: l.left_line_no, + right_line_no: l.right_line_no, + left_content: l.left_content, + right_content: l.right_content, + change_type: change_type_proto as i32, + } + } +} + +impl From for p::SideBySideFile { + fn from(f: crate::cmd::diff::SideBySideFile) -> Self { + p::SideBySideFile { + path: f.path, + additions: f.additions as u64, + deletions: f.deletions as u64, + is_binary: f.is_binary, + is_rename: f.is_rename, + lines: f.lines.into_iter().map(Into::into).collect(), + } + } +} + +impl From for p::SideBySideDiffResult { + fn from(r: crate::cmd::diff::SideBySideDiffResult) -> Self { + p::SideBySideDiffResult { + files: r.files.into_iter().map(Into::into).collect(), + total_additions: r.total_additions as u64, + total_deletions: r.total_deletions as u64, + } + } +} + +impl From for p::MergeAnalysisResult { + fn from(r: crate::cmd::merge::MergeAnalysisResult) -> Self { + p::MergeAnalysisResult { + is_none: r.is_none, + is_normal: r.is_normal, + is_up_to_date: r.is_up_to_date, + is_fast_forward: r.is_fast_forward, + is_unborn: r.is_unborn, + } + } +} + +impl From + for p::MergePreferenceResult +{ + fn from(r: crate::cmd::merge::MergePreferenceResult) -> Self { + p::MergePreferenceResult { + is_none: r.is_none, + is_no_fast_forward: r.is_no_fast_forward, + is_fastforward_only: r.is_fastforward_only, + } + } +} + +impl From for p::MergeOptions { + fn from(o: crate::cmd::merge::MergeOptions) -> Self { + p::MergeOptions { + find_renames: o.find_renames, + fail_on_conflict: o.fail_on_conflict, + skip_reuc: o.skip_reuc, + no_recursive: o.no_recursive, + rename_threshold: o.rename_threshold, + target_limit: o.target_limit, + recursion_limit: o.recursion_limit, + } + } +} + +impl From for crate::cmd::merge::MergeOptions { + fn from(o: p::MergeOptions) -> Self { + crate::cmd::merge::MergeOptions { + find_renames: o.find_renames, + fail_on_conflict: o.fail_on_conflict, + skip_reuc: o.skip_reuc, + no_recursive: o.no_recursive, + rename_threshold: o.rename_threshold, + target_limit: o.target_limit, + recursion_limit: o.recursion_limit, + } + } +} + +impl From for crate::cmd::merge::MergeCommitParams { + fn from(p: p::MergeCommitParams) -> Self { + crate::cmd::merge::MergeCommitParams { + their_commit: p.their_commit.unwrap_or_default().into(), + author: p.author.unwrap_or_default().into(), + committer: p.committer.unwrap_or_default().into(), + message: p.message, + update_ref: p.update_ref, + options: p.options.map(Into::into), + } + } +} + +impl From for p::MergeTreeResult { + fn from(r: crate::cmd::merge::MergeTreeResult) -> Self { + p::MergeTreeResult { + tree_id: Some(r.tree_id.into()), + has_conflicts: r.has_conflicts, + } + } +} + +impl From for crate::cmd::merge::SquashCommitParams { + fn from(p: p::SquashCommitParams) -> Self { + crate::cmd::merge::SquashCommitParams { + their_commit: p.their_commit.unwrap_or_default().into(), + options: p.options.map(Into::into), + } + } +} + +impl From for p::TagItem { + fn from(t: crate::cmd::tag::TagItem) -> Self { + p::TagItem { + name: t.name, + oid: Some(t.oid.into()), + target: Some(t.target.into()), + is_annotated: t.is_annotated, + message: t.message, + tagger: t.tagger, + tagger_email: t.tagger_email, + } + } +} + +impl From for p::TagSummary { + fn from(s: crate::cmd::tag::TagSummary) -> Self { + p::TagSummary { + total_count: s.total_count as u64, + } + } +} + +impl From for crate::cmd::tag::TagInitParams { + fn from(p: p::TagInitParams) -> Self { + crate::cmd::tag::TagInitParams { + name: p.name, + target: p.target.unwrap_or_default().into(), + message: p.message, + tagger: p.tagger.map(Into::into), + force: p.force, + } + } +} + +impl From for crate::cmd::tag::TagDeleteParams { + fn from(p: p::TagDeleteParams) -> Self { + crate::cmd::tag::TagDeleteParams { name: p.name } + } +} + +impl From for crate::cmd::tag::TagRenameParams { + fn from(p: p::TagRenameParams) -> Self { + crate::cmd::tag::TagRenameParams { + old_name: p.old_name, + new_name: p.new_name, + force: p.force, + } + } +} + +impl From + for crate::cmd::tag::TagUpdateMessageParams +{ + fn from(p: p::TagUpdateMessageParams) -> Self { + crate::cmd::tag::TagUpdateMessageParams { + name: p.name, + message: p.message, + tagger: p.tagger.unwrap_or_default().into(), + force: p.force, + } + } +} + +impl From for p::TreeKind { + fn from(k: crate::cmd::tree::TreeKind) -> Self { + match k { + crate::cmd::tree::TreeKind::Blob => p::TreeKind::Blob, + crate::cmd::tree::TreeKind::Tree => p::TreeKind::Tree, + crate::cmd::tree::TreeKind::LfsPointer => p::TreeKind::LfsPointer, + } + } +} + +impl From for p::TreeInfo { + fn from(i: crate::cmd::tree::TreeInfo) -> Self { + p::TreeInfo { + oid: Some(i.oid.into()), + entry_count: i.entry_count as u64, + is_empty: i.is_empty, + } + } +} + +impl From for p::TreeEntry { + fn from(e: crate::cmd::tree::TreeEntry) -> Self { + let kind_proto: p::TreeKind = e.kind.into(); + p::TreeEntry { + name: e.name, + oid: Some(e.oid.into()), + kind: kind_proto as i32, + filemode: e.filemode, + is_binary: e.is_binary, + is_lfs: e.is_lfs, + last_commit_message: String::new(), + last_commit_time: String::new(), + last_commit_author_name: String::new(), + last_commit_author_email: String::new(), + } + } +} + +impl From for p::CommitBlameHunk { + fn from(h: crate::cmd::blame::CommitBlameHunk) -> Self { + p::CommitBlameHunk { + commit_oid: Some(h.commit_oid.into()), + final_start_line: h.final_start_line, + final_lines: h.final_lines, + orig_start_line: h.orig_start_line, + orig_lines: h.orig_lines, + boundary: h.boundary, + orig_path: h.orig_path, + } + } +} + +impl From for p::CommitBlameLine { + fn from(l: crate::cmd::blame::CommitBlameLine) -> Self { + p::CommitBlameLine { + commit_oid: Some(l.commit_oid.into()), + line_no: l.line_no, + content: l.content, + orig_path: l.orig_path, + } + } +} + +impl From for crate::cmd::blame::BlameOptions { + fn from(o: p::BlameOptions) -> Self { + crate::cmd::blame::BlameOptions { + min_line: o.min_line.map(|l| l as usize), + max_line: o.max_line.map(|l| l as usize), + track_copies_same_file: o.track_copies_same_file, + track_copies_same_commit_moves: o.track_copies_same_commit_moves, + ignore_whitespace: o.ignore_whitespace, + } + } +} + +impl From for crate::cmd::archive::ArchiveOptions { + fn from(o: p::ArchiveOptions) -> Self { + crate::cmd::archive::ArchiveOptions { + tree: o.tree.unwrap_or_default().into(), + prefix: o.prefix, + pathspec: o.pathspec, + } + } +} + +impl From for crate::cmd::blob::BlobLoadParams { + fn from(p: p::BlobLoadParams) -> Self { + crate::cmd::blob::BlobLoadParams { + id: p.id.unwrap_or_default().into(), + path: p.path, + } + } +} + +impl From for crate::cmd::blob::BlobSizeParams { + fn from(p: p::BlobSizeParams) -> Self { + crate::cmd::blob::BlobSizeParams { + id: p.id.unwrap_or_default().into(), + path: p.path, + } + } +} + +impl From for crate::cmd::blob::BlobUploadParams { + fn from(p: p::BlobUploadParams) -> Self { + crate::cmd::blob::BlobUploadParams { + blob: p.blob, + path: p.path, + } + } +} + +impl From for crate::cmd::blob::BlobChunkParam { + fn from(p: p::BlobChunkParam) -> Self { + crate::cmd::blob::BlobChunkParam { + path: p.path, + oid: p.oid.unwrap_or_default().into(), + size: p.size as usize, + offset: p.offset as usize, + } + } +} + +impl From for p::BlobChunk { + fn from(c: crate::cmd::blob::BlobChunk) -> Self { + p::BlobChunk { + param: Some(c.param.into()), + chunk: c.chunk, + } + } +} + +impl From for crate::cmd::init::InitRepositoriesParams { + fn from(p: p::InitRepoParams) -> Self { + crate::cmd::init::InitRepositoriesParams { + namespace: p.namespace, + repo_name: p.repo_name, + default_branch: p.default_branch, + description: p.description, + initialize_with_readme: p.initialize_with_readme, + enable_lfs: p.enable_lfs, + } + } +} + +impl From for p::BlobChunkParam { + fn from(p: crate::cmd::blob::BlobChunkParam) -> Self { + p::BlobChunkParam { + path: p.path, + oid: Some(p.oid.into()), + size: p.size as u64, + offset: p.offset as u64, + } + } +} diff --git a/lib/git/rpc/diff.rs b/lib/git/rpc/diff.rs new file mode 100644 index 0000000..9141b3b --- /dev/null +++ b/lib/git/rpc/diff.rs @@ -0,0 +1,180 @@ +use std::sync::Arc; + +use cache::AppCache; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{Request, Response, Status}; + +use crate::rpc::{ + error::{spawn_blocking_error, to_status}, + proto as p, + registry::RepoRegistry, +}; + +pub struct DiffServiceImpl { + pub registry: Arc, + pub cache: AppCache, +} + +type DiffStream = ReceiverStream>; + +#[tonic::async_trait] +impl p::diff_service_server::DiffService for DiffServiceImpl { + async fn diff_stats( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let repo_id = inner.repo_id.clone(); + let old_str = inner.old_oid.clone().map(|o| o.value).unwrap_or_default(); + let new_str = inner.new_oid.clone().map(|o| o.value).unwrap_or_default(); + let cache_key = format!("git:rpc:cache:diff:stats:{}:{}:{}", repo_id, old_str, new_str); + + if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + return Ok(Response::new(cached)); + } + + let bare = self.registry.get(&repo_id).await?; + let old: crate::cmd::oid::ObjectId = + inner.old_oid.unwrap_or_default().into(); + let new: crate::cmd::oid::ObjectId = + inner.new_oid.unwrap_or_default().into(); + let opts = inner.options.map(Into::into); + let stats = tokio::task::spawn_blocking(move || { + bare.diff_stats(old, new, opts) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + let result = crate::cmd::diff::DiffResult { stats, deltas: vec![] }; + let resp = p::DiffStatsResponse { result: Some(result.into()) }; + let _ = self.cache.set(&cache_key, &resp).await; + Ok(Response::new(resp)) + } + + async fn diff_patch( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let repo_id = inner.repo_id.clone(); + let old_str = inner.old_oid.clone().map(|o| o.value).unwrap_or_default(); + let new_str = inner.new_oid.clone().map(|o| o.value).unwrap_or_default(); + let cache_key = format!("git:rpc:cache:diff:patch:{}:{}:{}", repo_id, old_str, new_str); + + if let Ok(Some(cached)) = self.cache.get::(&cache_key).await { + return Ok(Response::new(cached)); + } + + let bare = self.registry.get(&repo_id).await?; + let old: crate::cmd::oid::ObjectId = + inner.old_oid.unwrap_or_default().into(); + let new: crate::cmd::oid::ObjectId = + inner.new_oid.unwrap_or_default().into(); + let opts = inner.options.map(Into::into); + let result = tokio::task::spawn_blocking(move || { + bare.diff_patch(old, new, opts) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + let resp = p::DiffPatchResponse { result: Some(result.into()) }; + let _ = self.cache.set(&cache_key, &resp).await; + Ok(Response::new(resp)) + } + + type DiffStreamStream = DiffStream; + + async fn diff_stream( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let old: crate::cmd::oid::ObjectId = + inner.old_oid.unwrap_or_default().into(); + let new: crate::cmd::oid::ObjectId = + inner.new_oid.unwrap_or_default().into(); + let opts = inner.options.map(Into::into); + let (tx, rx) = tokio::sync::mpsc::channel(128); + tokio::task::spawn_blocking(move || { + let result = bare.diff_patch(old, new, opts); + match result { + Ok(diff_result) => { + for delta in diff_result.deltas { + if tx.blocking_send(Ok(delta.into())).is_err() { + break; + } + } + } + Err(e) => { + let _ = tx.blocking_send(Err(to_status(e))); + } + } + }); + Ok(Response::new(ReceiverStream::new(rx))) + } + + async fn diff_patch_side_by_side( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let old: crate::cmd::oid::ObjectId = + inner.old_oid.unwrap_or_default().into(); + let new: crate::cmd::oid::ObjectId = + inner.new_oid.unwrap_or_default().into(); + let opts = inner.options.map(Into::into); + let result = tokio::task::spawn_blocking(move || { + bare.diff_patch_side_by_side(old, new, opts) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::DiffPatchSideBySideResponse { + result: Some(result.into()), + })) + } + + async fn diff_tree_to_tree( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let old_tree: crate::cmd::oid::ObjectId = + inner.old_tree.unwrap_or_default().into(); + let new_tree: crate::cmd::oid::ObjectId = + inner.new_tree.unwrap_or_default().into(); + let options = inner.options.map(Into::into); + let result = tokio::task::spawn_blocking(move || { + bare.diff_tree_to_tree(old_tree, new_tree, options) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::DiffTreeToTreeResponse { + result: Some(result.into()), + })) + } + + async fn diff_index_to_tree( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let tree_oid: crate::cmd::oid::ObjectId = + inner.tree_oid.unwrap_or_default().into(); + let options = inner.options.map(Into::into); + let result = tokio::task::spawn_blocking(move || { + bare.diff_index_to_tree(tree_oid, options) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::DiffIndexToTreeResponse { + result: Some(result.into()), + })) + } +} diff --git a/lib/git/rpc/error.rs b/lib/git/rpc/error.rs new file mode 100644 index 0000000..5e90a26 --- /dev/null +++ b/lib/git/rpc/error.rs @@ -0,0 +1,48 @@ +use tonic::Status; + +use crate::errors::GitError; + +pub fn to_status(err: GitError) -> Status { + match err { + GitError::NotBareRepository => { + Status::not_found("not a bare repository") + } + GitError::CommandFailed { + status_code, + stderr, + } => { + let msg = format!( + "git command failed (status={status_code:?}): {stderr}" + ); + Status::internal(msg) + } + GitError::UnsafeCommand(cmd) => { + Status::invalid_argument(format!("unsafe command: {cmd}")) + } + GitError::ObjectNotFound(oid) => { + Status::not_found(format!("object not found: {oid}")) + } + GitError::RefNotFound(ref_name) => { + Status::not_found(format!("ref not found: {ref_name}")) + } + GitError::ParseError(msg) => { + Status::internal(format!("parse error: {msg}")) + } + GitError::Io(e) => Status::internal(format!("io error: {e}")), + GitError::DatabaseError(e) => { + Status::internal(format!("database error{e}")) + } + GitError::RepoNotFound => Status::not_found("repo not found"), + GitError::Internal(msg) => Status::internal(msg), + GitError::NotFound(msg) => Status::not_found(msg), + GitError::InvalidOid(msg) => Status::invalid_argument(msg), + GitError::Locked(msg) => Status::failed_precondition(msg), + GitError::PermissionDenied(msg) => Status::permission_denied(msg), + GitError::AuthFailed(msg) => Status::unauthenticated(msg), + GitError::Gix(msg) => Status::internal(format!("gix error: {msg}")), + } +} + +pub fn spawn_blocking_error(e: tokio::task::JoinError) -> Status { + Status::internal(format!("spawn_blocking failed: {e}")) +} diff --git a/lib/git/rpc/fork.rs b/lib/git/rpc/fork.rs new file mode 100644 index 0000000..d7e807f --- /dev/null +++ b/lib/git/rpc/fork.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; + +use tonic::{Request, Response, Status}; + +use crate::{ + cmd::fork::ForkRepoParams, + rpc::{error::to_status, proto as p, registry::RepoRegistry}, +}; + +pub struct ForkServiceImpl { + pub registry: Arc, +} + +#[tonic::async_trait] +impl p::fork_service_server::ForkService for ForkServiceImpl { + async fn fork_bare( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let proto_params = inner.params.unwrap_or_default(); + + let params = ForkRepoParams { + namespace: proto_params.namespace, + repo_name: proto_params.repo_name, + default_branch: proto_params.default_branch, + description: proto_params.description, + enable_lfs: proto_params.enable_lfs, + }; + + let storage_root = inner.storage_root; + let source_storage_path = inner.source_storage_path; + + let storage_path = ForkRepoParams::fork_bare( + storage_root, + source_storage_path, + params, + ) + .await + .map_err(to_status)?; + + Ok(Response::new(p::ForkBareResponse { storage_path })) + } +} diff --git a/lib/git/rpc/init.rs b/lib/git/rpc/init.rs new file mode 100644 index 0000000..ac8b7a0 --- /dev/null +++ b/lib/git/rpc/init.rs @@ -0,0 +1,83 @@ +use std::sync::Arc; + +use tonic::{Request, Response, Status}; + +use crate::{ + cmd::init::InitRepositoriesParams, + rpc::{ + error::{spawn_blocking_error, to_status}, + proto as p, + registry::RepoRegistry, + }, + sync::ReceiveSyncService, +}; + +pub struct InitServiceImpl { + pub registry: Arc, + pub sync: ReceiveSyncService, +} + +#[tonic::async_trait] +impl p::init_service_server::InitService for InitServiceImpl { + async fn init_bare( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let proto_params = inner.params.unwrap_or_default(); + let namespace = proto_params.namespace.clone(); + let repo_name = proto_params.repo_name.clone(); + + let params = InitRepositoriesParams { + namespace, + repo_name: repo_name.clone(), + default_branch: proto_params.default_branch, + description: proto_params.description, + initialize_with_readme: proto_params.initialize_with_readme, + enable_lfs: proto_params.enable_lfs, + }; + + let storage_root = inner.storage_root; + let storage_path = + InitRepositoriesParams::init_bare(storage_root, params) + .await + .map_err(to_status)?; + + // Note: init_bare creates a new repo; sync is triggered by the caller (service layer) + // because the repo UUID is assigned in the DB, not here. + + Ok(Response::new(p::InitBareResponse { storage_path })) + } + + async fn set_default_branch( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let branch_name = inner.branch_name; + tokio::task::spawn_blocking(move || bare.set_default_branch(&branch_name)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::SetDefaultBranchResponse {})) + } + + async fn clone_bare( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let params = crate::cmd::init::CloneRepoParams { + namespace: inner.namespace, + repo_name: inner.repo_name, + source_url: inner.source_url, + }; + let storage_root = inner.storage_root; + let storage_path = + crate::cmd::init::CloneRepoParams::clone_bare(storage_root, params) + .await + .map_err(to_status)?; + Ok(Response::new(p::CloneBareResponse { storage_path })) + } +} diff --git a/lib/git/rpc/merge.rs b/lib/git/rpc/merge.rs new file mode 100644 index 0000000..d444f53 --- /dev/null +++ b/lib/git/rpc/merge.rs @@ -0,0 +1,208 @@ +use std::sync::Arc; + +use tonic::{Request, Response, Status}; + +use crate::rpc::{ + error::{spawn_blocking_error, to_status}, + proto as p, + registry::RepoRegistry, +}; + +pub struct MergeServiceImpl { + pub registry: Arc, +} + +#[tonic::async_trait] +impl p::merge_service_server::MergeService for MergeServiceImpl { + async fn merge_base( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let a: crate::cmd::oid::ObjectId = + inner.oid_a.unwrap_or_default().into(); + let b: crate::cmd::oid::ObjectId = + inner.oid_b.unwrap_or_default().into(); + let result = tokio::task::spawn_blocking(move || bare.merge_base(a, b)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::MergeBaseResponse { + base_oid: Some(result.into()), + })) + } + + async fn merge_base_many( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let oids = inner.oids.into_iter().map(Into::into).collect::>(); + let result = + tokio::task::spawn_blocking(move || bare.merge_base_many(oids)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::MergeBaseManyResponse { + base_oid: Some(result.into()), + })) + } + + async fn merge_base_octopus( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let oids = inner.oids.into_iter().map(Into::into).collect::>(); + let result = + tokio::task::spawn_blocking(move || bare.merge_base_octopus(oids)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::MergeBaseOctopusResponse { + base_oid: Some(result.into()), + })) + } + + async fn merge_analysis( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let their_commit: crate::cmd::oid::ObjectId = + inner.oid_a.unwrap_or_default().into(); + let result = tokio::task::spawn_blocking(move || { + bare.merge_analysis(their_commit) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::MergeAnalysisResponse { + analysis: Some(result.0.into()), + preference: Some(result.1.into()), + })) + } + + async fn merge_analysis_for_ref( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let ref_name = inner.ref_name.clone(); + let result = tokio::task::spawn_blocking(move || { + bare.merge_analysis_for_ref(ref_name) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::MergeAnalysisForRefResponse { + analysis: Some(result.0.into()), + preference: Some(result.1.into()), + })) + } + + async fn merge_is_conflicted( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let result = + tokio::task::spawn_blocking(move || bare.merge_is_conflicted()) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::MergeIsConflictedResponse { + is_conflicted: result, + })) + } + + async fn mergehead_list( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let result = tokio::task::spawn_blocking(move || bare.mergehead_list()) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::MergeheadListResponse { + oids: result.into_iter().map(Into::into).collect(), + })) + } + + async fn merge_tree( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let ours: crate::cmd::oid::ObjectId = + inner.ours.unwrap_or_default().into(); + let theirs: crate::cmd::oid::ObjectId = + inner.theirs.unwrap_or_default().into(); + let options = inner.options.map(Into::into); + let result = tokio::task::spawn_blocking(move || { + bare.merge_tree(ours, theirs, options) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::MergeTreeResponse { + result: Some(result.into()), + })) + } + + async fn merge_commit( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = inner.params.unwrap_or_default().into(); + let result = + tokio::task::spawn_blocking(move || bare.merge_commit(params)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::MergeCommitResponse { + oid: Some(result.into()), + })) + } + + async fn squash_commit( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = inner.params.unwrap_or_default().into(); + let result = + tokio::task::spawn_blocking(move || bare.squash_commit(params)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::SquashCommitResponse { + oid: Some(result.into()), + })) + } + + async fn merge_abort( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let _result = tokio::task::spawn_blocking(move || bare.merge_abort()) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::MergeAbortResponse {})) + } +} diff --git a/lib/git/rpc/mod.rs b/lib/git/rpc/mod.rs new file mode 100644 index 0000000..5c18092 --- /dev/null +++ b/lib/git/rpc/mod.rs @@ -0,0 +1,19 @@ +pub mod archive; +pub mod blame; +pub mod blob; +pub mod branch; +pub mod commit; +pub mod convert; +pub mod diff; +pub mod error; +pub mod fork; +pub mod init; +pub mod merge; +pub mod registry; +pub mod server; +pub mod tag; +pub mod tree; + +pub mod proto { + tonic::include_proto!("git.v1"); +} diff --git a/lib/git/rpc/registry.rs b/lib/git/rpc/registry.rs new file mode 100644 index 0000000..fcfe0a2 --- /dev/null +++ b/lib/git/rpc/registry.rs @@ -0,0 +1,92 @@ +use std::{path::PathBuf, sync::Arc}; + +use cache::AppCache; +use dashmap::DashMap; +use db::database::AppDatabase; +use model::repos::RepoModel; +use sqlx::query_as; +use uuid::Uuid; + +use crate::bare::GitBare; + +const STORAGE_PATH_CACHE_KEY_PREFIX: &str = "git:rpc:repo:storage_path"; + +#[derive(Clone)] +pub struct RepoRegistry { + repos: DashMap, + db: AppDatabase, + cache: AppCache, +} + +impl RepoRegistry { + pub fn new(db: AppDatabase, cache: AppCache) -> Self { + Self { + repos: DashMap::new(), + db, + cache, + } + } + + pub fn register(&self, repo_id: Uuid, bare_dir: PathBuf) { + let bare = GitBare { bare_dir }; + self.repos.insert(repo_id, bare); + } + + pub fn unregister(&self, repo_id: &Uuid) { + self.repos.remove(repo_id); + } + + pub async fn get( + &self, + repo_id_str: &str, + ) -> Result { + let repo_id = repo_id_str.parse::().map_err(|e| { + tonic::Status::invalid_argument(format!( + "invalid repo_id UUID: {e}" + )) + })?; + + if let Some(bare) = self.repos.get(&repo_id) { + return Ok(bare.value().clone()); + } + + let storage_path = self.lookup_storage_path(repo_id).await?; + let bare = GitBare { + bare_dir: PathBuf::from(storage_path), + }; + self.repos.insert(repo_id, bare.clone()); + Ok(bare) + } + + async fn lookup_storage_path( + &self, + repo_id: Uuid, + ) -> Result { + let cache_key = format!("{STORAGE_PATH_CACHE_KEY_PREFIX}:{repo_id}"); + + if let Ok(Some(path)) = self.cache.get::(&cache_key).await { + return Ok(path); + } + + let model: RepoModel = query_as( + "SELECT id, wk, name, description, default_branch, visibility, size_bytes, is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at FROM repo WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(repo_id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| tonic::Status::internal(format!("database error: {e}")))?; + + let path = model.storage_path; + let _ = self.cache.set(&cache_key, &path).await; + + Ok(path) + } + + pub fn list(&self) -> Vec { + self.repos.iter().map(|r| *r.key()).collect() + } +} + +pub fn shared_registry(db: AppDatabase, cache: AppCache) -> Arc { + Arc::new(RepoRegistry::new(db, cache)) +} diff --git a/lib/git/rpc/server.rs b/lib/git/rpc/server.rs new file mode 100644 index 0000000..3dcbeb9 --- /dev/null +++ b/lib/git/rpc/server.rs @@ -0,0 +1,156 @@ +use std::{net::SocketAddr, path::PathBuf, sync::Arc}; + +use cache::AppCache; +use db::database::AppDatabase; +use tonic::transport::Server; +use uuid::Uuid; + +use crate::sync::ReceiveSyncService; +use crate::rpc::{ + archive::ArchiveServiceImpl, + blame::BlameServiceImpl, + blob::BlobServiceImpl, + branch::BranchServiceImpl, + commit::CommitServiceImpl, + diff::DiffServiceImpl, + fork::ForkServiceImpl, + init::InitServiceImpl, + merge::MergeServiceImpl, + proto::{ + archive_service_server::ArchiveServiceServer, + blame_service_server::BlameServiceServer, + blob_service_server::BlobServiceServer, + branch_service_server::BranchServiceServer, + commit_service_server::CommitServiceServer, + diff_service_server::DiffServiceServer, + fork_service_server::ForkServiceServer, + init_service_server::InitServiceServer, + merge_service_server::MergeServiceServer, + tag_service_server::TagServiceServer, + tree_service_server::TreeServiceServer, + }, + registry::{RepoRegistry, shared_registry}, + tag::TagServiceImpl, + tree::TreeServiceImpl, +}; + +type RepoId = Uuid; + +pub struct GitServer { + addr: SocketAddr, + pub registry: Arc, + pub cache: AppCache, + pub sync: ReceiveSyncService, +} + +impl GitServer { + pub fn new(addr: SocketAddr, db: AppDatabase, cache: AppCache, sync: ReceiveSyncService) -> Self { + Self { + addr, + cache: cache.clone(), + registry: shared_registry(db, cache), + sync, + } + } + + pub fn register_repo(&self, repo_id: Uuid, bare_dir: PathBuf) { + self.registry.register(repo_id, bare_dir); + } + + pub async fn serve(self) -> Result<(), Box> { + let archive = ArchiveServiceServer::new(ArchiveServiceImpl { + registry: self.registry.clone(), + }); + let blame = BlameServiceServer::new(BlameServiceImpl { + registry: self.registry.clone(), + cache: self.cache.clone(), + }); + let blob = BlobServiceServer::new(BlobServiceImpl { + registry: self.registry.clone(), + cache: self.cache.clone(), + }); + let branch = BranchServiceServer::new(BranchServiceImpl { + registry: self.registry.clone(), + }); + let commit = CommitServiceServer::new(CommitServiceImpl { + registry: self.registry.clone(), + cache: self.cache.clone(), + sync: self.sync.clone(), + }); + let diff = DiffServiceServer::new(DiffServiceImpl { + registry: self.registry.clone(), + cache: self.cache.clone(), + }); + let fork = ForkServiceServer::new(ForkServiceImpl { + registry: self.registry.clone(), + }); + let init = InitServiceServer::new(InitServiceImpl { + registry: self.registry.clone(), + sync: self.sync.clone(), + }); + let merge = MergeServiceServer::new(MergeServiceImpl { + registry: self.registry.clone(), + }); + let tag = TagServiceServer::new(TagServiceImpl { + registry: self.registry.clone(), + }); + let tree = TreeServiceServer::new(TreeServiceImpl { + registry: self.registry.clone(), + cache: self.cache.clone(), + }); + + Server::builder() + .add_service(archive) + .add_service(blame) + .add_service(blob) + .add_service(branch) + .add_service(commit) + .add_service(diff) + .add_service(fork) + .add_service(init) + .add_service(merge) + .add_service(tag) + .add_service(tree) + .serve(self.addr) + .await?; + + Ok(()) + } +} + +pub struct GitServerBuilder { + addr: SocketAddr, + db: AppDatabase, + cache: AppCache, + sync: ReceiveSyncService, + repos: Vec<(RepoId, PathBuf)>, +} + +impl GitServerBuilder { + pub fn new(addr: SocketAddr, db: AppDatabase, cache: AppCache, sync: ReceiveSyncService) -> Self { + Self { + addr, + db, + cache, + sync, + repos: Vec::new(), + } + } + + pub fn repo( + mut self, + repo_id: RepoId, + bare_dir: impl Into, + ) -> Self { + self.repos.push((repo_id, bare_dir.into())); + self + } + + pub fn build(self) -> GitServer { + let server = GitServer::new(self.addr, self.db, self.cache, self.sync); + for (repo_id, bare_dir) in self.repos { + server.register_repo(repo_id, bare_dir); + } + server + } +} diff --git a/lib/git/rpc/tag.rs b/lib/git/rpc/tag.rs new file mode 100644 index 0000000..400aba3 --- /dev/null +++ b/lib/git/rpc/tag.rs @@ -0,0 +1,126 @@ +use std::sync::Arc; + +use tonic::{Request, Response, Status}; + +use crate::rpc::{ + error::{spawn_blocking_error, to_status}, + proto as p, + registry::RepoRegistry, +}; + +pub struct TagServiceImpl { + pub registry: Arc, +} + +#[tonic::async_trait] +impl p::tag_service_server::TagService for TagServiceImpl { + async fn tag_list( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let result = tokio::task::spawn_blocking(move || bare.tag_list()) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::TagListResponse { + tags: result.into_iter().map(Into::into).collect(), + })) + } + + async fn tag_info( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let name = inner.name.clone(); + let result = tokio::task::spawn_blocking(move || bare.tag_info(name)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::TagInfoResponse { + tag: Some(result.into()), + })) + } + + async fn tag_summary( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let result = tokio::task::spawn_blocking(move || bare.tag_summary()) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::TagSummaryResponse { + summary: Some(result.into()), + })) + } + + async fn tag_init( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = inner.params.unwrap_or_default().into(); + let result = tokio::task::spawn_blocking(move || bare.tag_init(params)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::TagInitResponse { + oid: Some(result.into()), + })) + } + + async fn tag_delete( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = inner.params.unwrap_or_default().into(); + let _result = + tokio::task::spawn_blocking(move || bare.tag_delete(params)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::TagDeleteResponse {})) + } + + async fn tag_rename( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = inner.params.unwrap_or_default().into(); + let _result = + tokio::task::spawn_blocking(move || bare.tag_rename(params)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::TagRenameResponse {})) + } + + async fn tag_update_message( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let params = inner.params.unwrap_or_default().into(); + let result = tokio::task::spawn_blocking(move || { + bare.tag_update_message(params) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::TagUpdateMessageResponse { + oid: Some(result.into()), + })) + } +} diff --git a/lib/git/rpc/tree.rs b/lib/git/rpc/tree.rs new file mode 100644 index 0000000..2ab875f --- /dev/null +++ b/lib/git/rpc/tree.rs @@ -0,0 +1,167 @@ +use std::sync::Arc; + +use cache::AppCache; +use tonic::{Request, Response, Status}; + +use crate::rpc::{ + error::{spawn_blocking_error, to_status}, + proto as p, + registry::RepoRegistry, +}; + +pub struct TreeServiceImpl { + pub registry: Arc, + pub cache: AppCache, +} + +#[tonic::async_trait] +impl p::tree_service_server::TreeService for TreeServiceImpl { + async fn tree_entries( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let repo_id = inner.repo_id.clone(); + let oid_val = inner.oid.clone().map(|o| o.value).unwrap_or_default(); + let base_path = inner.base_path.clone(); + let want_last = inner.last; + if !want_last { + let bare = self.registry.get(&repo_id).await?; + let oid = inner.oid.unwrap_or_default().into(); + let entries = tokio::task::spawn_blocking(move || { + use crate::errors::GitError; + bare.tree_entries(oid) + .map_err(|e| GitError::Internal(format!("tree_entries: {e}"))) + .map(|e| e.into_iter().map(Into::into).collect::>()) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + return Ok(Response::new(p::TreeEntriesResponse { entries })); + } + let cache_key = format!( + "git:rpc:cache:tree:entries:{}:{}:{}", + repo_id, oid_val, base_path + ); + + if let Ok(Some(cached)) = self.cache.get::>(&cache_key).await { + return Ok(Response::new(p::TreeEntriesResponse { entries: cached })); + } + + let bare = self.registry.get(&repo_id).await?; + let bare_bg = bare.clone(); + let oid = inner.oid.unwrap_or_default().into(); + let base_path_bg = base_path.clone(); + let cache_bg = self.cache.clone(); + let cache_key_bg = cache_key.clone(); + + let entries = tokio::task::spawn_blocking(move || { + use crate::errors::GitError; + bare.tree_entries(oid) + .map_err(|e| GitError::Internal(format!("tree_entries: {e}"))) + .map(|e| e.into_iter().map(Into::into).collect::>()) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + + let mut response = entries.clone(); + for entry in &mut response { + entry.last_commit_message.clear(); + entry.last_commit_time.clear(); + entry.last_commit_author_name.clear(); + entry.last_commit_author_email.clear(); + } + + tokio::task::spawn(async move { + let enriched = tokio::task::spawn_blocking(move || { + let paths: Vec = entries.iter().map(|e| { + if base_path_bg.is_empty() { + e.name.clone() + } else { + format!("{}/{}", base_path_bg, e.name) + } + }).collect(); + + let last_commits = match bare_bg.last_commits_for_paths(&paths) { + Ok(lc) => lc, + Err(_) => return entries, + }; + + let mut out = entries; + for (i, info) in last_commits.into_iter().enumerate() { + if let Some(info) = info { + if let Some(e) = out.get_mut(i) { + e.last_commit_message = info.message; + e.last_commit_time = info.time; + e.last_commit_author_name = info.author_name; + e.last_commit_author_email = info.author_email; + } + } + } + out + }).await; + + if let Ok(entries) = enriched { + let _ = cache_bg.set(&cache_key_bg, &entries).await; + } + }); + + Ok(Response::new(p::TreeEntriesResponse { entries: response })) + } + + async fn tree_entry_by_path( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let tree_oid = inner.tree_oid.unwrap_or_default().into(); + let path = inner.path; + let result = tokio::task::spawn_blocking(move || { + bare.tree_entry_by_path(tree_oid, path) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::TreeEntryByPathResponse { + entry: Some(result.into()), + })) + } + + async fn tree_entry_by_path_from_commit( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let commit_oid = inner.commit_oid.unwrap_or_default().into(); + let path = inner.path; + let result = tokio::task::spawn_blocking(move || { + bare.tree_entry_by_path_from_commit(commit_oid, path) + }) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::TreeEntryByPathFromCommitResponse { + entry: Some(result.into()), + })) + } + + async fn resolve_tree( + &self, + req: Request, + ) -> Result, Status> { + let inner = req.into_inner(); + let bare = self.registry.get(&inner.repo_id).await?; + let oid = inner.oid.unwrap_or_default().into(); + let result = + tokio::task::spawn_blocking(move || bare.resolve_tree(oid)) + .await + .map_err(spawn_blocking_error)? + .map_err(to_status)?; + Ok(Response::new(p::ResolveTreeResponse { + info: Some(result.into()), + })) + } +} diff --git a/lib/git/ssh/authz.rs b/lib/git/ssh/authz.rs new file mode 100644 index 0000000..1f9d680 --- /dev/null +++ b/lib/git/ssh/authz.rs @@ -0,0 +1,325 @@ +use base64::{Engine as _, engine::general_purpose}; +use db::{database::AppDatabase, sqlx}; +use model::{ + repos::{RepoHistoryNameModel, RepoModel}, + users::{UserModel, UserSshKeyModel}, + workspace::{WkHistoryNameModel, WkMemberModel, WorkspaceModel}, +}; +use sha2::{Digest, Sha256}; + +use crate::errors::GitError; + +pub struct SshAuthService { + db: AppDatabase, +} + +pub struct SshKeyUser { + pub user: UserModel, + pub key_id: i64, + pub key_title: String, +} + +impl SshAuthService { + pub fn new(db: AppDatabase) -> Self { + Self { db } + } + + pub async fn find_repo( + &self, + namespace: &str, + repo_name: &str, + ) -> Result { + let namespace = self.find_namespace(namespace).await?; + self.find_repository_by_name_and_wk(repo_name, namespace.id) + .await + } + + async fn find_namespace( + &self, + namespace: &str, + ) -> Result { + let workspace = sqlx::query_as::<_, WorkspaceModel>( + "SELECT id, name, description, avatar_url, created_at FROM workspace WHERE name = $1", + ) + .bind(namespace) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))?; + + if let Some(ws) = workspace { + return Ok(ws); + } + + let history = sqlx::query_as::<_, WkHistoryNameModel>( + "SELECT id, wk, name, changed_by, created_at FROM wk_history_name WHERE name = $1", + ) + .bind(namespace) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))?; + + if let Some(history) = history { + let ws = sqlx::query_as::<_, WorkspaceModel>( + "SELECT id, name, description, avatar_url, created_at FROM workspace WHERE id = $1", + ) + .bind(history.wk) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))?; + + if let Some(ws) = ws { + return Ok(ws); + } + } + + Err(GitError::NotFound("Workspace not found".to_string())) + } + + async fn find_repository_by_name_and_wk( + &self, + repo_name: &str, + wk_id: uuid::Uuid, + ) -> Result { + let repo = sqlx::query_as::<_, RepoModel>( + "SELECT id, wk, name, description, default_branch, visibility, size_bytes, is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at FROM repo WHERE name = $1 AND wk = $2", + ) + .bind(repo_name) + .bind(wk_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))?; + + if let Some(repo) = repo { + return Ok(repo); + } + + let history = sqlx::query_as::<_, RepoHistoryNameModel>( + "SELECT id, repo, name, changed_by, created_at FROM repo_history_name WHERE name = $1", + ) + .bind(repo_name) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))?; + + if let Some(history) = history { + let repo = sqlx::query_as::<_, RepoModel>( + "SELECT id, wk, name, description, default_branch, visibility, size_bytes, is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at FROM repo WHERE id = $1 AND wk = $2", + ) + .bind(history.repo) + .bind(wk_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))?; + + if let Some(repo) = repo { + return Ok(repo); + } + } + + Err(GitError::NotFound("Repository not found".to_string())) + } + + pub async fn find_user_by_public_key( + &self, + public_key_str: &str, + ) -> Result, sqlx::Error> { + let fingerprint = + match self.generate_fingerprint_from_public_key(public_key_str) { + Ok(fp) => fp, + Err(e) => { + tracing::error!( + "failed to generate SSH key fingerprint error={}", + e + ); + return Ok(None); + } + }; + + let fingerprint_preview = if fingerprint.len() > 16 { + format!("{}...", &fingerprint[..16]) + } else { + fingerprint.clone() + }; + tracing::info!( + "looking up user with SSH key fingerprint={}", + fingerprint_preview + ); + + let ssh_key = sqlx::query_as::<_, UserSshKeyModel>( + "SELECT id, \"user\", title, public_key, fingerprint, key_type, key_bits, is_verified, last_used_at, expires_at, is_revoked, created_at, updated_at FROM user_ssh_key WHERE fingerprint = $1 AND is_revoked = false", + ) + .bind(&fingerprint) + .fetch_optional(self.db.reader()) + .await?; + + let ssh_key = match ssh_key { + Some(key) => key, + None => { + tracing::warn!("no SSH key found fingerprint={}", fingerprint); + return Ok(None); + } + }; + + if self.is_key_expired(&ssh_key) { + tracing::warn!( + "SSH key expired key_id={} expires_at={:?}", + ssh_key.id, + ssh_key.expires_at + ); + return Ok(None); + } + + let user_model = sqlx::query_as::<_, UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, allow_use, can_search, last_sign_in_at, created_at, updated_at FROM \"user\" WHERE id = $1", + ) + .bind(ssh_key.user) + .fetch_optional(self.db.reader()) + .await?; + + if let Some(user) = user_model { + tracing::info!( + "SSH key matched user={} key={}", + user.username, + ssh_key.title + ); + return Ok(Some(SshKeyUser { + user, + key_id: ssh_key.id, + key_title: ssh_key.title, + })); + } + + Ok(None) + } + + fn is_key_expired(&self, ssh_key: &UserSshKeyModel) -> bool { + if let Some(expires_at) = ssh_key.expires_at { + let now = chrono::Utc::now(); + now >= expires_at + } else { + false + } + } + + pub fn update_key_last_used_async(&self, key_id: i64) { + let db_clone = self.db.clone(); + tokio::spawn(async move { + if let Err(e) = + Self::update_key_last_used_sync(db_clone, key_id).await + { + tracing::warn!( + "failed to update key last_used key_id={} error={}", + key_id, + e + ); + } + }); + } + + async fn update_key_last_used_sync( + db: AppDatabase, + key_id: i64, + ) -> Result<(), sqlx::Error> { + let now = chrono::Utc::now(); + sqlx::query("UPDATE user_ssh_key SET last_used_at = $1, updated_at = $2 WHERE id = $3") + .bind(now) + .bind(now) + .bind(key_id) + .execute(db.writer()) + .await?; + + tracing::info!("updated key last_used key_id={}", key_id); + Ok(()) + } + + pub async fn check_repo_permission( + &self, + user: &UserModel, + repo: &RepoModel, + is_write: bool, + ) -> bool { + if repo.created_by == user.id { + tracing::info!( + "user is repo owner user={} repo={}", + user.username, + repo.name + ); + return true; + } + + if !is_write && repo.visibility == "public" { + tracing::info!("public repo allows read access repo={}", repo.name); + return true; + } + + let wk_id = repo.wk; + if self + .check_wk_member_permission(user, wk_id, is_write) + .await + .unwrap_or(false) + { + tracing::info!( + "user has workspace member access user={} repo={}", + user.username, + repo.name + ); + return true; + } + + tracing::warn!( + "access denied user={} repo={} is_write={}", + user.username, + repo.name, + is_write + ); + false + } + + async fn check_wk_member_permission( + &self, + user: &UserModel, + wk_id: uuid::Uuid, + is_write: bool, + ) -> Result { + let member = sqlx::query_as::<_, WkMemberModel>( + "SELECT wk, \"user\", owner, admin, join_at, leave_at FROM wk_member WHERE wk = $1 AND \"user\" = $2 AND leave_at IS NULL", + ) + .bind(wk_id) + .bind(user.id) + .fetch_optional(self.db.reader()) + .await?; + + if let Some(member) = member { + if member.owner || member.admin { + return Ok(true); + } + Ok(!is_write) + } else { + Ok(false) + } + } + + fn generate_fingerprint_from_public_key( + &self, + public_key_str: &str, + ) -> Result { + let key_data_base64 = public_key_str + .split_whitespace() + .nth(1) + .ok_or("Invalid SSH key format")?; + + let key_data = general_purpose::STANDARD + .decode(key_data_base64) + .map_err(|e| format!("Base64 decode error: {}", e))?; + + let mut hasher = Sha256::new(); + hasher.update(&key_data); + let hash = hasher.finalize(); + + let mut fingerprint = String::with_capacity(51); + fingerprint.push_str("SHA256:"); + fingerprint.push_str(&general_purpose::STANDARD_NO_PAD.encode(&hash)); + + Ok(fingerprint) + } +} diff --git a/lib/git/ssh/branch_protect.rs b/lib/git/ssh/branch_protect.rs new file mode 100644 index 0000000..abcfd4d --- /dev/null +++ b/lib/git/ssh/branch_protect.rs @@ -0,0 +1,54 @@ +use model::repos::RepoProtectModel; + +use crate::ssh::ref_update::RefUpdate; + +fn ref_matches_protection(ref_name: &str, protection_pattern: &str) -> bool { + ref_name == protection_pattern + || ref_name.starts_with(&format!("{}/", protection_pattern)) +} + +pub fn check_branch_protection( + branch_protects: &[RepoProtectModel], + r#ref: &RefUpdate, +) -> Option { + for protection in branch_protects { + if !ref_matches_protection(&r#ref.name, &protection.pattern) { + continue; + } + + if r#ref.new_oid == "0000000000000000000000000000000000000000" { + if !protection.allow_deletions { + return Some(format!( + "protected branch rejected. Deletion of '{}' is forbidden. Create a PR or ask a maintainer to update branch protection.", + r#ref.name + )); + } + continue; + } + + if r#ref.name.starts_with("refs/tags/") { + continue; + } + + let is_new_branch = + r#ref.old_oid == "0000000000000000000000000000000000000000"; + if !is_new_branch + && r#ref.old_oid != r#ref.new_oid + && r#ref.name.starts_with("refs/heads/") + && !protection.allow_force_pushes + { + return Some(format!( + "protected branch rejected. Force push to '{}' is forbidden. Create a PR instead of rewriting protected history.", + r#ref.name + )); + } + + if protection.require_pull_request { + return Some(format!( + "protected branch rejected. Direct push to '{}' is forbidden. Please push to a feature branch and create a PR.", + r#ref.name + )); + } + } + None +} diff --git a/lib/git/ssh/forward.rs b/lib/git/ssh/forward.rs new file mode 100644 index 0000000..7a365d9 --- /dev/null +++ b/lib/git/ssh/forward.rs @@ -0,0 +1,51 @@ +use std::{future::Future, time::Duration}; + +use actix_web::web::Bytes; +use russh::{ChannelId, server::Handle}; +use tokio::{ + io::{AsyncRead, AsyncReadExt}, + time::sleep, +}; + +pub async fn forward<'a, R, Fut, Fwd>( + session_handle: &'a Handle, + chan_id: ChannelId, + r: &mut R, + mut fwd: Fwd, +) -> Result<(), russh::Error> +where + R: AsyncRead + Send + Unpin, + Fut: Future> + 'a, + Fwd: FnMut(&'a Handle, ChannelId, Bytes) -> Fut, +{ + const BUF_SIZE: usize = 1024 * 32; + const MAX_RETRIES: usize = 5; + const RETRY_DELAY: u64 = 10; + + let mut buf = [0u8; BUF_SIZE]; + loop { + let read = r.read(&mut buf).await.map_err(russh::Error::IO)?; + + if read == 0 { + break; + } + + let mut chunk = Bytes::copy_from_slice(&buf[..read]); + let mut retries = 0; + loop { + match fwd(session_handle, chan_id, chunk).await { + Ok(()) => break, + Err(unsent) => { + retries += 1; + if retries >= MAX_RETRIES { + return Ok(()); + } + chunk = unsent; + sleep(Duration::from_millis(RETRY_DELAY)).await; + } + } + } + } + + Ok(()) +} diff --git a/lib/git/ssh/git_service.rs b/lib/git/ssh/git_service.rs new file mode 100644 index 0000000..3ed86af --- /dev/null +++ b/lib/git/ssh/git_service.rs @@ -0,0 +1,92 @@ +use std::{path::PathBuf, str::FromStr}; + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub enum GitService { + UploadPack, + ReceivePack, + UploadArchive, +} + +impl FromStr for GitService { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "upload-pack" => Ok(Self::UploadPack), + "receive-pack" => Ok(Self::ReceivePack), + "upload-archive" => Ok(Self::UploadArchive), + _ => Err(()), + } + } +} + +pub fn parse_git_command(cmd: &str) -> Option<(GitService, &str)> { + let (svc, path) = match cmd.split_once(' ') { + Some(("git-receive-pack", path)) => (GitService::ReceivePack, path), + Some(("git-upload-pack", path)) => (GitService::UploadPack, path), + Some(("git-upload-archive", path)) => (GitService::UploadArchive, path), + _ => return None, + }; + Some((svc, strip_apostrophes(path))) +} + +pub fn parse_repo_path(path: &str) -> Option<(&str, &str)> { + let path = path.trim_matches('/'); + let mut parts = path.splitn(2, '/'); + match (parts.next(), parts.next()) { + (Some(owner), Some(repo)) if !owner.is_empty() && !repo.is_empty() => { + Some((owner, repo)) + } + _ => None, + } +} + +pub fn build_git_command( + service: GitService, + path: PathBuf, +) -> tokio::process::Command { + let mut cmd = tokio::process::Command::new("git"); + + let cwd = match path.canonicalize() { + Ok(p) => p, + Err(e) => { + tracing::debug!(error = %e, "path canonicalize failed, falling back to raw path"); + path.clone() + } + }; + cmd.current_dir(cwd); + + match service { + GitService::UploadPack => { + cmd.arg("upload-pack"); + } + GitService::ReceivePack => { + cmd.arg("receive-pack"); + } + GitService::UploadArchive => { + cmd.arg("upload-archive"); + } + } + + cmd.arg(".") + .env("GIT_CONFIG_NOSYSTEM", "1") + .env("GIT_NO_REPLACE_OBJECTS", "1"); + + #[cfg(unix)] + { + cmd.env("GIT_CONFIG_GLOBAL", "/dev/null") + .env("GIT_CONFIG_SYSTEM", "/dev/null"); + } + #[cfg(windows)] + { + let nul = "NUL"; + cmd.env("GIT_CONFIG_GLOBAL", nul) + .env("GIT_CONFIG_SYSTEM", nul); + } + + cmd +} + +fn strip_apostrophes(s: &str) -> &str { + s.trim_matches('\'') +} diff --git a/lib/git/ssh/handler.rs b/lib/git/ssh/handler.rs new file mode 100644 index 0000000..b48312e --- /dev/null +++ b/lib/git/ssh/handler.rs @@ -0,0 +1,1064 @@ +use std::{ + collections::{HashMap, HashSet}, + io, + net::SocketAddr, + path::PathBuf, + process::Stdio, + sync::Arc, + time::Duration, +}; + +use cache::AppCache; +use db::{database::AppDatabase, sqlx}; +use model::{ + repos::{RepoModel, RepoProtectModel}, + users::UserModel, +}; +use russh::{ + Channel, ChannelId, Disconnect, + keys::{Certificate, PublicKey}, + server::{Auth, Msg, Session}, +}; +use tokio::{ + io::AsyncWriteExt, + process::ChildStdin, + sync::{Mutex, mpsc::Sender}, + time::sleep, +}; + +use crate::{ + ssh::{ + SshTokenService, + authz::SshAuthService, + branch_protect::check_branch_protection, + forward::forward, + git_service::{ + GitService, build_git_command, parse_git_command, parse_repo_path, + }, + ref_update::RefUpdate, + }, + sync::{ + ReceiveSyncService, RepoReceiveSyncTask, + push_queue::{ + PushQueueEvent, PushQueueWaitError, wait_for_push_queue_slot, + }, + }, +}; + +const PRE_PACK_LIMIT: usize = 1_048_576; +const ZERO_OID: &str = "0000000000000000000000000000000000000000"; + +pub struct SSHandle { + pub repo: Option, + pub model: Option, + pub stdin: HashMap, + pub eof: HashMap>, + pub operator: Option, + pub db: AppDatabase, + pub auth: SshAuthService, + pub buffer: HashMap>, + pub branch: HashMap>, + pub post_receive_refs: HashMap>>>, + pub service: Option, + pub cache: AppCache, + pub sync: ReceiveSyncService, + pub upload_pack_eof_sent: HashSet, + pub token_service: SshTokenService, + pub client_addr: Option, +} + +impl SSHandle { + pub fn new( + db: AppDatabase, + cache: AppCache, + sync: ReceiveSyncService, + token_service: SshTokenService, + client_addr: Option, + ) -> Self { + let auth = SshAuthService::new(db.clone()); + let addr_str = client_addr + .map(|addr| format!("{}", addr)) + .unwrap_or_else(|| "unknown".to_string()); + tracing::info!("SSH handler created client={}", addr_str); + Self { + repo: None, + model: None, + stdin: HashMap::new(), + eof: HashMap::new(), + operator: None, + db, + auth, + buffer: HashMap::new(), + branch: HashMap::new(), + post_receive_refs: HashMap::new(), + service: None, + cache, + sync, + upload_pack_eof_sent: HashSet::new(), + token_service, + client_addr, + } + } + + fn cleanup_channel(&mut self, channel_id: ChannelId) { + if let Some(stdin) = self.stdin.remove(&channel_id) { + let channel_id_for_task = channel_id; + tokio::spawn(async move { + let _ = tokio::time::timeout(Duration::from_secs(5), async { + let mut stdin = stdin; + if let Err(e) = stdin.flush().await { + tracing::warn!(error = %e, "ssh_cleanup_flush_failed channel={:?}", channel_id_for_task); + } + let _ = stdin.shutdown().await; + }) + .await; + }); + } + self.eof.remove(&channel_id); + self.post_receive_refs.remove(&channel_id); + self.upload_pack_eof_sent.remove(&channel_id); + } + + fn format_post_receive_hints( + namespace: &str, + repo: &RepoModel, + refs: &[RefUpdate], + queue: Option<(usize, usize)>, + ) -> String { + let mut lines = Vec::new(); + for r#ref in refs { + if r#ref.old_oid == ZERO_OID + && r#ref.name.starts_with("refs/heads/") + { + let branch = r#ref.name.trim_start_matches("refs/heads/"); + lines.push(format!( + "remote: new branch '{}' pushed. Create a PR: /{}/repo/{}/pulls/new?head={}\r\n", + branch, + namespace, + repo.name, + branch + )); + } + } + if let Some((position, total)) = queue { + lines.push(format!( + "remote: repository sync queued ({}/{}). Metadata, webhooks and search indexes will update shortly.\r\n", + position, total + )); + } + lines.concat() + } +} + +impl Drop for SSHandle { + fn drop(&mut self) { + let addr_str = self + .client_addr + .map(|addr| format!("{}", addr)) + .unwrap_or_else(|| "unknown".to_string()); + tracing::info!("ssh_handler_dropped client={}", addr_str); + + let channel_ids: Vec<_> = self.stdin.keys().copied().collect(); + for channel_id in channel_ids { + self.cleanup_channel(channel_id); + } + } +} + +impl russh::server::Handler for SSHandle { + type Error = russh::Error; + + async fn auth_none(&mut self, user: &str) -> Result { + let client_info = self + .client_addr + .map(|addr| format!("{}", addr)) + .unwrap_or_else(|| "unknown".to_string()); + tracing::info!( + "auth_none_received user={} client={}", + user, + client_info + ); + Ok(Auth::UnsupportedMethod) + } + + async fn auth_password( + &mut self, + _user: &str, + token: &str, + ) -> Result { + let client_info = self + .client_addr + .map(|addr| format!("{}", addr)) + .unwrap_or_else(|| "unknown".to_string()); + + if token.is_empty() { + tracing::warn!("auth_rejected_empty_token client={}", client_info); + return Err(russh::Error::NotAuthenticated); + } + + tracing::info!("auth_token_attempt client={}", client_info); + + let user_model = + match self.token_service.find_user_by_token(token).await { + Ok(Some(model)) => model, + Ok(None) => { + tracing::warn!( + "auth_rejected_token_not_found client={}", + client_info + ); + return Err(russh::Error::NotAuthenticated); + } + Err(e) => { + tracing::error!("auth_token_error error={}", e); + return Err(russh::Error::NotAuthenticated); + } + }; + + tracing::info!( + "auth_token_success user={} client={}", + user_model.username, + client_info + ); + self.operator = Some(user_model); + Ok(Auth::Accept) + } + + async fn auth_publickey_offered( + &mut self, + user: &str, + public_key: &PublicKey, + ) -> Result { + let client_info = self + .client_addr + .map(|addr| format!("{}", addr)) + .unwrap_or_else(|| "unknown".to_string()); + + if user != "git" { + tracing::warn!( + "auth_publickey_offer_rejected_invalid_username user={} client={}", + user, + client_info + ); + return Err(russh::Error::NotAuthenticated); + } + + let public_key_str = public_key.to_string(); + if public_key_str.len() < 32 { + tracing::warn!( + "auth_publickey_offer_rejected_invalid_key_length key_length={}", + public_key_str.len() + ); + return Err(russh::Error::NotAuthenticated); + } + + tracing::info!("auth_publickey_offer client={}", client_info); + match self.auth.find_user_by_public_key(&public_key_str).await { + Ok(Some(key_user)) => { + tracing::info!( + "auth_publickey_offer_accepted user={} key={} client={}", + key_user.user.username, + key_user.key_title, + client_info + ); + Ok(Auth::Accept) + } + Ok(None) => { + tracing::warn!( + "auth_publickey_offer_rejected_key_not_found client={}", + client_info + ); + Err(russh::Error::NotAuthenticated) + } + Err(e) => { + tracing::error!("auth_publickey_offer_error error={}", e); + Err(russh::Error::NotAuthenticated) + } + } + } + + async fn auth_publickey( + &mut self, + user: &str, + public_key: &PublicKey, + ) -> Result { + let client_info = self + .client_addr + .map(|addr| format!("{}", addr)) + .unwrap_or_else(|| "unknown".to_string()); + + if user != "git" { + tracing::warn!( + "auth_rejected_invalid_username user={} client={}", + user, + client_info + ); + return Err(russh::Error::NotAuthenticated); + } + let public_key_str = public_key.to_string(); + if public_key_str.len() < 32 { + tracing::warn!( + "auth_rejected_invalid_key_length key_length={}", + public_key_str.len() + ); + return Err(russh::Error::NotAuthenticated); + } + + tracing::info!("auth_publickey_attempt client={}", client_info); + let key_user = + match self.auth.find_user_by_public_key(&public_key_str).await { + Ok(Some(key_user)) => key_user, + Ok(None) => { + tracing::warn!( + "auth_rejected_key_not_found client={}", + client_info + ); + return Err(russh::Error::NotAuthenticated); + } + Err(e) => { + tracing::error!("auth_publickey_error error={}", e); + return Err(russh::Error::NotAuthenticated); + } + }; + + tracing::info!( + "auth_publickey_success user={} client={}", + key_user.user.username, + client_info + ); + self.auth.update_key_last_used_async(key_user.key_id); + self.operator = Some(key_user.user); + Ok(Auth::Accept) + } + + async fn auth_openssh_certificate( + &mut self, + user: &str, + certificate: &Certificate, + ) -> Result { + let client_info = self + .client_addr + .map(|addr| format!("{}", addr)) + .unwrap_or_else(|| "unknown".to_string()); + + if user != "git" { + tracing::warn!( + "auth_rejected_invalid_username user={} client={}", + user, + client_info + ); + return Err(russh::Error::NotAuthenticated); + } + let public_key_str = certificate.to_string(); + if public_key_str.len() < 32 { + tracing::warn!( + "auth_rejected_invalid_key_length key_length={}", + public_key_str.len() + ); + return Err(russh::Error::NotAuthenticated); + } + + tracing::info!("auth_publickey_attempt client={}", client_info); + let key_user = + match self.auth.find_user_by_public_key(&public_key_str).await { + Ok(Some(key_user)) => key_user, + Ok(None) => { + tracing::warn!( + "auth_rejected_key_not_found client={}", + client_info + ); + return Err(russh::Error::NotAuthenticated); + } + Err(e) => { + tracing::error!("auth_publickey_error error={}", e); + return Err(russh::Error::NotAuthenticated); + } + }; + + tracing::info!( + "auth_publickey_success user={} client={}", + key_user.user.username, + client_info + ); + self.auth.update_key_last_used_async(key_user.key_id); + self.operator = Some(key_user.user); + Ok(Auth::Accept) + } + + async fn channel_close( + &mut self, + channel: ChannelId, + _: &mut Session, + ) -> Result<(), Self::Error> { + tracing::info!( + "channel_close channel={:?} client={:?}", + channel, + self.client_addr + ); + self.cleanup_channel(channel); + Ok(()) + } + + async fn channel_eof( + &mut self, + channel: ChannelId, + _: &mut Session, + ) -> Result<(), Self::Error> { + tracing::info!( + "channel_eof channel={:?} client={:?}", + channel, + self.client_addr + ); + + if let Some(eof) = self.eof.get(&channel) { + let _ = eof.send(true).await; + } + + if let Some(mut stdin) = self.stdin.remove(&channel) { + tracing::info!( + "Closing stdin channel={:?} client={:?}", + channel, + self.client_addr + ); + let _ = tokio::time::timeout(Duration::from_secs(5), async { + if let Err(e) = stdin.flush().await { + tracing::warn!(error = %e, "ssh_eof_flush_failed channel={:?}", channel); + } + let _ = stdin.shutdown().await; + }) + .await; + } + + Ok(()) + } + + async fn channel_open_session( + &mut self, + channel: Channel, + session: &mut Session, + ) -> Result { + let client_info = self + .client_addr + .map(|addr| format!("{}", addr)) + .unwrap_or_else(|| "unknown".to_string()); + tracing::info!( + "channel_open_session channel={:?} client={}", + channel, + client_info + ); + if let Err(e) = session.flush() { + tracing::warn!(error = %e, "ssh_session_flush_failed"); + } + Ok(true) + } + + async fn pty_request( + &mut self, + channel: ChannelId, + term: &str, + col_width: u32, + row_height: u32, + _pix_width: u32, + _pix_height: u32, + _modes: &[(russh::Pty, u32)], + session: &mut Session, + ) -> Result<(), Self::Error> { + tracing::warn!( + "pty_request not supported channel={:?} term={} cols={} rows={}", + channel, + term, + col_width, + row_height + ); + if let Err(e) = session.flush() { + tracing::warn!(error = %e, "ssh_session_flush_failed"); + } + Ok(()) + } + + async fn subsystem_request( + &mut self, + channel: ChannelId, + name: &str, + session: &mut Session, + ) -> Result<(), Self::Error> { + tracing::info!( + "subsystem_request channel={:?} subsystem={}", + channel, + name + ); + if let Err(e) = session.flush() { + tracing::warn!(error = %e, "ssh_session_flush_failed"); + } + Ok(()) + } + + async fn data( + &mut self, + channel: ChannelId, + data: &[u8], + session: &mut Session, + ) -> Result<(), Self::Error> { + if matches!(self.service, Some(GitService::ReceivePack)) { + if !self.branch.contains_key(&channel) { + let bf = self.buffer.entry(channel).or_default(); + + if bf.len() + data.len() > PRE_PACK_LIMIT { + tracing::warn!( + "ssh_pre_pack_too_large channel={:?}", + channel + ); + let msg = "remote: Ref negotiation exceeds size limit\r\n"; + let _ = session.extended_data( + channel, + 1, + msg.as_bytes().to_vec(), + ); + let _ = session.exit_status_request(channel, 1); + let _ = session.eof(channel); + let _ = session.close(channel); + self.cleanup_channel(channel); + return Ok(()); + } + + bf.extend_from_slice(data); + + if !bf.windows(4).any(|w| w == b"0000") { + return Ok(()); + } + + let buffered = self.buffer.remove(&channel).unwrap_or_default(); + + match RefUpdate::parse_ref_updates(&buffered) { + Ok(refs) => { + if let Some(model) = &self.model { + let branch_protect_roles = sqlx::query_as::<_, RepoProtectModel>( + "SELECT id, repo, pattern, require_pull_request, required_approvals, require_status_checks, required_status_contexts, enforce_admins, allow_force_pushes, allow_deletions, created_at, updated_at FROM repo_protect WHERE repo = $1", + ) + .bind(model.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| { + russh::Error::IO(io::Error::new(io::ErrorKind::Other, e)) + })?; + + for r#ref in &refs { + if let Some(msg) = check_branch_protection( + &branch_protect_roles, + r#ref, + ) { + let full_msg = + format!("remote: {}\r\n", msg); + let _ = session.extended_data( + channel, + 1, + full_msg.into_bytes(), + ); + let _ = + session.exit_status_request(channel, 1); + let _ = session.eof(channel); + let _ = session.close(channel); + self.cleanup_channel(channel); + return Ok(()); + } + } + } + if let Some(refs_for_hints) = + self.post_receive_refs.get(&channel) + { + *refs_for_hints.lock().await = refs.clone(); + } + self.branch.insert(channel, refs); + } + Err(e) => { + tracing::warn!("ref_update_parse_error error={:?}", e); + if let Some(refs_for_hints) = + self.post_receive_refs.get(&channel) + { + refs_for_hints.lock().await.clear(); + } + self.branch.insert(channel, vec![]); + } + } + + if let Some(stdin) = self.stdin.get_mut(&channel) { + stdin.write_all(&buffered).await?; + stdin.flush().await?; + } + return Ok(()); + } + + if let Some(stdin) = self.stdin.get_mut(&channel) { + stdin.write_all(data).await?; + stdin.flush().await?; + } + return Ok(()); + } + + if let Some(stdin) = self.stdin.get_mut(&channel) { + stdin.write_all(data).await?; + if matches!(self.service, Some(GitService::UploadPack)) + && !self.upload_pack_eof_sent.contains(&channel) + { + let has_flush_pkt = data.windows(4).any(|w| w == b"0000"); + if has_flush_pkt { + stdin.flush().await?; + let _ = stdin.shutdown().await; + self.upload_pack_eof_sent.insert(channel); + } + } + } + Ok(()) + } + + async fn shell_request( + &mut self, + channel_id: ChannelId, + session: &mut Session, + ) -> Result<(), Self::Error> { + if let Some(user) = &self.operator { + let welcome_msg = format!( + "Hi {}! You've successfully authenticated, but interactive shell access is not provided.\r\n", + user.username + ); + + tracing::info!("shell_request user={}", user.username); + let _ = session.data(channel_id, welcome_msg.into_bytes()); + let _ = session.exit_status_request(channel_id, 0); + let _ = session.eof(channel_id); + let _ = session.close(channel_id); + let _ = session.flush(); + } else { + tracing::warn!( + "shell_request_unauthenticated channel={:?}", + channel_id + ); + let msg = "Authentication required\r\n"; + let _ = session.data(channel_id, msg.as_bytes().to_vec()); + let _ = session.exit_status_request(channel_id, 1); + let _ = session.eof(channel_id); + let _ = session.close(channel_id); + let _ = session.flush(); + } + Ok(()) + } + + async fn exec_request( + &mut self, + channel_id: ChannelId, + data: &[u8], + session: &mut Session, + ) -> Result<(), Self::Error> { + let client_info = self + .client_addr + .map(|addr| format!("{}", addr)) + .unwrap_or_else(|| "unknown".to_string()); + + tracing::info!( + "exec_request received channel={:?} client={}", + channel_id, + client_info + ); + + let git_shell_cmd = match std::str::from_utf8(data) { + Ok(cmd) => cmd.trim(), + Err(e) => { + tracing::error!("invalid_command_encoding error={}", e); + let _ = session.disconnect( + Disconnect::ServiceNotAvailable, + "Invalid command encoding", + "", + ); + return Err(russh::Error::Disconnect); + } + }; + let (service, path) = match parse_git_command(git_shell_cmd) { + Some((s, p)) => (s, p), + None => { + tracing::error!( + "invalid_git_command command={}", + git_shell_cmd + ); + let msg = format!("Invalid git command: {}", git_shell_cmd); + let _ = session.disconnect( + Disconnect::ServiceNotAvailable, + &msg, + "", + ); + return Err(russh::Error::Disconnect); + } + }; + self.service = Some(service); + let (owner, repo) = match parse_repo_path(path) { + Some(pair) => pair, + None => { + let msg = format!("Invalid repository path: {}", path); + tracing::error!("invalid_repo_path path={}", path); + let _ = session.disconnect( + Disconnect::ServiceNotAvailable, + &msg, + "", + ); + return Err(russh::Error::Disconnect); + } + }; + let namespace = owner.to_string(); + let repo = repo.strip_suffix(".git").unwrap_or(repo).to_string(); + + let repo = match self.auth.find_repo(owner, &repo).await { + Ok(repo) => repo, + Err(e) => { + tracing::error!("repo_fetch_error error={}", e); + let _ = session.disconnect( + Disconnect::ServiceNotAvailable, + "Repository not found", + "", + ); + return Err(russh::Error::Disconnect); + } + }; + + self.model = Some(repo.clone()); + let operator = match &self.operator { + Some(user) => user, + None => { + let msg = "Authentication error: no authenticated user"; + tracing::error!( + "exec_no_authenticated_user channel={:?}", + channel_id + ); + let _ = session.disconnect(Disconnect::ByApplication, msg, ""); + return Err(russh::Error::Disconnect); + } + }; + + let is_write = service == GitService::ReceivePack; + let has_permission = self + .auth + .check_repo_permission(operator, &repo, is_write) + .await; + + if !has_permission { + let msg = format!( + "Access denied: user '{}' does not have {} permission for repository {}", + operator.username, + if is_write { "write" } else { "read" }, + repo.name + ); + tracing::error!( + "access_denied user={} repo={} is_write={}", + operator.username, + repo.name, + is_write + ); + let _ = session.disconnect(Disconnect::ByApplication, &msg, ""); + return Err(russh::Error::Disconnect); + } + + tracing::info!( + "access_granted user={} repo={} is_write={}", + operator.username, + repo.name, + is_write + ); + + let mut push_queue_lease = if is_write { + let repo_id = repo.id; + let queue_result = + wait_for_push_queue_slot(self.sync.clone(), repo_id, |event, request_id| { + let request_id = request_id.to_string(); + match event { + PushQueueEvent::Waiting(position) => { + let msg = format!( + "remote: another push is running for this repository. Queued {}/{}.\r\n", + position.position, position.total + ); + let _ = session.extended_data(channel_id, 1, msg.as_bytes().to_vec()); + let _ = session.flush(); + tracing::info!( + repo_id = %repo_id, + request_id = %request_id, + position = position.position, + total = position.total, + "push_queue_waiting" + ); + } + PushQueueEvent::Acquired => { + let msg = "remote: push queue slot acquired. Processing now.\r\n"; + let _ = session.extended_data(channel_id, 1, msg.as_bytes().to_vec()); + let _ = session.flush(); + tracing::info!( + repo_id = %repo_id, + request_id = %request_id, + "push_queue_acquired" + ); + } + } + }) + .await; + + match queue_result { + Ok(lease) => Some(lease), + Err(error) => { + match &error { + PushQueueWaitError::Join(e) => { + tracing::error!(error = %e, repo = %repo.name, "push_queue_join_failed"); + let msg = "remote: push queue is temporarily unavailable. Please retry later.\r\n"; + let _ = session.extended_data( + channel_id, + 1, + msg.as_bytes().to_vec(), + ); + } + PushQueueWaitError::Lock(e) => { + tracing::error!(error = %e, repo_id = %repo.id, "push_queue_lock_failed"); + let msg = "remote: push queue lock failed. Please retry later.\r\n"; + let _ = session.extended_data( + channel_id, + 1, + msg.as_bytes().to_vec(), + ); + } + PushQueueWaitError::Timeout => { + tracing::warn!(repo_id = %repo.id, "push_queue_timeout"); + let msg = "remote: push queue timed out. Please retry in a moment.\r\n"; + let _ = session.extended_data( + channel_id, + 1, + msg.as_bytes().to_vec(), + ); + } + } + let _ = session.channel_failure(channel_id); + let _ = session.close(channel_id); + self.cleanup_channel(channel_id); + return if matches!(error, PushQueueWaitError::Timeout) { + Ok(()) + } else { + Err(russh::Error::IO(io::Error::new( + io::ErrorKind::Other, + error.to_string(), + ))) + }; + } + } + } else { + None + }; + + let repo = match &self.model { + Some(m) => m, + None => { + let msg = "Repository model not available"; + tracing::error!("repo_model_missing"); + let _ = session.disconnect(Disconnect::ByApplication, msg, ""); + return Err(russh::Error::Disconnect); + } + }; + let repo_path = PathBuf::from(&repo.storage_path); + if !repo_path.exists() { + tracing::error!("repo_path_not_found path={}", repo_path.display()); + } + tracing::info!( + "spawn_git_process service={:?} path={}", + service, + repo_path.display() + ); + let mut cmd = build_git_command(service, repo_path); + + let mut shell = match cmd + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + { + Ok(shell) => { + let _ = session.channel_success(channel_id); + shell + } + Err(e) => { + tracing::error!("process_spawn_failed error={}", e); + if let Some(lease) = &mut push_queue_lease { + lease.release().await; + } + let _ = session.channel_failure(channel_id); + self.cleanup_channel(channel_id); + return Err(russh::Error::IO(e)); + } + }; + let session_handle = session.handle(); + let stdin = match shell.stdin.take() { + Some(s) => s, + None => { + tracing::error!( + "stdin pipe unavailable for channel={:?}", + channel_id + ); + if let Some(lease) = &mut push_queue_lease { + lease.release().await; + } + let _ = session_handle.channel_failure(channel_id).await; + return Err(russh::Error::IO(io::Error::new( + io::ErrorKind::Other, + "stdin unavailable", + ))); + } + }; + self.stdin.insert(channel_id, stdin); + let mut shell_stdout = match shell.stdout.take() { + Some(s) => s, + None => { + tracing::error!( + "stdout pipe unavailable for channel={:?}", + channel_id + ); + if let Some(lease) = &mut push_queue_lease { + lease.release().await; + } + return Err(russh::Error::IO(io::Error::new( + io::ErrorKind::Other, + "stdout unavailable", + ))); + } + }; + let mut shell_stderr = match shell.stderr.take() { + Some(s) => s, + None => { + tracing::error!( + "stderr pipe unavailable for channel={:?}", + channel_id + ); + if let Some(lease) = &mut push_queue_lease { + lease.release().await; + } + return Err(russh::Error::IO(io::Error::new( + io::ErrorKind::Other, + "stderr unavailable", + ))); + } + }; + + let (eof_tx, mut eof_rx) = tokio::sync::mpsc::channel::(10); + self.eof.insert(channel_id, eof_tx); + let refs_for_hints = Arc::new(Mutex::new(Vec::new())); + self.post_receive_refs + .insert(channel_id, refs_for_hints.clone()); + let repo_uid = repo.id; + let repo_for_hints = repo.clone(); + let namespace_for_hints = namespace.clone(); + let should_sync = service == GitService::ReceivePack; + let sync = self.sync.clone(); + let mut push_queue_lease = push_queue_lease; + + let fut = async move { + tracing::info!(channel = ?channel_id, "git_task_started"); + + let mut stdout_done = false; + let mut stderr_done = false; + + let stdout_fut = forward( + &session_handle, + channel_id, + &mut shell_stdout, + |handle, chan, data| async move { handle.data(chan, data).await }, + ); + tokio::pin!(stdout_fut); + + let stderr_fut = forward( + &session_handle, + channel_id, + &mut shell_stderr, + |handle, chan, data| async move { + handle.extended_data(chan, 1, data).await + }, + ); + tokio::pin!(stderr_fut); + + loop { + tokio::select! { + result = shell.wait() => { + let status = match result { + Ok(status) => status, + Err(e) => { + if let Some(lease) = &mut push_queue_lease { + lease.release().await; + } + return Err(russh::Error::IO(e)); + } + }; + let status_code = status.code().unwrap_or(128) as u32; + + tracing::info!("git_process_exited channel={:?} status={}", channel_id, status_code); + + if let Some(lease) = &mut push_queue_lease { + lease.release().await; + } + + if !stdout_done || !stderr_done { + let _ = tokio::time::timeout(Duration::from_millis(100), async { + tokio::join!( + async { + if !stdout_done { + let _ = (&mut stdout_fut).await; + } + }, + async { + if !stderr_done { + let _ = (&mut stderr_fut).await; + } + } + ); + }).await; + } + + if should_sync && status_code == 0 { + let queue = sync.send(RepoReceiveSyncTask { repo_uid }).await; + let refs_for_hints = refs_for_hints.lock().await.clone(); + let msg = SSHandle::format_post_receive_hints( + &namespace_for_hints, + &repo_for_hints, + &refs_for_hints, + queue, + ); + if !msg.is_empty() { + let _ = session_handle + .extended_data(channel_id, 1, msg.into_bytes()) + .await; + } + } + + let _ = session_handle.exit_status_request(channel_id, status_code).await; + sleep(Duration::from_millis(50)).await; + let _ = session_handle.eof(channel_id).await; + let _ = session_handle.close(channel_id).await; + tracing::info!(channel = ?channel_id, "channel_closed"); + break; + } + result = &mut stdout_fut, if !stdout_done => { + tracing::info!("stdout completed"); + stdout_done = true; + if let Err(e) = result { + tracing::warn!(error = ?e, "stdout_forward_error"); + } + } + result = &mut stderr_fut, if !stderr_done => { + tracing::info!("stderr completed"); + stderr_done = true; + if let Err(e) = result { + tracing::warn!(error = ?e, "stderr_forward_error"); + } + } + } + } + + Ok::<(), russh::Error>(()) + }; + + tokio::spawn(async move { + if let Err(e) = fut.await { + tracing::error!("git_ssh_channel_task_error error={}", e); + } + while eof_rx.recv().await.is_some() {} + }); + Ok(()) + } +} diff --git a/lib/git/ssh/mod.rs b/lib/git/ssh/mod.rs new file mode 100644 index 0000000..1e554f8 --- /dev/null +++ b/lib/git/ssh/mod.rs @@ -0,0 +1,195 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::Context; +use argon2::{ + Argon2, + password_hash::{PasswordHash, PasswordVerifier}, +}; +use cache::AppCache; +use config::AppConfig; +use db::{database::AppDatabase, sqlx}; +use deadpool_redis::cluster::Pool as RedisPool; +use model::users::{UserModel, UserTokenModel}; +use russh::{ + MethodKind, MethodSet, SshId, + server::{Config, Server}, +}; + +use crate::errors::GitError; + +pub mod authz; +pub mod branch_protect; +pub mod forward; +pub mod git_service; +pub mod handler; +pub mod rate_limit; +pub mod ref_update; +pub mod server; + +#[derive(Clone)] +pub struct SSHHandle { + pub db: AppDatabase, + pub app: AppConfig, + pub cache: AppCache, + pub redis_pool: RedisPool, +} + +impl SSHHandle { + pub fn new( + db: AppDatabase, + app: AppConfig, + cache: AppCache, + redis_pool: RedisPool, + ) -> Self { + SSHHandle { + db, + app, + cache, + redis_pool, + } + } + + pub async fn run_ssh(&self) -> anyhow::Result<()> { + tracing::info!("SSH server starting"); + let key_file = self.app.ssh_server_private_key_file()?; + if key_file.is_empty() { + return Err(anyhow::anyhow!( + "SSH server private key file is not configured (APP_SSH_SERVER_PRIVATE_KEY_FILE)" + )); + } + + tracing::info!("Loading SSH private key from file: {}", key_file); + + let private_key_pem = std::fs::read_to_string(&key_file) + .with_context(|| format!("Failed to read SSH private key file: {}", key_file))?; + + let private_key = russh::keys::decode_secret_key(&private_key_pem, None) + .or_else(|e| { + tracing::info!("decode_secret_key failed: {}, trying from_openssh", e); + russh::keys::ssh_key::PrivateKey::from_openssh(&private_key_pem) + .map_err(|e2| anyhow::anyhow!( + "Failed to parse SSH private key from {}: decode_secret_key={}, from_openssh={}", + key_file, e, e2 + )) + })?; + + tracing::info!("SSH private key loaded"); + let mut config = Config::default(); + config.keys = vec![private_key]; + let version = format!("SSH-2.0-Work {}", env!("CARGO_PKG_VERSION")); + config.server_id = SshId::Standard(version.into()); + config.methods = MethodSet::empty(); + config.methods.push(MethodKind::PublicKey); + config.methods.push(MethodKind::Password); + config.auth_rejection_time = Duration::from_secs(5); + config.inactivity_timeout = Some(Duration::from_secs(300)); + config.keepalive_interval = Some(Duration::from_secs(60)); + config.keepalive_max = 3; + + tracing::info!( + "SSH server configured with methods: {:?}", + config.methods + ); + let token_service = SshTokenService::new(self.db.clone()); + let mut ssh_server = server::SSHServer::new( + self.db.clone(), + self.cache.clone(), + self.redis_pool.clone(), + token_service, + ); + + let _cleanup = ssh_server.rate_limiter.clone().start_cleanup(); + + let ssh_port = self.app.ssh_port()?; + let bind_addr = format!("0.0.0.0:{}", ssh_port); + let public_host = self.app.ssh_domain()?; + let msg = if ssh_port == 22 { + format!( + "SSH server listening on port 22. Please use port {} for SSH connections.", + ssh_port + ) + } else { + format!( + "SSH server listening on port {} (public: {}). Please use port {} for SSH connections.", + ssh_port, public_host, ssh_port + ) + }; + tracing::info!("{}", msg); + ssh_server + .run_on_address(Arc::new(config), bind_addr) + .await?; + Ok(()) + } +} +#[derive(Clone)] +pub struct SshTokenService { + db: AppDatabase, +} + +impl SshTokenService { + pub fn new(db: AppDatabase) -> Self { + Self { db } + } + + pub async fn find_user_by_token( + &self, + token: &str, + ) -> Result, GitError> { + let token_models = sqlx::query_as::<_, UserTokenModel>( + "SELECT id, user, name, token_hash, scopes, expires_at, is_revoked, created_at, updated_at FROM user_token WHERE is_revoked = false", + ) + .fetch_all(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))?; + + for token_model in token_models { + if token_model + .expires_at + .map(|expires_at| expires_at < chrono::Utc::now()) + .unwrap_or(false) + { + continue; + } + + let Ok(hash) = PasswordHash::new(&token_model.token_hash) else { + tracing::warn!( + token_id = token_model.id, + "invalid stored SSH token hash" + ); + continue; + }; + + if Argon2::default() + .verify_password(token.as_bytes(), &hash) + .is_err() + { + continue; + } + + let user_model = sqlx::query_as::<_, UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, allow_use, can_search, last_sign_in_at, created_at, updated_at FROM \"user\" WHERE id = $1", + ) + .bind(token_model.user) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| GitError::Internal(e.to_string()))?; + + return Ok(user_model); + } + + Ok(None) + } +} + +pub async fn run_ssh( + config: AppConfig, + db: AppDatabase, + cache: AppCache, + redis_pool: RedisPool, +) -> anyhow::Result<()> { + tracing::info!("SSH server initializing"); + SSHHandle::new(db, config.clone(), cache, redis_pool) + .run_ssh() + .await?; + Ok(()) +} diff --git a/lib/git/ssh/rate_limit.rs b/lib/git/ssh/rate_limit.rs new file mode 100644 index 0000000..2b60a5f --- /dev/null +++ b/lib/git/ssh/rate_limit.rs @@ -0,0 +1,114 @@ +use std::{ + collections::HashMap, + sync::Arc, + time::{Duration, Instant}, +}; + +use tokio::{sync::RwLock, time::interval}; + +#[derive(Debug, Clone)] +pub struct RateLimitConfig { + pub requests_per_window: u32, + pub window_duration: Duration, +} + +impl Default for RateLimitConfig { + fn default() -> Self { + Self { + requests_per_window: 100, + window_duration: Duration::from_secs(60), + } + } +} + +#[derive(Debug)] +struct RateLimitState { + count: u32, + reset_time: Instant, +} + +#[derive(Debug, Clone)] +pub struct RateLimiter { + limits: Arc>>, + config: RateLimitConfig, +} + +impl RateLimiter { + pub fn new(config: RateLimitConfig) -> Self { + Self { + limits: Arc::new(RwLock::new(HashMap::new())), + config, + } + } + + pub async fn is_allowed(&self, key: &str) -> bool { + let now = Instant::now(); + let mut limits = self.limits.write().await; + + let state = + limits + .entry(key.to_string()) + .or_insert_with(|| RateLimitState { + count: 0, + reset_time: now + self.config.window_duration, + }); + + if now >= state.reset_time { + state.count = 0; + state.reset_time = now + self.config.window_duration; + } + + if state.count >= self.config.requests_per_window { + return false; + } + + state.count += 1; + true + } + + pub fn start_cleanup(self: Arc) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + let mut ticker = interval(Duration::from_secs(300)); + loop { + ticker.tick().await; + let now = Instant::now(); + let mut limits = self.limits.write().await; + limits.retain(|_, state| now < state.reset_time); + } + }) + } +} + +pub struct SshRateLimiter { + limiter: RateLimiter, +} + +impl SshRateLimiter { + pub fn new() -> Self { + Self { + limiter: RateLimiter::new(RateLimitConfig::default()), + } + } + + pub async fn is_user_allowed(&self, user_id: &str) -> bool { + self.limiter.is_allowed(&format!("user:{}", user_id)).await + } + + pub async fn is_ip_allowed(&self, ip_address: &str) -> bool { + self.limiter.is_allowed(&format!("ip:{}", ip_address)).await + } + + pub async fn is_repo_access_allowed( + &self, + user_id: &str, + repo_path: &str, + ) -> bool { + self.limiter + .is_allowed(&format!("repo_access:{}:{}", user_id, repo_path)) + .await + } + + pub fn start_cleanup(self: Arc) -> tokio::task::JoinHandle<()> { + RateLimiter::start_cleanup(Arc::new(self.limiter.clone())) + } +} diff --git a/lib/git/ssh/ref_update.rs b/lib/git/ssh/ref_update.rs new file mode 100644 index 0000000..0486b4f --- /dev/null +++ b/lib/git/ssh/ref_update.rs @@ -0,0 +1,111 @@ +#[derive(Clone, Debug)] +pub struct RefUpdate { + pub name: String, + pub old_oid: String, + pub new_oid: String, +} + +impl RefUpdate { + pub fn parse_ref_updates(data: &[u8]) -> Result, String> { + let mut refs = Vec::new(); + + for payload in parse_pkt_line_payloads(data)? { + let line = String::from_utf8_lossy(payload); + let line = line.trim_end_matches(['\r', '\n']); + if line.is_empty() { + continue; + } + + let mut parts = line.splitn(3, ' '); + let old_oid = parts.next().unwrap_or_default(); + let new_oid = parts.next().unwrap_or_default(); + let raw_name = parts.next().unwrap_or_default(); + let name = raw_name + .split_once('\0') + .map(|(name, _)| name) + .unwrap_or(raw_name) + .trim(); + + if old_oid.len() != 40 || new_oid.len() != 40 || name.is_empty() { + continue; + } + + refs.push(RefUpdate { + old_oid: old_oid.to_string(), + new_oid: new_oid.to_string(), + name: name.to_string(), + }); + } + + Ok(refs) + } +} + +fn parse_pkt_line_payloads(data: &[u8]) -> Result, String> { + let mut payloads = Vec::new(); + let mut offset = 0; + + while offset + 4 <= data.len() { + let header = std::str::from_utf8(&data[offset..offset + 4]) + .map_err(|_| "invalid pkt-line header encoding".to_string())?; + let len = usize::from_str_radix(header, 16) + .map_err(|_| format!("invalid pkt-line length: {header}"))?; + offset += 4; + + match len { + 0 => break, + 1..=3 => return Err(format!("invalid pkt-line length: {len}")), + _ => { + let payload_len = len - 4; + if offset + payload_len > data.len() { + return Err("truncated pkt-line payload".to_string()); + } + payloads.push(&data[offset..offset + payload_len]); + offset += payload_len; + } + } + } + + Ok(payloads) +} + +#[cfg(test)] +mod tests { + use super::RefUpdate; + + fn pkt(payload: &str) -> Vec { + let len = payload.len() + 4; + let mut out = format!("{len:04x}").into_bytes(); + out.extend_from_slice(payload.as_bytes()); + out + } + + #[test] + fn parses_receive_pack_ref_with_capabilities() { + let mut data = pkt( + "0000000000000000000000000000000000000000 1111111111111111111111111111111111111111 refs/heads/feature\0 report-status\n", + ); + data.extend_from_slice(b"0000"); + + let refs = RefUpdate::parse_ref_updates(&data).unwrap(); + + assert_eq!(refs.len(), 1); + assert_eq!(refs[0].old_oid, "0000000000000000000000000000000000000000"); + assert_eq!(refs[0].new_oid, "1111111111111111111111111111111111111111"); + assert_eq!(refs[0].name, "refs/heads/feature"); + } + + #[test] + fn parses_receive_pack_ref_without_pack_payload() { + let mut data = pkt( + "2222222222222222222222222222222222222222 0000000000000000000000000000000000000000 refs/heads/old\n", + ); + data.extend_from_slice(b"0000"); + + let refs = RefUpdate::parse_ref_updates(&data).unwrap(); + + assert_eq!(refs.len(), 1); + assert_eq!(refs[0].name, "refs/heads/old"); + assert_eq!(refs[0].new_oid, "0000000000000000000000000000000000000000"); + } +} diff --git a/lib/git/ssh/server.rs b/lib/git/ssh/server.rs new file mode 100644 index 0000000..47bc127 --- /dev/null +++ b/lib/git/ssh/server.rs @@ -0,0 +1,98 @@ +use std::{io, net::SocketAddr, sync::Arc}; + +use cache::AppCache; +use db::database::AppDatabase; +use deadpool_redis::cluster::Pool as RedisPool; +use russh::server::Handler; + +use crate::{ + ssh::{SshTokenService, handler::SSHandle, rate_limit::SshRateLimiter}, + sync::ReceiveSyncService, +}; + +pub struct SSHServer { + pub db: AppDatabase, + pub cache: AppCache, + pub redis_pool: RedisPool, + pub token_service: SshTokenService, + pub rate_limiter: Arc, +} + +impl SSHServer { + pub fn new( + db: AppDatabase, + cache: AppCache, + redis_pool: RedisPool, + token_service: SshTokenService, + ) -> Self { + SSHServer { + db, + cache, + redis_pool, + token_service, + rate_limiter: Arc::new(SshRateLimiter::new()), + } + } +} + +impl russh::server::Server for SSHServer { + type Handler = SSHandle; + + fn new_client(&mut self, addr: Option) -> Self::Handler { + if let Some(addr) = addr { + let ip = addr.ip().to_string(); + tracing::info!("New SSH connection ip={} port={}", ip, addr.port()); + let limiter = self.rate_limiter.clone(); + let ip_clone = ip.clone(); + tokio::spawn(async move { + if !limiter.is_ip_allowed(&ip_clone).await { + tracing::warn!(ip = %ip_clone, "SSH connection rate limited"); + } + }); + } else { + tracing::info!("New SSH connection from unknown address"); + } + let sync_service = ReceiveSyncService::new(self.redis_pool.clone()); + SSHandle::new( + self.db.clone(), + self.cache.clone(), + sync_service, + self.token_service.clone(), + addr, + ) + } + + fn handle_session_error( + &mut self, + error: ::Error, + ) { + match error { + russh::Error::Disconnect => { + tracing::info!("Connection disconnected by peer"); + } + russh::Error::Inconsistent => { + tracing::warn!("Protocol inconsistency detected"); + } + russh::Error::NotAuthenticated => { + tracing::warn!("Authentication failed"); + } + russh::Error::IO(ref io_err) => { + tracing::warn!( + "SSH IO error kind={:?} message={} raw_os_error={:?}", + io_err.kind(), + io_err, + io_err.raw_os_error() + ); + + if io_err.kind() == io::ErrorKind::UnexpectedEof { + tracing::warn!( + "SSH peer closed the connection before a clean disconnect was received" + ); + } + } + _ => { + tracing::warn!("SSH session error error={}", error); + } + } + } +} diff --git a/lib/git/sync/branch.rs b/lib/git/sync/branch.rs new file mode 100644 index 0000000..dbf9c1b --- /dev/null +++ b/lib/git/sync/branch.rs @@ -0,0 +1,163 @@ +use std::collections::HashSet; + +use db::{database::AppDatabase, sqlx}; +use model::repos::RepoRefModel; +use uuid::Uuid; + +use crate::{bare::GitBare, errors::GitError}; +#[derive(Debug, Clone)] +pub struct BranchTip { + pub name: String, + pub shorthand: String, + pub target_oid: String, +} +pub fn collect_branch_tips(bare: &GitBare) -> Result, GitError> { + let repo = bare.gix_repo()?; + let refs = repo.references() + .map_err(|e| GitError::Internal(format!("failed to open references: {}", e)))?; + let iter = refs.all() + .map_err(|e| GitError::Internal(format!("failed to iterate refs: {}", e)))?; + + let mut branches = Vec::new(); + for ref_result in iter { + let reference = ref_result + .map_err(|e| GitError::Internal(format!("ref iteration error: {}", e)))?; + let full_name = reference.name().as_bstr().to_string(); + if !full_name.starts_with("refs/heads/") { + continue; + } + let target_oid = reference.target().try_id() + .map(|id| id.to_hex().to_string()) + .ok_or_else(|| GitError::Internal("ref has no direct target".to_string()))?; + let shorthand = reference.name().shorten().to_string(); + branches.push(BranchTip { + name: full_name, + shorthand, + target_oid, + }); + } + Ok(branches) +} +pub async fn sync_refs( + db: &AppDatabase, + bare: &GitBare, + repo_id: Uuid, +) -> Result<(), GitError> { + let now = chrono::Utc::now(); + let pool = db.writer(); + + let existing: Vec = sqlx::query_as::<_, RepoRefModel>( + "SELECT id, repo, name, kind, target_sha, is_default, is_protected, created_at, updated_at FROM repo_ref WHERE repo = $1 AND kind = 'branch'" + ) + .bind(repo_id) + .fetch_all(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to query branches: {}", e)))?; + + let mut existing_names: HashSet = + existing.iter().map(|r| r.name.clone()).collect(); + + let branches = collect_branch_tips(bare)?; + + const PREFERRED_BRANCHES: &[&str] = &["main", "master", "trunk"]; + + let current_default: Option = sqlx::query_scalar::<_, String>( + "SELECT default_branch FROM repo WHERE id = $1", + ) + .bind(repo_id) + .fetch_optional(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to re-read repo: {}", e)))? + .filter(|b| !b.is_empty()); + + let mut auto_detected_branch: Option = None; + if current_default.is_none() { + for preferred in PREFERRED_BRANCHES { + if branches.iter().any(|b| b.shorthand == *preferred) { + auto_detected_branch = Some((*preferred).to_string()); + break; + } + } + if auto_detected_branch.is_none() { + if let Some(first) = branches.first() { + auto_detected_branch = Some(first.shorthand.clone()); + } + } + } + + for branch in &branches { + if existing_names.contains(&branch.name) { + existing_names.remove(&branch.name); + sqlx::query( + "UPDATE repo_ref SET target_sha = $1, updated_at = $2 WHERE repo = $3 AND name = $4 AND kind = 'branch'" + ) + .bind(&branch.target_oid) + .bind(now) + .bind(repo_id) + .bind(&branch.name) + .execute(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to update branch: {}", e)))?; + } else { + let new_id = Uuid::new_v4(); + sqlx::query( + "INSERT INTO repo_ref (id, repo, name, kind, target_sha, is_default, is_protected, created_at, updated_at) VALUES ($1, $2, $3, 'branch', $4, $5, false, $6, $7)" + ) + .bind(new_id) + .bind(repo_id) + .bind(&branch.name) + .bind(&branch.target_oid) + .bind(false) + .bind(now) + .bind(now) + .execute(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to insert branch: {}", e)))?; + } + } + + if !existing_names.is_empty() { + let names_vec: Vec = existing_names.into_iter().collect(); + sqlx::query("DELETE FROM repo_ref WHERE repo = $1 AND name = ANY($2) AND kind = 'branch'") + .bind(repo_id) + .bind(&names_vec) + .execute(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to delete stale branches: {}", e)))?; + } + + if let Some(ref branch_name) = auto_detected_branch { + let result = sqlx::query( + "UPDATE repo SET default_branch = $1, updated_at = $2 WHERE id = $3 AND default_branch = ''" + ) + .bind(branch_name.clone()) + .bind(now) + .bind(repo_id) + .execute(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to set default branch: {}", e)))?; + + if result.rows_affected() > 0 { + sqlx::query( + "UPDATE repo_ref SET is_default = false, updated_at = $1 WHERE repo = $2 AND kind = 'branch'" + ) + .bind(now) + .bind(repo_id) + .execute(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to clear head flags: {}", e)))?; + + sqlx::query( + "UPDATE repo_ref SET is_default = true, updated_at = $1 WHERE repo = $2 AND name = $3 AND kind = 'branch'" + ) + .bind(now) + .bind(repo_id) + .bind(branch_name) + .execute(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to set head flag: {}", e)))?; + } + } + + Ok(()) +} diff --git a/lib/git/sync/cicheck.rs b/lib/git/sync/cicheck.rs new file mode 100644 index 0000000..28008fd --- /dev/null +++ b/lib/git/sync/cicheck.rs @@ -0,0 +1,157 @@ +use deadpool_redis::cluster::Pool as RedisPool; +use parsefile::{Pipeline, TriggerEvent}; + +use crate::{ + bare::GitBare, + errors::GitError, + sync::{HookTask, TaskType}, +}; + +const PIPELINE_FILE: &str = "pipeline.yaml"; +#[derive(Debug)] +pub enum CiCheckOutcome { + Enqueued, + NoPipelineFile, + NotTriggered, +} +fn ci_queue_keys(repo_id: uuid::Uuid) -> (String, String) { + let hash_tag = format!("{{ci:{}}}", repo_id); + ( + format!("{}:pending", hash_tag), + format!("{}:processing", hash_tag), + ) +} +pub async fn check_and_enqueue( + bare: &GitBare, + repo_id: uuid::Uuid, + event: &TriggerEvent, + redis_pool: &RedisPool, +) -> Result { + let output = bare.git_command_trusted_stdout(vec![ + "show".to_string(), + format!("HEAD:{}", PIPELINE_FILE), + ]); + + let content = match output { + Ok(c) => c, + Err(_) => return Ok(CiCheckOutcome::NoPipelineFile), + }; + + let pipeline = parsefile::parse_from_str(&content).map_err(|e| { + GitError::Internal(format!("failed to parse {}: {}", PIPELINE_FILE, e)) + })?; + + if !pipeline.should_run(event) { + return Ok(CiCheckOutcome::NotTriggered); + } + + enqueue_ci_task(repo_id, event, &pipeline, redis_pool) + .await + .map_err(|e| { + GitError::Internal(format!("failed to enqueue CI task: {}", e)) + })?; + + Ok(CiCheckOutcome::Enqueued) +} + +async fn enqueue_ci_task( + repo_id: uuid::Uuid, + event: &TriggerEvent, + pipeline: &Pipeline, + redis_pool: &RedisPool, +) -> Result<(), String> { + let hook_task = HookTask { + id: uuid::Uuid::new_v4().to_string(), + repo_id: repo_id.to_string(), + task_type: TaskType::Sync, + payload: serde_json::json!({ + "ci": true, + "pipeline_name": pipeline.name, + "trigger": event_variant_name(event), + }), + created_at: chrono::Utc::now(), + retry_count: 0, + }; + + let task_json = serde_json::to_string(&hook_task) + .map_err(|e| format!("serialize error: {}", e))?; + + let (pending_key, _) = ci_queue_keys(repo_id); + + let redis = redis_pool + .get() + .await + .map_err(|e| format!("redis pool: {}", e))?; + let mut conn: deadpool_redis::cluster::Connection = redis; + + redis::cmd("LPUSH") + .arg(&pending_key) + .arg(&task_json) + .query_async::<()>(&mut conn) + .await + .map_err(|e| format!("LPUSH error: {}", e))?; + + tracing::info!( + repo_id = %repo_id, + pipeline = %pipeline.name, + trigger = %event_variant_name(event), + "CI task enqueued" + ); + + Ok(()) +} + +fn event_variant_name(event: &TriggerEvent) -> &'static str { + match event { + TriggerEvent::PushBranch(_) => "push_branch", + TriggerEvent::PushTag(_) => "push_tag", + TriggerEvent::PullRequest { .. } => "pull_request", + } +} +pub async fn poll_ci_task_for_repo( + redis_pool: &RedisPool, + repo_id: uuid::Uuid, + block_timeout_secs: usize, +) -> Option { + let (pending_key, processing_key) = ci_queue_keys(repo_id); + + let redis = redis_pool.get().await.ok()?; + let mut conn: deadpool_redis::cluster::Connection = redis; + + redis::cmd("BLMOVE") + .arg(&pending_key) + .arg(&processing_key) + .arg("RIGHT") + .arg("LEFT") + .arg(block_timeout_secs) + .query_async::>(&mut conn) + .await + .ok() + .flatten() +} +pub async fn ack_ci_task( + redis_pool: &RedisPool, + repo_id: uuid::Uuid, + task_json: &str, +) { + let (_, processing_key) = ci_queue_keys(repo_id); + + let redis = match redis_pool.get().await { + Ok(c) => c, + Err(e) => { + tracing::warn!(error = %e, "CI ack: failed to get redis connection"); + return; + } + }; + let mut conn: deadpool_redis::cluster::Connection = redis; + + if let Err(e) = redis::cmd("LREM") + .arg(&processing_key) + .arg(1) + .arg(task_json) + .query_async::<()>(&mut conn) + .await + { + tracing::warn!(error = %e, "CI ack: LREM failed"); + } +} diff --git a/lib/git/sync/commit.rs b/lib/git/sync/commit.rs new file mode 100644 index 0000000..fde549f --- /dev/null +++ b/lib/git/sync/commit.rs @@ -0,0 +1,195 @@ +use std::collections::{HashMap, HashSet}; + +use chrono::{DateTime, Utc}; +use db::{database::AppDatabase, sqlx}; +use model::repos::RepoCommitterModel; +use uuid::Uuid; + +use crate::{bare::GitBare, cmd::oid::ObjectId, errors::GitError}; +pub async fn sync_commits( + db: &AppDatabase, + bare: &GitBare, + repo_id: Uuid, +) -> Result<(), GitError> { + let repo = bare.gix_repo()?; + let pool = db.writer(); + + let existing_oids: Vec = sqlx::query_scalar::<_, String>( + "SELECT sha FROM repo_commit WHERE repo = $1", + ) + .bind(repo_id) + .fetch_all(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to query commits: {}", e)))?; + let existing_set: HashSet = existing_oids.into_iter().collect(); + + let head_id = repo.head_id() + .map_err(|e| GitError::Internal(format!("failed to resolve HEAD: {}", e)))? + .detach(); + + let tips = { + let refs = repo.references() + .map_err(|e| GitError::Internal(format!("failed to open references: {}", e)))?; + let iter = refs.all() + .map_err(|e| GitError::Internal(format!("failed to iterate refs: {}", e)))?; + let mut tips = vec![head_id]; + for ref_result in iter { + let reference = ref_result + .map_err(|e| GitError::Internal(format!("ref iteration error: {}", e)))?; + let name = reference.name().as_bstr().to_string(); + if !name.starts_with("refs/heads/") { + continue; + } + if let Some(target_id) = reference.target().try_id() { + let hex = target_id.to_hex().to_string(); + if let Ok(gix_id) = gix::hash::ObjectId::from_hex(hex.as_bytes()) { + tips.push(gix_id); + } + } + } + tips + }; + + let platform = repo.rev_walk(tips) + .sorting(gix::revision::walk::Sorting::ByCommitTime( + gix::traverse::commit::simple::CommitTimeOrder::NewestFirst, + )); + let walk = platform.all() + .map_err(|e| GitError::Internal(format!("rev_walk failed: {}", e)))?; + + let mut new_commits: Vec = Vec::new(); + for info in walk { + let info = info.map_err(|e| GitError::Internal(format!("walk step error: {}", e)))?; + let hex = info.id().detach().to_hex().to_string(); + if !existing_set.contains(&hex) { + new_commits.push(info.id().detach()); + } + } + + if new_commits.is_empty() { + return Ok(()); + } + let mut committer_map: HashMap = HashMap::new(); // email → (committer_id, name) + + let existing_committers: Vec = sqlx::query_as::<_, RepoCommitterModel>( + "SELECT id, repo, \"user\", name, email, created_at, updated_at FROM repo_committer WHERE repo = $1", + ) + .bind(repo_id) + .fetch_all(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to query repo_committer: {}", e)))?; + + for model in &existing_committers { + committer_map.insert(model.email.clone(), (model.id, model.name.clone())); + } + + let email_map = resolve_user_ids(db, &committer_map).await?; + + let now = Utc::now(); + + for gix_id in &new_commits { + let hex_oid = gix_id.to_hex().to_string(); + let oid = ObjectId::new(&hex_oid); + let commit_meta = bare.commit_info(oid) + .map_err(|e| GitError::Internal(format!("commit_info failed for {}: {}", hex_oid, e)))?; + + let author_committer_id = ensure_committer( + &mut committer_map, pool, repo_id, &commit_meta.author.email, + &commit_meta.author.name, &email_map, now, + ).await?; + let committer_committer_id = ensure_committer( + &mut committer_map, pool, repo_id, &commit_meta.committer.email, + &commit_meta.committer.name, &email_map, now, + ).await?; + + let parent_shas = commit_meta.parent_ids + .iter() + .map(|p| p.as_str()) + .collect::>() + .join("."); + let authored_at = git_time_to_datetime(commit_meta.author.time_secs, commit_meta.author.offset_minutes); + let committed_at = git_time_to_datetime(commit_meta.committer.time_secs, commit_meta.committer.offset_minutes); + + let new_id = Uuid::new_v4(); + sqlx::query( + "INSERT INTO repo_commit (id, repo, sha, tree_sha, parent_shas, author, committer, message, authored_at, committed_at, created_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + ) + .bind(new_id) + .bind(repo_id) + .bind(&hex_oid) + .bind(commit_meta.tree_id.as_str()) + .bind(&parent_shas) + .bind(author_committer_id) + .bind(committer_committer_id) + .bind(&commit_meta.message) + .bind(authored_at) + .bind(committed_at) + .bind(now) + .execute(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to insert commit: {}", e)))?; + } + + Ok(()) +} + +async fn resolve_user_ids( + db: &AppDatabase, + committer_map: &HashMap, +) -> Result, GitError> { + if committer_map.is_empty() { + return Ok(HashMap::new()); + } + let pool = db.writer(); + let email_vec: Vec = committer_map.keys().cloned().collect(); + + let rows: Vec<(Uuid, String)> = sqlx::query_as( + "SELECT \"user\", email FROM user_email WHERE email = ANY($1) AND active = true", + ) + .bind(&email_vec) + .fetch_all(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to query user emails: {}", e)))?; + + let mut map = HashMap::new(); + for (user_id, email) in rows { + map.insert(email, user_id); + } + Ok(map) +} +async fn ensure_committer( + committer_map: &mut HashMap, + pool: &sqlx::Pool, + repo_id: Uuid, + email: &str, + name: &str, + email_map: &HashMap, + now: DateTime, +) -> Result { + if let Some((id, _)) = committer_map.get(email) { + return Ok(*id); + } + let user_id = email_map.get(email).copied(); + let new_id = Uuid::new_v4(); + sqlx::query( + "INSERT INTO repo_committer (id, repo, \"user\", name, email, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7)" + ) + .bind(new_id) + .bind(repo_id) + .bind(user_id) + .bind(name) + .bind(email) + .bind(now) + .bind(now) + .execute(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to insert repo_committer: {}", e)))?; + committer_map.insert(email.to_string(), (new_id, name.to_string())); + Ok(new_id) +} +fn git_time_to_datetime(secs: i64, offset_minutes: i32) -> DateTime { + let utc_secs = secs - (offset_minutes as i64 * 60); + DateTime::from_timestamp(utc_secs, 0).unwrap_or_else(|| Utc::now()) +} diff --git a/lib/git/sync/consumer.rs b/lib/git/sync/consumer.rs new file mode 100644 index 0000000..66dd820 --- /dev/null +++ b/lib/git/sync/consumer.rs @@ -0,0 +1,90 @@ +use deadpool_redis::cluster::Connection; +use redis::AsyncCommands; + +use crate::sync::{ReceiveSyncService, TaskType}; +pub struct SyncConsumer { + service: ReceiveSyncService, + block_timeout_secs: u64, +} + +impl SyncConsumer { + pub fn new(service: ReceiveSyncService, block_timeout_secs: u64) -> Self { + Self { + service, + block_timeout_secs, + } + } + pub async fn next(&self, task_type: &TaskType) -> Option<(String, String)> { + let prefix = &self.service.redis_prefix; + let queue_key = match task_type { + TaskType::Sync => format!("{prefix}:sync"), + TaskType::Fsck => format!("{prefix}:fsck"), + TaskType::Gc => format!("{prefix}:gc"), + TaskType::Webhook => format!("{prefix}:webhook"), + }; + let work_key = format!("{queue_key}:work"); + + let redis = self.service.pool.get().await.ok()?; + let mut conn: Connection = redis; + + let result: Option = redis::cmd("BLMOVE") + .arg(&queue_key) + .arg(&work_key) + .arg("RIGHT") + .arg("LEFT") + .arg(self.block_timeout_secs) + .query_async(&mut conn) + .await + .ok()?; + + result.map(|json| (json, work_key)) + } + pub async fn ack(&self, task_json: &str, work_key: &str) -> Option<()> { + let redis = self.service.pool.get().await.ok()?; + let mut conn: Connection = redis; + let removed: i32 = conn.lrem(work_key, 1, task_json).await.ok()?; + if removed > 0 { Some(()) } else { None } + } + pub async fn nak_with_retry( + &self, + task_json: &str, + work_key: &str, + queue_key: &str, + ) -> Option<()> { + let redis = self.service.pool.get().await.ok()?; + let mut conn: Connection = redis; + + let script = redis::Script::new( + r#" + local removed = redis.call("LREM", KEYS[1], 0, ARGV[1]) + if removed > 0 then + redis.call("LPUSH", KEYS[2], ARGV[1]) + end + return removed + "#, + ); + + let result: i32 = script + .key(work_key) + .key(queue_key) + .arg(task_json) + .invoke_async(&mut conn) + .await + .ok()?; + + if result > 0 { Some(()) } else { None } + } + + pub(crate) fn queue_key_for_task_type( + &self, + task_type: &TaskType, + ) -> String { + let prefix = &self.service.redis_prefix; + match task_type { + TaskType::Sync => format!("{prefix}:sync"), + TaskType::Fsck => format!("{prefix}:fsck"), + TaskType::Gc => format!("{prefix}:gc"), + TaskType::Webhook => format!("{prefix}:webhook"), + } + } +} diff --git a/lib/git/sync/language.rs b/lib/git/sync/language.rs new file mode 100644 index 0000000..f82e73d --- /dev/null +++ b/lib/git/sync/language.rs @@ -0,0 +1,176 @@ +use std::collections::HashMap; + +use db::{database::AppDatabase, sqlx}; +use uuid::Uuid; + +use crate::{bare::GitBare, cmd::oid::ObjectId, errors::GitError}; +fn language_from_extension(ext: &str) -> Option<&str> { + match ext { + "rs" => Some("Rust"), + "ts" | "tsx" => Some("TypeScript"), + "js" | "jsx" | "mjs" | "cjs" => Some("JavaScript"), + "py" | "pyi" => Some("Python"), + "go" => Some("Go"), + "java" => Some("Java"), + "kt" | "kts" => Some("Kotlin"), + "c" | "h" => Some("C"), + "cpp" | "cc" | "cxx" | "hpp" | "hxx" => Some("C++"), + "cs" => Some("C#"), + "rb" => Some("Ruby"), + "php" => Some("PHP"), + "swift" => Some("Swift"), + "scala" => Some("Scala"), + "lua" => Some("Lua"), + "r" | "R" => Some("R"), + "sql" => Some("SQL"), + "sh" | "bash" => Some("Shell"), + "ps1" => Some("PowerShell"), + "dart" => Some("Dart"), + "el" | "lisp" => Some("Emacs Lisp"), + "clj" | "cljs" => Some("Clojure"), + "hs" => Some("Haskell"), + "ex" | "exs" => Some("Elixir"), + "erl" => Some("Erlang"), + "vue" => Some("Vue"), + "svelte" => Some("Svelte"), + "css" | "scss" | "sass" | "less" => Some("CSS"), + "html" | "htm" => Some("HTML"), + "xml" | "xsl" | "xsd" => Some("XML"), + "json" | "jsonl" => Some("JSON"), + "yaml" | "yml" => Some("YAML"), + "toml" => Some("TOML"), + "md" | "markdown" => Some("Markdown"), + "dockerfile" => Some("Dockerfile"), + "proto" => Some("Protocol Buffers"), + "tf" => Some("HCL"), + "zig" => Some("Zig"), + "nim" => Some("Nim"), + "v" => Some("V"), + "wasm" => Some("WebAssembly"), + "glsl" => Some("GLSL"), + "cu" | "cuh" => Some("CUDA"), + "makefile" => Some("Makefile"), + _ => None, + } +} +fn language_from_filename(name: &str) -> Option<&str> { + let lower = name.to_ascii_lowercase(); + match lower.as_str() { + "makefile" | "gnumakefile" => Some("Makefile"), + "dockerfile" => Some("Dockerfile"), + "cmakelists.txt" => Some("CMake"), + "cargo.toml" => Some("TOML"), + "package.json" => Some("JSON"), + "tsconfig.json" => Some("JSON"), + ".gitignore" | ".gitattributes" => Some("Gitignore"), + _ => None, + } +} +fn collect_language_stats(bare: &GitBare) -> Result, GitError> { + let repo = bare.gix_repo()?; + let head_id = repo.head_id() + .map_err(|e| GitError::Internal(format!("failed to resolve HEAD: {}", e)))?; + let commit = repo.find_commit(head_id.detach()) + .map_err(|e| GitError::Internal(format!("failed to find HEAD commit: {}", e)))?; + let decoded = commit.decode() + .map_err(|e| GitError::Internal(format!("failed to decode commit: {}", e)))?; + let tree_oid = ObjectId::new(decoded.tree().to_hex().to_string()); + + let mut stats: HashMap = HashMap::new(); + walk_tree(bare, &tree_oid, &mut stats)?; + Ok(stats) +} +fn walk_tree( + bare: &GitBare, + tree_oid: &ObjectId, + stats: &mut HashMap, +) -> Result<(), GitError> { + let entries = bare.tree_entries(tree_oid.clone())?; + + for entry in entries { + if entry.kind == crate::cmd::tree::TreeKind::Tree { + walk_tree(bare, &entry.oid, stats)?; + continue; + } + if entry.kind == crate::cmd::tree::TreeKind::LfsPointer { + continue; + } + if entry.is_binary { + continue; + } + + let language = language_from_filename(&entry.name) + .or_else(|| { + let ext = entry.name.rsplit('.').next().unwrap_or(""); + language_from_extension(ext) + }); + + if let Some(lang) = language { + let size = blob_size(bare, &entry.oid)?; + *stats.entry(lang.to_string()).or_insert(0) += size; + } + } + Ok(()) +} +fn blob_size(bare: &GitBare, oid: &ObjectId) -> Result { + let repo = bare.gix_repo()?; + let gix_id: gix::hash::ObjectId = oid.try_into() + .map_err(|e| GitError::Internal(format!("invalid oid: {}", e)))?; + let header = repo.find_header(gix_id) + .map_err(|e| GitError::Internal(format!("blob header not found: {}", e)))?; + Ok(header.size() as u64) +} +pub async fn sync_languages( + db: &AppDatabase, + bare: &GitBare, + repo_id: Uuid, +) -> Result<(), GitError> { + let stats = collect_language_stats(bare)?; + if stats.is_empty() { + return Ok(()); + } + + let total_bytes: u64 = stats.values().sum(); + let pool = db.writer(); + + let mut tx = pool.begin() + .await + .map_err(|e| GitError::Internal(format!("failed to begin tx: {}", e)))?; + + sqlx::query("DELETE FROM repo_language WHERE repo = $1") + .bind(repo_id) + .execute(&mut *tx) + .await + .map_err(|e| GitError::Internal(format!("failed to delete repo_language: {}", e)))?; + + for (language, bytes) in &stats { + let percentage = if total_bytes > 0 { + (*bytes as f32 / total_bytes as f32) * 100.0 + } else { + 0.0 + }; + sqlx::query( + "INSERT INTO repo_language (repo, language, bytes, percentage) VALUES ($1, $2, $3, $4)" + ) + .bind(repo_id) + .bind(language) + .bind(*bytes as i64) + .bind(percentage) + .execute(&mut *tx) + .await + .map_err(|e| GitError::Internal(format!("failed to insert repo_language: {}", e)))?; + } + + tx.commit() + .await + .map_err(|e| GitError::Internal(format!("failed to commit tx: {}", e)))?; + + tracing::info!( + repo_id = %repo_id, + languages = stats.len(), + total_bytes, + "language stats synced" + ); + + Ok(()) +} diff --git a/lib/git/sync/lfs.rs b/lib/git/sync/lfs.rs new file mode 100644 index 0000000..010d48c --- /dev/null +++ b/lib/git/sync/lfs.rs @@ -0,0 +1,106 @@ +use std::collections::HashSet; + +use db::{database::AppDatabase, sqlx}; +use model::repos::RepoLfsObjectModel; +use uuid::Uuid; + +use crate::{bare::GitBare, errors::GitError}; +pub async fn sync_lfs_objects( + db: &AppDatabase, + bare: &GitBare, + repo_id: Uuid, +) -> Result<(), GitError> { + let pool = db.writer(); + + let existing: Vec = sqlx::query_as::<_, RepoLfsObjectModel>( + "SELECT repo, oid, size_bytes, storage_key, created_at FROM repo_lfs_object WHERE repo = $1" + ) + .bind(repo_id) + .fetch_all(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to query lfs objects: {}", e)))?; + + let mut existing_oids: HashSet = + existing.into_iter().map(|o| o.oid).collect(); + + let lfs_dir = bare.bare_dir.join(".lfs").join("objects"); + if !lfs_dir.exists() { + if !existing_oids.is_empty() { + let oids_vec: Vec = existing_oids.into_iter().collect(); + sqlx::query( + "DELETE FROM repo_lfs_object WHERE repo = $1 AND oid = ANY($2)", + ) + .bind(repo_id) + .bind(&oids_vec) + .execute(pool) + .await + .map_err(|e| { + GitError::Internal(format!( + "failed to delete stale lfs objects: {}", + e + )) + })?; + } + return Ok(()); + } + + let now = chrono::Utc::now(); + let mut new_objects: Vec<(String, i64, String)> = Vec::new(); + + if let Ok(prefix_entries) = std::fs::read_dir(&lfs_dir) { + for prefix_entry in prefix_entries.flatten() { + if let Ok(oid_entries) = std::fs::read_dir(prefix_entry.path()) { + for oid_entry in oid_entries.flatten() { + let oid_str = + oid_entry.file_name().to_string_lossy().to_string(); + if existing_oids.contains(&oid_str) { + existing_oids.remove(&oid_str); + continue; + } + + let path = oid_entry.path(); + let size_bytes = match std::fs::metadata(&path) { + Ok(meta) => meta.len() as i64, + Err(_) => continue, + }; + let storage_key = path.to_string_lossy().to_string(); + + new_objects.push((oid_str, size_bytes, storage_key)); + } + } + } + } + + for (oid, size_bytes, storage_key) in &new_objects { + sqlx::query( + "INSERT INTO repo_lfs_object (repo, oid, size_bytes, storage_key, created_at) VALUES ($1, $2, $3, $4, $5)" + ) + .bind(repo_id) + .bind(oid) + .bind(*size_bytes) + .bind(storage_key) + .bind(now) + .execute(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to insert lfs object: {}", e)))?; + } + + if !existing_oids.is_empty() { + let oids_vec: Vec = existing_oids.into_iter().collect(); + sqlx::query( + "DELETE FROM repo_lfs_object WHERE repo = $1 AND oid = ANY($2)", + ) + .bind(repo_id) + .bind(&oids_vec) + .execute(pool) + .await + .map_err(|e| { + GitError::Internal(format!( + "failed to delete stale lfs objects: {}", + e + )) + })?; + } + + Ok(()) +} diff --git a/lib/git/sync/lock.rs b/lib/git/sync/lock.rs new file mode 100644 index 0000000..4666109 --- /dev/null +++ b/lib/git/sync/lock.rs @@ -0,0 +1,58 @@ +use deadpool_redis::cluster::Pool as RedisPool; +pub async fn acquire_repo_lock( + pool: &RedisPool, + repo_id: uuid::Uuid, + lock_value: &str, + ttl_secs: usize, +) -> redis::RedisResult { + let lock_key = format!("git:repo:lock:{}", repo_id); + let redis = pool.get().await.map_err(|e| { + redis::RedisError::from(( + redis::ErrorKind::Io, + "failed to get Redis connection", + e.to_string(), + )) + })?; + let mut conn: deadpool_redis::cluster::Connection = redis; + let acquired: Option = redis::cmd("SET") + .arg(&lock_key) + .arg(lock_value) + .arg("NX") + .arg("EX") + .arg(ttl_secs) + .query_async(&mut conn) + .await?; + Ok(acquired.is_some()) +} +pub async fn release_repo_lock( + pool: &RedisPool, + repo_id: uuid::Uuid, + lock_value: &str, +) { + let lock_key = format!("git:repo:lock:{}", repo_id); + let redis = match pool.get().await { + Ok(c) => c, + Err(e) => { + tracing::warn!(error = %e, repo_id = %repo_id, "lock_release_redis_connection_failed"); + return; + } + }; + let mut conn: deadpool_redis::cluster::Connection = redis; + let script = redis::Script::new( + r#" + if redis.call("GET", KEYS[1]) == ARGV[1] then + redis.call("DEL", KEYS[1]) + return 1 + end + return 0 + "#, + ); + if let Err(e) = script + .key(&lock_key) + .arg(lock_value) + .invoke_async::(&mut conn) + .await + { + tracing::warn!(error = %e, repo_id = %repo_id, "lock_release_failed"); + } +} diff --git a/lib/git/sync/mod.rs b/lib/git/sync/mod.rs new file mode 100644 index 0000000..5906924 --- /dev/null +++ b/lib/git/sync/mod.rs @@ -0,0 +1,266 @@ +use deadpool_redis::cluster::Pool as RedisPool; +use redis::AsyncCommands; + +pub mod branch; +pub mod cicheck; +pub mod commit; +pub mod consumer; +pub mod language; +pub mod lfs; +pub mod lock; +pub mod push_queue; +pub mod tag; +pub mod webhook; +pub mod worker; +#[derive(Clone)] +pub struct ReceiveSyncService { + pool: RedisPool, + redis_prefix: String, +} + +impl ReceiveSyncService { + pub fn new(pool: RedisPool) -> Self { + Self { + pool, + redis_prefix: "{hook}".to_string(), + } + } + + pub fn pool(&self) -> RedisPool { + self.pool.clone() + } + + pub async fn queue_position( + &self, + repo_uid: uuid::Uuid, + ) -> Option<(usize, usize)> { + let queue_key = format!("{}:sync", self.redis_prefix); + let work_key = format!("{}:work", queue_key); + let redis = self.pool.get().await.ok()?; + let mut conn: deadpool_redis::cluster::Connection = redis; + let queue_items: Vec = + conn.lrange(&queue_key, 0, -1).await.ok()?; + let work_items: Vec = + conn.lrange(&work_key, 0, -1).await.unwrap_or_default(); + let repo_id = repo_uid.to_string(); + let queued_before = queue_items + .iter() + .rev() + .take_while(|item| { + serde_json::from_str::(item) + .map(|task| task.repo_id != repo_id) + .unwrap_or(true) + }) + .count(); + let total = work_items.len() + queue_items.len() + 1; + Some((work_items.len() + queued_before + 1, total)) + } + + pub(crate) fn push_queue_keys(repo_uid: uuid::Uuid) -> (String, String) { + let hash_tag = format!("{{push:{}}}", repo_uid); + ( + format!("git:{}:queue", hash_tag), + format!("git:{}:lock", hash_tag), + ) + } + + pub async fn join_push_queue( + &self, + repo_uid: uuid::Uuid, + request_id: &str, + ) -> redis::RedisResult<()> { + let (queue_key, _) = Self::push_queue_keys(repo_uid); + let redis = self.pool.get().await.map_err(|e| { + redis::RedisError::from(( + redis::ErrorKind::Io, + "failed to get Redis connection", + e.to_string(), + )) + })?; + let mut conn: deadpool_redis::cluster::Connection = redis; + redis::cmd("RPUSH") + .arg(&queue_key) + .arg(request_id) + .query_async::<()>(&mut conn) + .await + } + + pub async fn push_queue_position( + &self, + repo_uid: uuid::Uuid, + request_id: &str, + ) -> Option<(usize, usize)> { + let (queue_key, _) = Self::push_queue_keys(repo_uid); + let redis = self.pool.get().await.ok()?; + let mut conn: deadpool_redis::cluster::Connection = redis; + let queue_items: Vec = + conn.lrange(&queue_key, 0, -1).await.ok()?; + let position = + queue_items.iter().position(|item| item == request_id)? + 1; + Some((position, queue_items.len())) + } + + pub async fn try_acquire_push_lock( + &self, + repo_uid: uuid::Uuid, + request_id: &str, + ttl_secs: usize, + ) -> redis::RedisResult { + let (_, lock_key) = Self::push_queue_keys(repo_uid); + let redis = self.pool.get().await.map_err(|e| { + redis::RedisError::from(( + redis::ErrorKind::Io, + "failed to get Redis connection", + e.to_string(), + )) + })?; + let mut conn: deadpool_redis::cluster::Connection = redis; + let acquired: Option = redis::cmd("SET") + .arg(&lock_key) + .arg(request_id) + .arg("NX") + .arg("EX") + .arg(ttl_secs) + .query_async(&mut conn) + .await?; + Ok(acquired.is_some()) + } + + pub async fn release_push_queue( + &self, + repo_uid: uuid::Uuid, + request_id: &str, + ) { + let (queue_key, lock_key) = Self::push_queue_keys(repo_uid); + let redis = match self.pool.get().await { + Ok(c) => c, + Err(e) => { + tracing::warn!(error = %e, repo_id = %repo_uid, "push_queue_release_redis_connection_failed"); + return; + } + }; + let mut conn: deadpool_redis::cluster::Connection = redis; + let script = redis::Script::new( + r#" + redis.call("LREM", KEYS[1], 0, ARGV[1]) + if redis.call("GET", KEYS[2]) == ARGV[1] then + redis.call("DEL", KEYS[2]) + end + return 1 + "#, + ); + if let Err(e) = script + .key(&queue_key) + .key(&lock_key) + .arg(request_id) + .invoke_async::<()>(&mut conn) + .await + { + tracing::warn!(error = %e, repo_id = %repo_uid, "push_queue_release_failed"); + } + } + + pub async fn refresh_push_lock( + &self, + repo_uid: uuid::Uuid, + request_id: &str, + ttl_secs: usize, + ) -> redis::RedisResult { + let (_, lock_key) = Self::push_queue_keys(repo_uid); + let redis = self.pool.get().await.map_err(|e| { + redis::RedisError::from(( + redis::ErrorKind::Io, + "failed to get Redis connection", + e.to_string(), + )) + })?; + let mut conn: deadpool_redis::cluster::Connection = redis; + let refreshed: i32 = redis::Script::new( + r#" + if redis.call("GET", KEYS[1]) == ARGV[1] then + redis.call("EXPIRE", KEYS[1], ARGV[2]) + return 1 + end + return 0 + "#, + ) + .key(&lock_key) + .arg(request_id) + .arg(ttl_secs) + .invoke_async(&mut conn) + .await?; + Ok(refreshed == 1) + } + + pub async fn send( + &self, + task: RepoReceiveSyncTask, + ) -> Option<(usize, usize)> { + let position = self.queue_position(task.repo_uid).await; + let hook_task = HookTask { + id: uuid::Uuid::new_v4().to_string(), + repo_id: task.repo_uid.to_string(), + task_type: TaskType::Sync, + payload: serde_json::Value::Null, + created_at: chrono::Utc::now(), + retry_count: 0, + }; + + let task_json = match serde_json::to_string(&hook_task) { + Ok(j) => j, + Err(e) => { + tracing::error!("failed to serialize hook task: {}", e); + return position; + } + }; + + let queue_key = format!("{}:sync", self.redis_prefix); + + let redis = match self.pool.get().await { + Ok(c) => c, + Err(e) => { + tracing::error!("failed to get Redis connection: {}", e); + return position; + } + }; + + let mut conn: deadpool_redis::cluster::Connection = redis; + if let Err(e) = redis::cmd("LPUSH") + .arg(&queue_key) + .arg(&task_json) + .query_async::<()>(&mut conn) + .await + { + tracing::error!( + "failed to enqueue sync task repo_id={} error={}", + task.repo_uid, + e + ); + } else { + tracing::info!(repo_id = %task.repo_uid, "hook task queued to Redis"); + } + position + } +} + +#[derive(Clone)] +pub struct RepoReceiveSyncTask { + pub repo_uid: uuid::Uuid, +} +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct HookTask { + pub id: String, + pub repo_id: String, + pub task_type: TaskType, + pub payload: serde_json::Value, + pub created_at: chrono::DateTime, + pub retry_count: usize, +} + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub enum TaskType { + Sync, + Fsck, + Gc, + Webhook, +} diff --git a/lib/git/sync/push_queue.rs b/lib/git/sync/push_queue.rs new file mode 100644 index 0000000..e40e059 --- /dev/null +++ b/lib/git/sync/push_queue.rs @@ -0,0 +1,202 @@ +use std::{ + fmt, + time::{Duration, Instant}, +}; + +use tokio::{task::JoinHandle, time::sleep}; + +use crate::sync::ReceiveSyncService; + +pub const PUSH_QUEUE_TIMEOUT: Duration = Duration::from_secs(120); +pub const PUSH_LOCK_TTL_SECS: usize = 300; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct PushQueuePosition { + pub position: usize, + pub total: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum PushQueueEvent { + Waiting(PushQueuePosition), + Acquired, +} + +#[derive(Debug)] +pub enum PushQueueWaitError { + Join(redis::RedisError), + Lock(redis::RedisError), + Timeout, +} + +impl fmt::Display for PushQueueWaitError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Join(e) => write!(f, "failed to join push queue: {e}"), + Self::Lock(e) => { + write!(f, "failed to acquire push queue lock: {e}") + } + Self::Timeout => write!(f, "push queue timed out"), + } + } +} + +impl std::error::Error for PushQueueWaitError {} + +pub struct PushQueueLease { + service: ReceiveSyncService, + repo_uid: uuid::Uuid, + request_id: String, + heartbeat: Option>, + released: bool, +} + +impl PushQueueLease { + fn new( + service: ReceiveSyncService, + repo_uid: uuid::Uuid, + request_id: String, + ) -> Self { + let heartbeat = Some(start_lock_heartbeat( + service.clone(), + repo_uid, + request_id.clone(), + )); + Self { + service, + repo_uid, + request_id, + heartbeat, + released: false, + } + } + + pub fn request_id(&self) -> &str { + &self.request_id + } + + pub async fn release(&mut self) { + if self.released { + return; + } + self.service + .release_push_queue(self.repo_uid, &self.request_id) + .await; + if let Some(heartbeat) = self.heartbeat.take() { + heartbeat.abort(); + } + self.released = true; + } +} + +impl Drop for PushQueueLease { + fn drop(&mut self) { + if self.released { + return; + } + if let Some(heartbeat) = self.heartbeat.take() { + heartbeat.abort(); + } + let service = self.service.clone(); + let repo_uid = self.repo_uid; + let request_id = self.request_id.clone(); + tokio::spawn(async move { + service.release_push_queue(repo_uid, &request_id).await; + }); + } +} + +fn start_lock_heartbeat( + service: ReceiveSyncService, + repo_uid: uuid::Uuid, + request_id: String, +) -> JoinHandle<()> { + tokio::spawn(async move { + let interval = + Duration::from_secs((PUSH_LOCK_TTL_SECS as u64 / 3).max(30)); + loop { + sleep(interval).await; + match service + .refresh_push_lock(repo_uid, &request_id, PUSH_LOCK_TTL_SECS) + .await + { + Ok(true) => {} + Ok(false) => { + tracing::warn!( + repo_id = %repo_uid, + request_id = %request_id, + "push_queue_lock_lost" + ); + break; + } + Err(e) => { + tracing::warn!( + error = %e, + repo_id = %repo_uid, + request_id = %request_id, + "push_queue_lock_refresh_failed" + ); + } + } + } + }) +} + +pub async fn wait_for_push_queue_slot( + service: ReceiveSyncService, + repo_uid: uuid::Uuid, + mut on_event: F, +) -> Result +where + F: FnMut(PushQueueEvent, &str), +{ + let request_id = uuid::Uuid::new_v4().to_string(); + service + .join_push_queue(repo_uid, &request_id) + .await + .map_err(PushQueueWaitError::Join)?; + + let deadline = Instant::now() + PUSH_QUEUE_TIMEOUT; + let mut last_position = None; + + loop { + let position = service.push_queue_position(repo_uid, &request_id).await; + if let Some((position, total)) = position { + let position = PushQueuePosition { position, total }; + if last_position != Some(position) && position.position > 1 { + on_event(PushQueueEvent::Waiting(position), &request_id); + } + last_position = Some(position); + + if position.position == 1 { + match service + .try_acquire_push_lock( + repo_uid, + &request_id, + PUSH_LOCK_TTL_SECS, + ) + .await + { + Ok(true) => { + on_event(PushQueueEvent::Acquired, &request_id); + return Ok(PushQueueLease::new( + service, repo_uid, request_id, + )); + } + Ok(false) => {} + Err(e) => { + service.release_push_queue(repo_uid, &request_id).await; + return Err(PushQueueWaitError::Lock(e)); + } + } + } + } + + if Instant::now() >= deadline { + service.release_push_queue(repo_uid, &request_id).await; + return Err(PushQueueWaitError::Timeout); + } + + sleep(Duration::from_secs(1)).await; + } +} diff --git a/lib/git/sync/tag.rs b/lib/git/sync/tag.rs new file mode 100644 index 0000000..f69ea29 --- /dev/null +++ b/lib/git/sync/tag.rs @@ -0,0 +1,100 @@ +use std::collections::HashSet; + +use db::{database::AppDatabase, sqlx}; +use model::repos::RepoRefModel; +use uuid::Uuid; + +use crate::{bare::GitBare, errors::GitError}; +#[derive(Debug, Clone)] +pub struct TagTip { + pub name: String, + pub target_oid: String, +} +pub fn collect_tag_tips(bare: &GitBare) -> Result, GitError> { + let repo = bare.gix_repo()?; + let refs = repo.references() + .map_err(|e| GitError::Internal(format!("failed to open references: {}", e)))?; + let iter = refs.all() + .map_err(|e| GitError::Internal(format!("failed to iterate refs: {}", e)))?; + + let mut tags = Vec::new(); + for ref_result in iter { + let reference = ref_result + .map_err(|e| GitError::Internal(format!("ref iteration error: {}", e)))?; + let full_name = reference.name().as_bstr().to_string(); + if !full_name.starts_with("refs/tags/") { + continue; + } + let target_oid = reference.target().try_id() + .map(|id| id.to_hex().to_string()) + .ok_or_else(|| GitError::Internal("ref has no direct target".to_string()))?; + let short_name = reference.name().shorten().to_string(); + tags.push(TagTip { + name: short_name, + target_oid, + }); + } + Ok(tags) +} +pub async fn sync_tags( + db: &AppDatabase, + bare: &GitBare, + repo_id: Uuid, +) -> Result<(), GitError> { + let pool = db.writer(); + + let existing: Vec = sqlx::query_as::<_, RepoRefModel>( + "SELECT id, repo, name, kind, target_sha, is_default, is_protected, created_at, updated_at FROM repo_ref WHERE repo = $1 AND kind = 'tag'" + ) + .bind(repo_id) + .fetch_all(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to query tags: {}", e)))?; + let mut existing_names: HashSet = + existing.iter().map(|t| t.name.clone()).collect(); + + let tags = collect_tag_tips(bare)?; + let now = chrono::Utc::now(); + + for tag in tags { + if existing_names.contains(&tag.name) { + existing_names.remove(&tag.name); + sqlx::query( + "UPDATE repo_ref SET target_sha = $1, updated_at = $2 WHERE repo = $3 AND name = $4 AND kind = 'tag'" + ) + .bind(&tag.target_oid) + .bind(now) + .bind(repo_id) + .bind(&tag.name) + .execute(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to update tag: {}", e)))?; + } else { + let new_id = Uuid::new_v4(); + sqlx::query( + "INSERT INTO repo_ref (id, repo, name, kind, target_sha, is_default, is_protected, created_at, updated_at) VALUES ($1, $2, $3, 'tag', $4, false, false, $5, $6)" + ) + .bind(new_id) + .bind(repo_id) + .bind(&tag.name) + .bind(&tag.target_oid) + .bind(now) + .bind(now) + .execute(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to insert tag: {}", e)))?; + } + } + + if !existing_names.is_empty() { + let names_vec: Vec = existing_names.into_iter().collect(); + sqlx::query("DELETE FROM repo_ref WHERE repo = $1 AND name = ANY($2) AND kind = 'tag'") + .bind(repo_id) + .bind(&names_vec) + .execute(pool) + .await + .map_err(|e| GitError::Internal(format!("failed to delete stale tags: {}", e)))?; + } + + Ok(()) +} diff --git a/lib/git/sync/webhook.rs b/lib/git/sync/webhook.rs new file mode 100644 index 0000000..d798e9f --- /dev/null +++ b/lib/git/sync/webhook.rs @@ -0,0 +1,194 @@ +use std::time::Duration; + +use deadpool_redis::cluster::Pool as RedisPool; +use hmac::{Hmac, KeyInit, Mac}; +use sha2::Sha256; + +type HmacSha256 = Hmac; +fn webhook_queue_keys(repo_id: uuid::Uuid) -> (String, String) { + let hash_tag = format!("{{wh:{}}}", repo_id); + ( + format!("{}:pending", hash_tag), + format!("{}:processing", hash_tag), + ) +} +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct WebhookDeliveryTask { + pub id: String, + pub webhook_id: String, + pub repo_id: String, + pub event: String, + pub url: String, + pub secret: Option, + pub payload: serde_json::Value, + pub created_at: chrono::DateTime, + pub retry_count: usize, +} +pub async fn enqueue_delivery( + task: WebhookDeliveryTask, + redis_pool: &RedisPool, +) -> Result<(), String> { + let task_json = serde_json::to_string(&task) + .map_err(|e| format!("serialize error: {}", e))?; + + let repo_id: uuid::Uuid = task + .repo_id + .parse() + .map_err(|e| format!("invalid repo_id: {}", e))?; + let (pending_key, _) = webhook_queue_keys(repo_id); + + let redis = redis_pool + .get() + .await + .map_err(|e| format!("redis pool: {}", e))?; + let mut conn: deadpool_redis::cluster::Connection = redis; + + redis::cmd("LPUSH") + .arg(&pending_key) + .arg(&task_json) + .query_async::<()>(&mut conn) + .await + .map_err(|e| format!("LPUSH error: {}", e))?; + + tracing::info!( + webhook_id = %task.webhook_id, + repo_id = %task.repo_id, + event = %task.event, + "webhook delivery enqueued" + ); + + Ok(()) +} +pub async fn poll_delivery_for_repo( + redis_pool: &RedisPool, + repo_id: uuid::Uuid, + block_timeout_secs: usize, +) -> Option { + let (pending_key, processing_key) = webhook_queue_keys(repo_id); + + let redis = redis_pool.get().await.ok()?; + let mut conn: deadpool_redis::cluster::Connection = redis; + + redis::cmd("BLMOVE") + .arg(&pending_key) + .arg(&processing_key) + .arg("RIGHT") + .arg("LEFT") + .arg(block_timeout_secs) + .query_async::>(&mut conn) + .await + .ok() + .flatten() +} +pub async fn ack_delivery( + redis_pool: &RedisPool, + repo_id: uuid::Uuid, + task_json: &str, +) { + let (_, processing_key) = webhook_queue_keys(repo_id); + + let redis = match redis_pool.get().await { + Ok(c) => c, + Err(e) => { + tracing::warn!(error = %e, "webhook ack: failed to get redis connection"); + return; + } + }; + let mut conn: deadpool_redis::cluster::Connection = redis; + + if let Err(e) = redis::cmd("LREM") + .arg(&processing_key) + .arg(1) + .arg(task_json) + .query_async::<()>(&mut conn) + .await + { + tracing::warn!(error = %e, "webhook ack: LREM failed"); + } +} +fn compute_hmac_signature(secret: &str, body: &[u8]) -> String { + let mut mac = HmacSha256::new_from_slice(secret.as_bytes()) + .expect("HMAC can take key of any size"); + mac.update(body); + let result = mac.finalize(); + let code_bytes = result.into_bytes(); + hex::encode(code_bytes) +} +pub async fn deliver_webhook( + task: &WebhookDeliveryTask, +) -> WebhookDeliveryResult { + let body_bytes = serde_json::to_vec(&task.payload).unwrap_or_default(); + + let signature = task + .secret + .as_ref() + .map(|s| compute_hmac_signature(s, &body_bytes)); + + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(10)) + .build() + .unwrap_or_default(); + + let mut request = client + .post(&task.url) + .header("Content-Type", "application/json") + .header("X-GitData-Event", &task.event) + .header("X-GitData-Delivery", &task.id); + + if let Some(sig) = &signature { + request = request + .header("X-GitData-Signature-256", format!("sha256={}", sig)); + } + + let response = request.body(body_bytes.clone()).send().await; + + match response { + Ok(resp) => { + let status = resp.status().as_u16() as i32; + let resp_headers: String = resp + .headers() + .iter() + .map(|(k, v)| { + format!("{}: {}", k, v.to_str().unwrap_or_default()) + }) + .collect::>() + .join("\n"); + let resp_body = resp.text().await.unwrap_or_default(); + + WebhookDeliveryResult { + response_status: Some(status), + response_headers: Some(resp_headers), + response_body: Some(resp_body), + error: None, + request_headers: Some(format!( + "Content-Type: application/json\nX-GitData-Event: {}\nX-GitData-Delivery: {}", + task.event, task.id + )), + request_body: Some( + String::from_utf8_lossy(&body_bytes).to_string(), + ), + } + } + Err(e) => WebhookDeliveryResult { + response_status: None, + response_headers: None, + response_body: None, + error: Some(e.to_string()), + request_headers: Some(format!( + "Content-Type: application/json\nX-GitData-Event: {}\nX-GitData-Delivery: {}", + task.event, task.id + )), + request_body: Some( + String::from_utf8_lossy(&body_bytes).to_string(), + ), + }, + } +} +pub struct WebhookDeliveryResult { + pub request_headers: Option, + pub request_body: Option, + pub response_status: Option, + pub response_headers: Option, + pub response_body: Option, + pub error: Option, +} diff --git a/lib/git/sync/worker.rs b/lib/git/sync/worker.rs new file mode 100644 index 0000000..2eea2cc --- /dev/null +++ b/lib/git/sync/worker.rs @@ -0,0 +1,496 @@ +use std::time::Duration; + +use cache::AppCache; +use config::AppConfig; +use db::{database::AppDatabase, sqlx}; +use deadpool_redis::cluster::Pool as RedisPool; +use model::repos::RepoModel; +use parsefile::TriggerEvent; + +use crate::sync::{ + HookTask, TaskType, + cicheck::{CiCheckOutcome, check_and_enqueue}, + consumer::SyncConsumer, + lock::{acquire_repo_lock, release_repo_lock}, + webhook::{WebhookDeliveryTask, deliver_webhook}, +}; +pub struct SyncWorker { + pub consumer: SyncConsumer, + pub db: AppDatabase, + pub cache: AppCache, + pub redis_pool: RedisPool, + pub config: AppConfig, + pub max_retries: usize, + pub worker_id: String, +} + +impl SyncWorker { + pub fn new( + consumer: SyncConsumer, + db: AppDatabase, + cache: AppCache, + redis_pool: RedisPool, + config: AppConfig, + worker_id: String, + ) -> Self { + Self { + consumer, + db, + cache, + redis_pool, + config, + max_retries: 3, + worker_id, + } + } + pub async fn run(&self) { + tracing::info!(worker_id = %self.worker_id, "sync worker starting"); + let mut backoff_secs: u64 = 1; + + loop { + let mut had_error = false; + + for task_type in &[ + TaskType::Sync, + TaskType::Fsck, + TaskType::Gc, + TaskType::Webhook, + ] { + let queue_key = + self.consumer.queue_key_for_task_type(task_type); + + if let Some((task_json, work_key)) = + self.consumer.next(task_type).await + { + let task: HookTask = match serde_json::from_str(&task_json) + { + Ok(t) => t, + Err(e) => { + tracing::error!(error = %e, "failed to deserialize hook task"); + self.consumer.ack(&task_json, &work_key).await; + continue; + } + }; + + tracing::info!( + task_id = %task.id, + repo_id = %task.repo_id, + task_type = ?task.task_type, + "processing hook task" + ); + + let result = self + .process_task(&task, &task_json, &work_key, &queue_key) + .await; + + match result { + ProcessResult::Success => { + self.consumer.ack(&task_json, &work_key).await; + backoff_secs = 1; + } + ProcessResult::Locked => { + self.consumer + .nak_with_retry( + &task_json, &work_key, &queue_key, + ) + .await; + backoff_secs = 1; + } + ProcessResult::Error => { + if task.retry_count >= self.max_retries { + tracing::warn!( + task_id = %task.id, + repo_id = %task.repo_id, + retry_count = task.retry_count, + "max retries exceeded, dropping task" + ); + self.consumer.ack(&task_json, &work_key).await; + } else { + tracing::warn!( + task_id = %task.id, + repo_id = %task.repo_id, + retry_count = task.retry_count, + "task failed, re-queueing" + ); + let mut updated_task = task.clone(); + updated_task.retry_count += 1; + if let Ok(updated_json) = + serde_json::to_string(&updated_task) + { + self.consumer + .nak_with_retry( + &updated_json, + &work_key, + &queue_key, + ) + .await; + } else { + self.consumer + .nak_with_retry( + &task_json, &work_key, &queue_key, + ) + .await; + } + } + had_error = true; + } + } + } + } + + if had_error { + tokio::time::sleep(Duration::from_secs(backoff_secs)).await; + backoff_secs = (backoff_secs * 2).min(32); + } + } + } + + async fn process_task( + &self, + task: &HookTask, + _task_json: &str, + _work_key: &str, + _queue_key: &str, + ) -> ProcessResult { + match task.task_type { + TaskType::Sync => self.run_sync(task).await, + TaskType::Fsck => self.run_fsck(task).await, + TaskType::Gc => self.run_gc(task).await, + TaskType::Webhook => self.run_webhook(task).await, + } + } + + async fn run_sync(&self, task: &HookTask) -> ProcessResult { + let repo_id = match task.repo_id.parse::() { + Ok(id) => id, + Err(e) => { + tracing::error!(error = %e, repo_id = %task.repo_id, "invalid repo_id UUID"); + return ProcessResult::Error; + } + }; + + let lock_value = format!("{}:{}", self.worker_id, task.id); + let lock_result = acquire_repo_lock( + &self.redis_pool, + repo_id, + &lock_value, + 300, // 5 min TTL + ) + .await; + + match lock_result { + Ok(true) => {} + Ok(false) => return ProcessResult::Locked, + Err(e) => { + tracing::error!(error = %e, repo_id = %repo_id, "failed to acquire repo lock"); + return ProcessResult::Error; + } + } + + let result = self.do_sync(repo_id).await; + + release_repo_lock(&self.redis_pool, repo_id, &lock_value).await; + + match result { + Ok(()) => ProcessResult::Success, + Err(e) => { + tracing::error!(error = %e, repo_id = %repo_id, "sync pipeline failed"); + ProcessResult::Error + } + } + } + + async fn do_sync(&self, repo_id: uuid::Uuid) -> anyhow::Result<()> { + let pool = self.db.reader(); + + let repo_model = sqlx::query_as::<_, RepoModel>( + "SELECT id, wk, name, description, default_branch, visibility, size_bytes, is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at FROM repo WHERE id = $1" + ) + .bind(repo_id) + .fetch_optional(pool) + .await? + .ok_or_else(|| anyhow::anyhow!("repo not found: {}", repo_id))?; + + let bare_dir = std::path::PathBuf::from(&repo_model.storage_path); + if !bare_dir.exists() { + tracing::warn!(repo_id = %repo_id, "bare repo directory missing, skipping sync"); + return Ok(()); + } + + let bare = crate::bare::GitBare { bare_dir }; + if let Err(e) = + crate::sync::branch::sync_refs(&self.db, &bare, repo_id).await + { + tracing::error!(error = %e, repo_id = %repo_id, "sync_refs failed"); + } + + if let Err(e) = + crate::sync::commit::sync_commits(&self.db, &bare, repo_id).await + { + tracing::error!(error = %e, repo_id = %repo_id, "sync_commits failed"); + } + + if let Err(e) = + crate::sync::tag::sync_tags(&self.db, &bare, repo_id).await + { + tracing::error!(error = %e, repo_id = %repo_id, "sync_tags failed"); + } + + if let Err(e) = + crate::sync::lfs::sync_lfs_objects(&self.db, &bare, repo_id).await + { + tracing::error!(error = %e, repo_id = %repo_id, "sync_lfs_objects failed"); + } + + if let Err(e) = + crate::sync::language::sync_languages(&self.db, &bare, repo_id).await + { + tracing::error!(error = %e, repo_id = %repo_id, "sync_languages failed"); + } + + let gc_result = bare.git_command_trusted_unchecked(vec![ + "gc".to_string(), + "--auto".to_string(), + "--quiet".to_string(), + ]); + if let Ok(output) = gc_result { + if !output.success { + tracing::warn!(repo_id = %repo_id, "git gc failed: {}", output.stderr_lossy()); + } + } + + let pattern = format!("git:rpc:cache:*:{}:*", repo_id); + let _ = self.cache.delete_pattern(&pattern).await; + + tracing::info!(repo_id = %repo_id, "sync completed"); + if let Err(e) = self + .run_ci_check(&repo_model.default_branch, &bare, repo_id) + .await + { + tracing::warn!(error = %e, repo_id = %repo_id, "CI check failed"); + } + + Ok(()) + } + + async fn run_fsck(&self, task: &HookTask) -> ProcessResult { + let repo_id = match task.repo_id.parse::() { + Ok(id) => id, + Err(_) => return ProcessResult::Error, + }; + + let pool = self.db.reader(); + + let storage_path = sqlx::query_scalar::<_, String>( + "SELECT storage_path FROM repo WHERE id = $1", + ) + .bind(repo_id) + .fetch_optional(pool) + .await; + + let storage_path = match storage_path { + Ok(Some(s)) => s, + _ => return ProcessResult::Error, + }; + + let bare = crate::bare::GitBare { + bare_dir: std::path::PathBuf::from(&storage_path), + }; + + let result = bare.git_command_trusted_unchecked(vec![ + "fsck".to_string(), + "--full".to_string(), + ]); + + match result { + Ok(output) if output.success => ProcessResult::Success, + Ok(output) => { + tracing::warn!(repo_id = %repo_id, "fsck failed: {}", output.stderr_lossy()); + ProcessResult::Error + } + Err(e) => { + tracing::error!(error = %e, repo_id = %repo_id, "fsck command failed"); + ProcessResult::Error + } + } + } + + async fn run_gc(&self, task: &HookTask) -> ProcessResult { + let repo_id = match task.repo_id.parse::() { + Ok(id) => id, + Err(_) => return ProcessResult::Error, + }; + + let pool = self.db.reader(); + + let storage_path = sqlx::query_scalar::<_, String>( + "SELECT storage_path FROM repo WHERE id = $1", + ) + .bind(repo_id) + .fetch_optional(pool) + .await; + + let storage_path = match storage_path { + Ok(Some(s)) => s, + _ => return ProcessResult::Error, + }; + + let bare = crate::bare::GitBare { + bare_dir: std::path::PathBuf::from(&storage_path), + }; + + let result = bare + .git_command_trusted(vec!["gc".to_string(), "--auto".to_string()]); + + match result { + Ok(_) => ProcessResult::Success, + Err(e) => { + tracing::error!(error = %e, repo_id = %repo_id, "gc command failed"); + ProcessResult::Error + } + } + } + + async fn run_ci_check( + &self, + default_branch: &str, + bare: &crate::bare::GitBare, + repo_id: uuid::Uuid, + ) -> anyhow::Result<()> { + let event = TriggerEvent::PushBranch(default_branch.to_owned()); + + let outcome = + check_and_enqueue(bare, repo_id, &event, &self.redis_pool).await?; + + match outcome { + CiCheckOutcome::Enqueued => { + tracing::info!( + repo_id = %repo_id, + branch = %default_branch, + "CI pipeline triggered" + ); + } + CiCheckOutcome::NoPipelineFile => { + tracing::debug!(repo_id = %repo_id, "no pipeline.yaml found"); + } + CiCheckOutcome::NotTriggered => { + tracing::debug!( + repo_id = %repo_id, + branch = %default_branch, + "pipeline.yaml exists but not triggered for this event" + ); + } + } + + Ok(()) + } + async fn run_webhook(&self, task: &HookTask) -> ProcessResult { + let repo_id = match task.repo_id.parse::() { + Ok(id) => id, + Err(_) => return ProcessResult::Error, + }; + + let event = task + .payload + .get("webhook_event") + .and_then(|v| v.as_str()) + .unwrap_or("push"); + + let webhooks: Vec<(uuid::Uuid, String, Option, String)> = + sqlx::query_as( + "SELECT id, url, secret_hash, events \ + FROM repo_webhook WHERE repo = $1 AND active = true", + ) + .bind(repo_id) + .fetch_all(self.db.reader()) + .await + .unwrap_or_default(); + + for (wh_id, wh_url, wh_secret, wh_events) in webhooks { + let subscribed: Vec<&str> = + wh_events.split('.').filter(|s| !s.is_empty()).collect(); + let matches = subscribed.iter().any(|e| { + *e == event + || (*e == "push" + && (event == "push_branch" || event == "push_tag")) + }); + if !matches { + continue; + } + + let delivery_id = uuid::Uuid::now_v7(); + let now = chrono::Utc::now(); + + sqlx::query( + "INSERT INTO repo_webhook_delivery \ + (id, repo, webhook, event, request_headers, request_body, \ + response_status, response_headers, response_body, error, delivered_at, created_at) \ + VALUES ($1, $2, $3, $4, NULL, NULL, NULL, NULL, NULL, NULL, NULL, $5)", + ) + .bind(delivery_id) + .bind(repo_id) + .bind(wh_id) + .bind(event) + .bind(now) + .execute(self.db.writer()) + .await + .ok(); + + let wh_task = WebhookDeliveryTask { + id: delivery_id.to_string(), + webhook_id: wh_id.to_string(), + repo_id: repo_id.to_string(), + event: event.to_string(), + url: wh_url, + secret: wh_secret, + payload: task.payload.clone(), + created_at: now, + retry_count: 0, + }; + + let result = deliver_webhook(&wh_task).await; + + sqlx::query( + "UPDATE repo_webhook_delivery SET \ + request_headers = $1, request_body = $2, \ + response_status = $3, response_headers = $4, response_body = $5, \ + error = $6, delivered_at = $7 WHERE id = $8", + ) + .bind(&result.request_headers) + .bind(&result.request_body) + .bind(result.response_status) + .bind(&result.response_headers) + .bind(&result.response_body) + .bind(&result.error) + .bind(now) + .bind(delivery_id) + .execute(self.db.writer()) + .await + .ok(); + + if result.error.is_some() { + tracing::warn!( + webhook_id = %wh_id, + repo_id = %repo_id, + error = ?result.error, + "webhook delivery failed" + ); + } else { + tracing::info!( + webhook_id = %wh_id, + repo_id = %repo_id, + status = ?result.response_status, + "webhook delivered" + ); + } + } + + ProcessResult::Success + } +} + +enum ProcessResult { + Success, + Locked, + Error, +} diff --git a/lib/issues/Cargo.toml b/lib/issues/Cargo.toml new file mode 100644 index 0000000..300fec1 --- /dev/null +++ b/lib/issues/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "issues" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[dependencies] + +[lints] +workspace = true diff --git a/lib/issues/src/main.rs b/lib/issues/src/main.rs new file mode 100644 index 0000000..e7a11a9 --- /dev/null +++ b/lib/issues/src/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello, world!"); +} diff --git a/lib/migrate/Cargo.toml b/lib/migrate/Cargo.toml new file mode 100644 index 0000000..013a088 --- /dev/null +++ b/lib/migrate/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "migrate" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[dependencies] +tokio = { workspace = true, features = ["full"] } +sqlx = { workspace = true, features = ["postgres", "runtime-tokio", "chrono", "uuid"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } +anyhow = { workspace = true } +clap = { version = "4", features = ["derive"] } +serde = { workspace = true, features = ["derive"] } + +[lints] +workspace = true \ No newline at end of file diff --git a/lib/migrate/sql/agent/agent_compression_lock_down_01.sql b/lib/migrate/sql/agent/agent_compression_lock_down_01.sql new file mode 100644 index 0000000..63e999a --- /dev/null +++ b/lib/migrate/sql/agent/agent_compression_lock_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS agent_compression_lock CASCADE; diff --git a/lib/migrate/sql/agent/agent_compression_lock_up_01.sql b/lib/migrate/sql/agent/agent_compression_lock_up_01.sql new file mode 100644 index 0000000..6831829 --- /dev/null +++ b/lib/migrate/sql/agent/agent_compression_lock_up_01.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS agent_compression_lock ( + session_id UUID PRIMARY KEY REFERENCES agent_session(id) ON DELETE CASCADE, + holder TEXT NOT NULL, + acquired_at TIMESTAMPTZ NOT NULL DEFAULT now(), + expires_at TIMESTAMPTZ NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_compression_lock_expires + ON agent_compression_lock(expires_at); diff --git a/lib/migrate/sql/agent/agent_conversation_compacted_summary_down_02.sql b/lib/migrate/sql/agent/agent_conversation_compacted_summary_down_02.sql new file mode 100644 index 0000000..77ef7d0 --- /dev/null +++ b/lib/migrate/sql/agent/agent_conversation_compacted_summary_down_02.sql @@ -0,0 +1 @@ +ALTER TABLE agent_conversation DROP COLUMN IF EXISTS compacted_summary; diff --git a/lib/migrate/sql/agent/agent_conversation_compacted_summary_up_02.sql b/lib/migrate/sql/agent/agent_conversation_compacted_summary_up_02.sql new file mode 100644 index 0000000..0f60054 --- /dev/null +++ b/lib/migrate/sql/agent/agent_conversation_compacted_summary_up_02.sql @@ -0,0 +1 @@ +ALTER TABLE agent_conversation ADD COLUMN IF NOT EXISTS compacted_summary TEXT; diff --git a/lib/migrate/sql/agent/agent_conversation_down_01.sql b/lib/migrate/sql/agent/agent_conversation_down_01.sql new file mode 100644 index 0000000..aa4eadb --- /dev/null +++ b/lib/migrate/sql/agent/agent_conversation_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS agent_conversation CASCADE; diff --git a/lib/migrate/sql/agent/agent_conversation_up_01.sql b/lib/migrate/sql/agent/agent_conversation_up_01.sql new file mode 100644 index 0000000..67a34f2 --- /dev/null +++ b/lib/migrate/sql/agent/agent_conversation_up_01.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS agent_conversation ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + session UUID NOT NULL, + title TEXT NOT NULL DEFAULT '', + created_by UUID NOT NULL, + last_message_at TIMESTAMPTZ, + archived_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/agent/agent_hook_registration_down_01.sql b/lib/migrate/sql/agent/agent_hook_registration_down_01.sql new file mode 100644 index 0000000..04e031d --- /dev/null +++ b/lib/migrate/sql/agent/agent_hook_registration_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS agent_hook_registration CASCADE; diff --git a/lib/migrate/sql/agent/agent_hook_registration_up_01.sql b/lib/migrate/sql/agent/agent_hook_registration_up_01.sql new file mode 100644 index 0000000..ba49da2 --- /dev/null +++ b/lib/migrate/sql/agent/agent_hook_registration_up_01.sql @@ -0,0 +1,13 @@ +CREATE TABLE IF NOT EXISTS agent_hook_registration ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + session_id UUID NOT NULL REFERENCES agent_session(id) ON DELETE CASCADE, + hook_type TEXT NOT NULL, + handler_name TEXT NOT NULL, + config_json TEXT, + priority INTEGER NOT NULL DEFAULT 0, + enabled BOOLEAN NOT NULL DEFAULT true, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_agent_hook_session + ON agent_hook_registration(session_id, hook_type); diff --git a/lib/migrate/sql/agent/agent_knowledge_base_down_01.sql b/lib/migrate/sql/agent/agent_knowledge_base_down_01.sql new file mode 100644 index 0000000..2523cae --- /dev/null +++ b/lib/migrate/sql/agent/agent_knowledge_base_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS agent_knowledge_base CASCADE; diff --git a/lib/migrate/sql/agent/agent_knowledge_base_up_01.sql b/lib/migrate/sql/agent/agent_knowledge_base_up_01.sql new file mode 100644 index 0000000..b618534 --- /dev/null +++ b/lib/migrate/sql/agent/agent_knowledge_base_up_01.sql @@ -0,0 +1,13 @@ +CREATE TABLE IF NOT EXISTS agent_knowledge_base ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + session UUID NOT NULL, + title TEXT NOT NULL, + source_type TEXT NOT NULL, + source_url TEXT, + content TEXT, + embedding_ref TEXT, + created_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/agent/agent_long_term_memory_down_01.sql b/lib/migrate/sql/agent/agent_long_term_memory_down_01.sql new file mode 100644 index 0000000..25a9fac --- /dev/null +++ b/lib/migrate/sql/agent/agent_long_term_memory_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS agent_long_term_memory CASCADE; diff --git a/lib/migrate/sql/agent/agent_long_term_memory_session_key_down_02.sql b/lib/migrate/sql/agent/agent_long_term_memory_session_key_down_02.sql new file mode 100644 index 0000000..1e60633 --- /dev/null +++ b/lib/migrate/sql/agent/agent_long_term_memory_session_key_down_02.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS idx_agent_ltm_session_key; diff --git a/lib/migrate/sql/agent/agent_long_term_memory_session_key_up_02.sql b/lib/migrate/sql/agent/agent_long_term_memory_session_key_up_02.sql new file mode 100644 index 0000000..31bab62 --- /dev/null +++ b/lib/migrate/sql/agent/agent_long_term_memory_session_key_up_02.sql @@ -0,0 +1,4 @@ +-- Add partial unique index so we can upsert on (session, key) ignoring soft-deleted rows. +CREATE UNIQUE INDEX IF NOT EXISTS idx_agent_ltm_session_key + ON agent_long_term_memory (session, key) + WHERE deleted_at IS NULL; diff --git a/lib/migrate/sql/agent/agent_long_term_memory_up_01.sql b/lib/migrate/sql/agent/agent_long_term_memory_up_01.sql new file mode 100644 index 0000000..dcd851f --- /dev/null +++ b/lib/migrate/sql/agent/agent_long_term_memory_up_01.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS agent_long_term_memory ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + session UUID NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + importance INTEGER NOT NULL DEFAULT 0, + last_used_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/agent/agent_message_down_01.sql b/lib/migrate/sql/agent/agent_message_down_01.sql new file mode 100644 index 0000000..7b3664b --- /dev/null +++ b/lib/migrate/sql/agent/agent_message_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS agent_message CASCADE; diff --git a/lib/migrate/sql/agent/agent_message_down_03.sql b/lib/migrate/sql/agent/agent_message_down_03.sql new file mode 100644 index 0000000..19031f9 --- /dev/null +++ b/lib/migrate/sql/agent/agent_message_down_03.sql @@ -0,0 +1,6 @@ +DROP INDEX IF EXISTS idx_agent_message_search; + +ALTER TABLE agent_message + DROP COLUMN IF EXISTS search_vector, + DROP COLUMN IF EXISTS finish_reason, + DROP COLUMN IF EXISTS token_count; diff --git a/lib/migrate/sql/agent/agent_message_fork_down_01.sql b/lib/migrate/sql/agent/agent_message_fork_down_01.sql new file mode 100644 index 0000000..07d6dac --- /dev/null +++ b/lib/migrate/sql/agent/agent_message_fork_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS agent_message_fork CASCADE; diff --git a/lib/migrate/sql/agent/agent_message_fork_up_01.sql b/lib/migrate/sql/agent/agent_message_fork_up_01.sql new file mode 100644 index 0000000..4698605 --- /dev/null +++ b/lib/migrate/sql/agent/agent_message_fork_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS agent_message_fork ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + source_message UUID NOT NULL, + forked_conversation UUID NOT NULL, + forked_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/agent/agent_message_reasoning_content_down_02.sql b/lib/migrate/sql/agent/agent_message_reasoning_content_down_02.sql new file mode 100644 index 0000000..078d788 --- /dev/null +++ b/lib/migrate/sql/agent/agent_message_reasoning_content_down_02.sql @@ -0,0 +1,2 @@ +-- Remove reasoning_content column from agent_message table +ALTER TABLE agent_message DROP COLUMN IF EXISTS reasoning_content; diff --git a/lib/migrate/sql/agent/agent_message_reasoning_content_up_02.sql b/lib/migrate/sql/agent/agent_message_reasoning_content_up_02.sql new file mode 100644 index 0000000..77ab6fd --- /dev/null +++ b/lib/migrate/sql/agent/agent_message_reasoning_content_up_02.sql @@ -0,0 +1,2 @@ +-- Add reasoning_content column to agent_message table +ALTER TABLE agent_message ADD COLUMN IF NOT EXISTS reasoning_content TEXT; diff --git a/lib/migrate/sql/agent/agent_message_up_01.sql b/lib/migrate/sql/agent/agent_message_up_01.sql new file mode 100644 index 0000000..08b6ad2 --- /dev/null +++ b/lib/migrate/sql/agent/agent_message_up_01.sql @@ -0,0 +1,14 @@ +CREATE TABLE IF NOT EXISTS agent_message ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + conversation UUID NOT NULL, + parent UUID, + role TEXT NOT NULL, + author UUID, + content TEXT NOT NULL, + content_type TEXT NOT NULL DEFAULT 'text', + status TEXT NOT NULL DEFAULT 'sent', + model_invocation UUID, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/agent/agent_message_up_03.sql b/lib/migrate/sql/agent/agent_message_up_03.sql new file mode 100644 index 0000000..f4ae49e --- /dev/null +++ b/lib/migrate/sql/agent/agent_message_up_03.sql @@ -0,0 +1,10 @@ +ALTER TABLE agent_message + ADD COLUMN IF NOT EXISTS token_count INTEGER, + ADD COLUMN IF NOT EXISTS finish_reason TEXT; + +ALTER TABLE agent_message + ADD COLUMN IF NOT EXISTS search_vector tsvector + GENERATED ALWAYS AS (to_tsvector('english', coalesce(content, ''))) STORED; + +CREATE INDEX IF NOT EXISTS idx_agent_message_search + ON agent_message USING GIN (search_vector); diff --git a/lib/migrate/sql/agent/agent_model_invocation_down_01.sql b/lib/migrate/sql/agent/agent_model_invocation_down_01.sql new file mode 100644 index 0000000..7374c19 --- /dev/null +++ b/lib/migrate/sql/agent/agent_model_invocation_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS agent_model_invocation CASCADE; diff --git a/lib/migrate/sql/agent/agent_model_invocation_up_01.sql b/lib/migrate/sql/agent/agent_model_invocation_up_01.sql new file mode 100644 index 0000000..b58ec47 --- /dev/null +++ b/lib/migrate/sql/agent/agent_model_invocation_up_01.sql @@ -0,0 +1,15 @@ +CREATE TABLE IF NOT EXISTS agent_model_invocation ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + session UUID NOT NULL, + conversation UUID, + message UUID, + model_version UUID NOT NULL, + request_id TEXT, + status TEXT NOT NULL DEFAULT 'pending', + prompt TEXT, + response TEXT, + error TEXT, + started_at TIMESTAMPTZ NOT NULL DEFAULT now(), + finished_at TIMESTAMPTZ, + latency_ms BIGINT +); diff --git a/lib/migrate/sql/agent/agent_session_config_down_02.sql b/lib/migrate/sql/agent/agent_session_config_down_02.sql new file mode 100644 index 0000000..c4b8315 --- /dev/null +++ b/lib/migrate/sql/agent/agent_session_config_down_02.sql @@ -0,0 +1,8 @@ +ALTER TABLE agent_session + DROP COLUMN IF EXISTS rollback_from_version, + DROP COLUMN IF EXISTS published_at, + DROP COLUMN IF EXISTS version, + DROP COLUMN IF EXISTS visibility, + DROP COLUMN IF EXISTS variables, + DROP COLUMN IF EXISTS knowledge_base_ids, + DROP COLUMN IF EXISTS tool_policy; diff --git a/lib/migrate/sql/agent/agent_session_config_up_02.sql b/lib/migrate/sql/agent/agent_session_config_up_02.sql new file mode 100644 index 0000000..30fe96e --- /dev/null +++ b/lib/migrate/sql/agent/agent_session_config_up_02.sql @@ -0,0 +1,8 @@ +ALTER TABLE agent_session + ADD COLUMN IF NOT EXISTS tool_policy TEXT, + ADD COLUMN IF NOT EXISTS knowledge_base_ids TEXT, + ADD COLUMN IF NOT EXISTS variables TEXT, + ADD COLUMN IF NOT EXISTS visibility TEXT NOT NULL DEFAULT 'private', + ADD COLUMN IF NOT EXISTS version INTEGER NOT NULL DEFAULT 1, + ADD COLUMN IF NOT EXISTS published_at TIMESTAMPTZ, + ADD COLUMN IF NOT EXISTS rollback_from_version INTEGER; diff --git a/lib/migrate/sql/agent/agent_session_down_01.sql b/lib/migrate/sql/agent/agent_session_down_01.sql new file mode 100644 index 0000000..4bca45f --- /dev/null +++ b/lib/migrate/sql/agent/agent_session_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS agent_session CASCADE; diff --git a/lib/migrate/sql/agent/agent_session_down_03.sql b/lib/migrate/sql/agent/agent_session_down_03.sql new file mode 100644 index 0000000..9fc5da4 --- /dev/null +++ b/lib/migrate/sql/agent/agent_session_down_03.sql @@ -0,0 +1,7 @@ +ALTER TABLE agent_session + DROP COLUMN IF EXISTS source, + DROP COLUMN IF EXISTS parent_session_id, + DROP COLUMN IF EXISTS toolset_json, + DROP COLUMN IF EXISTS memory_provider, + DROP COLUMN IF EXISTS memory_provider_config, + DROP COLUMN IF EXISTS iteration_budget; diff --git a/lib/migrate/sql/agent/agent_session_up_01.sql b/lib/migrate/sql/agent/agent_session_up_01.sql new file mode 100644 index 0000000..a5100ec --- /dev/null +++ b/lib/migrate/sql/agent/agent_session_up_01.sql @@ -0,0 +1,24 @@ +CREATE TABLE IF NOT EXISTS agent_session ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + "user" UUID, + wk UUID, + name TEXT NOT NULL, + description TEXT, + agent_kind TEXT NOT NULL, + model_version UUID, + system_prompt TEXT, + temperature REAL, + max_output_tokens INTEGER, + tool_policy TEXT, + knowledge_base_ids TEXT, + variables TEXT, + visibility TEXT NOT NULL DEFAULT 'private', + version INTEGER NOT NULL DEFAULT 1, + published_at TIMESTAMPTZ, + rollback_from_version INTEGER, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + created_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/agent/agent_session_up_03.sql b/lib/migrate/sql/agent/agent_session_up_03.sql new file mode 100644 index 0000000..ddf7d76 --- /dev/null +++ b/lib/migrate/sql/agent/agent_session_up_03.sql @@ -0,0 +1,8 @@ +-- depends_on:agent_conversation +ALTER TABLE agent_session + ADD COLUMN IF NOT EXISTS source TEXT NOT NULL DEFAULT 'api', + ADD COLUMN IF NOT EXISTS parent_session_id UUID REFERENCES agent_session(id), + ADD COLUMN IF NOT EXISTS toolset_json TEXT, + ADD COLUMN IF NOT EXISTS memory_provider TEXT, + ADD COLUMN IF NOT EXISTS memory_provider_config TEXT, + ADD COLUMN IF NOT EXISTS iteration_budget INTEGER NOT NULL DEFAULT 90; diff --git a/lib/migrate/sql/agent/agent_subagent_session_down_01.sql b/lib/migrate/sql/agent/agent_subagent_session_down_01.sql new file mode 100644 index 0000000..78dd522 --- /dev/null +++ b/lib/migrate/sql/agent/agent_subagent_session_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS agent_subagent_session CASCADE; diff --git a/lib/migrate/sql/agent/agent_subagent_session_up_01.sql b/lib/migrate/sql/agent/agent_subagent_session_up_01.sql new file mode 100644 index 0000000..f390044 --- /dev/null +++ b/lib/migrate/sql/agent/agent_subagent_session_up_01.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS agent_subagent_session ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + parent_session UUID NOT NULL, + child_session UUID NOT NULL, + name TEXT NOT NULL, + purpose TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + ended_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/agent/agent_token_usage_down_01.sql b/lib/migrate/sql/agent/agent_token_usage_down_01.sql new file mode 100644 index 0000000..30ade9c --- /dev/null +++ b/lib/migrate/sql/agent/agent_token_usage_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS agent_token_usage CASCADE; diff --git a/lib/migrate/sql/agent/agent_token_usage_down_02.sql b/lib/migrate/sql/agent/agent_token_usage_down_02.sql new file mode 100644 index 0000000..04967cd --- /dev/null +++ b/lib/migrate/sql/agent/agent_token_usage_down_02.sql @@ -0,0 +1,4 @@ +ALTER TABLE agent_token_usage + DROP COLUMN IF EXISTS reasoning_tokens, + DROP COLUMN IF EXISTS cache_write_tokens, + DROP COLUMN IF EXISTS cache_read_tokens; diff --git a/lib/migrate/sql/agent/agent_token_usage_up_01.sql b/lib/migrate/sql/agent/agent_token_usage_up_01.sql new file mode 100644 index 0000000..5d24f02 --- /dev/null +++ b/lib/migrate/sql/agent/agent_token_usage_up_01.sql @@ -0,0 +1,13 @@ +CREATE TABLE IF NOT EXISTS agent_token_usage ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + invocation UUID NOT NULL, + session UUID NOT NULL, + model_version UUID NOT NULL, + input_tokens BIGINT NOT NULL DEFAULT 0, + output_tokens BIGINT NOT NULL DEFAULT 0, + cached_input_tokens BIGINT NOT NULL DEFAULT 0, + total_tokens BIGINT NOT NULL DEFAULT 0, + cost NUMERIC, + currency TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/agent/agent_token_usage_up_02.sql b/lib/migrate/sql/agent/agent_token_usage_up_02.sql new file mode 100644 index 0000000..46d14e8 --- /dev/null +++ b/lib/migrate/sql/agent/agent_token_usage_up_02.sql @@ -0,0 +1,4 @@ +ALTER TABLE agent_token_usage + ADD COLUMN IF NOT EXISTS cache_read_tokens BIGINT NOT NULL DEFAULT 0, + ADD COLUMN IF NOT EXISTS cache_write_tokens BIGINT NOT NULL DEFAULT 0, + ADD COLUMN IF NOT EXISTS reasoning_tokens BIGINT NOT NULL DEFAULT 0; diff --git a/lib/migrate/sql/agent/agent_tool_call_log_down_01.sql b/lib/migrate/sql/agent/agent_tool_call_log_down_01.sql new file mode 100644 index 0000000..269dd62 --- /dev/null +++ b/lib/migrate/sql/agent/agent_tool_call_log_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS agent_tool_call_log CASCADE; diff --git a/lib/migrate/sql/agent/agent_tool_call_log_tool_id_down_02.sql b/lib/migrate/sql/agent/agent_tool_call_log_tool_id_down_02.sql new file mode 100644 index 0000000..9b678a0 --- /dev/null +++ b/lib/migrate/sql/agent/agent_tool_call_log_tool_id_down_02.sql @@ -0,0 +1,2 @@ +ALTER TABLE agent_tool_call_log + DROP COLUMN IF EXISTS tool_call_id; diff --git a/lib/migrate/sql/agent/agent_tool_call_log_tool_id_up_02.sql b/lib/migrate/sql/agent/agent_tool_call_log_tool_id_up_02.sql new file mode 100644 index 0000000..264f2bc --- /dev/null +++ b/lib/migrate/sql/agent/agent_tool_call_log_tool_id_up_02.sql @@ -0,0 +1,2 @@ +ALTER TABLE agent_tool_call_log + ADD COLUMN IF NOT EXISTS tool_call_id TEXT; diff --git a/lib/migrate/sql/agent/agent_tool_call_log_up_01.sql b/lib/migrate/sql/agent/agent_tool_call_log_up_01.sql new file mode 100644 index 0000000..a285d93 --- /dev/null +++ b/lib/migrate/sql/agent/agent_tool_call_log_up_01.sql @@ -0,0 +1,16 @@ +CREATE TABLE IF NOT EXISTS agent_tool_call_log ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + invocation UUID, + session UUID NOT NULL, + conversation UUID, + message UUID, + tool_call_id TEXT, + tool_name TEXT NOT NULL, + arguments TEXT, + result TEXT, + error TEXT, + status TEXT NOT NULL DEFAULT 'pending', + started_at TIMESTAMPTZ NOT NULL DEFAULT now(), + finished_at TIMESTAMPTZ, + latency_ms BIGINT +); diff --git a/lib/migrate/sql/agent/agent_trace_down_01.sql b/lib/migrate/sql/agent/agent_trace_down_01.sql new file mode 100644 index 0000000..831b700 --- /dev/null +++ b/lib/migrate/sql/agent/agent_trace_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS agent_trace CASCADE; diff --git a/lib/migrate/sql/agent/agent_trace_up_01.sql b/lib/migrate/sql/agent/agent_trace_up_01.sql new file mode 100644 index 0000000..99d2e45 --- /dev/null +++ b/lib/migrate/sql/agent/agent_trace_up_01.sql @@ -0,0 +1,17 @@ +CREATE TABLE IF NOT EXISTS agent_trace ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + invocation UUID NOT NULL REFERENCES agent_model_invocation(id), + conversation UUID NOT NULL, + sequence INT NOT NULL, + phase TEXT NOT NULL CHECK (phase IN ('think', 'answer', 'act', 'summarize')), + content TEXT, + tool_calls JSONB, + tool_results JSONB, + input_tokens BIGINT, + output_tokens BIGINT, + metadata JSONB, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_agent_trace_invocation ON agent_trace(invocation); +CREATE INDEX IF NOT EXISTS idx_agent_trace_conversation ON agent_trace(conversation, sequence); diff --git a/lib/migrate/sql/agent/agent_z_model_invocations_compat_down_01.sql b/lib/migrate/sql/agent/agent_z_model_invocations_compat_down_01.sql new file mode 100644 index 0000000..4b42353 --- /dev/null +++ b/lib/migrate/sql/agent/agent_z_model_invocations_compat_down_01.sql @@ -0,0 +1,9 @@ +DO $$ +BEGIN + IF EXISTS ( + SELECT 1 FROM pg_class + WHERE oid = to_regclass('agent_model_invocations') AND relkind = 'v' + ) THEN + DROP VIEW agent_model_invocations; + END IF; +END $$; diff --git a/lib/migrate/sql/agent/agent_z_model_invocations_compat_up_01.sql b/lib/migrate/sql/agent/agent_z_model_invocations_compat_up_01.sql new file mode 100644 index 0000000..9af20df --- /dev/null +++ b/lib/migrate/sql/agent/agent_z_model_invocations_compat_up_01.sql @@ -0,0 +1,21 @@ +DO $$ +BEGIN + IF to_regclass('agent_model_invocations') IS NULL THEN + CREATE VIEW agent_model_invocations AS + SELECT + id, + session, + conversation, + message, + model_version, + request_id, + status, + prompt, + response, + error, + started_at, + finished_at, + latency_ms + FROM agent_model_invocation; + END IF; +END $$; diff --git a/lib/migrate/sql/ai/ai_model_card_down_01.sql b/lib/migrate/sql/ai/ai_model_card_down_01.sql new file mode 100644 index 0000000..4763729 --- /dev/null +++ b/lib/migrate/sql/ai/ai_model_card_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ai_model_card CASCADE; diff --git a/lib/migrate/sql/ai/ai_model_card_up_01.sql b/lib/migrate/sql/ai/ai_model_card_up_01.sql new file mode 100644 index 0000000..cd1d4ec --- /dev/null +++ b/lib/migrate/sql/ai/ai_model_card_up_01.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS ai_model_card ( + model UUID PRIMARY KEY, + overview TEXT, + strengths TEXT, + limitations TEXT, + safety_notes TEXT, + eval_summary TEXT, + metadata TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/ai/ai_model_discussion_down_01.sql b/lib/migrate/sql/ai/ai_model_discussion_down_01.sql new file mode 100644 index 0000000..254d7ac --- /dev/null +++ b/lib/migrate/sql/ai/ai_model_discussion_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ai_model_discussion CASCADE; diff --git a/lib/migrate/sql/ai/ai_model_discussion_up_01.sql b/lib/migrate/sql/ai/ai_model_discussion_up_01.sql new file mode 100644 index 0000000..16956b1 --- /dev/null +++ b/lib/migrate/sql/ai/ai_model_discussion_up_01.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS ai_model_discussion ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + model UUID NOT NULL, + "user" UUID NOT NULL, + parent UUID, + body TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/ai/ai_model_down_01.sql b/lib/migrate/sql/ai/ai_model_down_01.sql new file mode 100644 index 0000000..ccf5a13 --- /dev/null +++ b/lib/migrate/sql/ai/ai_model_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ai_model CASCADE; diff --git a/lib/migrate/sql/ai/ai_model_like_down_01.sql b/lib/migrate/sql/ai/ai_model_like_down_01.sql new file mode 100644 index 0000000..2bbec5d --- /dev/null +++ b/lib/migrate/sql/ai/ai_model_like_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ai_model_like CASCADE; diff --git a/lib/migrate/sql/ai/ai_model_like_up_01.sql b/lib/migrate/sql/ai/ai_model_like_up_01.sql new file mode 100644 index 0000000..fd626ca --- /dev/null +++ b/lib/migrate/sql/ai/ai_model_like_up_01.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS ai_model_like ( + model UUID NOT NULL, + "user" UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (model, "user") +); diff --git a/lib/migrate/sql/ai/ai_model_model_tag_down_01.sql b/lib/migrate/sql/ai/ai_model_model_tag_down_01.sql new file mode 100644 index 0000000..583b04a --- /dev/null +++ b/lib/migrate/sql/ai/ai_model_model_tag_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ai_model_model_tag CASCADE; diff --git a/lib/migrate/sql/ai/ai_model_model_tag_up_01.sql b/lib/migrate/sql/ai/ai_model_model_tag_up_01.sql new file mode 100644 index 0000000..1ab23fa --- /dev/null +++ b/lib/migrate/sql/ai/ai_model_model_tag_up_01.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS ai_model_model_tag ( + model UUID NOT NULL, + tag TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (model, tag) +); diff --git a/lib/migrate/sql/ai/ai_model_up_01.sql b/lib/migrate/sql/ai/ai_model_up_01.sql new file mode 100644 index 0000000..a23c766 --- /dev/null +++ b/lib/migrate/sql/ai/ai_model_up_01.sql @@ -0,0 +1,16 @@ +CREATE TABLE IF NOT EXISTS ai_model ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + provider UUID NOT NULL, + name TEXT NOT NULL, + display_name TEXT NOT NULL, + description TEXT, + modality TEXT NOT NULL, + context_window INTEGER, + input_token_limit INTEGER, + output_token_limit INTEGER, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + public BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/ai/ai_model_version_down_01.sql b/lib/migrate/sql/ai/ai_model_version_down_01.sql new file mode 100644 index 0000000..033ebef --- /dev/null +++ b/lib/migrate/sql/ai/ai_model_version_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ai_model_version CASCADE; diff --git a/lib/migrate/sql/ai/ai_model_version_up_01.sql b/lib/migrate/sql/ai/ai_model_version_up_01.sql new file mode 100644 index 0000000..cb839ca --- /dev/null +++ b/lib/migrate/sql/ai/ai_model_version_up_01.sql @@ -0,0 +1,15 @@ +CREATE TABLE IF NOT EXISTS ai_model_version ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + model UUID NOT NULL, + version TEXT NOT NULL, + provider_model_name TEXT NOT NULL, + input_price_per_million NUMERIC, + output_price_per_million NUMERIC, + cached_input_price_per_million NUMERIC, + training_cutoff TEXT, + released_at TIMESTAMPTZ, + deprecated_at TIMESTAMPTZ, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/ai/ai_provider_down_01.sql b/lib/migrate/sql/ai/ai_provider_down_01.sql new file mode 100644 index 0000000..ed37ba4 --- /dev/null +++ b/lib/migrate/sql/ai/ai_provider_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ai_provider CASCADE; diff --git a/lib/migrate/sql/ai/ai_provider_up_01.sql b/lib/migrate/sql/ai/ai_provider_up_01.sql new file mode 100644 index 0000000..32e69c5 --- /dev/null +++ b/lib/migrate/sql/ai/ai_provider_up_01.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS ai_provider ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name TEXT NOT NULL, + base_url TEXT, + website_url TEXT, + logo_url TEXT, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/issue/issue_assignee_down_01.sql b/lib/migrate/sql/issue/issue_assignee_down_01.sql new file mode 100644 index 0000000..3df6a6b --- /dev/null +++ b/lib/migrate/sql/issue/issue_assignee_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS issue_assignee CASCADE; diff --git a/lib/migrate/sql/issue/issue_assignee_up_01.sql b/lib/migrate/sql/issue/issue_assignee_up_01.sql new file mode 100644 index 0000000..b19739a --- /dev/null +++ b/lib/migrate/sql/issue/issue_assignee_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS issue_assignee ( + issue UUID NOT NULL REFERENCES issue(id), + "user" UUID NOT NULL, + assigned_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (issue, "user") +); diff --git a/lib/migrate/sql/issue/issue_comment_down_01.sql b/lib/migrate/sql/issue/issue_comment_down_01.sql new file mode 100644 index 0000000..b889eb4 --- /dev/null +++ b/lib/migrate/sql/issue/issue_comment_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS issue_comment CASCADE; diff --git a/lib/migrate/sql/issue/issue_comment_up_01.sql b/lib/migrate/sql/issue/issue_comment_up_01.sql new file mode 100644 index 0000000..c5e28d4 --- /dev/null +++ b/lib/migrate/sql/issue/issue_comment_up_01.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS issue_comment ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + issue UUID NOT NULL REFERENCES issue(id), + author UUID NOT NULL, + body TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/issue/issue_down_01.sql b/lib/migrate/sql/issue/issue_down_01.sql new file mode 100644 index 0000000..896d131 --- /dev/null +++ b/lib/migrate/sql/issue/issue_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS issue CASCADE; diff --git a/lib/migrate/sql/issue/issue_event_down_01.sql b/lib/migrate/sql/issue/issue_event_down_01.sql new file mode 100644 index 0000000..e90988c --- /dev/null +++ b/lib/migrate/sql/issue/issue_event_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS issue_event CASCADE; diff --git a/lib/migrate/sql/issue/issue_event_up_01.sql b/lib/migrate/sql/issue/issue_event_up_01.sql new file mode 100644 index 0000000..256a033 --- /dev/null +++ b/lib/migrate/sql/issue/issue_event_up_01.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS issue_event ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + issue UUID NOT NULL REFERENCES issue(id), + actor UUID, + event TEXT NOT NULL, + from_value TEXT, + to_value TEXT, + metadata TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/issue/issue_label_down_01.sql b/lib/migrate/sql/issue/issue_label_down_01.sql new file mode 100644 index 0000000..f39be4d --- /dev/null +++ b/lib/migrate/sql/issue/issue_label_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS issue_label CASCADE; diff --git a/lib/migrate/sql/issue/issue_label_up_01.sql b/lib/migrate/sql/issue/issue_label_up_01.sql new file mode 100644 index 0000000..761cc37 --- /dev/null +++ b/lib/migrate/sql/issue/issue_label_up_01.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS issue_label ( + issue UUID NOT NULL REFERENCES issue(id), + label UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (issue, label) +); diff --git a/lib/migrate/sql/issue/issue_milestone_down_01.sql b/lib/migrate/sql/issue/issue_milestone_down_01.sql new file mode 100644 index 0000000..b8104ad --- /dev/null +++ b/lib/migrate/sql/issue/issue_milestone_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS issue_milestone CASCADE; diff --git a/lib/migrate/sql/issue/issue_milestone_up_01.sql b/lib/migrate/sql/issue/issue_milestone_up_01.sql new file mode 100644 index 0000000..318f778 --- /dev/null +++ b/lib/migrate/sql/issue/issue_milestone_up_01.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS issue_milestone ( + issue UUID NOT NULL REFERENCES issue(id), + milestone UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (issue, milestone) +); diff --git a/lib/migrate/sql/issue/issue_pull_request_down_01.sql b/lib/migrate/sql/issue/issue_pull_request_down_01.sql new file mode 100644 index 0000000..447ebf9 --- /dev/null +++ b/lib/migrate/sql/issue/issue_pull_request_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS issue_pull_request CASCADE; diff --git a/lib/migrate/sql/issue/issue_pull_request_up_01.sql b/lib/migrate/sql/issue/issue_pull_request_up_01.sql new file mode 100644 index 0000000..a64f9b6 --- /dev/null +++ b/lib/migrate/sql/issue/issue_pull_request_up_01.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS issue_pull_request ( + issue UUID NOT NULL REFERENCES issue(id), + pull_request UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (issue, pull_request) +); diff --git a/lib/migrate/sql/issue/issue_reaction_down_01.sql b/lib/migrate/sql/issue/issue_reaction_down_01.sql new file mode 100644 index 0000000..58957ab --- /dev/null +++ b/lib/migrate/sql/issue/issue_reaction_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS issue_reaction CASCADE; diff --git a/lib/migrate/sql/issue/issue_reaction_up_01.sql b/lib/migrate/sql/issue/issue_reaction_up_01.sql new file mode 100644 index 0000000..b239cef --- /dev/null +++ b/lib/migrate/sql/issue/issue_reaction_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS issue_reaction ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + issue UUID NOT NULL REFERENCES issue(id), + comment UUID REFERENCES issue_comment(id), + "user" UUID NOT NULL, + reaction TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/issue/issue_reference_down_01.sql b/lib/migrate/sql/issue/issue_reference_down_01.sql new file mode 100644 index 0000000..c70af9b --- /dev/null +++ b/lib/migrate/sql/issue/issue_reference_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS issue_reference CASCADE; diff --git a/lib/migrate/sql/issue/issue_reference_up_01.sql b/lib/migrate/sql/issue/issue_reference_up_01.sql new file mode 100644 index 0000000..6687f5f --- /dev/null +++ b/lib/migrate/sql/issue/issue_reference_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS issue_reference ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + issue UUID NOT NULL REFERENCES issue(id), + target_type TEXT NOT NULL, + target_id UUID NOT NULL, + created_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/issue/issue_repo_down_01.sql b/lib/migrate/sql/issue/issue_repo_down_01.sql new file mode 100644 index 0000000..793e8e0 --- /dev/null +++ b/lib/migrate/sql/issue/issue_repo_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS issue_repo CASCADE; diff --git a/lib/migrate/sql/issue/issue_repo_up_01.sql b/lib/migrate/sql/issue/issue_repo_up_01.sql new file mode 100644 index 0000000..cf0355c --- /dev/null +++ b/lib/migrate/sql/issue/issue_repo_up_01.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS issue_repo ( + issue UUID NOT NULL REFERENCES issue(id), + repo UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (issue, repo) +); diff --git a/lib/migrate/sql/issue/issue_up_01.sql b/lib/migrate/sql/issue/issue_up_01.sql new file mode 100644 index 0000000..002af55 --- /dev/null +++ b/lib/migrate/sql/issue/issue_up_01.sql @@ -0,0 +1,16 @@ +CREATE TABLE IF NOT EXISTS issue ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + wk UUID NOT NULL, + number BIGINT NOT NULL, + title TEXT NOT NULL, + body TEXT, + state TEXT NOT NULL DEFAULT 'open', + priority TEXT NOT NULL DEFAULT 'medium', + author UUID NOT NULL, + closed_by UUID, + closed_at TIMESTAMPTZ, + due_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/issue/label_down_01.sql b/lib/migrate/sql/issue/label_down_01.sql new file mode 100644 index 0000000..ace0845 --- /dev/null +++ b/lib/migrate/sql/issue/label_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS label CASCADE; diff --git a/lib/migrate/sql/issue/label_up_01.sql b/lib/migrate/sql/issue/label_up_01.sql new file mode 100644 index 0000000..17ecb13 --- /dev/null +++ b/lib/migrate/sql/issue/label_up_01.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS label ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + wk UUID NOT NULL, + name TEXT NOT NULL, + color TEXT NOT NULL, + description TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/issue/milestone_down_01.sql b/lib/migrate/sql/issue/milestone_down_01.sql new file mode 100644 index 0000000..ba8a87a --- /dev/null +++ b/lib/migrate/sql/issue/milestone_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS milestone CASCADE; diff --git a/lib/migrate/sql/issue/milestone_up_01.sql b/lib/migrate/sql/issue/milestone_up_01.sql new file mode 100644 index 0000000..62f76c7 --- /dev/null +++ b/lib/migrate/sql/issue/milestone_up_01.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS milestone ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + wk UUID NOT NULL, + title TEXT NOT NULL, + description TEXT, + state TEXT NOT NULL DEFAULT 'open', + due_at TIMESTAMPTZ, + closed_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/notify/user_app_notify_down_01.sql b/lib/migrate/sql/notify/user_app_notify_down_01.sql new file mode 100644 index 0000000..dc2891c --- /dev/null +++ b/lib/migrate/sql/notify/user_app_notify_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_app_notify CASCADE; diff --git a/lib/migrate/sql/notify/user_app_notify_up_01.sql b/lib/migrate/sql/notify/user_app_notify_up_01.sql new file mode 100644 index 0000000..e85c49b --- /dev/null +++ b/lib/migrate/sql/notify/user_app_notify_up_01.sql @@ -0,0 +1,14 @@ +CREATE TABLE IF NOT EXISTS user_app_notify ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + "user" UUID NOT NULL, + title TEXT NOT NULL, + body TEXT NOT NULL, + notify_type TEXT NOT NULL, + target_type TEXT, + target_id UUID, + metadata TEXT, + read_at TIMESTAMPTZ, + archived_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/notify/user_email_notify_down_01.sql b/lib/migrate/sql/notify/user_email_notify_down_01.sql new file mode 100644 index 0000000..69fc324 --- /dev/null +++ b/lib/migrate/sql/notify/user_email_notify_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_email_notify CASCADE; diff --git a/lib/migrate/sql/notify/user_email_notify_up_01.sql b/lib/migrate/sql/notify/user_email_notify_up_01.sql new file mode 100644 index 0000000..f730b08 --- /dev/null +++ b/lib/migrate/sql/notify/user_email_notify_up_01.sql @@ -0,0 +1,23 @@ +CREATE TABLE IF NOT EXISTS user_email_notify ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + "user" UUID NOT NULL, + email TEXT NOT NULL, + subject TEXT NOT NULL, + template TEXT NOT NULL, + body_text TEXT, + body_html TEXT, + notify_type TEXT NOT NULL, + target_type TEXT, + target_id UUID, + metadata TEXT, + status TEXT NOT NULL DEFAULT 'queued', + provider_message_id TEXT, + error TEXT, + retry_count INTEGER NOT NULL DEFAULT 0, + queued_at TIMESTAMPTZ NOT NULL DEFAULT now(), + sent_at TIMESTAMPTZ, + delivered_at TIMESTAMPTZ, + opened_at TIMESTAMPTZ, + clicked_at TIMESTAMPTZ, + failed_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/pull_request/pull_request_assignee_down_01.sql b/lib/migrate/sql/pull_request/pull_request_assignee_down_01.sql new file mode 100644 index 0000000..30878ae --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_assignee_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS pull_request_assignee CASCADE; diff --git a/lib/migrate/sql/pull_request/pull_request_assignee_up_01.sql b/lib/migrate/sql/pull_request/pull_request_assignee_up_01.sql new file mode 100644 index 0000000..2bb9031 --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_assignee_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS pull_request_assignee ( + pull_request UUID NOT NULL REFERENCES pull_request(id), + "user" UUID NOT NULL, + assigned_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (pull_request, "user") +); diff --git a/lib/migrate/sql/pull_request/pull_request_comment_down_01.sql b/lib/migrate/sql/pull_request/pull_request_comment_down_01.sql new file mode 100644 index 0000000..4e7a986 --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_comment_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS pull_request_comment CASCADE; diff --git a/lib/migrate/sql/pull_request/pull_request_comment_up_01.sql b/lib/migrate/sql/pull_request/pull_request_comment_up_01.sql new file mode 100644 index 0000000..aca9d29 --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_comment_up_01.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS pull_request_comment ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + pull_request UUID NOT NULL REFERENCES pull_request(id), + author UUID NOT NULL, + body TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/pull_request/pull_request_commit_down_01.sql b/lib/migrate/sql/pull_request/pull_request_commit_down_01.sql new file mode 100644 index 0000000..ac4901c --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_commit_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS pull_request_commit CASCADE; diff --git a/lib/migrate/sql/pull_request/pull_request_commit_up_01.sql b/lib/migrate/sql/pull_request/pull_request_commit_up_01.sql new file mode 100644 index 0000000..b97815e --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_commit_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS pull_request_commit ( + pull_request UUID NOT NULL REFERENCES pull_request(id), + commit UUID NOT NULL, + sha TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (pull_request, commit) +); diff --git a/lib/migrate/sql/pull_request/pull_request_down_01.sql b/lib/migrate/sql/pull_request/pull_request_down_01.sql new file mode 100644 index 0000000..8eff68d --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS pull_request CASCADE; diff --git a/lib/migrate/sql/pull_request/pull_request_label_down_01.sql b/lib/migrate/sql/pull_request/pull_request_label_down_01.sql new file mode 100644 index 0000000..70a7794 --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_label_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS pull_request_label CASCADE; diff --git a/lib/migrate/sql/pull_request/pull_request_label_up_01.sql b/lib/migrate/sql/pull_request/pull_request_label_up_01.sql new file mode 100644 index 0000000..80f8873 --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_label_up_01.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS pull_request_label ( + pull_request UUID NOT NULL REFERENCES pull_request(id), + label UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (pull_request, label) +); diff --git a/lib/migrate/sql/pull_request/pull_request_reaction_down_01.sql b/lib/migrate/sql/pull_request/pull_request_reaction_down_01.sql new file mode 100644 index 0000000..5c1a290 --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_reaction_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS pull_request_reaction CASCADE; diff --git a/lib/migrate/sql/pull_request/pull_request_reaction_up_01.sql b/lib/migrate/sql/pull_request/pull_request_reaction_up_01.sql new file mode 100644 index 0000000..302e586 --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_reaction_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS pull_request_reaction ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + pull_request UUID NOT NULL REFERENCES pull_request(id), + comment UUID REFERENCES pull_request_comment(id), + "user" UUID NOT NULL, + reaction TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/pull_request/pull_request_review_comment_down_01.sql b/lib/migrate/sql/pull_request/pull_request_review_comment_down_01.sql new file mode 100644 index 0000000..2aba044 --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_review_comment_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS pull_request_review_comment CASCADE; diff --git a/lib/migrate/sql/pull_request/pull_request_review_comment_up_01.sql b/lib/migrate/sql/pull_request/pull_request_review_comment_up_01.sql new file mode 100644 index 0000000..e9c9e37 --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_review_comment_up_01.sql @@ -0,0 +1,19 @@ +CREATE TABLE IF NOT EXISTS pull_request_review_comment ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + pull_request UUID NOT NULL REFERENCES pull_request(id), + review UUID REFERENCES pull_request_review(id), + author UUID NOT NULL, + body TEXT NOT NULL, + path TEXT NOT NULL, + commit_sha TEXT NOT NULL, + original_commit_sha TEXT, + line INTEGER, + original_line INTEGER, + side TEXT, + resolved BOOLEAN NOT NULL DEFAULT FALSE, + resolved_by UUID, + resolved_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/pull_request/pull_request_review_down_01.sql b/lib/migrate/sql/pull_request/pull_request_review_down_01.sql new file mode 100644 index 0000000..75be7f4 --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_review_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS pull_request_review CASCADE; diff --git a/lib/migrate/sql/pull_request/pull_request_review_reaction_down_01.sql b/lib/migrate/sql/pull_request/pull_request_review_reaction_down_01.sql new file mode 100644 index 0000000..7b7afde --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_review_reaction_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS pull_request_review_reaction CASCADE; diff --git a/lib/migrate/sql/pull_request/pull_request_review_reaction_up_01.sql b/lib/migrate/sql/pull_request/pull_request_review_reaction_up_01.sql new file mode 100644 index 0000000..55fa418 --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_review_reaction_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS pull_request_review_reaction ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + review_comment UUID NOT NULL REFERENCES pull_request_review_comment(id), + "user" UUID NOT NULL, + reaction TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/pull_request/pull_request_review_request_down_01.sql b/lib/migrate/sql/pull_request/pull_request_review_request_down_01.sql new file mode 100644 index 0000000..64d9f3c --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_review_request_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS pull_request_review_request CASCADE; diff --git a/lib/migrate/sql/pull_request/pull_request_review_request_up_01.sql b/lib/migrate/sql/pull_request/pull_request_review_request_up_01.sql new file mode 100644 index 0000000..f2d4a3b --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_review_request_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS pull_request_review_request ( + pull_request UUID NOT NULL REFERENCES pull_request(id), + reviewer UUID, + "group" UUID, + requested_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + removed_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/pull_request/pull_request_review_up_01.sql b/lib/migrate/sql/pull_request/pull_request_review_up_01.sql new file mode 100644 index 0000000..ed143f9 --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_review_up_01.sql @@ -0,0 +1,14 @@ +CREATE TABLE IF NOT EXISTS pull_request_review ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + pull_request UUID NOT NULL REFERENCES pull_request(id), + reviewer UUID NOT NULL, + state TEXT NOT NULL, + body TEXT, + commit_sha TEXT, + submitted_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + dismissed_by UUID, + dismissed_at TIMESTAMPTZ, + dismiss_reason TEXT +); diff --git a/lib/migrate/sql/pull_request/pull_request_up_01.sql b/lib/migrate/sql/pull_request/pull_request_up_01.sql new file mode 100644 index 0000000..f1a0379 --- /dev/null +++ b/lib/migrate/sql/pull_request/pull_request_up_01.sql @@ -0,0 +1,22 @@ +CREATE TABLE IF NOT EXISTS pull_request ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL, + number BIGINT NOT NULL, + title TEXT NOT NULL, + body TEXT, + state TEXT NOT NULL DEFAULT 'open', + draft BOOLEAN NOT NULL DEFAULT FALSE, + author UUID NOT NULL, + source_repo UUID NOT NULL, + source_branch TEXT NOT NULL, + source_sha TEXT NOT NULL, + target_branch TEXT NOT NULL, + target_sha TEXT NOT NULL, + merged_by UUID, + merged_at TIMESTAMPTZ, + closed_by UUID, + closed_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/repo/repo_audit_log_down_01.sql b/lib/migrate/sql/repo/repo_audit_log_down_01.sql new file mode 100644 index 0000000..97324f5 --- /dev/null +++ b/lib/migrate/sql/repo/repo_audit_log_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_audit_log CASCADE; diff --git a/lib/migrate/sql/repo/repo_audit_log_up_01.sql b/lib/migrate/sql/repo/repo_audit_log_up_01.sql new file mode 100644 index 0000000..7a5f269 --- /dev/null +++ b/lib/migrate/sql/repo/repo_audit_log_up_01.sql @@ -0,0 +1,12 @@ +CREATE TABLE IF NOT EXISTS repo_audit_log ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL REFERENCES repo(id), + actor UUID, + action TEXT NOT NULL, + target_type TEXT NOT NULL, + target_id TEXT, + ip_address TEXT, + user_agent TEXT, + metadata TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/repo/repo_commit_down_01.sql b/lib/migrate/sql/repo/repo_commit_down_01.sql new file mode 100644 index 0000000..f5c284d --- /dev/null +++ b/lib/migrate/sql/repo/repo_commit_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_commit CASCADE; diff --git a/lib/migrate/sql/repo/repo_commit_status_down_01.sql b/lib/migrate/sql/repo/repo_commit_status_down_01.sql new file mode 100644 index 0000000..5330ecf --- /dev/null +++ b/lib/migrate/sql/repo/repo_commit_status_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_commit_status; \ No newline at end of file diff --git a/lib/migrate/sql/repo/repo_commit_status_up_01.sql b/lib/migrate/sql/repo/repo_commit_status_up_01.sql new file mode 100644 index 0000000..57ae5b7 --- /dev/null +++ b/lib/migrate/sql/repo/repo_commit_status_up_01.sql @@ -0,0 +1,14 @@ +CREATE TABLE IF NOT EXISTS repo_commit_status ( + id UUID PRIMARY KEY, + repo UUID NOT NULL REFERENCES repo(id) ON DELETE CASCADE, + commit_sha TEXT NOT NULL, + state TEXT NOT NULL CHECK (state IN ('pending', 'success', 'failure', 'error')), + target_url TEXT, + description TEXT, + context TEXT NOT NULL DEFAULT 'default', + creator UUID NOT NULL REFERENCES "user"(id), + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); +CREATE INDEX IF NOT EXISTS idx_commit_status_repo_sha ON repo_commit_status(repo, commit_sha); +CREATE INDEX IF NOT EXISTS idx_commit_status_context ON repo_commit_status(repo, commit_sha, context); diff --git a/lib/migrate/sql/repo/repo_commit_up_01.sql b/lib/migrate/sql/repo/repo_commit_up_01.sql new file mode 100644 index 0000000..936f465 --- /dev/null +++ b/lib/migrate/sql/repo/repo_commit_up_01.sql @@ -0,0 +1,14 @@ +-- depends_on: repo_committer +CREATE TABLE IF NOT EXISTS repo_commit ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL REFERENCES repo(id), + sha TEXT NOT NULL, + tree_sha TEXT NOT NULL, + parent_shas TEXT NOT NULL, + author UUID NOT NULL REFERENCES repo_committer(id), + committer UUID NOT NULL REFERENCES repo_committer(id), + message TEXT NOT NULL, + authored_at TIMESTAMPTZ NOT NULL, + committed_at TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); \ No newline at end of file diff --git a/lib/migrate/sql/repo/repo_committer_down_01.sql b/lib/migrate/sql/repo/repo_committer_down_01.sql new file mode 100644 index 0000000..df58380 --- /dev/null +++ b/lib/migrate/sql/repo/repo_committer_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_committer CASCADE; diff --git a/lib/migrate/sql/repo/repo_committer_up_01.sql b/lib/migrate/sql/repo/repo_committer_up_01.sql new file mode 100644 index 0000000..de2cd4f --- /dev/null +++ b/lib/migrate/sql/repo/repo_committer_up_01.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS repo_committer ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL REFERENCES repo(id), + "user" UUID, + name TEXT NOT NULL, + email TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/repo/repo_deploy_key_down_01.sql b/lib/migrate/sql/repo/repo_deploy_key_down_01.sql new file mode 100644 index 0000000..c43b410 --- /dev/null +++ b/lib/migrate/sql/repo/repo_deploy_key_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_deploy_key CASCADE; diff --git a/lib/migrate/sql/repo/repo_deploy_key_up_01.sql b/lib/migrate/sql/repo/repo_deploy_key_up_01.sql new file mode 100644 index 0000000..c4d84db --- /dev/null +++ b/lib/migrate/sql/repo/repo_deploy_key_up_01.sql @@ -0,0 +1,16 @@ +CREATE TABLE IF NOT EXISTS repo_deploy_key ( + id BIGSERIAL PRIMARY KEY, + repo UUID NOT NULL REFERENCES repo(id), + title TEXT NOT NULL, + public_key TEXT NOT NULL, + fingerprint TEXT NOT NULL, + key_type TEXT NOT NULL, + key_bits INTEGER, + read_only BOOLEAN NOT NULL DEFAULT TRUE, + last_used_at TIMESTAMPTZ, + expires_at TIMESTAMPTZ, + is_revoked BOOLEAN NOT NULL DEFAULT FALSE, + created_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/repo/repo_down_01.sql b/lib/migrate/sql/repo/repo_down_01.sql new file mode 100644 index 0000000..3942f1c --- /dev/null +++ b/lib/migrate/sql/repo/repo_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo CASCADE; diff --git a/lib/migrate/sql/repo/repo_fork_down_01.sql b/lib/migrate/sql/repo/repo_fork_down_01.sql new file mode 100644 index 0000000..56a017b --- /dev/null +++ b/lib/migrate/sql/repo/repo_fork_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_fork CASCADE; diff --git a/lib/migrate/sql/repo/repo_fork_up_01.sql b/lib/migrate/sql/repo/repo_fork_up_01.sql new file mode 100644 index 0000000..1360048 --- /dev/null +++ b/lib/migrate/sql/repo/repo_fork_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS repo_fork ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL REFERENCES repo(id), + source_repo UUID NOT NULL, + forked_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/repo/repo_history_name_down_01.sql b/lib/migrate/sql/repo/repo_history_name_down_01.sql new file mode 100644 index 0000000..76b515c --- /dev/null +++ b/lib/migrate/sql/repo/repo_history_name_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_history_name CASCADE; diff --git a/lib/migrate/sql/repo/repo_history_name_up_01.sql b/lib/migrate/sql/repo/repo_history_name_up_01.sql new file mode 100644 index 0000000..07476d1 --- /dev/null +++ b/lib/migrate/sql/repo/repo_history_name_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS repo_history_name ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL REFERENCES repo(id), + name TEXT NOT NULL, + changed_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/repo/repo_language_down_01.sql b/lib/migrate/sql/repo/repo_language_down_01.sql new file mode 100644 index 0000000..a716609 --- /dev/null +++ b/lib/migrate/sql/repo/repo_language_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_language CASCADE; diff --git a/lib/migrate/sql/repo/repo_language_up_01.sql b/lib/migrate/sql/repo/repo_language_up_01.sql new file mode 100644 index 0000000..46dee54 --- /dev/null +++ b/lib/migrate/sql/repo/repo_language_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS repo_language ( + repo UUID NOT NULL REFERENCES repo(id), + language TEXT NOT NULL, + bytes BIGINT NOT NULL DEFAULT 0, + percentage REAL NOT NULL DEFAULT 0, + PRIMARY KEY (repo, language) +); diff --git a/lib/migrate/sql/repo/repo_lfs_lock_down_01.sql b/lib/migrate/sql/repo/repo_lfs_lock_down_01.sql new file mode 100644 index 0000000..504f7fb --- /dev/null +++ b/lib/migrate/sql/repo/repo_lfs_lock_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_lfs_lock CASCADE; diff --git a/lib/migrate/sql/repo/repo_lfs_lock_up_01.sql b/lib/migrate/sql/repo/repo_lfs_lock_up_01.sql new file mode 100644 index 0000000..67ba80d --- /dev/null +++ b/lib/migrate/sql/repo/repo_lfs_lock_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS repo_lfs_lock ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL REFERENCES repo(id), + path TEXT NOT NULL, + locked_by UUID NOT NULL, + ref_name TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/repo/repo_lfs_object_down_01.sql b/lib/migrate/sql/repo/repo_lfs_object_down_01.sql new file mode 100644 index 0000000..ce29cad --- /dev/null +++ b/lib/migrate/sql/repo/repo_lfs_object_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_lfs_object CASCADE; diff --git a/lib/migrate/sql/repo/repo_lfs_object_up_01.sql b/lib/migrate/sql/repo/repo_lfs_object_up_01.sql new file mode 100644 index 0000000..5a33993 --- /dev/null +++ b/lib/migrate/sql/repo/repo_lfs_object_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS repo_lfs_object ( + repo UUID NOT NULL REFERENCES repo(id), + oid TEXT NOT NULL, + size_bytes BIGINT NOT NULL DEFAULT 0, + storage_key TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (repo, oid) +); diff --git a/lib/migrate/sql/repo/repo_license_down_01.sql b/lib/migrate/sql/repo/repo_license_down_01.sql new file mode 100644 index 0000000..0aceb8c --- /dev/null +++ b/lib/migrate/sql/repo/repo_license_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_license CASCADE; diff --git a/lib/migrate/sql/repo/repo_license_up_01.sql b/lib/migrate/sql/repo/repo_license_up_01.sql new file mode 100644 index 0000000..958beef --- /dev/null +++ b/lib/migrate/sql/repo/repo_license_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS repo_license ( + repo UUID PRIMARY KEY REFERENCES repo(id), + spdx_id TEXT, + name TEXT NOT NULL, + url TEXT, + detected_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/repo/repo_lock_down_01.sql b/lib/migrate/sql/repo/repo_lock_down_01.sql new file mode 100644 index 0000000..f0f56ea --- /dev/null +++ b/lib/migrate/sql/repo/repo_lock_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_lock CASCADE; diff --git a/lib/migrate/sql/repo/repo_lock_up_01.sql b/lib/migrate/sql/repo/repo_lock_up_01.sql new file mode 100644 index 0000000..1166359 --- /dev/null +++ b/lib/migrate/sql/repo/repo_lock_up_01.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS repo_lock ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL REFERENCES repo(id), + locked_by UUID NOT NULL, + reason TEXT NOT NULL, + expires_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + released_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/repo/repo_protect_down_01.sql b/lib/migrate/sql/repo/repo_protect_down_01.sql new file mode 100644 index 0000000..cc61b65 --- /dev/null +++ b/lib/migrate/sql/repo/repo_protect_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_protect CASCADE; diff --git a/lib/migrate/sql/repo/repo_protect_up_01.sql b/lib/migrate/sql/repo/repo_protect_up_01.sql new file mode 100644 index 0000000..3967716 --- /dev/null +++ b/lib/migrate/sql/repo/repo_protect_up_01.sql @@ -0,0 +1,14 @@ +CREATE TABLE IF NOT EXISTS repo_protect ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL REFERENCES repo(id), + pattern TEXT NOT NULL, + require_pull_request BOOLEAN NOT NULL DEFAULT TRUE, + required_approvals INTEGER NOT NULL DEFAULT 1, + require_status_checks BOOLEAN NOT NULL DEFAULT FALSE, + required_status_contexts TEXT NOT NULL DEFAULT '', + enforce_admins BOOLEAN NOT NULL DEFAULT FALSE, + allow_force_pushes BOOLEAN NOT NULL DEFAULT FALSE, + allow_deletions BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/repo/repo_ref_down_01.sql b/lib/migrate/sql/repo/repo_ref_down_01.sql new file mode 100644 index 0000000..137f78d --- /dev/null +++ b/lib/migrate/sql/repo/repo_ref_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_ref CASCADE; diff --git a/lib/migrate/sql/repo/repo_ref_up_01.sql b/lib/migrate/sql/repo/repo_ref_up_01.sql new file mode 100644 index 0000000..6a7451c --- /dev/null +++ b/lib/migrate/sql/repo/repo_ref_up_01.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS repo_ref ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL REFERENCES repo(id), + name TEXT NOT NULL, + kind TEXT NOT NULL, + target_sha TEXT NOT NULL, + is_default BOOLEAN NOT NULL DEFAULT FALSE, + is_protected BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/repo/repo_release_asset_down_01.sql b/lib/migrate/sql/repo/repo_release_asset_down_01.sql new file mode 100644 index 0000000..6c228e2 --- /dev/null +++ b/lib/migrate/sql/repo/repo_release_asset_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_release_asset; \ No newline at end of file diff --git a/lib/migrate/sql/repo/repo_release_asset_up_01.sql b/lib/migrate/sql/repo/repo_release_asset_up_01.sql new file mode 100644 index 0000000..15b94df --- /dev/null +++ b/lib/migrate/sql/repo/repo_release_asset_up_01.sql @@ -0,0 +1,12 @@ +CREATE TABLE IF NOT EXISTS repo_release_asset ( + id UUID PRIMARY KEY, + release_id UUID NOT NULL REFERENCES repo_release(id) ON DELETE CASCADE, + name TEXT NOT NULL, + content_type TEXT, + size BIGINT NOT NULL DEFAULT 0, + download_count BIGINT NOT NULL DEFAULT 0, + storage_path TEXT NOT NULL, + uploader UUID NOT NULL REFERENCES "user"(id), + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); +CREATE INDEX IF NOT EXISTS idx_release_asset_release ON repo_release_asset(release_id); diff --git a/lib/migrate/sql/repo/repo_release_down_01.sql b/lib/migrate/sql/repo/repo_release_down_01.sql new file mode 100644 index 0000000..45d4f1c --- /dev/null +++ b/lib/migrate/sql/repo/repo_release_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_release CASCADE; diff --git a/lib/migrate/sql/repo/repo_release_up_01.sql b/lib/migrate/sql/repo/repo_release_up_01.sql new file mode 100644 index 0000000..0795efe --- /dev/null +++ b/lib/migrate/sql/repo/repo_release_up_01.sql @@ -0,0 +1,14 @@ +CREATE TABLE IF NOT EXISTS repo_release ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL REFERENCES repo(id), + tag_name TEXT NOT NULL, + target_commit_sha TEXT NOT NULL, + name TEXT NOT NULL, + body TEXT, + draft BOOLEAN NOT NULL DEFAULT FALSE, + prerelease BOOLEAN NOT NULL DEFAULT FALSE, + author UUID NOT NULL, + published_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/repo/repo_secret_down_01.sql b/lib/migrate/sql/repo/repo_secret_down_01.sql new file mode 100644 index 0000000..26fcc54 --- /dev/null +++ b/lib/migrate/sql/repo/repo_secret_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_secret CASCADE; diff --git a/lib/migrate/sql/repo/repo_secret_up_01.sql b/lib/migrate/sql/repo/repo_secret_up_01.sql new file mode 100644 index 0000000..34c8e3f --- /dev/null +++ b/lib/migrate/sql/repo/repo_secret_up_01.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS repo_secret ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL REFERENCES repo(id), + name TEXT NOT NULL, + encrypted_value TEXT NOT NULL, + key_id TEXT NOT NULL, + created_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/repo/repo_star_down_01.sql b/lib/migrate/sql/repo/repo_star_down_01.sql new file mode 100644 index 0000000..2b922eb --- /dev/null +++ b/lib/migrate/sql/repo/repo_star_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_star CASCADE; diff --git a/lib/migrate/sql/repo/repo_star_up_01.sql b/lib/migrate/sql/repo/repo_star_up_01.sql new file mode 100644 index 0000000..a5965e8 --- /dev/null +++ b/lib/migrate/sql/repo/repo_star_up_01.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS repo_star ( + repo UUID NOT NULL REFERENCES repo(id), + "user" UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (repo, "user") +); diff --git a/lib/migrate/sql/repo/repo_topic_down_01.sql b/lib/migrate/sql/repo/repo_topic_down_01.sql new file mode 100644 index 0000000..17bdf9c --- /dev/null +++ b/lib/migrate/sql/repo/repo_topic_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_topic CASCADE; diff --git a/lib/migrate/sql/repo/repo_topic_up_01.sql b/lib/migrate/sql/repo/repo_topic_up_01.sql new file mode 100644 index 0000000..69d9faf --- /dev/null +++ b/lib/migrate/sql/repo/repo_topic_up_01.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS repo_topic ( + repo UUID NOT NULL REFERENCES repo(id), + topic TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (repo, topic) +); diff --git a/lib/migrate/sql/repo/repo_up_01.sql b/lib/migrate/sql/repo/repo_up_01.sql new file mode 100644 index 0000000..6445bd8 --- /dev/null +++ b/lib/migrate/sql/repo/repo_up_01.sql @@ -0,0 +1,17 @@ +CREATE TABLE IF NOT EXISTS repo ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + wk UUID NOT NULL, + name TEXT NOT NULL, + description TEXT, + default_branch TEXT NOT NULL DEFAULT 'main', + visibility TEXT NOT NULL DEFAULT 'private', + size_bytes BIGINT NOT NULL DEFAULT 0, + is_archived BOOLEAN NOT NULL DEFAULT FALSE, + is_template BOOLEAN NOT NULL DEFAULT FALSE, + is_mirror BOOLEAN NOT NULL DEFAULT FALSE, + created_by UUID NOT NULL, + storage_path TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/repo/repo_watch_down_01.sql b/lib/migrate/sql/repo/repo_watch_down_01.sql new file mode 100644 index 0000000..fe2b2f7 --- /dev/null +++ b/lib/migrate/sql/repo/repo_watch_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_watch CASCADE; diff --git a/lib/migrate/sql/repo/repo_watch_up_01.sql b/lib/migrate/sql/repo/repo_watch_up_01.sql new file mode 100644 index 0000000..e95dbb9 --- /dev/null +++ b/lib/migrate/sql/repo/repo_watch_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS repo_watch ( + repo UUID NOT NULL REFERENCES repo(id), + "user" UUID NOT NULL, + level TEXT NOT NULL DEFAULT 'participating', + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (repo, "user") +); diff --git a/lib/migrate/sql/repo/repo_webhook_delivery_down_01.sql b/lib/migrate/sql/repo/repo_webhook_delivery_down_01.sql new file mode 100644 index 0000000..cc4f6e5 --- /dev/null +++ b/lib/migrate/sql/repo/repo_webhook_delivery_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_webhook_delivery CASCADE; diff --git a/lib/migrate/sql/repo/repo_webhook_delivery_up_01.sql b/lib/migrate/sql/repo/repo_webhook_delivery_up_01.sql new file mode 100644 index 0000000..20f35b6 --- /dev/null +++ b/lib/migrate/sql/repo/repo_webhook_delivery_up_01.sql @@ -0,0 +1,14 @@ +CREATE TABLE IF NOT EXISTS repo_webhook_delivery ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL REFERENCES repo(id), + webhook UUID NOT NULL REFERENCES repo_webhook(id), + event TEXT NOT NULL, + request_headers TEXT, + request_body TEXT, + response_status INTEGER, + response_headers TEXT, + response_body TEXT, + error TEXT, + delivered_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/repo/repo_webhook_down_01.sql b/lib/migrate/sql/repo/repo_webhook_down_01.sql new file mode 100644 index 0000000..9326406 --- /dev/null +++ b/lib/migrate/sql/repo/repo_webhook_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS repo_webhook CASCADE; diff --git a/lib/migrate/sql/repo/repo_webhook_up_01.sql b/lib/migrate/sql/repo/repo_webhook_up_01.sql new file mode 100644 index 0000000..6f6d98a --- /dev/null +++ b/lib/migrate/sql/repo/repo_webhook_up_01.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS repo_webhook ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + repo UUID NOT NULL REFERENCES repo(id), + url TEXT NOT NULL, + secret_hash TEXT, + events TEXT NOT NULL, + active BOOLEAN NOT NULL DEFAULT TRUE, + created_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/room/dm_conversation_down_01.sql b/lib/migrate/sql/room/dm_conversation_down_01.sql new file mode 100644 index 0000000..30e712e --- /dev/null +++ b/lib/migrate/sql/room/dm_conversation_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS dm_conversation; diff --git a/lib/migrate/sql/room/dm_conversation_up_01.sql b/lib/migrate/sql/room/dm_conversation_up_01.sql new file mode 100644 index 0000000..47e3551 --- /dev/null +++ b/lib/migrate/sql/room/dm_conversation_up_01.sql @@ -0,0 +1,17 @@ +-- depends_on: room +CREATE TABLE IF NOT EXISTS dm_conversation ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + room UUID NOT NULL REFERENCES room(id) ON DELETE CASCADE, + initiator UUID NOT NULL, + recipient UUID NOT NULL, + is_closed BOOLEAN NOT NULL DEFAULT FALSE, + closed_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + UNIQUE (initiator, recipient), + CHECK (initiator < recipient) +); + +CREATE INDEX IF NOT EXISTS idx_dm_conversation_room ON dm_conversation (room); +CREATE INDEX IF NOT EXISTS idx_dm_conversation_initiator ON dm_conversation (initiator); +CREATE INDEX IF NOT EXISTS idx_dm_conversation_recipient ON dm_conversation (recipient); diff --git a/lib/migrate/sql/room/message_read_down_01.sql b/lib/migrate/sql/room/message_read_down_01.sql new file mode 100644 index 0000000..1a6fbba --- /dev/null +++ b/lib/migrate/sql/room/message_read_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS message_read; diff --git a/lib/migrate/sql/room/message_read_up_01.sql b/lib/migrate/sql/room/message_read_up_01.sql new file mode 100644 index 0000000..7f80294 --- /dev/null +++ b/lib/migrate/sql/room/message_read_up_01.sql @@ -0,0 +1,12 @@ +-- depends_on: room_message +CREATE TABLE IF NOT EXISTS message_read ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message UUID NOT NULL, + room UUID NOT NULL REFERENCES room(id), + "user" UUID NOT NULL, + read_at TIMESTAMPTZ NOT NULL DEFAULT now(), + UNIQUE (message, "user") +); + +CREATE INDEX IF NOT EXISTS idx_message_read_room_user ON message_read (room, "user"); +CREATE INDEX IF NOT EXISTS idx_message_read_message ON message_read (message); diff --git a/lib/migrate/sql/room/message_star_down_01.sql b/lib/migrate/sql/room/message_star_down_01.sql new file mode 100644 index 0000000..938c23a --- /dev/null +++ b/lib/migrate/sql/room/message_star_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS message_star; diff --git a/lib/migrate/sql/room/message_star_up_01.sql b/lib/migrate/sql/room/message_star_up_01.sql new file mode 100644 index 0000000..af87d00 --- /dev/null +++ b/lib/migrate/sql/room/message_star_up_01.sql @@ -0,0 +1,12 @@ +-- depends_on: room_message +CREATE TABLE IF NOT EXISTS message_star ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message UUID NOT NULL, + room UUID NOT NULL REFERENCES room(id), + "user" UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + UNIQUE (message, "user") +); + +CREATE INDEX IF NOT EXISTS idx_message_star_user_room ON message_star ("user", room); +CREATE INDEX IF NOT EXISTS idx_message_star_message ON message_star (message); diff --git a/lib/migrate/sql/room/room_ai_down_01.sql b/lib/migrate/sql/room/room_ai_down_01.sql new file mode 100644 index 0000000..2ecc847 --- /dev/null +++ b/lib/migrate/sql/room/room_ai_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS room_ai CASCADE; diff --git a/lib/migrate/sql/room/room_ai_up_01.sql b/lib/migrate/sql/room/room_ai_up_01.sql new file mode 100644 index 0000000..90f2704 --- /dev/null +++ b/lib/migrate/sql/room/room_ai_up_01.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS room_ai ( + room UUID NOT NULL REFERENCES room(id), + agent_session UUID NOT NULL, + enabled BOOLEAN NOT NULL DEFAULT FALSE, + auto_reply BOOLEAN NOT NULL DEFAULT FALSE, + created_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (room, agent_session) +); diff --git a/lib/migrate/sql/room/room_attachment_down_01.sql b/lib/migrate/sql/room/room_attachment_down_01.sql new file mode 100644 index 0000000..0e9c2b7 --- /dev/null +++ b/lib/migrate/sql/room/room_attachment_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS room_attachment CASCADE; diff --git a/lib/migrate/sql/room/room_attachment_up_01.sql b/lib/migrate/sql/room/room_attachment_up_01.sql new file mode 100644 index 0000000..16ceea7 --- /dev/null +++ b/lib/migrate/sql/room/room_attachment_up_01.sql @@ -0,0 +1,12 @@ +CREATE TABLE IF NOT EXISTS room_attachment ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message UUID NOT NULL, + seq BIGINT NOT NULL, + file_name TEXT NOT NULL, + content_type TEXT, + size_bytes BIGINT NOT NULL DEFAULT 0, + storage_key TEXT NOT NULL, + url TEXT, + uploaded_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/room/room_category_down_01.sql b/lib/migrate/sql/room/room_category_down_01.sql new file mode 100644 index 0000000..a1be06a --- /dev/null +++ b/lib/migrate/sql/room/room_category_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS room_category CASCADE; diff --git a/lib/migrate/sql/room/room_category_up_01.sql b/lib/migrate/sql/room/room_category_up_01.sql new file mode 100644 index 0000000..e4f43d1 --- /dev/null +++ b/lib/migrate/sql/room/room_category_up_01.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS room_category ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + wk UUID NOT NULL, + name TEXT NOT NULL, + position INTEGER NOT NULL DEFAULT 0, + collapsed BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/room/room_down_01.sql b/lib/migrate/sql/room/room_down_01.sql new file mode 100644 index 0000000..34f8373 --- /dev/null +++ b/lib/migrate/sql/room/room_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS room CASCADE; diff --git a/lib/migrate/sql/room/room_mention_down_01.sql b/lib/migrate/sql/room/room_mention_down_01.sql new file mode 100644 index 0000000..f63129d --- /dev/null +++ b/lib/migrate/sql/room/room_mention_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS room_mention CASCADE; diff --git a/lib/migrate/sql/room/room_mention_up_01.sql b/lib/migrate/sql/room/room_mention_up_01.sql new file mode 100644 index 0000000..fc07588 --- /dev/null +++ b/lib/migrate/sql/room/room_mention_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS room_mention ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message UUID NOT NULL, + seq BIGINT NOT NULL, + mention_type TEXT NOT NULL, + target_id UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/room/room_message_down_01.sql b/lib/migrate/sql/room/room_message_down_01.sql new file mode 100644 index 0000000..e49bd0d --- /dev/null +++ b/lib/migrate/sql/room/room_message_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS room_message CASCADE; diff --git a/lib/migrate/sql/room/room_message_edit_history_down_01.sql b/lib/migrate/sql/room/room_message_edit_history_down_01.sql new file mode 100644 index 0000000..6a51e30 --- /dev/null +++ b/lib/migrate/sql/room/room_message_edit_history_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS room_message_edit_history CASCADE; diff --git a/lib/migrate/sql/room/room_message_edit_history_up_01.sql b/lib/migrate/sql/room/room_message_edit_history_up_01.sql new file mode 100644 index 0000000..cc3c9c7 --- /dev/null +++ b/lib/migrate/sql/room/room_message_edit_history_up_01.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS room_message_edit_history ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message UUID NOT NULL, + seq BIGINT NOT NULL, + editor UUID NOT NULL, + old_content TEXT NOT NULL, + new_content TEXT NOT NULL, + edited_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/room/room_message_up_01.sql b/lib/migrate/sql/room/room_message_up_01.sql new file mode 100644 index 0000000..49226c1 --- /dev/null +++ b/lib/migrate/sql/room/room_message_up_01.sql @@ -0,0 +1,15 @@ +CREATE TABLE IF NOT EXISTS room_message ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + room UUID NOT NULL REFERENCES room(id), + seq BIGINT NOT NULL, + thread UUID, + parent UUID, + author UUID NOT NULL, + content TEXT NOT NULL, + content_type TEXT NOT NULL DEFAULT 'text', + pinned BOOLEAN NOT NULL DEFAULT FALSE, + edited_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/room/room_permission_overwrite_down_01.sql b/lib/migrate/sql/room/room_permission_overwrite_down_01.sql new file mode 100644 index 0000000..30db0f7 --- /dev/null +++ b/lib/migrate/sql/room/room_permission_overwrite_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS room_permission_overwrite CASCADE; diff --git a/lib/migrate/sql/room/room_permission_overwrite_up_01.sql b/lib/migrate/sql/room/room_permission_overwrite_up_01.sql new file mode 100644 index 0000000..5402777 --- /dev/null +++ b/lib/migrate/sql/room/room_permission_overwrite_up_01.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS room_permission_overwrite ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + room UUID NOT NULL REFERENCES room(id), + target_type TEXT NOT NULL, + target_id UUID NOT NULL, + allow_permissions TEXT NOT NULL, + deny_permissions TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/room/room_pin_down_01.sql b/lib/migrate/sql/room/room_pin_down_01.sql new file mode 100644 index 0000000..48fc2ed --- /dev/null +++ b/lib/migrate/sql/room/room_pin_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS room_pin CASCADE; diff --git a/lib/migrate/sql/room/room_pin_up_01.sql b/lib/migrate/sql/room/room_pin_up_01.sql new file mode 100644 index 0000000..a89fc25 --- /dev/null +++ b/lib/migrate/sql/room/room_pin_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS room_pin ( + room UUID NOT NULL REFERENCES room(id), + message UUID NOT NULL, + seq BIGINT NOT NULL, + pinned_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (room, message) +); diff --git a/lib/migrate/sql/room/room_reaction_down_01.sql b/lib/migrate/sql/room/room_reaction_down_01.sql new file mode 100644 index 0000000..b5e4b5c --- /dev/null +++ b/lib/migrate/sql/room/room_reaction_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS room_reaction CASCADE; diff --git a/lib/migrate/sql/room/room_reaction_up_01.sql b/lib/migrate/sql/room/room_reaction_up_01.sql new file mode 100644 index 0000000..8806528 --- /dev/null +++ b/lib/migrate/sql/room/room_reaction_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS room_reaction ( + message UUID NOT NULL, + "user" UUID NOT NULL, + seq BIGINT NOT NULL, + reaction TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (message, "user", reaction) +); diff --git a/lib/migrate/sql/room/room_server_label_down_01.sql b/lib/migrate/sql/room/room_server_label_down_01.sql new file mode 100644 index 0000000..c949eb9 --- /dev/null +++ b/lib/migrate/sql/room/room_server_label_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS room_server_label CASCADE; diff --git a/lib/migrate/sql/room/room_server_label_up_01.sql b/lib/migrate/sql/room/room_server_label_up_01.sql new file mode 100644 index 0000000..af3f794 --- /dev/null +++ b/lib/migrate/sql/room/room_server_label_up_01.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS room_server_label ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + wk UUID NOT NULL, + name TEXT NOT NULL, + color TEXT NOT NULL, + description TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/room/room_thread_down_01.sql b/lib/migrate/sql/room/room_thread_down_01.sql new file mode 100644 index 0000000..d8569db --- /dev/null +++ b/lib/migrate/sql/room/room_thread_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS room_thread CASCADE; diff --git a/lib/migrate/sql/room/room_thread_up_01.sql b/lib/migrate/sql/room/room_thread_up_01.sql new file mode 100644 index 0000000..19cc5c9 --- /dev/null +++ b/lib/migrate/sql/room/room_thread_up_01.sql @@ -0,0 +1,14 @@ +CREATE TABLE IF NOT EXISTS room_thread ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + room UUID NOT NULL REFERENCES room(id), + seq BIGINT NOT NULL, + starter_message UUID, + title TEXT NOT NULL, + created_by UUID NOT NULL, + archived BOOLEAN NOT NULL DEFAULT FALSE, + locked BOOLEAN NOT NULL DEFAULT FALSE, + last_message_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + archived_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/room/room_up_01.sql b/lib/migrate/sql/room/room_up_01.sql new file mode 100644 index 0000000..8125d48 --- /dev/null +++ b/lib/migrate/sql/room/room_up_01.sql @@ -0,0 +1,15 @@ +CREATE TABLE IF NOT EXISTS room ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + wk UUID NOT NULL, + parent UUID, + name TEXT NOT NULL, + topic TEXT, + room_type TEXT NOT NULL, + position INTEGER NOT NULL DEFAULT 0, + is_private BOOLEAN NOT NULL DEFAULT FALSE, + is_archived BOOLEAN NOT NULL DEFAULT FALSE, + created_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + deleted_at TIMESTAMPTZ +); diff --git a/lib/migrate/sql/room/system_message_down_01.sql b/lib/migrate/sql/room/system_message_down_01.sql new file mode 100644 index 0000000..3407f97 --- /dev/null +++ b/lib/migrate/sql/room/system_message_down_01.sql @@ -0,0 +1,2 @@ +ALTER TABLE room_message DROP COLUMN IF EXISTS metadata; +ALTER TABLE room_message DROP COLUMN IF EXISTS system_type; diff --git a/lib/migrate/sql/room/system_message_up_01.sql b/lib/migrate/sql/room/system_message_up_01.sql new file mode 100644 index 0000000..14be638 --- /dev/null +++ b/lib/migrate/sql/room/system_message_up_01.sql @@ -0,0 +1,10 @@ +-- depends_on: room_message +ALTER TABLE room_message ADD COLUMN IF NOT EXISTS system_type TEXT; +ALTER TABLE room_message ADD COLUMN IF NOT EXISTS metadata JSONB NOT NULL DEFAULT '{}'; + +COMMENT ON COLUMN room_message.system_type IS + 'System message type: user_joined, user_left, room_renamed, room_topic_changed, ' + 'room_archived, message_pinned, dm_created, etc. NULL for regular user messages. ' + 'Learned from Rocket.Chat MessageTypes system.'; +COMMENT ON COLUMN room_message.metadata IS + 'Structured metadata for system messages (e.g. old_name, new_name for rename events).'; diff --git a/lib/migrate/sql/room/user_room_state_down_01.sql b/lib/migrate/sql/room/user_room_state_down_01.sql new file mode 100644 index 0000000..5ee0469 --- /dev/null +++ b/lib/migrate/sql/room/user_room_state_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_room_state; diff --git a/lib/migrate/sql/room/user_room_state_up_01.sql b/lib/migrate/sql/room/user_room_state_up_01.sql new file mode 100644 index 0000000..af9f950 --- /dev/null +++ b/lib/migrate/sql/room/user_room_state_up_01.sql @@ -0,0 +1,19 @@ +-- depends_on: room +CREATE TABLE IF NOT EXISTS user_room_state ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + "user" UUID NOT NULL, + room UUID NOT NULL REFERENCES room(id), + last_read_seq BIGINT NOT NULL DEFAULT 0, + last_read_at TIMESTAMPTZ, + is_pinned BOOLEAN NOT NULL DEFAULT FALSE, + is_muted BOOLEAN NOT NULL DEFAULT FALSE, + hide_muted BOOLEAN NOT NULL DEFAULT FALSE, + notify_level TEXT NOT NULL DEFAULT 'all', + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + UNIQUE ("user", room) +); + +CREATE INDEX IF NOT EXISTS idx_user_room_state_user ON user_room_state ("user"); +CREATE INDEX IF NOT EXISTS idx_user_room_state_room ON user_room_state (room); +CREATE INDEX IF NOT EXISTS idx_user_room_state_pinned ON user_room_state ("user", is_pinned) WHERE is_pinned = TRUE; diff --git a/lib/migrate/sql/user/user_2fa_down_01.sql b/lib/migrate/sql/user/user_2fa_down_01.sql new file mode 100644 index 0000000..8ced9a0 --- /dev/null +++ b/lib/migrate/sql/user/user_2fa_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_2fa CASCADE; diff --git a/lib/migrate/sql/user/user_2fa_up_01.sql b/lib/migrate/sql/user/user_2fa_up_01.sql new file mode 100644 index 0000000..c8935d2 --- /dev/null +++ b/lib/migrate/sql/user/user_2fa_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS user_2fa ( + "user" UUID PRIMARY KEY REFERENCES "user"(id), + secret TEXT, + backup_codes TEXT NOT NULL, + enabled BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_accessibility_down_01.sql b/lib/migrate/sql/user/user_accessibility_down_01.sql new file mode 100644 index 0000000..3696ecf --- /dev/null +++ b/lib/migrate/sql/user/user_accessibility_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_accessibility CASCADE; diff --git a/lib/migrate/sql/user/user_accessibility_up_01.sql b/lib/migrate/sql/user/user_accessibility_up_01.sql new file mode 100644 index 0000000..cc9bbdf --- /dev/null +++ b/lib/migrate/sql/user/user_accessibility_up_01.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS user_accessibility ( + "user" UUID PRIMARY KEY REFERENCES "user"(id), + reduce_motion BOOLEAN NOT NULL DEFAULT FALSE, + high_contrast BOOLEAN NOT NULL DEFAULT FALSE, + screen_reader_optimized BOOLEAN NOT NULL DEFAULT FALSE, + font_scale_percent INTEGER NOT NULL DEFAULT 100, + color_blind_mode TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_appearance_down_01.sql b/lib/migrate/sql/user/user_appearance_down_01.sql new file mode 100644 index 0000000..8d53c16 --- /dev/null +++ b/lib/migrate/sql/user/user_appearance_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_appearance CASCADE; diff --git a/lib/migrate/sql/user/user_appearance_up_01.sql b/lib/migrate/sql/user/user_appearance_up_01.sql new file mode 100644 index 0000000..325d66e --- /dev/null +++ b/lib/migrate/sql/user/user_appearance_up_01.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS user_appearance ( + "user" UUID PRIMARY KEY REFERENCES "user"(id), + theme TEXT NOT NULL DEFAULT 'system', + code_theme TEXT NOT NULL DEFAULT 'default', + layout_density TEXT NOT NULL DEFAULT 'normal', + sidebar_collapsed BOOLEAN NOT NULL DEFAULT FALSE, + show_line_numbers BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_billing_down_01.sql b/lib/migrate/sql/user/user_billing_down_01.sql new file mode 100644 index 0000000..1be8c62 --- /dev/null +++ b/lib/migrate/sql/user/user_billing_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_billing CASCADE; diff --git a/lib/migrate/sql/user/user_billing_history_down_01.sql b/lib/migrate/sql/user/user_billing_history_down_01.sql new file mode 100644 index 0000000..8a050b6 --- /dev/null +++ b/lib/migrate/sql/user/user_billing_history_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_billing_history CASCADE; diff --git a/lib/migrate/sql/user/user_billing_history_up_01.sql b/lib/migrate/sql/user/user_billing_history_up_01.sql new file mode 100644 index 0000000..07f5e1c --- /dev/null +++ b/lib/migrate/sql/user/user_billing_history_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS user_billing_history ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + "user" UUID NOT NULL REFERENCES "user"(id), + amount NUMERIC NOT NULL, + currency TEXT NOT NULL, + reason TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_billing_up_01.sql b/lib/migrate/sql/user/user_billing_up_01.sql new file mode 100644 index 0000000..f218651 --- /dev/null +++ b/lib/migrate/sql/user/user_billing_up_01.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS user_billing ( + "user" UUID PRIMARY KEY REFERENCES "user"(id), + balance NUMERIC NOT NULL DEFAULT 0, + is_pro BOOLEAN NOT NULL DEFAULT FALSE, + total_supply NUMERIC NOT NULL DEFAULT 0, + total_supply_usable NUMERIC NOT NULL DEFAULT 0, + cycle_start TIMESTAMPTZ, + cycle_end TIMESTAMPTZ, + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_blacklist_down_01.sql b/lib/migrate/sql/user/user_blacklist_down_01.sql new file mode 100644 index 0000000..61a5b7e --- /dev/null +++ b/lib/migrate/sql/user/user_blacklist_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_blacklist CASCADE; diff --git a/lib/migrate/sql/user/user_blacklist_up_01.sql b/lib/migrate/sql/user/user_blacklist_up_01.sql new file mode 100644 index 0000000..6dfdcce --- /dev/null +++ b/lib/migrate/sql/user/user_blacklist_up_01.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS user_blacklist ( + "user" UUID NOT NULL REFERENCES "user"(id), + black UUID NOT NULL REFERENCES "user"(id), + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY ("user", black) +); diff --git a/lib/migrate/sql/user/user_down_01.sql b/lib/migrate/sql/user/user_down_01.sql new file mode 100644 index 0000000..feb4402 --- /dev/null +++ b/lib/migrate/sql/user/user_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS "user" CASCADE; diff --git a/lib/migrate/sql/user/user_email_down_01.sql b/lib/migrate/sql/user/user_email_down_01.sql new file mode 100644 index 0000000..c9a324a --- /dev/null +++ b/lib/migrate/sql/user/user_email_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_email CASCADE; diff --git a/lib/migrate/sql/user/user_email_up_01.sql b/lib/migrate/sql/user/user_email_up_01.sql new file mode 100644 index 0000000..e73bae8 --- /dev/null +++ b/lib/migrate/sql/user/user_email_up_01.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS user_email ( + "user" UUID NOT NULL REFERENCES "user"(id), + email TEXT NOT NULL, + active BOOLEAN NOT NULL DEFAULT FALSE, + last_use_login TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY ("user", email) +); diff --git a/lib/migrate/sql/user/user_favorite_down_01.sql b/lib/migrate/sql/user/user_favorite_down_01.sql new file mode 100644 index 0000000..580de26 --- /dev/null +++ b/lib/migrate/sql/user/user_favorite_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_favorite CASCADE; diff --git a/lib/migrate/sql/user/user_favorite_up_01.sql b/lib/migrate/sql/user/user_favorite_up_01.sql new file mode 100644 index 0000000..75d1138 --- /dev/null +++ b/lib/migrate/sql/user/user_favorite_up_01.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS user_favorite ( + "user" UUID NOT NULL REFERENCES "user"(id), + target UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY ("user", target) +); diff --git a/lib/migrate/sql/user/user_gpg_key_down_01.sql b/lib/migrate/sql/user/user_gpg_key_down_01.sql new file mode 100644 index 0000000..d12de78 --- /dev/null +++ b/lib/migrate/sql/user/user_gpg_key_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_gpg_key CASCADE; diff --git a/lib/migrate/sql/user/user_gpg_key_up_01.sql b/lib/migrate/sql/user/user_gpg_key_up_01.sql new file mode 100644 index 0000000..4bebca9 --- /dev/null +++ b/lib/migrate/sql/user/user_gpg_key_up_01.sql @@ -0,0 +1,16 @@ +CREATE TABLE IF NOT EXISTS user_gpg_key ( + id BIGSERIAL PRIMARY KEY, + "user" UUID NOT NULL REFERENCES "user"(id), + title TEXT NOT NULL, + public_key TEXT NOT NULL, + fingerprint TEXT NOT NULL, + key_id TEXT NOT NULL, + primary_key_id TEXT, + emails TEXT NOT NULL, + is_verified BOOLEAN NOT NULL DEFAULT FALSE, + last_used_at TIMESTAMPTZ, + expires_at TIMESTAMPTZ, + is_revoked BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_notification_down_01.sql b/lib/migrate/sql/user/user_notification_down_01.sql new file mode 100644 index 0000000..701f35a --- /dev/null +++ b/lib/migrate/sql/user/user_notification_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_notification CASCADE; diff --git a/lib/migrate/sql/user/user_notification_up_01.sql b/lib/migrate/sql/user/user_notification_up_01.sql new file mode 100644 index 0000000..46084e1 --- /dev/null +++ b/lib/migrate/sql/user/user_notification_up_01.sql @@ -0,0 +1,18 @@ +CREATE TABLE IF NOT EXISTS user_notification ( + "user" UUID PRIMARY KEY REFERENCES "user"(id), + email_enabled BOOLEAN NOT NULL DEFAULT TRUE, + in_app_enabled BOOLEAN NOT NULL DEFAULT TRUE, + push_enabled BOOLEAN NOT NULL DEFAULT FALSE, + digest_mode TEXT NOT NULL DEFAULT 'daily', + dnd_enabled BOOLEAN NOT NULL DEFAULT FALSE, + dnd_start_minute INTEGER, + dnd_end_minute INTEGER, + marketing_enabled BOOLEAN NOT NULL DEFAULT TRUE, + security_enabled BOOLEAN NOT NULL DEFAULT TRUE, + product_enabled BOOLEAN NOT NULL DEFAULT TRUE, + push_subscription_endpoint TEXT, + push_subscription_keys_p256dh TEXT, + push_subscription_keys_auth TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_password_down_01.sql b/lib/migrate/sql/user/user_password_down_01.sql new file mode 100644 index 0000000..ae0cc9e --- /dev/null +++ b/lib/migrate/sql/user/user_password_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_password CASCADE; diff --git a/lib/migrate/sql/user/user_password_reset_down_01.sql b/lib/migrate/sql/user/user_password_reset_down_01.sql new file mode 100644 index 0000000..4d46741 --- /dev/null +++ b/lib/migrate/sql/user/user_password_reset_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_password_reset CASCADE; diff --git a/lib/migrate/sql/user/user_password_reset_up_01.sql b/lib/migrate/sql/user/user_password_reset_up_01.sql new file mode 100644 index 0000000..fde446c --- /dev/null +++ b/lib/migrate/sql/user/user_password_reset_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS user_password_reset ( + token TEXT PRIMARY KEY, + "user" UUID NOT NULL REFERENCES "user"(id), + expires_at TIMESTAMPTZ NOT NULL, + used BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_password_up_01.sql b/lib/migrate/sql/user/user_password_up_01.sql new file mode 100644 index 0000000..66ffb2e --- /dev/null +++ b/lib/migrate/sql/user/user_password_up_01.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS user_password ( + "user" UUID NOT NULL REFERENCES "user"(id), + hash TEXT NOT NULL, + salt TEXT NOT NULL, + is_active BOOLEAN NOT NULL DEFAULT TRUE, + reason TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY ("user", created_at) +); diff --git a/lib/migrate/sql/user/user_privacy_down_01.sql b/lib/migrate/sql/user/user_privacy_down_01.sql new file mode 100644 index 0000000..a17f82c --- /dev/null +++ b/lib/migrate/sql/user/user_privacy_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_privacy CASCADE; diff --git a/lib/migrate/sql/user/user_privacy_up_01.sql b/lib/migrate/sql/user/user_privacy_up_01.sql new file mode 100644 index 0000000..964c319 --- /dev/null +++ b/lib/migrate/sql/user/user_privacy_up_01.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS user_privacy ( + "user" UUID PRIMARY KEY REFERENCES "user"(id), + profile_visibility TEXT NOT NULL DEFAULT 'public', + email_visibility TEXT NOT NULL DEFAULT 'private', + activity_visibility TEXT NOT NULL DEFAULT 'public', + allow_search_indexing BOOLEAN NOT NULL DEFAULT TRUE, + allow_direct_messages BOOLEAN NOT NULL DEFAULT TRUE, + show_online_status BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_profile_down_01.sql b/lib/migrate/sql/user/user_profile_down_01.sql new file mode 100644 index 0000000..47a090e --- /dev/null +++ b/lib/migrate/sql/user/user_profile_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_profile CASCADE; diff --git a/lib/migrate/sql/user/user_profile_up_01.sql b/lib/migrate/sql/user/user_profile_up_01.sql new file mode 100644 index 0000000..b61bef0 --- /dev/null +++ b/lib/migrate/sql/user/user_profile_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS user_profile ( + "user" UUID PRIMARY KEY REFERENCES "user"(id), + language TEXT NOT NULL DEFAULT 'en', + theme TEXT NOT NULL DEFAULT 'system', + timezone TEXT NOT NULL DEFAULT 'UTC', + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_session_down_01.sql b/lib/migrate/sql/user/user_session_down_01.sql new file mode 100644 index 0000000..b65a522 --- /dev/null +++ b/lib/migrate/sql/user/user_session_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_session CASCADE; diff --git a/lib/migrate/sql/user/user_session_up_01.sql b/lib/migrate/sql/user/user_session_up_01.sql new file mode 100644 index 0000000..ae62824 --- /dev/null +++ b/lib/migrate/sql/user/user_session_up_01.sql @@ -0,0 +1,13 @@ +CREATE TABLE IF NOT EXISTS user_session ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + "user" UUID NOT NULL REFERENCES "user"(id), + token_hash TEXT NOT NULL, + device_name TEXT, + user_agent TEXT, + ip_address TEXT, + last_seen_at TIMESTAMPTZ, + expires_at TIMESTAMPTZ NOT NULL, + is_revoked BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_ssh_key_down_01.sql b/lib/migrate/sql/user/user_ssh_key_down_01.sql new file mode 100644 index 0000000..025fa53 --- /dev/null +++ b/lib/migrate/sql/user/user_ssh_key_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_ssh_key CASCADE; diff --git a/lib/migrate/sql/user/user_ssh_key_up_01.sql b/lib/migrate/sql/user/user_ssh_key_up_01.sql new file mode 100644 index 0000000..15a88cf --- /dev/null +++ b/lib/migrate/sql/user/user_ssh_key_up_01.sql @@ -0,0 +1,15 @@ +CREATE TABLE IF NOT EXISTS user_ssh_key ( + id BIGSERIAL PRIMARY KEY, + "user" UUID NOT NULL REFERENCES "user"(id), + title TEXT NOT NULL, + public_key TEXT NOT NULL, + fingerprint TEXT NOT NULL, + key_type TEXT NOT NULL, + key_bits INTEGER, + is_verified BOOLEAN NOT NULL DEFAULT FALSE, + last_used_at TIMESTAMPTZ, + expires_at TIMESTAMPTZ, + is_revoked BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_token_down_01.sql b/lib/migrate/sql/user/user_token_down_01.sql new file mode 100644 index 0000000..3a2f55a --- /dev/null +++ b/lib/migrate/sql/user/user_token_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_token CASCADE; diff --git a/lib/migrate/sql/user/user_token_up_01.sql b/lib/migrate/sql/user/user_token_up_01.sql new file mode 100644 index 0000000..01ce193 --- /dev/null +++ b/lib/migrate/sql/user/user_token_up_01.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS user_token ( + id BIGSERIAL PRIMARY KEY, + "user" UUID NOT NULL REFERENCES "user"(id), + name TEXT NOT NULL, + token_hash TEXT NOT NULL, + scopes TEXT NOT NULL, + expires_at TIMESTAMPTZ, + is_revoked BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_up_01.sql b/lib/migrate/sql/user/user_up_01.sql new file mode 100644 index 0000000..0c188ac --- /dev/null +++ b/lib/migrate/sql/user/user_up_01.sql @@ -0,0 +1,12 @@ +CREATE TABLE IF NOT EXISTS "user" ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + username TEXT NOT NULL, + display_name TEXT NOT NULL, + avatar_url TEXT NOT NULL, + website_url TEXT NOT NULL, + allow_use BOOLEAN NOT NULL DEFAULT FALSE, + can_search BOOLEAN NOT NULL DEFAULT FALSE, + last_sign_in_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/user/user_user_notifications_compat_down_01.sql b/lib/migrate/sql/user/user_user_notifications_compat_down_01.sql new file mode 100644 index 0000000..10ea4b9 --- /dev/null +++ b/lib/migrate/sql/user/user_user_notifications_compat_down_01.sql @@ -0,0 +1,9 @@ +DO $$ +BEGIN + IF EXISTS ( + SELECT 1 FROM pg_class + WHERE oid = to_regclass('user_notifications') AND relkind = 'v' + ) THEN + DROP VIEW user_notifications; + END IF; +END $$; diff --git a/lib/migrate/sql/user/user_user_notifications_compat_up_01.sql b/lib/migrate/sql/user/user_user_notifications_compat_up_01.sql new file mode 100644 index 0000000..69ac412 --- /dev/null +++ b/lib/migrate/sql/user/user_user_notifications_compat_up_01.sql @@ -0,0 +1,24 @@ +DO $$ +BEGIN + IF to_regclass('user_notifications') IS NULL THEN + CREATE VIEW user_notifications AS + SELECT + "user", + email_enabled, + in_app_enabled, + push_enabled, + digest_mode, + dnd_enabled, + dnd_start_minute, + dnd_end_minute, + marketing_enabled, + security_enabled, + product_enabled, + push_subscription_endpoint, + push_subscription_keys_p256dh, + push_subscription_keys_auth, + created_at, + updated_at + FROM user_notification; + END IF; +END $$; diff --git a/lib/migrate/sql/user/user_user_pass_compat_down_01.sql b/lib/migrate/sql/user/user_user_pass_compat_down_01.sql new file mode 100644 index 0000000..3b2ecf8 --- /dev/null +++ b/lib/migrate/sql/user/user_user_pass_compat_down_01.sql @@ -0,0 +1,9 @@ +DO $$ +BEGIN + IF EXISTS ( + SELECT 1 FROM pg_class + WHERE oid = to_regclass('user_pass') AND relkind = 'v' + ) THEN + DROP VIEW user_pass; + END IF; +END $$; diff --git a/lib/migrate/sql/user/user_user_pass_compat_up_01.sql b/lib/migrate/sql/user/user_user_pass_compat_up_01.sql new file mode 100644 index 0000000..8ada6b2 --- /dev/null +++ b/lib/migrate/sql/user/user_user_pass_compat_up_01.sql @@ -0,0 +1,15 @@ +DO $$ +BEGIN + IF to_regclass('user_pass') IS NULL THEN + CREATE VIEW user_pass AS + SELECT + "user", + hash, + salt, + is_active, + reason, + created_at, + updated_at + FROM user_password; + END IF; +END $$; diff --git a/lib/migrate/sql/workspace/wk_apply_join_down_01.sql b/lib/migrate/sql/workspace/wk_apply_join_down_01.sql new file mode 100644 index 0000000..5d5df30 --- /dev/null +++ b/lib/migrate/sql/workspace/wk_apply_join_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS wk_apply_join CASCADE; diff --git a/lib/migrate/sql/workspace/wk_apply_join_up_01.sql b/lib/migrate/sql/workspace/wk_apply_join_up_01.sql new file mode 100644 index 0000000..bd6b0aa --- /dev/null +++ b/lib/migrate/sql/workspace/wk_apply_join_up_01.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS wk_apply_join ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + wk UUID NOT NULL, + "user" UUID NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + question TEXT, + answer TEXT, + message TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/workspace/wk_billing_down_01.sql b/lib/migrate/sql/workspace/wk_billing_down_01.sql new file mode 100644 index 0000000..b1eb572 --- /dev/null +++ b/lib/migrate/sql/workspace/wk_billing_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS wk_billing CASCADE; diff --git a/lib/migrate/sql/workspace/wk_billing_up_01.sql b/lib/migrate/sql/workspace/wk_billing_up_01.sql new file mode 100644 index 0000000..888d094 --- /dev/null +++ b/lib/migrate/sql/workspace/wk_billing_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS wk_billing ( + wk UUID PRIMARY KEY, + balance NUMERIC NOT NULL DEFAULT 0, + total_supply NUMERIC NOT NULL DEFAULT 0, + total_supply_usable NUMERIC NOT NULL DEFAULT 0, + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/workspace/wk_gp_member_down_01.sql b/lib/migrate/sql/workspace/wk_gp_member_down_01.sql new file mode 100644 index 0000000..ada6080 --- /dev/null +++ b/lib/migrate/sql/workspace/wk_gp_member_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS wk_gp_member CASCADE; diff --git a/lib/migrate/sql/workspace/wk_gp_member_up_01.sql b/lib/migrate/sql/workspace/wk_gp_member_up_01.sql new file mode 100644 index 0000000..eb8dbed --- /dev/null +++ b/lib/migrate/sql/workspace/wk_gp_member_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS wk_gp_member ( + "user" UUID NOT NULL, + gp UUID NOT NULL, + join_at TIMESTAMPTZ NOT NULL DEFAULT now(), + leave_at TIMESTAMPTZ, + PRIMARY KEY ("user", gp, join_at) +); diff --git a/lib/migrate/sql/workspace/wk_gp_role_down_01.sql b/lib/migrate/sql/workspace/wk_gp_role_down_01.sql new file mode 100644 index 0000000..a613596 --- /dev/null +++ b/lib/migrate/sql/workspace/wk_gp_role_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS wk_gp_role CASCADE; diff --git a/lib/migrate/sql/workspace/wk_gp_role_up_01.sql b/lib/migrate/sql/workspace/wk_gp_role_up_01.sql new file mode 100644 index 0000000..5bd25ee --- /dev/null +++ b/lib/migrate/sql/workspace/wk_gp_role_up_01.sql @@ -0,0 +1,14 @@ +CREATE TABLE IF NOT EXISTS wk_gp_role ( + wk UUID NOT NULL, + gp UUID NOT NULL, + repo_read BOOLEAN NOT NULL DEFAULT FALSE, + repo_write BOOLEAN NOT NULL DEFAULT FALSE, + channel_read BOOLEAN NOT NULL DEFAULT FALSE, + channel_write BOOLEAN NOT NULL DEFAULT FALSE, + ai_read BOOLEAN NOT NULL DEFAULT FALSE, + ai_write BOOLEAN NOT NULL DEFAULT FALSE, + pr_review BOOLEAN NOT NULL DEFAULT FALSE, + issues_ass BOOLEAN NOT NULL DEFAULT FALSE, + log_view BOOLEAN NOT NULL DEFAULT FALSE, + PRIMARY KEY (wk, gp) +); diff --git a/lib/migrate/sql/workspace/wk_group_down_01.sql b/lib/migrate/sql/workspace/wk_group_down_01.sql new file mode 100644 index 0000000..54ed5d3 --- /dev/null +++ b/lib/migrate/sql/workspace/wk_group_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS wk_group CASCADE; diff --git a/lib/migrate/sql/workspace/wk_group_up_01.sql b/lib/migrate/sql/workspace/wk_group_up_01.sql new file mode 100644 index 0000000..937c6cf --- /dev/null +++ b/lib/migrate/sql/workspace/wk_group_up_01.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS wk_group ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name TEXT NOT NULL, + wk UUID NOT NULL, + avatar_url TEXT, + is_deleted BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/workspace/wk_history_name_down_01.sql b/lib/migrate/sql/workspace/wk_history_name_down_01.sql new file mode 100644 index 0000000..9d6a677 --- /dev/null +++ b/lib/migrate/sql/workspace/wk_history_name_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS wk_history_name CASCADE; diff --git a/lib/migrate/sql/workspace/wk_history_name_up_01.sql b/lib/migrate/sql/workspace/wk_history_name_up_01.sql new file mode 100644 index 0000000..3d2e56d --- /dev/null +++ b/lib/migrate/sql/workspace/wk_history_name_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS wk_history_name ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + wk UUID NOT NULL, + name TEXT NOT NULL, + changed_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/workspace/wk_join_approval_down_01.sql b/lib/migrate/sql/workspace/wk_join_approval_down_01.sql new file mode 100644 index 0000000..e1c2e1f --- /dev/null +++ b/lib/migrate/sql/workspace/wk_join_approval_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS wk_join_approval CASCADE; diff --git a/lib/migrate/sql/workspace/wk_join_approval_up_01.sql b/lib/migrate/sql/workspace/wk_join_approval_up_01.sql new file mode 100644 index 0000000..d1d8aaa --- /dev/null +++ b/lib/migrate/sql/workspace/wk_join_approval_up_01.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS wk_join_approval ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + apply UUID NOT NULL, + wk UUID NOT NULL, + "user" UUID NOT NULL, + approver UUID NOT NULL, + approved BOOLEAN NOT NULL DEFAULT FALSE, + reason TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/workspace/wk_join_strategy_down_01.sql b/lib/migrate/sql/workspace/wk_join_strategy_down_01.sql new file mode 100644 index 0000000..1497e81 --- /dev/null +++ b/lib/migrate/sql/workspace/wk_join_strategy_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS wk_join_strategy CASCADE; diff --git a/lib/migrate/sql/workspace/wk_join_strategy_up_01.sql b/lib/migrate/sql/workspace/wk_join_strategy_up_01.sql new file mode 100644 index 0000000..e1b5165 --- /dev/null +++ b/lib/migrate/sql/workspace/wk_join_strategy_up_01.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS wk_join_strategy ( + wk UUID PRIMARY KEY, + require_approval BOOLEAN NOT NULL DEFAULT TRUE, + require_question BOOLEAN NOT NULL DEFAULT FALSE, + question TEXT, + answer TEXT, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/sql/workspace/wk_member_down_01.sql b/lib/migrate/sql/workspace/wk_member_down_01.sql new file mode 100644 index 0000000..2016988 --- /dev/null +++ b/lib/migrate/sql/workspace/wk_member_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS wk_member CASCADE; diff --git a/lib/migrate/sql/workspace/wk_member_up_01.sql b/lib/migrate/sql/workspace/wk_member_up_01.sql new file mode 100644 index 0000000..12c160f --- /dev/null +++ b/lib/migrate/sql/workspace/wk_member_up_01.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS wk_member ( + wk UUID NOT NULL, + "user" UUID NOT NULL, + owner BOOLEAN NOT NULL DEFAULT FALSE, + admin BOOLEAN NOT NULL DEFAULT FALSE, + join_at TIMESTAMPTZ NOT NULL DEFAULT now(), + leave_at TIMESTAMPTZ, + PRIMARY KEY (wk, "user") +); diff --git a/lib/migrate/sql/workspace/workspace_down_01.sql b/lib/migrate/sql/workspace/workspace_down_01.sql new file mode 100644 index 0000000..ab80a64 --- /dev/null +++ b/lib/migrate/sql/workspace/workspace_down_01.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS workspace CASCADE; diff --git a/lib/migrate/sql/workspace/workspace_up_01.sql b/lib/migrate/sql/workspace/workspace_up_01.sql new file mode 100644 index 0000000..ee629a9 --- /dev/null +++ b/lib/migrate/sql/workspace/workspace_up_01.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS workspace ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name TEXT NOT NULL, + description TEXT NOT NULL DEFAULT '', + avatar_url TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/lib/migrate/src/main.rs b/lib/migrate/src/main.rs new file mode 100644 index 0000000..ad2867b --- /dev/null +++ b/lib/migrate/src/main.rs @@ -0,0 +1,419 @@ +use anyhow::{Context, Result, bail}; +use clap::{Parser, Subcommand}; +use sqlx::postgres::PgPoolOptions; +use std::collections::{BTreeMap, HashMap, VecDeque}; +use std::path::{Path, PathBuf}; +use tracing::info; + +#[derive(Parser)] +#[command(name = "migrate", about = "Database migration tool")] +struct Cli { + #[arg(short, long)] + database_url: String, + + #[command(subcommand)] + command: Command, +} + +#[derive(Subcommand)] +enum Command { + Up, + Down, + Fresh, + List, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct Migration { + domain: String, + table: String, + version: u32, + direction: MigrationDir, + path: PathBuf, + depends_on: Vec, +} + +impl Ord for Migration { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + (&self.domain, &self.table, self.version, &self.direction) + .cmp(&(&other.domain, &other.table, other.version, &other.direction)) + } +} + +impl PartialOrd for Migration { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +enum MigrationDir { + Up, + Down, +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter("migrate=info") + .init(); + + let cli = Cli::parse(); + + let database_url = std::env::var("DATABASE_URL") + .context("DATABASE_URL must be set or provided via --database-url")?; + + let pool = PgPoolOptions::new() + .max_connections(1) + .connect(&database_url) + .await + .context("Failed to connect to database")?; + + let sql_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("sql"); + + match cli.command { + Command::Up => run_up(&pool, &sql_root).await, + Command::Down => run_down(&pool, &sql_root).await, + Command::Fresh => run_fresh(&pool, &sql_root).await, + Command::List => run_list(&pool, &sql_root).await, + } +} + +fn discover_migrations(sql_root: &Path) -> Result> { + let mut migrations = Vec::new(); + + if !sql_root.exists() { + bail!("SQL directory not found: {}", sql_root.display()); + } + + for dir_entry in std::fs::read_dir(sql_root)? { + let dir = dir_entry?; + if !dir.file_type()?.is_dir() { + continue; + } + let domain = dir.file_name().to_string_lossy().to_string(); + + for file_entry in std::fs::read_dir(dir.path())? { + let file = file_entry?; + let path = file.path(); + if path.extension().and_then(|e| e.to_str()) != Some("sql") { + continue; + } + + let stem = path + .file_stem() + .and_then(|s| s.to_str()) + .context("Invalid filename")?; + + let (table, direction, version) = parse_migration_stem(stem)?; + + let content = std::fs::read_to_string(&path) + .context(format!("Failed to read {path:?}"))?; + let depends_on = parse_depends_on(&content); + + migrations.push(Migration { + domain: domain.clone(), + table, + version, + direction, + path, + depends_on, + }); + } + } + + migrations.sort(); + Ok(migrations) +} + +fn parse_migration_stem(stem: &str) -> Result<(String, MigrationDir, u32)> { + if let Some(pos) = stem.rfind("_up_") { + let table = stem[..pos].to_string(); + let ver_str = &stem[pos + 4..]; + let version = + ver_str.parse::().context("Invalid version number")?; + Ok((table, MigrationDir::Up, version)) + } else if let Some(pos) = stem.rfind("_down_") { + let table = stem[..pos].to_string(); + let ver_str = &stem[pos + 6..]; + let version = + ver_str.parse::().context("Invalid version number")?; + Ok((table, MigrationDir::Down, version)) + } else { + bail!("Migration filename must contain _up_ or _down_: {stem}"); + } +} + +async fn ensure_migrations_table(pool: &sqlx::PgPool) -> Result<()> { + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS _sql_migrations ( + domain TEXT NOT NULL, + table_name TEXT NOT NULL, + version INTEGER NOT NULL, + applied_at TIMESTAMPTZ NOT NULL DEFAULT now(), + checksum TEXT NOT NULL DEFAULT '', + PRIMARY KEY (domain, table_name, version) + ) + "#, + ) + .execute(pool) + .await?; + Ok(()) +} + +async fn applied_set( + pool: &sqlx::PgPool, +) -> Result> { + let rows: Vec<(String, String, i32, String)> = + sqlx::query_as("SELECT domain, table_name, version, checksum FROM _sql_migrations ORDER BY domain, table_name, version") + .fetch_all(pool) + .await?; + + Ok(rows + .into_iter() + .map(|(d, t, v, c)| ((d, t, v as u32), c)) + .collect()) +} + +async fn record_migration( + pool: &sqlx::PgPool, + m: &Migration, + checksum: &str, +) -> Result<()> { + sqlx::query( + r#" + INSERT INTO _sql_migrations (domain, table_name, version, checksum) + VALUES ($1, $2, $3, $4) + ON CONFLICT DO NOTHING + "#, + ) + .bind(&m.domain) + .bind(&m.table) + .bind(m.version as i32) + .bind(checksum) + .execute(pool) + .await?; + Ok(()) +} + +async fn delete_migration(pool: &sqlx::PgPool, m: &Migration) -> Result<()> { + sqlx::query( + "DELETE FROM _sql_migrations WHERE domain = $1 AND table_name = $2 AND version = $3", + ) + .bind(&m.domain) + .bind(&m.table) + .bind(m.version as i32) + .execute(pool) + .await?; + Ok(()) +} + +fn compute_checksum(content: &str) -> String { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + content.hash(&mut hasher); + format!("{:x}", hasher.finish()) +} + +fn parse_depends_on(content: &str) -> Vec { + content + .lines() + .filter_map(|line| { + let line = line.trim(); + line.strip_prefix("-- depends_on:") + .map(|deps| { + deps.split(',') + .map(|d| d.trim().to_string()) + .filter(|d| !d.is_empty()) + .collect::>() + }) + }) + .flatten() + .collect() +} + +fn topo_sort(migrations: &mut [Migration]) -> Result<()> { + let table_to_idx: HashMap = migrations + .iter() + .enumerate() + .map(|(i, m)| (m.table.clone(), i)) + .collect(); + + let n = migrations.len(); + let mut in_degree = vec![0u32; n]; + let mut adj: Vec> = vec![Vec::new(); n]; + + for (i, m) in migrations.iter().enumerate() { + for dep in &m.depends_on { + if let Some(&j) = table_to_idx.get(dep) { + adj[j].push(i); + in_degree[i] += 1; + } + } + } + + let mut queue: VecDeque = (0..n) + .filter(|&i| in_degree[i] == 0) + .collect(); + + let mut order = Vec::with_capacity(n); + while let Some(i) = queue.pop_front() { + order.push(i); + for &next in &adj[i] { + in_degree[next] -= 1; + if in_degree[next] == 0 { + queue.push_back(next); + } + } + } + + if order.len() != n { + bail!("Circular dependency detected among migrations"); + } + + let original: Vec = migrations.iter().cloned().collect(); + for (slot, &idx) in order.iter().enumerate() { + migrations[slot] = original[idx].clone(); + } + + Ok(()) +} +fn into_static(s: String) -> &'static str { + Box::leak(s.into_boxed_str()) +} + +async fn exec_sql(pool: &sqlx::PgPool, sql: &str) -> Result<()> { + sqlx::raw_sql(into_static(sql.to_owned())) + .execute(pool) + .await?; + Ok(()) +} + +async fn run_up(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> { + ensure_migrations_table(pool).await?; + let all = discover_migrations(sql_root)?; + let applied = applied_set(pool).await?; + + let mut up_migrations: Vec<_> = all + .into_iter() + .filter(|m| m.direction == MigrationDir::Up) + .filter(|m| { + !applied.contains_key(&( + m.domain.clone(), + m.table.clone(), + m.version, + )) + }) + .collect(); + + if up_migrations.is_empty() { + info!("All migrations are already applied."); + return Ok(()); + } + + topo_sort(&mut up_migrations)?; + + for m in &up_migrations { + let sql = std::fs::read_to_string(&m.path) + .context(format!("Failed to read {:?}", m.path))?; + let checksum = compute_checksum(&sql); + + info!("Applying {}/{}/v{}", m.domain, m.table, m.version); + exec_sql(pool, &sql).await?; + record_migration(pool, m, &checksum).await?; + } + + info!("Applied {} migration(s).", up_migrations.len()); + Ok(()) +} + +async fn run_down(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> { + ensure_migrations_table(pool).await?; + let all = discover_migrations(sql_root)?; + let applied = applied_set(pool).await?; + + let mut down_targets: Vec<_> = all + .into_iter() + .filter(|m| m.direction == MigrationDir::Down) + .filter(|m| { + applied.contains_key(&( + m.domain.clone(), + m.table.clone(), + m.version, + )) + }) + .collect(); + down_targets.sort(); + + if down_targets.is_empty() { + info!("No migrations to roll back."); + return Ok(()); + } + + let m = &down_targets[down_targets.len() - 1]; + let sql = std::fs::read_to_string(&m.path)?; + + info!("Rolling back {}/{}/v{}", m.domain, m.table, m.version); + exec_sql(pool, &sql).await?; + delete_migration(pool, m).await?; + + info!("Rolled back 1 migration."); + Ok(()) +} + +async fn run_fresh(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> { + info!("Dropping all tables and re-applying migrations..."); + + exec_sql(pool, "DROP TABLE IF EXISTS _sql_migrations CASCADE").await?; + + let all = discover_migrations(sql_root)?; + let down_migrations: Vec<_> = all + .into_iter() + .filter(|m| m.direction == MigrationDir::Down) + .collect(); + + let mut drops: Vec<_> = down_migrations.iter().collect(); + drops.sort(); + drops.reverse(); + + for m in &drops { + let sql = std::fs::read_to_string(&m.path)?; + let _ = exec_sql(pool, &sql).await; + } + + run_up(pool, sql_root).await +} + +async fn run_list(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> { + ensure_migrations_table(pool).await?; + let all = discover_migrations(sql_root)?; + let applied = applied_set(pool).await?; + + let up_migrations: Vec<_> = all + .into_iter() + .filter(|m| m.direction == MigrationDir::Up) + .collect(); + + println!( + "{:<20} {:<30} {:>8} {}", + "Domain", "Table", "Version", "Status" + ); + println!("{:-<20} {:-<30} {:-<8} {:-<10}", "", "", "", ""); + + for m in &up_migrations { + let key = (m.domain.clone(), m.table.clone(), m.version); + let status = if applied.contains_key(&key) { + "Applied" + } else { + "Pending" + }; + println!( + "{:<20} {:<30} {:>8} {}", + m.domain, m.table, m.version, status + ); + } + + Ok(()) +} diff --git a/lib/model/Cargo.toml b/lib/model/Cargo.toml new file mode 100644 index 0000000..84400d8 --- /dev/null +++ b/lib/model/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "model" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true +[lib] +path = "lib.rs" +name = "model" +[dependencies] +sqlx = { workspace = true, features = [ + "derive", + "postgres", + "runtime-tokio", + "uuid", + "rust_decimal", + "chrono" +]} +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +rust_decimal = { workspace = true, features = ["serde"] } +uuid = { workspace = true, features = ["serde","v7"] } +chrono = { workspace = true, features = ["serde"] } +db = { workspace = true } +[lints] +workspace = true diff --git a/lib/model/agent/agent_conversation.rs b/lib/model/agent/agent_conversation.rs new file mode 100644 index 0000000..1761c31 --- /dev/null +++ b/lib/model/agent/agent_conversation.rs @@ -0,0 +1,18 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AgentConversationModel { + pub id: Uuid, + pub session: Uuid, + pub title: String, + pub created_by: Uuid, + pub last_message_at: Option>, + pub archived_at: Option>, + pub compacted_summary: Option, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/agent/agent_knowledge_base.rs b/lib/model/agent/agent_knowledge_base.rs new file mode 100644 index 0000000..d4fce79 --- /dev/null +++ b/lib/model/agent/agent_knowledge_base.rs @@ -0,0 +1,19 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AgentKnowledgeBaseModel { + pub id: Uuid, + pub session: Uuid, + pub title: String, + pub source_type: String, + pub source_url: Option, + pub content: Option, + pub embedding_ref: Option, + pub created_by: Uuid, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/agent/agent_long_term_memories.rs b/lib/model/agent/agent_long_term_memories.rs new file mode 100644 index 0000000..3b4ca20 --- /dev/null +++ b/lib/model/agent/agent_long_term_memories.rs @@ -0,0 +1,17 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AgentLongTermMemoryModel { + pub id: Uuid, + pub session: Uuid, + pub key: String, + pub value: String, + pub importance: i32, + pub last_used_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/agent/agent_message.rs b/lib/model/agent/agent_message.rs new file mode 100644 index 0000000..2cfd6ee --- /dev/null +++ b/lib/model/agent/agent_message.rs @@ -0,0 +1,21 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AgentMessageModel { + pub id: Uuid, + pub conversation: Uuid, + pub parent: Option, + pub role: String, + pub author: Option, + pub content: String, + pub content_type: String, + pub status: String, + pub model_invocation: Option, + pub reasoning_content: Option, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/agent/agent_message_fork.rs b/lib/model/agent/agent_message_fork.rs new file mode 100644 index 0000000..2919483 --- /dev/null +++ b/lib/model/agent/agent_message_fork.rs @@ -0,0 +1,13 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AgentMessageForkModel { + pub id: Uuid, + pub source_message: Uuid, + pub forked_conversation: Uuid, + pub forked_by: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/agent/agent_model_invocations.rs b/lib/model/agent/agent_model_invocations.rs new file mode 100644 index 0000000..ec981d7 --- /dev/null +++ b/lib/model/agent/agent_model_invocations.rs @@ -0,0 +1,21 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AgentModelInvocationModel { + pub id: Uuid, + pub session: Uuid, + pub conversation: Option, + pub message: Option, + pub model_version: Uuid, + pub request_id: Option, + pub status: String, + pub prompt: Option, + pub response: Option, + pub error: Option, + pub started_at: DateTime, + pub finished_at: Option>, + pub latency_ms: Option, +} diff --git a/lib/model/agent/agent_session.rs b/lib/model/agent/agent_session.rs new file mode 100644 index 0000000..e3c9f26 --- /dev/null +++ b/lib/model/agent/agent_session.rs @@ -0,0 +1,36 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AgentSessionModel { + pub id: Uuid, + pub user: Option, + pub wk: Option, + pub name: String, + pub description: Option, + pub agent_kind: String, + pub model_version: Option, + pub system_prompt: Option, + pub temperature: Option, + pub max_output_tokens: Option, + pub tool_policy: Option, + pub knowledge_base_ids: Option, + pub variables: Option, + pub visibility: String, + pub version: i32, + pub published_at: Option>, + pub rollback_from_version: Option, + pub enabled: bool, + pub source: Option, + pub parent_session_id: Option, + pub toolset_json: Option, + pub memory_provider: Option, + pub memory_provider_config: Option, + pub iteration_budget: Option, + pub created_by: Uuid, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/agent/agent_subagent_session.rs b/lib/model/agent/agent_subagent_session.rs new file mode 100644 index 0000000..d690880 --- /dev/null +++ b/lib/model/agent/agent_subagent_session.rs @@ -0,0 +1,15 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AgentSubagentSessionModel { + pub id: Uuid, + pub parent_session: Uuid, + pub child_session: Uuid, + pub name: String, + pub purpose: Option, + pub created_at: DateTime, + pub ended_at: Option>, +} diff --git a/lib/model/agent/agent_token_usage.rs b/lib/model/agent/agent_token_usage.rs new file mode 100644 index 0000000..0e6c144 --- /dev/null +++ b/lib/model/agent/agent_token_usage.rs @@ -0,0 +1,22 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::{FromRow, types::Decimal}; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AgentTokenUsageModel { + pub id: Uuid, + pub invocation: Uuid, + pub session: Uuid, + pub model_version: Uuid, + pub input_tokens: i64, + pub output_tokens: i64, + pub cached_input_tokens: i64, + pub cache_read_tokens: Option, + pub cache_write_tokens: Option, + pub reasoning_tokens: Option, + pub total_tokens: i64, + pub cost: Option, + pub currency: Option, + pub created_at: DateTime, +} diff --git a/lib/model/agent/agent_tool_call_log.rs b/lib/model/agent/agent_tool_call_log.rs new file mode 100644 index 0000000..8934f53 --- /dev/null +++ b/lib/model/agent/agent_tool_call_log.rs @@ -0,0 +1,22 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AgentToolCallLogModel { + pub id: Uuid, + pub invocation: Option, + pub session: Uuid, + pub conversation: Option, + pub message: Option, + pub tool_call_id: Option, + pub tool_name: String, + pub arguments: Option, + pub result: Option, + pub error: Option, + pub status: String, + pub started_at: DateTime, + pub finished_at: Option>, + pub latency_ms: Option, +} diff --git a/lib/model/agent/agent_trace.rs b/lib/model/agent/agent_trace.rs new file mode 100644 index 0000000..3d13836 --- /dev/null +++ b/lib/model/agent/agent_trace.rs @@ -0,0 +1,33 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AgentTraceModel { + pub id: Uuid, + pub invocation: Uuid, + pub conversation: Uuid, + pub sequence: i32, + pub phase: String, + pub content: Option, + pub tool_calls: Option, + pub tool_results: Option, + pub input_tokens: Option, + pub output_tokens: Option, + pub metadata: Option, + pub created_at: DateTime, +} + +impl AgentTraceModel { + pub fn phase_label(&self) -> &str { + match self.phase.as_str() { + "think" => "Thinking", + "answer" => "Answering", + "act" => "Acting", + "summarize" => "Summarizing", + _ => &self.phase, + } + } +} diff --git a/lib/model/agent/mod.rs b/lib/model/agent/mod.rs new file mode 100644 index 0000000..37f8c1f --- /dev/null +++ b/lib/model/agent/mod.rs @@ -0,0 +1,23 @@ +pub mod agent_conversation; +pub mod agent_knowledge_base; +pub mod agent_long_term_memories; +pub mod agent_message; +pub mod agent_message_fork; +pub mod agent_model_invocations; +pub mod agent_session; +pub mod agent_subagent_session; +pub mod agent_token_usage; +pub mod agent_tool_call_log; +pub mod agent_trace; + +pub use agent_conversation::AgentConversationModel; +pub use agent_knowledge_base::AgentKnowledgeBaseModel; +pub use agent_long_term_memories::AgentLongTermMemoryModel; +pub use agent_message::AgentMessageModel; +pub use agent_message_fork::AgentMessageForkModel; +pub use agent_model_invocations::AgentModelInvocationModel; +pub use agent_session::AgentSessionModel; +pub use agent_subagent_session::AgentSubagentSessionModel; +pub use agent_token_usage::AgentTokenUsageModel; +pub use agent_tool_call_log::AgentToolCallLogModel; +pub use agent_trace::AgentTraceModel; diff --git a/lib/model/ai/ai_model.rs b/lib/model/ai/ai_model.rs new file mode 100644 index 0000000..6879b66 --- /dev/null +++ b/lib/model/ai/ai_model.rs @@ -0,0 +1,22 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AiModelModel { + pub id: Uuid, + pub provider: Uuid, + pub name: String, + pub display_name: String, + pub description: Option, + pub modality: String, + pub context_window: Option, + pub input_token_limit: Option, + pub output_token_limit: Option, + pub enabled: bool, + pub public: bool, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/ai/ai_model_card.rs b/lib/model/ai/ai_model_card.rs new file mode 100644 index 0000000..bda0aab --- /dev/null +++ b/lib/model/ai/ai_model_card.rs @@ -0,0 +1,17 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AiModelCardModel { + pub model: Uuid, + pub overview: Option, + pub strengths: Option, + pub limitations: Option, + pub safety_notes: Option, + pub eval_summary: Option, + pub metadata: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/ai/ai_model_discussion.rs b/lib/model/ai/ai_model_discussion.rs new file mode 100644 index 0000000..7268474 --- /dev/null +++ b/lib/model/ai/ai_model_discussion.rs @@ -0,0 +1,16 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AiModelDiscussionModel { + pub id: Uuid, + pub model: Uuid, + pub user: Uuid, + pub parent: Option, + pub body: String, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/ai/ai_model_like.rs b/lib/model/ai/ai_model_like.rs new file mode 100644 index 0000000..2ff4362 --- /dev/null +++ b/lib/model/ai/ai_model_like.rs @@ -0,0 +1,11 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AiModelLikeModel { + pub model: Uuid, + pub user: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/ai/ai_model_model_tag.rs b/lib/model/ai/ai_model_model_tag.rs new file mode 100644 index 0000000..4723b0b --- /dev/null +++ b/lib/model/ai/ai_model_model_tag.rs @@ -0,0 +1,11 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AiModelModelTagModel { + pub model: Uuid, + pub tag: String, + pub created_at: DateTime, +} diff --git a/lib/model/ai/ai_model_version.rs b/lib/model/ai/ai_model_version.rs new file mode 100644 index 0000000..545e0cf --- /dev/null +++ b/lib/model/ai/ai_model_version.rs @@ -0,0 +1,21 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::{FromRow, types::Decimal}; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AiModelVersionModel { + pub id: Uuid, + pub model: Uuid, + pub version: String, + pub provider_model_name: String, + pub input_price_per_million: Option, + pub output_price_per_million: Option, + pub cached_input_price_per_million: Option, + pub training_cutoff: Option, + pub released_at: Option>, + pub deprecated_at: Option>, + pub enabled: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/ai/ai_provider.rs b/lib/model/ai/ai_provider.rs new file mode 100644 index 0000000..57505ab --- /dev/null +++ b/lib/model/ai/ai_provider.rs @@ -0,0 +1,16 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct AiProviderModel { + pub id: Uuid, + pub name: String, + pub base_url: Option, + pub website_url: Option, + pub logo_url: Option, + pub enabled: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/ai/mod.rs b/lib/model/ai/mod.rs new file mode 100644 index 0000000..4718c87 --- /dev/null +++ b/lib/model/ai/mod.rs @@ -0,0 +1,15 @@ +pub mod ai_model; +pub mod ai_model_card; +pub mod ai_model_discussion; +pub mod ai_model_like; +pub mod ai_model_model_tag; +pub mod ai_model_version; +pub mod ai_provider; + +pub use ai_model::AiModelModel; +pub use ai_model_card::AiModelCardModel; +pub use ai_model_discussion::AiModelDiscussionModel; +pub use ai_model_like::AiModelLikeModel; +pub use ai_model_model_tag::AiModelModelTagModel; +pub use ai_model_version::AiModelVersionModel; +pub use ai_provider::AiProviderModel; diff --git a/lib/model/issues/issue.rs b/lib/model/issues/issue.rs new file mode 100644 index 0000000..6c8f071 --- /dev/null +++ b/lib/model/issues/issue.rs @@ -0,0 +1,22 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct IssueModel { + pub id: Uuid, + pub wk: Uuid, + pub number: i64, + pub title: String, + pub body: Option, + pub state: String, + pub priority: String, + pub author: Uuid, + pub closed_by: Option, + pub closed_at: Option>, + pub due_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/issues/issue_assignee.rs b/lib/model/issues/issue_assignee.rs new file mode 100644 index 0000000..8996932 --- /dev/null +++ b/lib/model/issues/issue_assignee.rs @@ -0,0 +1,12 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct IssueAssigneeModel { + pub issue: Uuid, + pub user: Uuid, + pub assigned_by: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/issues/issue_comment.rs b/lib/model/issues/issue_comment.rs new file mode 100644 index 0000000..9417d8d --- /dev/null +++ b/lib/model/issues/issue_comment.rs @@ -0,0 +1,15 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct IssueCommentModel { + pub id: Uuid, + pub issue: Uuid, + pub author: Uuid, + pub body: String, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/issues/issue_event.rs b/lib/model/issues/issue_event.rs new file mode 100644 index 0000000..2829a7a --- /dev/null +++ b/lib/model/issues/issue_event.rs @@ -0,0 +1,16 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct IssueEventModel { + pub id: Uuid, + pub issue: Uuid, + pub actor: Option, + pub event: String, + pub from_value: Option, + pub to_value: Option, + pub metadata: Option, + pub created_at: DateTime, +} diff --git a/lib/model/issues/issue_label.rs b/lib/model/issues/issue_label.rs new file mode 100644 index 0000000..8d08ace --- /dev/null +++ b/lib/model/issues/issue_label.rs @@ -0,0 +1,11 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct IssueLabelModel { + pub issue: Uuid, + pub label: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/issues/issue_milestone.rs b/lib/model/issues/issue_milestone.rs new file mode 100644 index 0000000..7387292 --- /dev/null +++ b/lib/model/issues/issue_milestone.rs @@ -0,0 +1,11 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct IssueMilestoneModel { + pub issue: Uuid, + pub milestone: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/issues/issue_pull_request.rs b/lib/model/issues/issue_pull_request.rs new file mode 100644 index 0000000..d495009 --- /dev/null +++ b/lib/model/issues/issue_pull_request.rs @@ -0,0 +1,11 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct IssuePullRequestModel { + pub issue: Uuid, + pub pull_request: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/issues/issue_reaction.rs b/lib/model/issues/issue_reaction.rs new file mode 100644 index 0000000..00941b2 --- /dev/null +++ b/lib/model/issues/issue_reaction.rs @@ -0,0 +1,14 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct IssueReactionModel { + pub id: Uuid, + pub issue: Uuid, + pub comment: Option, + pub user: Uuid, + pub reaction: String, + pub created_at: DateTime, +} diff --git a/lib/model/issues/issue_reference.rs b/lib/model/issues/issue_reference.rs new file mode 100644 index 0000000..b5cf8a0 --- /dev/null +++ b/lib/model/issues/issue_reference.rs @@ -0,0 +1,14 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct IssueReferenceModel { + pub id: Uuid, + pub issue: Uuid, + pub target_type: String, + pub target_id: Uuid, + pub created_by: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/issues/issue_repo.rs b/lib/model/issues/issue_repo.rs new file mode 100644 index 0000000..1112d6a --- /dev/null +++ b/lib/model/issues/issue_repo.rs @@ -0,0 +1,11 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct IssueRepoModel { + pub issue: Uuid, + pub repo: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/issues/label.rs b/lib/model/issues/label.rs new file mode 100644 index 0000000..98c71b0 --- /dev/null +++ b/lib/model/issues/label.rs @@ -0,0 +1,15 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct LabelModel { + pub id: Uuid, + pub wk: Uuid, + pub name: String, + pub color: String, + pub description: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/issues/milestone.rs b/lib/model/issues/milestone.rs new file mode 100644 index 0000000..eb09b6c --- /dev/null +++ b/lib/model/issues/milestone.rs @@ -0,0 +1,17 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct MilestoneModel { + pub id: Uuid, + pub wk: Uuid, + pub title: String, + pub description: Option, + pub state: String, + pub due_at: Option>, + pub closed_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/issues/mod.rs b/lib/model/issues/mod.rs new file mode 100644 index 0000000..4b8c8de --- /dev/null +++ b/lib/model/issues/mod.rs @@ -0,0 +1,25 @@ +pub mod issue; +pub mod issue_assignee; +pub mod issue_comment; +pub mod issue_event; +pub mod issue_label; +pub mod issue_milestone; +pub mod issue_pull_request; +pub mod issue_reaction; +pub mod issue_reference; +pub mod issue_repo; +pub mod label; +pub mod milestone; + +pub use issue::IssueModel; +pub use issue_assignee::IssueAssigneeModel; +pub use issue_comment::IssueCommentModel; +pub use issue_event::IssueEventModel; +pub use issue_label::IssueLabelModel; +pub use issue_milestone::IssueMilestoneModel; +pub use issue_pull_request::IssuePullRequestModel; +pub use issue_reaction::IssueReactionModel; +pub use issue_reference::IssueReferenceModel; +pub use issue_repo::IssueRepoModel; +pub use label::LabelModel; +pub use milestone::MilestoneModel; diff --git a/lib/model/lib.rs b/lib/model/lib.rs new file mode 100644 index 0000000..7cf073d --- /dev/null +++ b/lib/model/lib.rs @@ -0,0 +1,18 @@ +use db::AppDatabase; + +pub mod agent; +pub mod ai; +pub mod issues; +pub mod logs; +pub mod notify; +pub mod pull_request; +pub mod repos; +pub mod room; +pub mod system; +pub mod users; +pub mod workspace; + +#[derive(Clone)] +pub struct DatabaseMapper { + pub db: AppDatabase, +} diff --git a/lib/model/notify/mod.rs b/lib/model/notify/mod.rs new file mode 100644 index 0000000..e2dc81f --- /dev/null +++ b/lib/model/notify/mod.rs @@ -0,0 +1,5 @@ +pub mod user_app_notify; +pub mod user_email_notify; + +pub use user_app_notify::UserAppNotifyModel; +pub use user_email_notify::UserEmailNotifyModel; diff --git a/lib/model/notify/user_app_notify.rs b/lib/model/notify/user_app_notify.rs new file mode 100644 index 0000000..7ce3ba4 --- /dev/null +++ b/lib/model/notify/user_app_notify.rs @@ -0,0 +1,20 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserAppNotifyModel { + pub id: Uuid, + pub user: Uuid, + pub title: String, + pub body: String, + pub notify_type: String, + pub target_type: Option, + pub target_id: Option, + pub metadata: Option, + pub read_at: Option>, + pub archived_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/notify/user_email_notify.rs b/lib/model/notify/user_email_notify.rs new file mode 100644 index 0000000..7694e35 --- /dev/null +++ b/lib/model/notify/user_email_notify.rs @@ -0,0 +1,29 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserEmailNotifyModel { + pub id: Uuid, + pub user: Uuid, + pub email: String, + pub subject: String, + pub template: String, + pub body_text: Option, + pub body_html: Option, + pub notify_type: String, + pub target_type: Option, + pub target_id: Option, + pub metadata: Option, + pub status: String, + pub provider_message_id: Option, + pub error: Option, + pub retry_count: i32, + pub queued_at: DateTime, + pub sent_at: Option>, + pub delivered_at: Option>, + pub opened_at: Option>, + pub clicked_at: Option>, + pub failed_at: Option>, +} diff --git a/lib/model/pull_request/mod.rs b/lib/model/pull_request/mod.rs new file mode 100644 index 0000000..10e7b30 --- /dev/null +++ b/lib/model/pull_request/mod.rs @@ -0,0 +1,21 @@ +pub mod pull_request; +pub mod pull_request_assignee; +pub mod pull_request_comment; +pub mod pull_request_commit; +pub mod pull_request_label; +pub mod pull_request_reaction; +pub mod pull_request_review; +pub mod pull_request_review_comment; +pub mod pull_request_review_reaction; +pub mod pull_request_review_request; + +pub use pull_request::PullRequestModel; +pub use pull_request_assignee::PullRequestAssigneeModel; +pub use pull_request_comment::PullRequestCommentModel; +pub use pull_request_commit::PullRequestCommitModel; +pub use pull_request_label::PullRequestLabelModel; +pub use pull_request_reaction::PullRequestReactionModel; +pub use pull_request_review::PullRequestReviewModel; +pub use pull_request_review_comment::PullRequestReviewCommentModel; +pub use pull_request_review_reaction::PullRequestReviewReactionModel; +pub use pull_request_review_request::PullRequestReviewRequestModel; diff --git a/lib/model/pull_request/pull_request.rs b/lib/model/pull_request/pull_request.rs new file mode 100644 index 0000000..b8d1847 --- /dev/null +++ b/lib/model/pull_request/pull_request.rs @@ -0,0 +1,28 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct PullRequestModel { + pub id: Uuid, + pub repo: Uuid, + pub number: i64, + pub title: String, + pub body: Option, + pub state: String, + pub draft: bool, + pub author: Uuid, + pub source_repo: Uuid, + pub source_branch: String, + pub source_sha: String, + pub target_branch: String, + pub target_sha: String, + pub merged_by: Option, + pub merged_at: Option>, + pub closed_by: Option, + pub closed_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/pull_request/pull_request_assignee.rs b/lib/model/pull_request/pull_request_assignee.rs new file mode 100644 index 0000000..db6b91d --- /dev/null +++ b/lib/model/pull_request/pull_request_assignee.rs @@ -0,0 +1,12 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct PullRequestAssigneeModel { + pub pull_request: Uuid, + pub user: Uuid, + pub assigned_by: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/pull_request/pull_request_comment.rs b/lib/model/pull_request/pull_request_comment.rs new file mode 100644 index 0000000..7f21947 --- /dev/null +++ b/lib/model/pull_request/pull_request_comment.rs @@ -0,0 +1,15 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct PullRequestCommentModel { + pub id: Uuid, + pub pull_request: Uuid, + pub author: Uuid, + pub body: String, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/pull_request/pull_request_commit.rs b/lib/model/pull_request/pull_request_commit.rs new file mode 100644 index 0000000..22f591f --- /dev/null +++ b/lib/model/pull_request/pull_request_commit.rs @@ -0,0 +1,12 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct PullRequestCommitModel { + pub pull_request: Uuid, + pub commit: Uuid, + pub sha: String, + pub created_at: DateTime, +} diff --git a/lib/model/pull_request/pull_request_label.rs b/lib/model/pull_request/pull_request_label.rs new file mode 100644 index 0000000..f54cf03 --- /dev/null +++ b/lib/model/pull_request/pull_request_label.rs @@ -0,0 +1,11 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct PullRequestLabelModel { + pub pull_request: Uuid, + pub label: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/pull_request/pull_request_reaction.rs b/lib/model/pull_request/pull_request_reaction.rs new file mode 100644 index 0000000..33e3f87 --- /dev/null +++ b/lib/model/pull_request/pull_request_reaction.rs @@ -0,0 +1,14 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct PullRequestReactionModel { + pub id: Uuid, + pub pull_request: Uuid, + pub comment: Option, + pub user: Uuid, + pub reaction: String, + pub created_at: DateTime, +} diff --git a/lib/model/pull_request/pull_request_review.rs b/lib/model/pull_request/pull_request_review.rs new file mode 100644 index 0000000..b311df5 --- /dev/null +++ b/lib/model/pull_request/pull_request_review.rs @@ -0,0 +1,20 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct PullRequestReviewModel { + pub id: Uuid, + pub pull_request: Uuid, + pub reviewer: Uuid, + pub state: String, + pub body: Option, + pub commit_sha: Option, + pub submitted_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, + pub dismissed_by: Option, + pub dismissed_at: Option>, + pub dismiss_reason: Option, +} diff --git a/lib/model/pull_request/pull_request_review_comment.rs b/lib/model/pull_request/pull_request_review_comment.rs new file mode 100644 index 0000000..e8f1aae --- /dev/null +++ b/lib/model/pull_request/pull_request_review_comment.rs @@ -0,0 +1,25 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct PullRequestReviewCommentModel { + pub id: Uuid, + pub pull_request: Uuid, + pub review: Option, + pub author: Uuid, + pub body: String, + pub path: String, + pub commit_sha: String, + pub original_commit_sha: Option, + pub line: Option, + pub original_line: Option, + pub side: Option, + pub resolved: bool, + pub resolved_by: Option, + pub resolved_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/pull_request/pull_request_review_reaction.rs b/lib/model/pull_request/pull_request_review_reaction.rs new file mode 100644 index 0000000..409631b --- /dev/null +++ b/lib/model/pull_request/pull_request_review_reaction.rs @@ -0,0 +1,13 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct PullRequestReviewReactionModel { + pub id: Uuid, + pub review_comment: Uuid, + pub user: Uuid, + pub reaction: String, + pub created_at: DateTime, +} diff --git a/lib/model/pull_request/pull_request_review_request.rs b/lib/model/pull_request/pull_request_review_request.rs new file mode 100644 index 0000000..98b6125 --- /dev/null +++ b/lib/model/pull_request/pull_request_review_request.rs @@ -0,0 +1,14 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct PullRequestReviewRequestModel { + pub pull_request: Uuid, + pub reviewer: Option, + pub group: Option, + pub requested_by: Uuid, + pub created_at: DateTime, + pub removed_at: Option>, +} diff --git a/lib/model/repos/mod.rs b/lib/model/repos/mod.rs new file mode 100644 index 0000000..f142237 --- /dev/null +++ b/lib/model/repos/mod.rs @@ -0,0 +1,47 @@ +pub mod repo; +pub mod repo_audit_log; +pub mod repo_commit; +pub mod repo_commit_status; +pub mod repo_committer; +pub mod repo_deploy_key; +pub mod repo_fork; +pub mod repo_history_name; +pub mod repo_language; +pub mod repo_lfs_lock; +pub mod repo_lfs_object; +pub mod repo_license; +pub mod repo_lock; +pub mod repo_protect; +pub mod repo_ref; +pub mod repo_release; +pub mod repo_release_asset; +pub mod repo_secret; +pub mod repo_star; +pub mod repo_topic; +pub mod repo_watch; +pub mod repo_webhook; +pub mod repo_webhook_delivery; + +pub use repo::RepoModel; +pub use repo_audit_log::RepoAuditLogModel; +pub use repo_commit::RepoCommitModel; +pub use repo_committer::RepoCommitterModel; +pub use repo_deploy_key::RepoDeployKeyModel; +pub use repo_fork::RepoForkModel; +pub use repo_history_name::RepoHistoryNameModel; +pub use repo_language::RepoLanguageModel; +pub use repo_lfs_lock::RepoLfsLockModel; +pub use repo_lfs_object::RepoLfsObjectModel; +pub use repo_license::RepoLicenseModel; +pub use repo_lock::RepoLockModel; +pub use repo_protect::RepoProtectModel; +pub use repo_ref::RepoRefModel; +pub use repo_release::RepoReleaseModel; +pub use repo_release_asset::RepoReleaseAssetModel; +pub use repo_commit_status::RepoCommitStatusModel; +pub use repo_secret::RepoSecretModel; +pub use repo_star::RepoStarModel; +pub use repo_topic::RepoTopicModel; +pub use repo_watch::RepoWatchModel; +pub use repo_webhook::RepoWebhookModel; +pub use repo_webhook_delivery::RepoWebhookDeliveryModel; diff --git a/lib/model/repos/repo.rs b/lib/model/repos/repo.rs new file mode 100644 index 0000000..e855b4b --- /dev/null +++ b/lib/model/repos/repo.rs @@ -0,0 +1,23 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoModel { + pub id: Uuid, + pub wk: Uuid, + pub name: String, + pub description: Option, + pub default_branch: String, + pub visibility: String, + pub size_bytes: i64, + pub is_archived: bool, + pub is_template: bool, + pub is_mirror: bool, + pub created_by: Uuid, + pub storage_path: String, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/repos/repo_audit_log.rs b/lib/model/repos/repo_audit_log.rs new file mode 100644 index 0000000..0d36141 --- /dev/null +++ b/lib/model/repos/repo_audit_log.rs @@ -0,0 +1,18 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoAuditLogModel { + pub id: Uuid, + pub repo: Uuid, + pub actor: Option, + pub action: String, + pub target_type: String, + pub target_id: Option, + pub ip_address: Option, + pub user_agent: Option, + pub metadata: Option, + pub created_at: DateTime, +} diff --git a/lib/model/repos/repo_commit.rs b/lib/model/repos/repo_commit.rs new file mode 100644 index 0000000..e4337bd --- /dev/null +++ b/lib/model/repos/repo_commit.rs @@ -0,0 +1,19 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoCommitModel { + pub id: Uuid, + pub repo: Uuid, + pub sha: String, + pub tree_sha: String, + pub parent_shas: String, + pub author: Uuid, + pub committer: Uuid, + pub message: String, + pub authored_at: DateTime, + pub committed_at: DateTime, + pub created_at: DateTime, +} diff --git a/lib/model/repos/repo_commit_status.rs b/lib/model/repos/repo_commit_status.rs new file mode 100644 index 0000000..bc4d239 --- /dev/null +++ b/lib/model/repos/repo_commit_status.rs @@ -0,0 +1,18 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoCommitStatusModel { + pub id: Uuid, + pub repo: Uuid, + pub commit_sha: String, + pub state: String, + pub target_url: Option, + pub description: Option, + pub context: String, + pub creator: Uuid, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/repos/repo_committer.rs b/lib/model/repos/repo_committer.rs new file mode 100644 index 0000000..c8d069b --- /dev/null +++ b/lib/model/repos/repo_committer.rs @@ -0,0 +1,15 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoCommitterModel { + pub id: Uuid, + pub repo: Uuid, + pub user: Option, + pub name: String, + pub email: String, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/repos/repo_deploy_key.rs b/lib/model/repos/repo_deploy_key.rs new file mode 100644 index 0000000..9a6cef5 --- /dev/null +++ b/lib/model/repos/repo_deploy_key.rs @@ -0,0 +1,22 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoDeployKeyModel { + pub id: i64, + pub repo: Uuid, + pub title: String, + pub public_key: String, + pub fingerprint: String, + pub key_type: String, + pub key_bits: Option, + pub read_only: bool, + pub last_used_at: Option>, + pub expires_at: Option>, + pub is_revoked: bool, + pub created_by: Uuid, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/repos/repo_fork.rs b/lib/model/repos/repo_fork.rs new file mode 100644 index 0000000..6893747 --- /dev/null +++ b/lib/model/repos/repo_fork.rs @@ -0,0 +1,13 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoForkModel { + pub id: Uuid, + pub repo: Uuid, + pub source_repo: Uuid, + pub forked_by: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/repos/repo_history_name.rs b/lib/model/repos/repo_history_name.rs new file mode 100644 index 0000000..ebb2647 --- /dev/null +++ b/lib/model/repos/repo_history_name.rs @@ -0,0 +1,13 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoHistoryNameModel { + pub id: Uuid, + pub repo: Uuid, + pub name: String, + pub changed_by: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/repos/repo_language.rs b/lib/model/repos/repo_language.rs new file mode 100644 index 0000000..5cb13ca --- /dev/null +++ b/lib/model/repos/repo_language.rs @@ -0,0 +1,11 @@ +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoLanguageModel { + pub repo: Uuid, + pub language: String, + pub bytes: i64, + pub percentage: f32, +} diff --git a/lib/model/repos/repo_lfs_lock.rs b/lib/model/repos/repo_lfs_lock.rs new file mode 100644 index 0000000..fb46f23 --- /dev/null +++ b/lib/model/repos/repo_lfs_lock.rs @@ -0,0 +1,14 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoLfsLockModel { + pub id: Uuid, + pub repo: Uuid, + pub path: String, + pub locked_by: Uuid, + pub ref_name: Option, + pub created_at: DateTime, +} diff --git a/lib/model/repos/repo_lfs_object.rs b/lib/model/repos/repo_lfs_object.rs new file mode 100644 index 0000000..b5faf04 --- /dev/null +++ b/lib/model/repos/repo_lfs_object.rs @@ -0,0 +1,13 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoLfsObjectModel { + pub repo: Uuid, + pub oid: String, + pub size_bytes: i64, + pub storage_key: String, + pub created_at: DateTime, +} diff --git a/lib/model/repos/repo_license.rs b/lib/model/repos/repo_license.rs new file mode 100644 index 0000000..a929812 --- /dev/null +++ b/lib/model/repos/repo_license.rs @@ -0,0 +1,13 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoLicenseModel { + pub repo: Uuid, + pub spdx_id: Option, + pub name: String, + pub url: Option, + pub detected_at: DateTime, +} diff --git a/lib/model/repos/repo_lock.rs b/lib/model/repos/repo_lock.rs new file mode 100644 index 0000000..117c313 --- /dev/null +++ b/lib/model/repos/repo_lock.rs @@ -0,0 +1,15 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoLockModel { + pub id: Uuid, + pub repo: Uuid, + pub locked_by: Uuid, + pub reason: String, + pub expires_at: Option>, + pub created_at: DateTime, + pub released_at: Option>, +} diff --git a/lib/model/repos/repo_protect.rs b/lib/model/repos/repo_protect.rs new file mode 100644 index 0000000..b6e2b00 --- /dev/null +++ b/lib/model/repos/repo_protect.rs @@ -0,0 +1,20 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoProtectModel { + pub id: Uuid, + pub repo: Uuid, + pub pattern: String, + pub require_pull_request: bool, + pub required_approvals: i32, + pub require_status_checks: bool, + pub required_status_contexts: String, + pub enforce_admins: bool, + pub allow_force_pushes: bool, + pub allow_deletions: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/repos/repo_ref.rs b/lib/model/repos/repo_ref.rs new file mode 100644 index 0000000..497a795 --- /dev/null +++ b/lib/model/repos/repo_ref.rs @@ -0,0 +1,17 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoRefModel { + pub id: Uuid, + pub repo: Uuid, + pub name: String, + pub kind: String, + pub target_sha: String, + pub is_default: bool, + pub is_protected: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/repos/repo_release.rs b/lib/model/repos/repo_release.rs new file mode 100644 index 0000000..9ef581f --- /dev/null +++ b/lib/model/repos/repo_release.rs @@ -0,0 +1,20 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoReleaseModel { + pub id: Uuid, + pub repo: Uuid, + pub tag_name: String, + pub target_commit_sha: String, + pub name: String, + pub body: Option, + pub draft: bool, + pub prerelease: bool, + pub author: Uuid, + pub published_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/repos/repo_release_asset.rs b/lib/model/repos/repo_release_asset.rs new file mode 100644 index 0000000..5f7997c --- /dev/null +++ b/lib/model/repos/repo_release_asset.rs @@ -0,0 +1,17 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoReleaseAssetModel { + pub id: Uuid, + pub release_id: Uuid, + pub name: String, + pub content_type: Option, + pub size: i64, + pub download_count: i64, + pub storage_path: String, + pub uploader: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/repos/repo_secret.rs b/lib/model/repos/repo_secret.rs new file mode 100644 index 0000000..500ec58 --- /dev/null +++ b/lib/model/repos/repo_secret.rs @@ -0,0 +1,16 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoSecretModel { + pub id: Uuid, + pub repo: Uuid, + pub name: String, + pub encrypted_value: String, + pub key_id: String, + pub created_by: Uuid, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/repos/repo_star.rs b/lib/model/repos/repo_star.rs new file mode 100644 index 0000000..276dc4b --- /dev/null +++ b/lib/model/repos/repo_star.rs @@ -0,0 +1,11 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoStarModel { + pub repo: Uuid, + pub user: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/repos/repo_topic.rs b/lib/model/repos/repo_topic.rs new file mode 100644 index 0000000..842a6a2 --- /dev/null +++ b/lib/model/repos/repo_topic.rs @@ -0,0 +1,11 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoTopicModel { + pub repo: Uuid, + pub topic: String, + pub created_at: DateTime, +} diff --git a/lib/model/repos/repo_watch.rs b/lib/model/repos/repo_watch.rs new file mode 100644 index 0000000..e74a9d4 --- /dev/null +++ b/lib/model/repos/repo_watch.rs @@ -0,0 +1,13 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoWatchModel { + pub repo: Uuid, + pub user: Uuid, + pub level: String, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/repos/repo_webhook.rs b/lib/model/repos/repo_webhook.rs new file mode 100644 index 0000000..fb68fd0 --- /dev/null +++ b/lib/model/repos/repo_webhook.rs @@ -0,0 +1,17 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoWebhookModel { + pub id: Uuid, + pub repo: Uuid, + pub url: String, + pub secret_hash: Option, + pub events: String, + pub active: bool, + pub created_by: Uuid, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/repos/repo_webhook_delivery.rs b/lib/model/repos/repo_webhook_delivery.rs new file mode 100644 index 0000000..f8fd7a5 --- /dev/null +++ b/lib/model/repos/repo_webhook_delivery.rs @@ -0,0 +1,20 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RepoWebhookDeliveryModel { + pub id: Uuid, + pub repo: Uuid, + pub webhook: Uuid, + pub event: String, + pub request_headers: Option, + pub request_body: Option, + pub response_status: Option, + pub response_headers: Option, + pub response_body: Option, + pub error: Option, + pub delivered_at: Option>, + pub created_at: DateTime, +} diff --git a/lib/model/room/dm_conversation.rs b/lib/model/room/dm_conversation.rs new file mode 100644 index 0000000..2e37242 --- /dev/null +++ b/lib/model/room/dm_conversation.rs @@ -0,0 +1,15 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct DmConversationModel { + pub id: Uuid, + pub room: Uuid, + pub initiator: Uuid, + pub recipient: Uuid, + pub is_closed: bool, + pub closed_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/room/message_read.rs b/lib/model/room/message_read.rs new file mode 100644 index 0000000..bf0d1c8 --- /dev/null +++ b/lib/model/room/message_read.rs @@ -0,0 +1,12 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct MessageReadModel { + pub id: Uuid, + pub message: Uuid, + pub room: Uuid, + pub user: Uuid, + pub read_at: DateTime, +} diff --git a/lib/model/room/message_star.rs b/lib/model/room/message_star.rs new file mode 100644 index 0000000..d28332e --- /dev/null +++ b/lib/model/room/message_star.rs @@ -0,0 +1,12 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct MessageStarModel { + pub id: Uuid, + pub message: Uuid, + pub room: Uuid, + pub user: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/room/mod.rs b/lib/model/room/mod.rs new file mode 100644 index 0000000..2b511c8 --- /dev/null +++ b/lib/model/room/mod.rs @@ -0,0 +1,33 @@ +pub mod room; +pub mod room_ai; +pub mod room_attachments; +pub mod room_categories; +pub mod room_message; +pub mod room_message_edit_history; +pub mod room_permission_overwrite; +pub mod room_pins; +pub mod room_reactions; +pub mod room_server_label; +pub mod room_threads; +pub mod room_mention; +pub mod user_room_state; +pub mod dm_conversation; +pub mod message_read; +pub mod message_star; + +pub use room::RoomModel; +pub use room_ai::RoomAiModel; +pub use room_attachments::RoomAttachmentModel; +pub use room_categories::RoomCategoryModel; +pub use room_mention::RoomMentionModel; +pub use room_message::RoomMessageModel; +pub use room_message_edit_history::RoomMessageEditHistoryModel; +pub use room_permission_overwrite::RoomPermissionOverwriteModel; +pub use room_pins::RoomPinModel; +pub use room_reactions::RoomReactionModel; +pub use room_server_label::RoomServerLabelModel; +pub use room_threads::RoomThreadModel; +pub use user_room_state::UserRoomStateModel; +pub use dm_conversation::DmConversationModel; +pub use message_read::MessageReadModel; +pub use message_star::MessageStarModel; diff --git a/lib/model/room/room.rs b/lib/model/room/room.rs new file mode 100644 index 0000000..8d629b5 --- /dev/null +++ b/lib/model/room/room.rs @@ -0,0 +1,21 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RoomModel { + pub id: Uuid, + pub wk: Uuid, + pub parent: Option, + pub name: String, + pub topic: Option, + pub room_type: String, + pub position: i32, + pub is_private: bool, + pub is_archived: bool, + pub created_by: Uuid, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/room/room_ai.rs b/lib/model/room/room_ai.rs new file mode 100644 index 0000000..69eea18 --- /dev/null +++ b/lib/model/room/room_ai.rs @@ -0,0 +1,15 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RoomAiModel { + pub room: Uuid, + pub agent_session: Uuid, + pub enabled: bool, + pub auto_reply: bool, + pub created_by: Uuid, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/room/room_attachments.rs b/lib/model/room/room_attachments.rs new file mode 100644 index 0000000..58cb77e --- /dev/null +++ b/lib/model/room/room_attachments.rs @@ -0,0 +1,18 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RoomAttachmentModel { + pub id: Uuid, + pub message: Uuid, + pub seq: i64, + pub file_name: String, + pub content_type: Option, + pub size_bytes: i64, + pub storage_key: String, + pub url: Option, + pub uploaded_by: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/room/room_categories.rs b/lib/model/room/room_categories.rs new file mode 100644 index 0000000..7136d7a --- /dev/null +++ b/lib/model/room/room_categories.rs @@ -0,0 +1,15 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RoomCategoryModel { + pub id: Uuid, + pub wk: Uuid, + pub name: String, + pub position: i32, + pub collapsed: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/room/room_mention.rs b/lib/model/room/room_mention.rs new file mode 100644 index 0000000..2e73bf0 --- /dev/null +++ b/lib/model/room/room_mention.rs @@ -0,0 +1,14 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RoomMentionModel { + pub id: Uuid, + pub message: Uuid, + pub seq: i64, + pub mention_type: String, + pub target_id: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/room/room_message.rs b/lib/model/room/room_message.rs new file mode 100644 index 0000000..7473a06 --- /dev/null +++ b/lib/model/room/room_message.rs @@ -0,0 +1,25 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RoomMessageModel { + pub id: Uuid, + pub room: Uuid, + pub seq: i64, + pub thread: Option, + pub parent: Option, + pub author: Uuid, + pub content: String, + pub content_type: String, + pub pinned: bool, + pub system_type: Option, + #[serde(default)] + pub metadata: Value, + pub edited_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, +} diff --git a/lib/model/room/room_message_edit_history.rs b/lib/model/room/room_message_edit_history.rs new file mode 100644 index 0000000..d373d54 --- /dev/null +++ b/lib/model/room/room_message_edit_history.rs @@ -0,0 +1,15 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RoomMessageEditHistoryModel { + pub id: Uuid, + pub message: Uuid, + pub seq: i64, + pub editor: Uuid, + pub old_content: String, + pub new_content: String, + pub edited_at: DateTime, +} diff --git a/lib/model/room/room_permission_overwrite.rs b/lib/model/room/room_permission_overwrite.rs new file mode 100644 index 0000000..6286eb6 --- /dev/null +++ b/lib/model/room/room_permission_overwrite.rs @@ -0,0 +1,16 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RoomPermissionOverwriteModel { + pub id: Uuid, + pub room: Uuid, + pub target_type: String, + pub target_id: Uuid, + pub allow: String, + pub deny: String, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/room/room_pins.rs b/lib/model/room/room_pins.rs new file mode 100644 index 0000000..f5d50b2 --- /dev/null +++ b/lib/model/room/room_pins.rs @@ -0,0 +1,13 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RoomPinModel { + pub room: Uuid, + pub message: Uuid, + pub seq: i64, + pub pinned_by: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/room/room_reactions.rs b/lib/model/room/room_reactions.rs new file mode 100644 index 0000000..aeba374 --- /dev/null +++ b/lib/model/room/room_reactions.rs @@ -0,0 +1,13 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RoomReactionModel { + pub message: Uuid, + pub user: Uuid, + pub seq: i64, + pub reaction: String, + pub created_at: DateTime, +} diff --git a/lib/model/room/room_server_label.rs b/lib/model/room/room_server_label.rs new file mode 100644 index 0000000..3548ec1 --- /dev/null +++ b/lib/model/room/room_server_label.rs @@ -0,0 +1,15 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RoomServerLabelModel { + pub id: Uuid, + pub wk: Uuid, + pub name: String, + pub color: String, + pub description: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/room/room_threads.rs b/lib/model/room/room_threads.rs new file mode 100644 index 0000000..fb71547 --- /dev/null +++ b/lib/model/room/room_threads.rs @@ -0,0 +1,20 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct RoomThreadModel { + pub id: Uuid, + pub room: Uuid, + pub seq: i64, + pub starter_message: Option, + pub title: String, + pub created_by: Uuid, + pub archived: bool, + pub locked: bool, + pub last_message_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, + pub archived_at: Option>, +} diff --git a/lib/model/room/user_room_state.rs b/lib/model/room/user_room_state.rs new file mode 100644 index 0000000..611377a --- /dev/null +++ b/lib/model/room/user_room_state.rs @@ -0,0 +1,18 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserRoomStateModel { + pub id: Uuid, + pub user: Uuid, + pub room: Uuid, + pub last_read_seq: i64, + pub last_read_at: Option>, + pub is_pinned: bool, + pub is_muted: bool, + pub hide_muted: bool, + pub notify_level: String, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/system/mod.rs b/lib/model/system/mod.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/lib/model/system/mod.rs @@ -0,0 +1 @@ + diff --git a/lib/model/users/mod.rs b/lib/model/users/mod.rs new file mode 100644 index 0000000..fd6be7f --- /dev/null +++ b/lib/model/users/mod.rs @@ -0,0 +1,37 @@ +pub mod user; +pub mod user_2fa; +pub mod user_accessibility; +pub mod user_appearance; +pub mod user_billing; +pub mod user_billing_history; +pub mod user_blacklist; +pub mod user_email; +pub mod user_favorite; +pub mod user_gpg_key; +pub mod user_notifications; +pub mod user_pass; +pub mod user_passreset; +pub mod user_privacy; +pub mod user_profile; +pub mod user_session; +pub mod user_ssh_key; +pub mod user_token; + +pub use user::UserModel; +pub use user_2fa::User2FaModel; +pub use user_accessibility::UserAccessibilityModel; +pub use user_appearance::UserAppearanceModel; +pub use user_billing::UserBillingModel; +pub use user_billing_history::UserBillingHistoryModel; +pub use user_blacklist::UserBlacklistModel; +pub use user_email::UserEmailModel; +pub use user_favorite::UserFavoriteModel; +pub use user_gpg_key::UserGpgKeyModel; +pub use user_notifications::UserNotificationModel; +pub use user_pass::UserPasswordModel; +pub use user_passreset::UserPasswordResetModel; +pub use user_privacy::UserPrivacyModel; +pub use user_profile::UserProfileModel; +pub use user_session::UserSessionModel; +pub use user_ssh_key::UserSshKeyModel; +pub use user_token::UserTokenModel; diff --git a/lib/model/users/user.rs b/lib/model/users/user.rs new file mode 100644 index 0000000..dffa0e3 --- /dev/null +++ b/lib/model/users/user.rs @@ -0,0 +1,18 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserModel { + pub id: Uuid, + pub username: String, + pub display_name: String, + pub avatar_url: String, + pub website_url: String, + pub allow_use: bool, + pub can_search: bool, + pub last_sign_in_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/users/user_2fa.rs b/lib/model/users/user_2fa.rs new file mode 100644 index 0000000..b3a3981 --- /dev/null +++ b/lib/model/users/user_2fa.rs @@ -0,0 +1,14 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct User2FaModel { + pub user: Uuid, + pub secret: Option, + pub backup_codes: String, + pub enabled: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/users/user_accessibility.rs b/lib/model/users/user_accessibility.rs new file mode 100644 index 0000000..86db5cd --- /dev/null +++ b/lib/model/users/user_accessibility.rs @@ -0,0 +1,16 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserAccessibilityModel { + pub user: Uuid, + pub reduce_motion: bool, + pub high_contrast: bool, + pub screen_reader_optimized: bool, + pub font_scale_percent: i32, + pub color_blind_mode: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/users/user_appearance.rs b/lib/model/users/user_appearance.rs new file mode 100644 index 0000000..ff8fd2d --- /dev/null +++ b/lib/model/users/user_appearance.rs @@ -0,0 +1,16 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserAppearanceModel { + pub user: Uuid, + pub theme: String, + pub code_theme: String, + pub layout_density: String, + pub sidebar_collapsed: bool, + pub show_line_numbers: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/users/user_billing.rs b/lib/model/users/user_billing.rs new file mode 100644 index 0000000..57ba865 --- /dev/null +++ b/lib/model/users/user_billing.rs @@ -0,0 +1,16 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::{FromRow, types::Decimal}; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserBillingModel { + pub user: Uuid, + pub balance: Decimal, + pub is_pro: bool, + pub total_supply: Decimal, + pub total_supply_usable: Decimal, + pub cycle_start: Option>, + pub cycle_end: Option>, + pub updated_at: DateTime, +} diff --git a/lib/model/users/user_billing_history.rs b/lib/model/users/user_billing_history.rs new file mode 100644 index 0000000..07d69fb --- /dev/null +++ b/lib/model/users/user_billing_history.rs @@ -0,0 +1,14 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::{FromRow, types::Decimal}; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserBillingHistoryModel { + pub id: Uuid, + pub user: Uuid, + pub amount: Decimal, + pub currency: String, + pub reason: String, + pub created_at: DateTime, +} diff --git a/lib/model/users/user_blacklist.rs b/lib/model/users/user_blacklist.rs new file mode 100644 index 0000000..1841d19 --- /dev/null +++ b/lib/model/users/user_blacklist.rs @@ -0,0 +1,11 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserBlacklistModel { + pub user: Uuid, + pub black: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/users/user_email.rs b/lib/model/users/user_email.rs new file mode 100644 index 0000000..2b60b61 --- /dev/null +++ b/lib/model/users/user_email.rs @@ -0,0 +1,14 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserEmailModel { + pub user: Uuid, + pub email: String, + pub created_at: DateTime, + pub active: bool, + pub last_use_login: Option>, + pub updated_at: DateTime, +} diff --git a/lib/model/users/user_favorite.rs b/lib/model/users/user_favorite.rs new file mode 100644 index 0000000..dd794bb --- /dev/null +++ b/lib/model/users/user_favorite.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserFavoriteModel { + pub user: Uuid, + pub target: Uuid, + pub created_at: Uuid, +} diff --git a/lib/model/users/user_gpg_key.rs b/lib/model/users/user_gpg_key.rs new file mode 100644 index 0000000..dd5348a --- /dev/null +++ b/lib/model/users/user_gpg_key.rs @@ -0,0 +1,22 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserGpgKeyModel { + pub id: i64, + pub user: Uuid, + pub title: String, + pub public_key: String, + pub fingerprint: String, + pub key_id: String, + pub primary_key_id: Option, + pub emails: String, + pub is_verified: bool, + pub last_used_at: Option>, + pub expires_at: Option>, + pub is_revoked: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/users/user_notifications.rs b/lib/model/users/user_notifications.rs new file mode 100644 index 0000000..701a4a3 --- /dev/null +++ b/lib/model/users/user_notifications.rs @@ -0,0 +1,24 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserNotificationModel { + pub user: Uuid, + pub email_enabled: bool, + pub in_app_enabled: bool, + pub push_enabled: bool, + pub digest_mode: String, + pub dnd_enabled: bool, + pub dnd_start_minute: Option, + pub dnd_end_minute: Option, + pub marketing_enabled: bool, + pub security_enabled: bool, + pub product_enabled: bool, + pub push_subscription_endpoint: Option, + pub push_subscription_keys_p256dh: Option, + pub push_subscription_keys_auth: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/users/user_pass.rs b/lib/model/users/user_pass.rs new file mode 100644 index 0000000..961d770 --- /dev/null +++ b/lib/model/users/user_pass.rs @@ -0,0 +1,15 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserPasswordModel { + pub user: Uuid, + pub hash: String, + pub salt: String, + pub is_active: bool, + pub reason: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/users/user_passreset.rs b/lib/model/users/user_passreset.rs new file mode 100644 index 0000000..96951c8 --- /dev/null +++ b/lib/model/users/user_passreset.rs @@ -0,0 +1,13 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserPasswordResetModel { + pub token: String, + pub user: Uuid, + pub expires_at: DateTime, + pub used: bool, + pub created_at: DateTime, +} diff --git a/lib/model/users/user_privacy.rs b/lib/model/users/user_privacy.rs new file mode 100644 index 0000000..a278029 --- /dev/null +++ b/lib/model/users/user_privacy.rs @@ -0,0 +1,17 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserPrivacyModel { + pub user: Uuid, + pub profile_visibility: String, + pub email_visibility: String, + pub activity_visibility: String, + pub allow_search_indexing: bool, + pub allow_direct_messages: bool, + pub show_online_status: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/users/user_profile.rs b/lib/model/users/user_profile.rs new file mode 100644 index 0000000..3805203 --- /dev/null +++ b/lib/model/users/user_profile.rs @@ -0,0 +1,15 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserProfileModel { + pub user: Uuid, + pub language: String, + pub theme: String, + pub timezone: String, + + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/users/user_session.rs b/lib/model/users/user_session.rs new file mode 100644 index 0000000..53cd1c6 --- /dev/null +++ b/lib/model/users/user_session.rs @@ -0,0 +1,19 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserSessionModel { + pub id: Uuid, + pub user: Uuid, + pub token_hash: String, + pub device_name: Option, + pub user_agent: Option, + pub ip_address: Option, + pub last_seen_at: Option>, + pub expires_at: DateTime, + pub is_revoked: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/users/user_ssh_key.rs b/lib/model/users/user_ssh_key.rs new file mode 100644 index 0000000..7efab54 --- /dev/null +++ b/lib/model/users/user_ssh_key.rs @@ -0,0 +1,21 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserSshKeyModel { + pub id: i64, + pub user: Uuid, + pub title: String, + pub public_key: String, + pub fingerprint: String, + pub key_type: String, + pub key_bits: Option, + pub is_verified: bool, + pub last_used_at: Option>, + pub expires_at: Option>, + pub is_revoked: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/users/user_token.rs b/lib/model/users/user_token.rs new file mode 100644 index 0000000..a7c580d --- /dev/null +++ b/lib/model/users/user_token.rs @@ -0,0 +1,17 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct UserTokenModel { + pub id: i64, + pub user: Uuid, + pub name: String, + pub token_hash: String, + pub scopes: String, + pub expires_at: Option>, + pub is_revoked: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/workspace/mod.rs b/lib/model/workspace/mod.rs new file mode 100644 index 0000000..92dbd87 --- /dev/null +++ b/lib/model/workspace/mod.rs @@ -0,0 +1,19 @@ +pub mod wk_apply_join; +pub mod wk_billing; +pub mod wk_gp_member; +pub mod wk_gp_role; +pub mod wk_group; +pub mod wk_history_name; +pub mod wk_join_approval; +pub mod wk_join_strategy; +pub mod wk_member; +pub mod workspace; +pub use wk_apply_join::WkApplyJoinModel; +pub use wk_billing::WkBillingModel; +pub use wk_gp_role::WkGpRoleModel; +pub use wk_group::WkGroupModel; +pub use wk_history_name::WkHistoryNameModel; +pub use wk_join_approval::WkJoinApprovalModel; +pub use wk_join_strategy::WkJoinStrategyModel; +pub use wk_member::WkMemberModel; +pub use workspace::WorkspaceModel; diff --git a/lib/model/workspace/wk_apply_join.rs b/lib/model/workspace/wk_apply_join.rs new file mode 100644 index 0000000..0b3556a --- /dev/null +++ b/lib/model/workspace/wk_apply_join.rs @@ -0,0 +1,17 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct WkApplyJoinModel { + pub id: Uuid, + pub wk: Uuid, + pub user: Uuid, + pub status: String, + pub question: Option, + pub answer: Option, + pub message: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/workspace/wk_billing.rs b/lib/model/workspace/wk_billing.rs new file mode 100644 index 0000000..6b14874 --- /dev/null +++ b/lib/model/workspace/wk_billing.rs @@ -0,0 +1,14 @@ +use chrono::{DateTime, Utc}; +use rust_decimal::Decimal; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct WkBillingModel { + pub wk: Uuid, + pub balance: Decimal, + pub total_supply: Decimal, + pub total_supply_usable: Decimal, + pub updated_at: DateTime, +} diff --git a/lib/model/workspace/wk_gp_member.rs b/lib/model/workspace/wk_gp_member.rs new file mode 100644 index 0000000..15491cc --- /dev/null +++ b/lib/model/workspace/wk_gp_member.rs @@ -0,0 +1,12 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct WkGpMemberModel { + pub user: Uuid, + pub gp: Uuid, + pub join_at: DateTime, + pub leave_at: Option>, +} diff --git a/lib/model/workspace/wk_gp_role.rs b/lib/model/workspace/wk_gp_role.rs new file mode 100644 index 0000000..87c3af3 --- /dev/null +++ b/lib/model/workspace/wk_gp_role.rs @@ -0,0 +1,18 @@ +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct WkGpRoleModel { + pub wk: Uuid, + pub gp: Uuid, + pub repo_read: bool, + pub repo_write: bool, + pub channel_read: bool, + pub channel_write: bool, + pub ai_read: bool, + pub ai_write: bool, + pub pr_review: bool, + pub issues_ass: bool, + pub log_view: bool, +} diff --git a/lib/model/workspace/wk_group.rs b/lib/model/workspace/wk_group.rs new file mode 100644 index 0000000..652e27c --- /dev/null +++ b/lib/model/workspace/wk_group.rs @@ -0,0 +1,14 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct WkGroupModel { + pub id: Uuid, + pub name: String, + pub wk: Uuid, + pub created_at: DateTime, + pub avatar_url: Option, + pub is_deleted: bool, +} diff --git a/lib/model/workspace/wk_history_name.rs b/lib/model/workspace/wk_history_name.rs new file mode 100644 index 0000000..091b034 --- /dev/null +++ b/lib/model/workspace/wk_history_name.rs @@ -0,0 +1,13 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct WkHistoryNameModel { + pub id: Uuid, + pub wk: Uuid, + pub name: String, + pub changed_by: Uuid, + pub created_at: DateTime, +} diff --git a/lib/model/workspace/wk_join_approval.rs b/lib/model/workspace/wk_join_approval.rs new file mode 100644 index 0000000..67cb57a --- /dev/null +++ b/lib/model/workspace/wk_join_approval.rs @@ -0,0 +1,16 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct WkJoinApprovalModel { + pub id: Uuid, + pub apply: Uuid, + pub wk: Uuid, + pub user: Uuid, + pub approver: Uuid, + pub approved: bool, + pub reason: Option, + pub created_at: DateTime, +} diff --git a/lib/model/workspace/wk_join_strategy.rs b/lib/model/workspace/wk_join_strategy.rs new file mode 100644 index 0000000..dab260e --- /dev/null +++ b/lib/model/workspace/wk_join_strategy.rs @@ -0,0 +1,16 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct WkJoinStrategyModel { + pub wk: Uuid, + pub require_approval: bool, + pub require_question: bool, + pub question: Option, + pub answer: Option, + pub enabled: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/lib/model/workspace/wk_member.rs b/lib/model/workspace/wk_member.rs new file mode 100644 index 0000000..882c134 --- /dev/null +++ b/lib/model/workspace/wk_member.rs @@ -0,0 +1,14 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct WkMemberModel { + pub wk: Uuid, + pub user: Uuid, + pub owner: bool, + pub admin: bool, + pub join_at: DateTime, + pub leave_at: Option>, +} diff --git a/lib/model/workspace/workspace.rs b/lib/model/workspace/workspace.rs new file mode 100644 index 0000000..dffd347 --- /dev/null +++ b/lib/model/workspace/workspace.rs @@ -0,0 +1,13 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use uuid::Uuid; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct WorkspaceModel { + pub id: Uuid, + pub name: String, + pub description: String, + pub avatar_url: String, + pub created_at: DateTime, +} diff --git a/lib/parsefile/Cargo.toml b/lib/parsefile/Cargo.toml new file mode 100644 index 0000000..b80aa26 --- /dev/null +++ b/lib/parsefile/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "parsefile" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[dependencies] +serde = { workspace = true, features = ["derive"] } +serde_yaml = { workspace = true } +thiserror = { workspace = true } +indexmap = { workspace = true, features = ["serde"] } + +[lints] +workspace = true diff --git a/lib/parsefile/ep.pipeline.yaml b/lib/parsefile/ep.pipeline.yaml new file mode 100644 index 0000000..84be346 --- /dev/null +++ b/lib/parsefile/ep.pipeline.yaml @@ -0,0 +1,75 @@ +version: "1" + +name: "container-ci" + +run_on: + push: + branches: + - main + - "release/*" + tags: + - "v*" + pull_request: + branches: + - main + - "feature/*" + +jobs: + test: + stage: verify + runtime: + type: container + image: "node:20-bookworm" + runner_labels: + - linux + - container + steps: + - name: "Checkout" + task: "source.checkout" + - name: "Test" + command: "pnpm test" + + build: + stage: package + depends_on: + - test + runtime: + type: container + image: "node:20-bookworm" + steps: + - name: "Checkout" + task: "source.checkout" + - name: "Build" + command: "pnpm build" + - name: "Upload dist" + task: "artifact.upload" + params: + name: "dist" + paths: + - "dist/**" + + image: + stage: package + depends_on: + - build + runtime: + type: container + image: "docker:27-cli" + runner_labels: + - linux + - docker + steps: + - name: "Checkout" + task: "source.checkout" + - name: "Download dist" + task: "artifact.download" + params: + name: "dist" + path: "dist" + - name: "Build image" + task: "container.build" + params: + dockerfile: "Dockerfile" + context: "." + tags: + - "registry.example.com/app:${commit.sha}" \ No newline at end of file diff --git a/lib/parsefile/src/error.rs b/lib/parsefile/src/error.rs new file mode 100644 index 0000000..8499fca --- /dev/null +++ b/lib/parsefile/src/error.rs @@ -0,0 +1,23 @@ +#[derive(Debug, thiserror::Error)] +pub enum ParseError { + #[error("failed to read file: {0}")] + Io(#[from] std::io::Error), + #[error("failed to parse YAML: {0}")] + Yaml(#[from] serde_yaml::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum ValidationError { + #[error("step '{step}' in job '{job}' has neither task nor command")] + EmptyStep { job: String, step: String }, + #[error("dependency error: {0}")] + Dependency(#[from] DependencyError), +} + +#[derive(Debug, thiserror::Error)] +pub enum DependencyError { + #[error("job '{job}' depends_on '{dependency}' which does not exist")] + UnknownReference { job: String, dependency: String }, + #[error("circular dependency detected: {chain}")] + CircularDependency { chain: String }, +} diff --git a/lib/parsefile/src/lib.rs b/lib/parsefile/src/lib.rs new file mode 100644 index 0000000..976ee04 --- /dev/null +++ b/lib/parsefile/src/lib.rs @@ -0,0 +1,411 @@ +mod error; +mod matcher; +mod model; + +use std::path::Path; + +pub use error::{DependencyError, ParseError, ValidationError}; +pub use model::{ + BranchPattern, Job, Pipeline, PipelineVersion, PullRequestTrigger, + PushTrigger, RunOn, Runtime, RuntimeType, Step, StepKind, StepParam, + TriggerEvent, +}; + +pub fn parse_file(path: &Path) -> Result { + let content = std::fs::read_to_string(path)?; + parse_from_str(&content) +} + +pub fn parse_from_str(content: &str) -> Result { + let pipeline: Pipeline = serde_yaml::from_str(content)?; + Ok(pipeline) +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use super::*; + + fn sample_pipeline_path() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("ep.pipeline.yaml") + } + + #[test] + fn parse_sample_pipeline() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + + assert_eq!(pipeline.version, PipelineVersion::V1); + assert_eq!(pipeline.name, "container-ci"); + assert_eq!(pipeline.jobs.len(), 3); + } + + #[test] + fn parse_test_job() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + let test_job = &pipeline.jobs["test"]; + + assert_eq!(test_job.stage, "verify"); + assert!(test_job.depends_on.is_empty()); + assert_eq!(test_job.runtime.runtime_type, RuntimeType::Container); + assert_eq!(test_job.runtime.image, "node:20-bookworm"); + assert_eq!(test_job.runner_labels, vec!["linux", "container"]); + assert_eq!(test_job.steps.len(), 2); + + let checkout = &test_job.steps[0]; + assert_eq!(checkout.name, "Checkout"); + assert_eq!(checkout.kind(), Some(StepKind::Task)); + assert_eq!(checkout.task, Some("source.checkout".to_string())); + + let test_step = &test_job.steps[1]; + assert_eq!(test_step.name, "Test"); + assert_eq!(test_step.kind(), Some(StepKind::Command)); + assert_eq!(test_step.command, Some("pnpm test".to_string())); + } + + #[test] + fn parse_build_job_with_params() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + let build_job = &pipeline.jobs["build"]; + + assert_eq!(build_job.stage, "package"); + assert_eq!(build_job.depends_on, vec!["test"]); + + let upload_step = &build_job.steps[2]; + assert_eq!(upload_step.name, "Upload dist"); + assert_eq!(upload_step.kind(), Some(StepKind::Task)); + assert_eq!(upload_step.task, Some("artifact.upload".to_string())); + assert_eq!(upload_step.params.get("name").unwrap().as_string(), "dist"); + assert_eq!( + upload_step.params.get("paths").unwrap().as_list(), + &["dist/**".to_string()] + ); + } + + #[test] + fn parse_image_job() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + let image_job = &pipeline.jobs["image"]; + + assert_eq!(image_job.stage, "package"); + assert_eq!(image_job.depends_on, vec!["build"]); + assert_eq!(image_job.runtime.runtime_type, RuntimeType::Container); + assert_eq!(image_job.runtime.image, "docker:27-cli"); + assert_eq!(image_job.runner_labels, vec!["linux", "docker"]); + + let build_image_step = &image_job.steps[2]; + assert_eq!(build_image_step.name, "Build image"); + assert_eq!(build_image_step.kind(), Some(StepKind::Task)); + assert_eq!( + build_image_step + .params + .get("dockerfile") + .unwrap() + .as_string(), + "Dockerfile" + ); + assert_eq!( + build_image_step.params.get("tags").unwrap().as_list(), + &["registry.example.com/app:${commit.sha}".to_string()] + ); + } + + #[test] + fn parse_run_on() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + let run_on = pipeline.run_on.as_ref().unwrap(); + + let push = run_on.push.as_ref().unwrap(); + assert_eq!(push.branches.len(), 2); + assert_eq!(push.branches[0].as_str(), "main"); + assert_eq!(push.branches[1].as_str(), "release/*"); + assert_eq!(push.tags.len(), 1); + assert_eq!(push.tags[0].as_str(), "v*"); + + let pr = run_on.pull_request.as_ref().unwrap(); + assert_eq!(pr.branches.len(), 2); + assert_eq!(pr.branches[0].as_str(), "main"); + assert_eq!(pr.branches[1].as_str(), "feature/*"); + } + + #[test] + fn stages_and_jobs_by_stage() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + + let stages = pipeline.stages(); + assert_eq!(stages, vec!["verify", "package"]); + + let verify_jobs = pipeline.jobs_by_stage("verify"); + assert_eq!(verify_jobs.len(), 1); + + let package_jobs = pipeline.jobs_by_stage("package"); + assert_eq!(package_jobs.len(), 2); + } + + #[test] + fn execution_order() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + let order = pipeline.execution_order().unwrap(); + + let test_idx = order.iter().position(|j| j == "test").unwrap(); + let build_idx = order.iter().position(|j| j == "build").unwrap(); + let image_idx = order.iter().position(|j| j == "image").unwrap(); + assert!(test_idx < build_idx); + assert!(build_idx < image_idx); + } + + #[test] + fn validate_sample_pipeline() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + assert!(pipeline.validate().is_ok()); + } + + #[test] + fn validation_empty_step() { + let yaml = r#" +version: "1" +name: "bad-pipeline" +jobs: + build: + stage: package + runtime: + type: container + image: "node:20" + steps: + - name: "NoAction" +"#; + let pipeline = parse_from_str(yaml).unwrap(); + let err = pipeline.validate().unwrap_err(); + assert!(matches!(err, ValidationError::EmptyStep { .. })); + } + + #[test] + fn validation_unknown_dependency() { + let yaml = r#" +version: "1" +name: "bad-pipeline" +jobs: + build: + stage: package + depends_on: + - nonexistent + runtime: + type: container + image: "node:20" + steps: + - name: "Build" + command: "echo hi" +"#; + let pipeline = parse_from_str(yaml).unwrap(); + let err = pipeline.execution_order().unwrap_err(); + assert!(matches!(err, DependencyError::UnknownReference { .. })); + } + + #[test] + fn validation_circular_dependency() { + let yaml = r#" +version: "1" +name: "bad-pipeline" +jobs: + a: + stage: build + depends_on: + - b + runtime: + type: container + image: "node:20" + steps: + - name: "A" + command: "echo a" + b: + stage: build + depends_on: + - a + runtime: + type: container + image: "node:20" + steps: + - name: "B" + command: "echo b" +"#; + let pipeline = parse_from_str(yaml).unwrap(); + let err = pipeline.execution_order().unwrap_err(); + assert!(matches!(err, DependencyError::CircularDependency { .. })); + } + + #[test] + fn parse_from_str_unknown_version() { + let yaml = r#" +version: "99" +name: "test" +jobs: {} +"#; + assert!(parse_from_str(yaml).is_err()); + } + + #[test] + fn parse_from_str_unknown_runtime_type() { + let yaml = r#" +version: "1" +name: "test" +jobs: + build: + stage: package + runtime: + type: unknown + image: "node:20" + steps: + - name: "Build" + command: "echo hi" +"#; + assert!(parse_from_str(yaml).is_err()); + } + + #[test] + fn should_run_push_main() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + assert!(pipeline.should_run(&TriggerEvent::PushBranch("main".into()))); + } + + #[test] + fn should_run_push_release_branch() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + assert!( + pipeline + .should_run(&TriggerEvent::PushBranch("release/1.0".into())) + ); + } + + #[test] + fn should_run_push_feature_branch_rejected() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + assert!( + !pipeline + .should_run(&TriggerEvent::PushBranch("feature/login".into())) + ); + } + + #[test] + fn should_run_push_tag() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + assert!(pipeline.should_run(&TriggerEvent::PushTag("v1.0.0".into()))); + assert!(!pipeline.should_run(&TriggerEvent::PushTag("1.0.0".into()))); + } + + #[test] + fn should_run_pull_request_main() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + assert!(pipeline.should_run(&TriggerEvent::PullRequest { + target: "main".into() + })); + } + + #[test] + fn should_run_pull_request_feature() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + assert!(pipeline.should_run(&TriggerEvent::PullRequest { + target: "feature/login".into() + })); + } + + #[test] + fn should_run_pull_request_develop_rejected() { + let pipeline = parse_file(&sample_pipeline_path()).unwrap(); + assert!(!pipeline.should_run(&TriggerEvent::PullRequest { + target: "develop".into() + })); + } + + #[test] + fn should_run_no_run_on_always_true() { + let yaml = r#" +version: "1" +name: "always-run" +jobs: + build: + stage: package + runtime: + type: container + image: "node:20" + steps: + - name: "Build" + command: "echo hi" +"#; + let pipeline = parse_from_str(yaml).unwrap(); + assert!( + pipeline.should_run(&TriggerEvent::PushBranch("anything".into())) + ); + assert!(pipeline.should_run(&TriggerEvent::PullRequest { + target: "anything".into() + })); + } + + #[test] + fn should_run_push_with_empty_branches_matches_all() { + let yaml = r#" +version: "1" +name: "push-any-branch" +run_on: + push: {} +jobs: + build: + stage: package + runtime: + type: container + image: "node:20" + steps: + - name: "Build" + command: "echo hi" +"#; + let pipeline = parse_from_str(yaml).unwrap(); + assert!( + pipeline.should_run(&TriggerEvent::PushBranch("anything".into())) + ); + assert!(pipeline.should_run(&TriggerEvent::PushTag("v1".into()))); + } + + #[test] + fn should_run_push_only_no_pull_request() { + let yaml = r#" +version: "1" +name: "push-only" +run_on: + push: + branches: + - main +jobs: + build: + stage: package + runtime: + type: container + image: "node:20" + steps: + - name: "Build" + command: "echo hi" +"#; + let pipeline = parse_from_str(yaml).unwrap(); + assert!(pipeline.should_run(&TriggerEvent::PushBranch("main".into()))); + assert!(!pipeline.should_run(&TriggerEvent::PullRequest { + target: "main".into() + })); + } + + #[test] + fn branch_pattern_exact_and_glob() { + let exact = BranchPattern::new("main"); + assert!(!exact.is_pattern()); + assert!(exact.matches("main")); + assert!(!exact.matches("develop")); + + let glob = BranchPattern::new("release/*"); + assert!(glob.is_pattern()); + assert!(glob.matches("release/1.0")); + assert!(!glob.matches("release/1.0/fix")); + assert!(!glob.matches("main")); + + let deep = BranchPattern::new("feature/**"); + assert!(deep.is_pattern()); + assert!(deep.matches("feature/login")); + assert!(deep.matches("feature/login/subtask")); + } +} diff --git a/lib/parsefile/src/matcher.rs b/lib/parsefile/src/matcher.rs new file mode 100644 index 0000000..6aee9cd --- /dev/null +++ b/lib/parsefile/src/matcher.rs @@ -0,0 +1,93 @@ +pub fn glob_match(pattern: &str, text: &str) -> bool { + match_impl(pattern.as_bytes(), text.as_bytes()) +} + +fn match_impl(pattern: &[u8], text: &[u8]) -> bool { + if pattern.is_empty() { + return text.is_empty(); + } + if pattern.len() >= 2 && pattern[0] == b'*' && pattern[1] == b'*' { + let rest = skip_double_star_slash(pattern); + for i in 0..=text.len() { + if match_impl(rest, &text[i..]) { + return true; + } + } + return false; + } + if pattern[0] == b'*' { + let rest = &pattern[1..]; + for i in 0..=text.len() { + if i > 0 && text[i - 1] == b'/' { + break; + } + if match_impl(rest, &text[i..]) { + return true; + } + } + return false; + } + + if text.is_empty() { + return false; + } + if pattern[0] == b'?' { + if text[0] == b'/' { + return false; + } + return match_impl(&pattern[1..], &text[1..]); + } + if pattern[0] == text[0] { + return match_impl(&pattern[1..], &text[1..]); + } + + false +} +fn skip_double_star_slash(pattern: &[u8]) -> &[u8] { + let rest = &pattern[2..]; + if rest.first() == Some(&b'/') { + &rest[1..] + } else { + rest + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn exact_match() { + assert!(glob_match("main", "main")); + assert!(!glob_match("main", "develop")); + } + + #[test] + fn single_star_matches_segment() { + assert!(glob_match("featurerelease", "release")); + assert!(glob_match("**/release", "a/b/c/release")); + } + + #[test] + fn question_mark() { + assert!(glob_match("v?", "v1")); + assert!(glob_match("v?", "vA")); + assert!(!glob_match("v?", "v12")); + assert!(!glob_match("v?", "v/")); + } + + #[test] + fn tag_patterns() { + assert!(glob_match("v*", "v1.0.0")); + assert!(glob_match("v*", "v2")); + assert!(!glob_match("v*", "2.0.0")); + assert!(glob_match("release-*", "release-2024")); + } + + #[test] + fn star_at_start() { + assert!(glob_match("*-staging", "dev-staging")); + assert!(glob_match("*-staging", "test-staging")); + assert!(!glob_match("*-staging", "dev/staging")); + } +} diff --git a/lib/parsefile/src/model.rs b/lib/parsefile/src/model.rs new file mode 100644 index 0000000..2c3dd67 --- /dev/null +++ b/lib/parsefile/src/model.rs @@ -0,0 +1,280 @@ +use std::collections::HashMap; + +use indexmap::IndexMap; +use serde::{Deserialize, Serialize}; + +use crate::{ + error::{DependencyError, ValidationError}, + matcher::glob_match, +}; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PipelineVersion { + #[serde(rename = "1")] + V1, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum RuntimeType { + #[serde(rename = "container")] + Container, + #[serde(rename = "host")] + Host, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum StepKind { + Task, + Command, +} +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TriggerEvent { + PushBranch(String), + PushTag(String), + PullRequest { target: String }, +} +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RunOn { + #[serde(default)] + pub push: Option, + #[serde(default)] + pub pull_request: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PushTrigger { + #[serde(default)] + pub branches: Vec, + #[serde(default)] + pub tags: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PullRequestTrigger { + #[serde(default)] + pub branches: Vec, +} +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct BranchPattern(String); + +impl BranchPattern { + pub fn new(pattern: impl Into) -> Self { + Self(pattern.into()) + } + + pub fn as_str(&self) -> &str { + &self.0 + } + + pub fn is_pattern(&self) -> bool { + self.0.contains('*') || self.0.contains('?') + } + pub fn matches(&self, name: &str) -> bool { + if !self.is_pattern() { + self.0 == name + } else { + glob_match(&self.0, name) + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Pipeline { + pub version: PipelineVersion, + pub name: String, + #[serde(default)] + pub run_on: Option, + pub jobs: IndexMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Job { + pub stage: String, + #[serde(default)] + pub depends_on: Vec, + pub runtime: Runtime, + #[serde(default)] + pub runner_labels: Vec, + pub steps: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Runtime { + #[serde(rename = "type")] + pub runtime_type: RuntimeType, + pub image: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Step { + pub name: String, + #[serde(default)] + pub task: Option, + #[serde(default)] + pub command: Option, + #[serde(default)] + pub params: IndexMap, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum StepParam { + String(String), + List(Vec), +} + +impl Step { + pub fn kind(&self) -> Option { + if self.task.is_some() { + Some(StepKind::Task) + } else if self.command.is_some() { + Some(StepKind::Command) + } else { + None + } + } +} + +impl StepParam { + pub fn as_string(&self) -> &str { + match self { + StepParam::String(s) => s, + StepParam::List(_) => { + panic!("StepParam::as_string called on a List variant") + } + } + } + + pub fn as_list(&self) -> &[String] { + match self { + StepParam::String(_) => { + panic!("StepParam::as_list called on a String variant") + } + StepParam::List(v) => v, + } + } +} + +impl Pipeline { + pub fn jobs_by_stage(&self, stage: &str) -> Vec<&Job> { + self.jobs.values().filter(|j| j.stage == stage).collect() + } + + pub fn stages(&self) -> Vec<&str> { + let mut seen: Vec<&str> = Vec::new(); + for job in self.jobs.values() { + if !seen.contains(&job.stage.as_str()) { + seen.push(&job.stage); + } + } + seen + } + pub fn should_run(&self, event: &TriggerEvent) -> bool { + let run_on = match &self.run_on { + None => return true, + Some(r) => r, + }; + + match event { + TriggerEvent::PushBranch(branch) => { + run_on.push.as_ref().map_or(false, |p| { + if p.branches.is_empty() { + true + } else { + p.branches.iter().any(|bp| bp.matches(branch)) + } + }) + } + TriggerEvent::PushTag(tag) => { + run_on.push.as_ref().map_or(false, |p| { + if p.tags.is_empty() { + true + } else { + p.tags.iter().any(|tp| tp.matches(tag)) + } + }) + } + TriggerEvent::PullRequest { target } => { + run_on.pull_request.as_ref().map_or(false, |pr| { + if pr.branches.is_empty() { + true + } else { + pr.branches.iter().any(|bp| bp.matches(target)) + } + }) + } + } + } + + pub fn execution_order(&self) -> Result, DependencyError> { + let mut in_degree: HashMap<&str, usize> = HashMap::new(); + let mut graph: HashMap<&str, Vec<&str>> = HashMap::new(); + + for name in self.jobs.keys() { + in_degree.insert(name, 0); + graph.insert(name, Vec::new()); + } + + for (name, job) in &self.jobs { + for dep in &job.depends_on { + if !self.jobs.contains_key(dep) { + return Err(DependencyError::UnknownReference { + job: name.clone(), + dependency: dep.clone(), + }); + } + graph.get_mut(dep.as_str()).unwrap().push(name); + *in_degree.get_mut(name.as_str()).unwrap() += 1; + } + } + + let mut queue: Vec<&str> = in_degree + .iter() + .filter(|(_, deg)| **deg == 0) + .map(|(name, _)| *name) + .collect(); + queue.sort(); + + let mut order: Vec = Vec::new(); + while let Some(node) = queue.pop() { + order.push(node.to_string()); + for &neighbor in &graph[node] { + let deg = in_degree.get_mut(neighbor).unwrap(); + *deg -= 1; + if *deg == 0 { + queue.push(neighbor); + queue.sort(); + } + } + } + + if order.len() != self.jobs.len() { + let cycle_nodes: Vec = self + .jobs + .keys() + .filter(|k| !order.contains(k)) + .cloned() + .collect(); + return Err(DependencyError::CircularDependency { + chain: cycle_nodes.join(" -> "), + }); + } + + Ok(order) + } + + pub fn validate(&self) -> Result<(), ValidationError> { + for (name, job) in &self.jobs { + for step in &job.steps { + if step.kind().is_none() { + return Err(ValidationError::EmptyStep { + job: name.clone(), + step: step.name.clone(), + }); + } + } + } + self.execution_order()?; + Ok(()) + } +} diff --git a/lib/queue/Cargo.toml b/lib/queue/Cargo.toml new file mode 100644 index 0000000..6c78440 --- /dev/null +++ b/lib/queue/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "queue" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[lib] +path = "lib.rs" +name = "queue" +[dependencies] +async-nats = { workspace = true } +tokio = { workspace = true, features = ["full"] } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +tracing = { workspace = true } +anyhow = { workspace = true } +async-trait = { workspace = true } +config = { workspace = true } +futures-util = { workspace = true } + +[lints] +workspace = true diff --git a/lib/queue/consumer.rs b/lib/queue/consumer.rs new file mode 100644 index 0000000..77fa31b --- /dev/null +++ b/lib/queue/consumer.rs @@ -0,0 +1,206 @@ +use std::{sync::Arc, time::Duration}; + +use async_nats::{HeaderMap, jetstream}; +use config::AppConfig; +use futures_util::StreamExt; +use tracing::{error, info, warn}; + +use crate::{ + handler::{AckAction, MessageHandler}, + producer::{NatsProducer, connect_jetstream, ensure_stream}, +}; + +pub struct NatsConsumer { + stream: jetstream::stream::Stream, + producer: NatsProducer, + max_deliver: i64, + retry_delay_secs: u64, + durable_name: String, +} + +impl NatsConsumer { + pub async fn new( + config: &AppConfig, + group_id: &str, + ) -> anyhow::Result { + let jetstream = connect_jetstream(config).await?; + let stream = ensure_stream(config, &jetstream).await?; + let producer = NatsProducer::new(config).await?; + + Ok(Self { + stream, + producer, + max_deliver: config.nats_max_deliver(), + retry_delay_secs: config.nats_retry_delay_secs(), + durable_name: durable_name(group_id), + }) + } + + pub async fn start_consuming( + &self, + topics: &[&str], + handler: H, + ) -> anyhow::Result<()> + where + H: MessageHandler + 'static, + { + let topics_owned: Vec = + topics.iter().map(|topic| topic.to_string()).collect(); + + let consumer = self + .stream + .get_or_create_consumer( + &self.durable_name, + jetstream::consumer::pull::Config { + durable_name: Some(self.durable_name.clone()), + ack_wait: Duration::from_secs(self.retry_delay_secs), + max_deliver: self.max_deliver, + filter_subjects: topics_owned.clone(), + ..Default::default() + }, + ) + .await?; + + info!("NATS consumer started subscribing to: {:?}", topics_owned); + + let producer = self.producer.clone(); + let max_deliver = self.max_deliver; + let retry_delay_secs = self.retry_delay_secs; + let handler = Arc::new(handler); + + tokio::spawn(async move { + let messages = consumer.messages().await; + let mut messages = match messages { + Ok(messages) => messages, + Err(error) => { + error!( + "NATS error while opening consumer stream: {:?}", + error + ); + return; + } + }; + + while let Some(message_result) = messages.next().await { + match message_result { + Ok(message) => { + handle_message( + &producer, + max_deliver, + retry_delay_secs, + handler.as_ref(), + message, + ) + .await; + } + Err(error) => { + error!("NATS error while consuming: {:?}", error); + } + } + } + }); + + Ok(()) + } +} + +async fn handle_message( + producer: &NatsProducer, + max_deliver: i64, + retry_delay_secs: u64, + handler: &H, + message: jetstream::Message, +) where + H: MessageHandler + ?Sized, +{ + let subject = message.subject.to_string(); + let payload = message.payload.clone(); + let delivered = message.info().map(|info| info.delivered).unwrap_or(1); + + match handler.handle(&subject, &payload).await { + AckAction::Ack => ack_message(&message, &subject, "message").await, + AckAction::Nack => { + if let Err(error) = handle_nack( + producer, + &message, + &subject, + &payload, + delivered, + max_deliver, + retry_delay_secs, + ) + .await + { + error!( + "Failed to route NACKed message from subject {}: {:?}", + subject, error + ); + } + } + } +} + +async fn handle_nack( + producer: &NatsProducer, + message: &jetstream::Message, + subject: &str, + payload: &[u8], + delivered: i64, + max_deliver: i64, + retry_delay_secs: u64, +) -> anyhow::Result<()> { + if delivered < max_deliver { + warn!( + "Message in subject {} failed (NACK). Retrying delivery {}/{} in {} seconds", + subject, delivered, max_deliver, retry_delay_secs + ); + message + .ack_with(jetstream::AckKind::Nak(Some(Duration::from_secs( + retry_delay_secs, + )))) + .await + .map_err(|error| { + anyhow::anyhow!("failed to nack message: {error}") + })?; + return Ok(()); + } + + let dlq_subject = format!("{subject}.dlq"); + error!( + "Message in subject {} exceeded max deliver attempts ({}). Routing to DLQ: {}", + subject, max_deliver, dlq_subject + ); + + let mut headers = HeaderMap::new(); + headers.append("x-original-subject", subject); + headers.append("x-delivered-count", delivered.to_string()); + producer + .send_raw(&dlq_subject, "", payload, Some(headers)) + .await + .map_err(|error| { + anyhow::anyhow!( + "failed to send DLQ message to {dlq_subject}: {error}" + ) + })?; + message.ack().await.map_err(|error| { + anyhow::anyhow!("failed to ack DLQ message: {error}") + })?; + Ok(()) +} + +async fn ack_message( + message: &jetstream::Message, + subject: &str, + description: &str, +) { + if let Err(error) = message.ack().await { + error!( + "Failed to ack {} in subject {}: {:?}", + description, subject, error + ); + } +} + +fn durable_name(name: &str) -> String { + name.replace('.', "-") +} diff --git a/lib/queue/handler.rs b/lib/queue/handler.rs new file mode 100644 index 0000000..869c5a7 --- /dev/null +++ b/lib/queue/handler.rs @@ -0,0 +1,9 @@ +pub enum AckAction { + Ack, + Nack, +} + +#[async_trait::async_trait] +pub trait MessageHandler: Send + Sync { + async fn handle(&self, topic: &str, payload: &[u8]) -> AckAction; +} diff --git a/lib/queue/lib.rs b/lib/queue/lib.rs new file mode 100644 index 0000000..1c57159 --- /dev/null +++ b/lib/queue/lib.rs @@ -0,0 +1,7 @@ +mod consumer; +mod handler; +mod producer; + +pub use consumer::NatsConsumer; +pub use handler::{AckAction, MessageHandler}; +pub use producer::NatsProducer; diff --git a/lib/queue/producer.rs b/lib/queue/producer.rs new file mode 100644 index 0000000..6e770b9 --- /dev/null +++ b/lib/queue/producer.rs @@ -0,0 +1,90 @@ +use std::time::Duration; + +use async_nats::{HeaderMap, jetstream}; +use config::AppConfig; +use serde::Serialize; + +#[derive(Clone)] +pub struct NatsProducer { + jetstream: jetstream::Context, +} + +impl NatsProducer { + pub async fn new(config: &AppConfig) -> anyhow::Result { + let jetstream = connect_jetstream(config).await?; + ensure_stream(config, &jetstream).await?; + + Ok(Self { jetstream }) + } + + pub async fn send( + &self, + subject: &str, + key: &str, + payload: &T, + headers: Option, + ) -> anyhow::Result<()> + where + T: Serialize + ?Sized, + { + let payload_bytes = serde_json::to_vec(payload)?; + self.send_raw(subject, key, &payload_bytes, headers).await + } + + pub async fn send_raw( + &self, + subject: &str, + key: &str, + payload: &[u8], + headers: Option, + ) -> anyhow::Result<()> { + let mut headers = headers.unwrap_or_default(); + if !key.is_empty() { + headers.append("x-message-key", key); + } + + let subject = subject.to_string(); + let publish = if headers.is_empty() { + self.jetstream + .publish(subject.clone(), payload.to_vec().into()) + .await? + } else { + self.jetstream + .publish_with_headers(subject, headers, payload.to_vec().into()) + .await? + }; + + tokio::time::timeout(Duration::from_secs(5), publish).await??; + + Ok(()) + } +} + +pub(crate) async fn connect_jetstream( + config: &AppConfig, +) -> anyhow::Result { + let client = match config.nats_token() { + Some(token) if !token.is_empty() => { + async_nats::ConnectOptions::with_token(token) + .connect(config.nats_url()) + .await? + } + _ => async_nats::connect(config.nats_url()).await?, + }; + + Ok(jetstream::new(client)) +} + +pub(crate) async fn ensure_stream( + config: &AppConfig, + jetstream: &jetstream::Context, +) -> anyhow::Result { + Ok(jetstream + .get_or_create_stream(jetstream::stream::Config { + name: config.nats_stream_name(), + subjects: config.nats_stream_subjects(), + max_age: Duration::from_secs(config.nats_max_age_secs()), + ..Default::default() + }) + .await?) +} diff --git a/lib/service/Cargo.toml b/lib/service/Cargo.toml new file mode 100644 index 0000000..892efcb --- /dev/null +++ b/lib/service/Cargo.toml @@ -0,0 +1,59 @@ +[package] +name = "service" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[lib] +path = "lib.rs" +name = "service" +[dependencies] +db = { workspace = true } +config = { workspace = true } +cache = { workspace = true } +queue = { workspace = true } +email = { workspace = true } +ai = { workspace = true } +session = { workspace = true } +storage = { workspace = true } +model = { workspace = true } + +serde = { workspace = true, features = ["derive"]} +utoipa = { workspace = true, features = ["chrono","uuid"] } +hkdf = "0.13.0" +rsa = "0.10.0-rc.18" +rand_chacha = "0.10.0" +sha2 = "0.11.0" +chacha20poly1305 = "0.11.0-rc.3" +base64 = "0.22.1" +serde_json = "1.0.150" +rand = "0.10.1" +hex = "0.4.3" +tracing = "0.1.44" +sqlx = { workspace = true } +captcha-rs = "0.5.0" +argon2 = "0.6.0-rc.8" +chrono = { workspace = true } +uuid = { workspace = true, features = ["v4", "v7", "serde"] } +rust_decimal = "1.42.0" +hmac = "0.13.0" +sha1 = "0.11.0" +git = { workspace = true } +tonic = { workspace = true, features = ["transport"] } +comrak = { workspace = true } +deadpool-redis = { workspace = true, features = ["cluster"] } +redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp"] } +tokio = { workspace = true, features = ["full"] } +tokio-util = { workspace = true } +async-trait = { workspace = true } +rig-core = { workspace = true } +[lints] +workspace = true diff --git a/lib/service/agent/billing.rs b/lib/service/agent/billing.rs new file mode 100644 index 0000000..cb7fc2c --- /dev/null +++ b/lib/service/agent/billing.rs @@ -0,0 +1,316 @@ +use chrono::Utc; +use db::sqlx::{self, types::Decimal}; +use rust_decimal::Decimal as RustDecimal; +use uuid::Uuid; + +use super::types::{BillingRecord, BillingTarget, SessionContext}; +use crate::error::AppError; +use crate::AppService; + +impl AppService { + pub(crate) async fn agent_calculate_cost( + &self, + model_version_id: Uuid, + input_tokens: i64, + output_tokens: i64, + ) -> Result, AppError> { + let (input_price_per_million, output_price_per_million) = + self.agent_resolve_pricing(model_version_id).await?; + + let input_price = match input_price_per_million { + Some(p) => p, + None => return Ok(None), + }; + + let output_price = match output_price_per_million { + Some(p) => p, + None => return Ok(None), + }; + + let million = RustDecimal::from(1_000_000u64); + let input_decimal = RustDecimal::from(input_tokens); + let output_decimal = RustDecimal::from(output_tokens); + + let cost = (input_decimal * input_price / million) + + (output_decimal * output_price / million); + Ok(Some((cost, "USD".to_string()))) + } + pub(crate) async fn agent_deduct_billing( + &self, + ctx: &SessionContext, + cost: RustDecimal, + ) -> Result<(), AppError> { + match ctx.billing_target { + BillingTarget::User => { + let user_id = ctx + .user_id + .ok_or_else(|| AppError::BadRequest("user billing target requires user_id".to_string()))?; + self.deduct_user_balance(user_id, cost).await + } + BillingTarget::Workspace => { + let wk_id = ctx + .workspace_id + .ok_or_else(|| AppError::BadRequest("workspace billing target requires workspace_id".to_string()))?; + self.deduct_workspace_balance(wk_id, cost).await + } + } + } + pub(crate) async fn agent_record_usage( + &self, + record: &BillingRecord, + ) -> Result<(), AppError> { + let cost_decimal: Option = record.cost.map(|c| c.into()); + + sqlx::query( + "INSERT INTO agent_token_usage \ + (invocation, session, model_version, \ + input_tokens, output_tokens, cached_input_tokens, \ + cache_read_tokens, cache_write_tokens, reasoning_tokens, \ + total_tokens, cost, currency, created_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)", + ) + .bind(record.invocation_id) + .bind(record.session_id) + .bind(record.model_version_id) + .bind(record.input_tokens) + .bind(record.output_tokens) + .bind(record.cached_input_tokens) + .bind(record.cache_read_tokens) + .bind(record.cache_write_tokens) + .bind(record.reasoning_tokens) + .bind(record.total_tokens) + .bind(&cost_decimal) + .bind(&record.currency) + .bind(record.created_at) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } + pub(crate) async fn agent_record_invocation( + &self, + invocation_id: Uuid, + session_id: Uuid, + conversation_id: Option, + message_id: Option, + model_version_id: Uuid, + status: &str, + error: Option<&str>, + ) -> Result<(), AppError> { + let now = Utc::now(); + sqlx::query( + "INSERT INTO agent_model_invocation \ + (id, session, conversation, message, model_version, status, error, \ + started_at, finished_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)", + ) + .bind(invocation_id) + .bind(session_id) + .bind(conversation_id) + .bind(message_id) + .bind(model_version_id) + .bind(status) + .bind(error) + .bind(now) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } + pub(crate) async fn agent_record_tool_call( + &self, + invocation_id: Uuid, + session_id: Uuid, + conversation_id: Option, + message_id: Option, + tool_call_id: &str, + tool_name: &str, + arguments: Option<&str>, + result: Option<&str>, + error: Option<&str>, + status: &str, + latency_ms: Option, + ) -> Result<(), AppError> { + let now = Utc::now(); + sqlx::query( + "INSERT INTO agent_tool_call_log \ + (invocation, session, conversation, message, tool_call_id, \ + tool_name, arguments, result, error, status, \ + started_at, finished_at, latency_ms) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)", + ) + .bind(invocation_id) + .bind(session_id) + .bind(conversation_id) + .bind(message_id) + .bind(tool_call_id) + .bind(tool_name) + .bind(arguments) + .bind(result) + .bind(error) + .bind(status) + .bind(now) + .bind(now) + .bind(latency_ms) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } +} + +impl AppService { + async fn deduct_user_balance( + &self, + user_id: Uuid, + cost: RustDecimal, + ) -> Result<(), AppError> { + const MAX_RETRIES: u32 = 3; + for attempt in 0..MAX_RETRIES { + match self.try_deduct_user_balance(user_id, cost).await { + Ok(()) => return Ok(()), + Err(AppError::TxnError) if attempt < MAX_RETRIES - 1 => { + let backoff_ms = 10 * (1 << attempt); + tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await; + continue; + } + Err(e) => return Err(e), + } + } + Err(AppError::TxnError) + } + + async fn try_deduct_user_balance( + &self, + user_id: Uuid, + cost: RustDecimal, + ) -> Result<(), AppError> { + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + + let current: Option = sqlx::query_scalar( + "SELECT balance FROM user_billing WHERE \"user\" = $1 FOR UPDATE", + ) + .bind(user_id) + .fetch_optional(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let current = match current { + Some(balance) => balance, + None => { + let default_balance = RustDecimal::from(20); + let now = Utc::now(); + sqlx::query( + "INSERT INTO user_billing (\"user\", balance, created_at, updated_at) \ + VALUES ($1, $2, $3, $3)", + ) + .bind(user_id) + .bind(&default_balance) + .bind(now) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + default_balance + } + }; + + if current < cost { + return Err(AppError::BadRequest( + "insufficient balance".to_string(), + )); + } + + let new_balance = current - cost; + + sqlx::query( + "UPDATE user_billing SET balance = $1, updated_at = $2 WHERE \"user\" = $3", + ) + .bind(&new_balance) + .bind(Utc::now()) + .bind(user_id) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + txn.commit().await.map_err(|_| AppError::TxnError)?; + Ok(()) + } + + async fn deduct_workspace_balance( + &self, + wk_id: Uuid, + cost: RustDecimal, + ) -> Result<(), AppError> { + const MAX_RETRIES: u32 = 3; + for attempt in 0..MAX_RETRIES { + match self.try_deduct_workspace_balance(wk_id, cost).await { + Ok(()) => return Ok(()), + Err(AppError::TxnError) if attempt < MAX_RETRIES - 1 => { + let backoff_ms = 10 * (1 << attempt); + tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await; + continue; + } + Err(e) => return Err(e), + } + } + Err(AppError::TxnError) + } + + async fn try_deduct_workspace_balance( + &self, + wk_id: Uuid, + cost: RustDecimal, + ) -> Result<(), AppError> { + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + + let current: Option = sqlx::query_scalar( + "SELECT balance FROM wk_billing WHERE wk = $1 FOR UPDATE", + ) + .bind(wk_id) + .fetch_optional(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let current = match current { + Some(balance) => balance, + None => { + let default_balance = RustDecimal::from(20); + let now = Utc::now(); + sqlx::query( + "INSERT INTO wk_billing (wk, balance, updated_at) \ + VALUES ($1, $2, $3)", + ) + .bind(wk_id) + .bind(&default_balance) + .bind(now) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + default_balance + } + }; + + if current < cost { + return Err(AppError::BadRequest( + "insufficient workspace balance".to_string(), + )); + } + + let new_balance = current - cost; + + sqlx::query( + "UPDATE wk_billing SET balance = $1, updated_at = $2 WHERE wk = $3", + ) + .bind(&new_balance) + .bind(Utc::now()) + .bind(wk_id) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + txn.commit().await.map_err(|_| AppError::TxnError)?; + Ok(()) + } +} diff --git a/lib/service/agent/compaction.rs b/lib/service/agent/compaction.rs new file mode 100644 index 0000000..b63bb00 --- /dev/null +++ b/lib/service/agent/compaction.rs @@ -0,0 +1,163 @@ +use ai::agent::AgentConfig; +use ai::agent::RigAgent; +use ai::client::AiClient; +use ai::agent::request::AgentRequest; +use db::sqlx; +use tracing::{info, warn}; +use uuid::Uuid; + +use crate::error::AppError; +use crate::AppService; + +const COMPACTION_SYSTEM_PROMPT: &str = r#"You are a conversation context compaction assistant. + +Your task: summarize the older portion of a conversation so the agent can continue working with only the summary + recent messages. + +Rules: +- Preserve: key decisions, file paths, technical details, user preferences, unresolved questions. +- Discard: redundant tool outputs, verbose explanations that were already acted upon, pleasantries. +- Write in the same language as the conversation. +- Output a concise structured summary using bullet points. +- Keep the summary under 800 tokens. +- Do NOT answer any questions from the conversation. Only summarize."#; + +const COMPACTION_TRIGGER_CHARS: usize = 80_000; +const RECENT_MESSAGES_TO_KEEP: usize = 10; + +impl AppService { + pub(crate) async fn agent_maybe_compact( + &self, + ai_client: &AiClient, + model_name: &str, + conversation_id: Uuid, + ) -> Result<(), AppError> { + let rows: Vec<(Uuid, String, String)> = sqlx::query_as( + "SELECT id, role, content \ + FROM agent_message \ + WHERE conversation = $1 \ + AND deleted_at IS NULL \ + AND status = 'completed' \ + ORDER BY created_at ASC", + ) + .bind(conversation_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let total_chars: usize = rows.iter().map(|(_, _, c)| c.len()).sum(); + + if total_chars < COMPACTION_TRIGGER_CHARS { + return Ok(()); + } + + if rows.len() <= RECENT_MESSAGES_TO_KEEP { + return Ok(()); + } + + let split_at = rows.len().saturating_sub(RECENT_MESSAGES_TO_KEEP); + let older = &rows[..split_at]; + + let existing_summary: Option = sqlx::query_scalar( + "SELECT compacted_summary FROM agent_conversation WHERE id = $1", + ) + .bind(conversation_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .flatten(); + + let mut body = String::new(); + if let Some(ref prev) = existing_summary { + body.push_str("\n"); + body.push_str(prev); + body.push_str("\n\n\n"); + body.push_str("Merge the previous summary with the new messages below:\n\n"); + } + for (_, role, content) in older { + body.push_str(&format!("[{}]: {}\n\n", role, content)); + } + + let summary = match self + .agent_run_compaction_llm(ai_client, model_name, &body) + .await + { + Ok(s) => s, + Err(e) => { + warn!( + conversation_id = %conversation_id, + error = %e, + "compaction LLM call failed, skipping compaction" + ); + return Ok(()); + } + }; + + sqlx::query( + "UPDATE agent_conversation \ + SET compacted_summary = $1, updated_at = now() \ + WHERE id = $2", + ) + .bind(&summary) + .bind(conversation_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let ids: Vec = older.iter().map(|(id, _, _)| *id).collect(); + if !ids.is_empty() { + let now = chrono::Utc::now(); + sqlx::query( + "UPDATE agent_message \ + SET deleted_at = $1, updated_at = $1 \ + WHERE id = ANY($2::uuid[]) AND deleted_at IS NULL", + ) + .bind(now) + .bind(&ids) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + } + + info!( + conversation_id = %conversation_id, + compacted_count = older.len(), + kept_recent = RECENT_MESSAGES_TO_KEEP, + "conversation compacted successfully" + ); + + Ok(()) + } + + async fn agent_run_compaction_llm( + &self, + ai_client: &AiClient, + model_name: &str, + body: &str, + ) -> Result { + let config = AgentConfig::new(model_name) + .map_err(|e| AppError::AiError(e))? + .with_system_prompt(COMPACTION_SYSTEM_PROMPT) + .with_temperature(Some(0.2)) + .with_max_completion_tokens(Some(1024)) + .with_quiet_mode(true); + + let agent = RigAgent::new(ai_client.clone(), config) + .map_err(|e| AppError::AiError(e))?; + + let request = AgentRequest::new(body); + let summary = agent + .chat(request, Vec::new()) + .await + .map_err(|e| AppError::AiError(e))?; + + let summary = summary.trim().to_string(); + + if summary.is_empty() { + return Err(AppError::InternalServerError( + "compaction returned empty summary".to_string(), + )); + } + + Ok(summary) + } +} diff --git a/lib/service/agent/config.rs b/lib/service/agent/config.rs new file mode 100644 index 0000000..0420280 --- /dev/null +++ b/lib/service/agent/config.rs @@ -0,0 +1,395 @@ +use ai::{ + agent::AgentConfig, + client::{AiClient, AiClientConfig, EmbedConfig, EndpointConfig}, +}; +use db::sqlx::{self, types::Decimal}; +use model::{ + agent::AgentSessionModel, ai::AiModelVersionModel, ai::AiProviderModel, +}; +use uuid::Uuid; + +use super::types::SessionContext; +use crate::error::AppError; +use crate::AppService; + +impl AppService { + pub(crate) async fn agent_session_context( + &self, + session_id: Uuid, + user_id: Uuid, + ) -> Result { + let session = sqlx::query_as::<_, AgentSessionModel>( + "SELECT id, \"user\", wk, name, description, agent_kind, model_version, \ + system_prompt, temperature, max_output_tokens, tool_policy, \ + knowledge_base_ids, variables, visibility, version, \ + published_at, rollback_from_version, enabled, \ + source, parent_session_id, toolset_json, \ + memory_provider, memory_provider_config, iteration_budget, \ + created_by, created_at, updated_at, deleted_at \ + FROM agent_session WHERE id = $1 AND deleted_at IS NULL AND enabled = true", + ) + .bind(session_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?; + + if let Some(wk_id) = session.wk { + let _ = self + .workspace_require_member(wk_id, user_id) + .await?; + } else if Some(user_id) != session.user { + return Err(AppError::PermissionDenied); + } + + let model_version_id = session + .model_version + .ok_or_else(|| AppError::BadRequest("agent session has no model_version".to_string()))?; + + let version = self.resolve_model_version(model_version_id).await?; + + let billing_target = if session.wk.is_some() { + super::types::BillingTarget::Workspace + } else { + super::types::BillingTarget::User + }; + + Ok(SessionContext { + session_id, + user_id: session.user, + workspace_id: session.wk, + system_prompt: self.build_system_prompt_with_context( + &session, + user_id, + ).await, + model_version_id: version.id, + provider_model_name: version.provider_model_name, + temperature: session.temperature, + max_output_tokens: session.max_output_tokens, + tool_policy_json: session.tool_policy, + toolset_json: session.toolset_json, + variables_json: session.variables, + iteration_budget: session.iteration_budget, + memory_provider: session.memory_provider, + source: session.source, + parent_session_id: session.parent_session_id, + billing_target, + }) + } + pub(crate) async fn agent_build_ai_client( + &self, + model_version_id: Uuid, + ) -> Result { + let version = self.resolve_model_version(model_version_id).await?; + + let model_record = sqlx::query_as::<_, model::ai::AiModelModel>( + "SELECT id, provider, name, display_name, description, modality, \ + context_window, input_token_limit, output_token_limit, \ + enabled, public, created_at, updated_at, deleted_at \ + FROM ai_model WHERE id = $1", + ) + .bind(version.model) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("ai model not found".to_string()))?; + + let provider = sqlx::query_as::<_, AiProviderModel>( + "SELECT id, name, base_url, website_url, logo_url, enabled, created_at, updated_at \ + FROM ai_provider WHERE id = $1 AND enabled = true", + ) + .bind(model_record.provider) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("ai provider not found".to_string()))?; + + let base_url = provider + .base_url + .unwrap_or_else(|| self.config.ai_basic_url().unwrap_or_default()); + let api_key = self + .config + .ai_api_key() + .map_err(|e| AppError::InternalServerError(format!("AI API key: {e}")))?; + + let embed_base_url = self + .config + .get_embed_model_base_url() + .map_err(|e| AppError::InternalServerError(format!("embed base url: {e}")))?; + let embed_api_key = self + .config + .get_embed_model_api_key() + .map_err(|e| AppError::InternalServerError(format!("embed api key: {e}")))?; + + let llm_config = EndpointConfig::new(&base_url, &api_key) + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + let embed_endpoint = EndpointConfig::new(&embed_base_url, &embed_api_key) + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + let embed_config = EmbedConfig::new( + embed_endpoint, + self.config + .get_embed_model_name() + .map_err(|e| AppError::InternalServerError(e.to_string()))?, + self.config + .get_embed_model_dimensions() + .map_err(|e| AppError::InternalServerError(e.to_string()))?, + ) + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + let client_config = AiClientConfig::new(llm_config, embed_config) + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + AiClient::new(client_config).map_err(|e| AppError::InternalServerError(e.to_string())) + } + pub(crate) fn agent_build_config( + &self, + ctx: &SessionContext, + max_steps_override: Option, + ) -> AgentConfig { + let mut config = AgentConfig::new(&ctx.provider_model_name) + .unwrap_or_else(|_| { + AgentConfig::new("gpt-4o").expect("default agent config") + }); + + config.model = ctx.provider_model_name.clone(); + config.system_prompt = ctx.system_prompt.clone(); + if let Some(ref vars_json) = ctx.variables_json { + if let Ok(vars) = + serde_json::from_str::>(vars_json) + { + if !vars.is_empty() { + let mut prompt = config.system_prompt.clone(); + let mut any_replaced = false; + for (key, val) in &vars { + let placeholder = format!("{{{{{}}}}}", key); + if prompt.contains(&placeholder) { + let replacement = match val { + serde_json::Value::String(s) => s.clone(), + other => other.to_string(), + }; + prompt = prompt.replace(&placeholder, &replacement); + any_replaced = true; + } + } + if !any_replaced { + prompt.push_str("\n\n\n"); + for (key, val) in &vars { + let val_str = match val { + serde_json::Value::String(s) => s.clone(), + other => other.to_string(), + }; + prompt.push_str(&format!("- {}: {}\n", key, val_str)); + } + prompt.push_str(""); + } + config.system_prompt = prompt; + } + } + } + + if let Some(temp) = ctx.temperature { + config.temperature = Some(temp as f64); + } + if let Some(max_tok) = ctx.max_output_tokens { + config.max_completion_tokens = Some(max_tok as u64); + } + if let Some(max_steps) = max_steps_override { + config.max_iterations = max_steps; + } + if let Some(budget) = ctx.iteration_budget { + config.iteration_budget = budget as usize; + } + if let Some(ref policy_json) = ctx.tool_policy_json { + match serde_json::from_str::(policy_json) { + Ok(policy) => { + let allowed: Vec = policy + .get("allowed") + .and_then(|v| v.as_array()) + .map(|a| { + a.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default(); + let denied: Vec = policy + .get("denied") + .and_then(|v| v.as_array()) + .map(|a| { + a.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default(); + + if !allowed.is_empty() || !denied.is_empty() { + config = config.with_tool_policy(allowed, denied); + } + } + Err(e) => { + tracing::warn!( + error = %e, + "failed to parse tool policy JSON, ignoring" + ); + } + } + } + if let Some(ref toolset_json) = ctx.toolset_json { + if let Ok(policy) = + serde_json::from_str::(toolset_json) + { + let enabled: Vec = policy + .get("enabled") + .and_then(|v| v.as_array()) + .map(|a| { + a.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default(); + let disabled: Vec = policy + .get("disabled") + .and_then(|v| v.as_array()) + .map(|a| { + a.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default(); + + if !enabled.is_empty() || !disabled.is_empty() { + config = config.with_toolset_policy(enabled, disabled); + } + } + } + if ctx.memory_provider.as_deref() == Some("none") { + config.skip_memory = true; + } + + config + } + async fn resolve_model_version( + &self, + id: Uuid, + ) -> Result { + if let Some(version) = sqlx::query_as::<_, AiModelVersionModel>( + "SELECT id, model, version, provider_model_name, \ + input_price_per_million, output_price_per_million, cached_input_price_per_million, \ + training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \ + FROM ai_model_version WHERE id = $1 AND enabled = true", + ) + .bind(id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + { + return Ok(version); + } + + let version = sqlx::query_as::<_, AiModelVersionModel>( + "SELECT v.id, v.model, v.version, v.provider_model_name, \ + v.input_price_per_million, v.output_price_per_million, v.cached_input_price_per_million, \ + v.training_cutoff, v.released_at, v.deprecated_at, v.enabled, v.created_at, v.updated_at \ + FROM ai_model_version v \ + INNER JOIN ai_model m ON m.id = v.model \ + WHERE m.id = $1 AND v.enabled = true AND m.enabled = true \ + ORDER BY v.created_at DESC LIMIT 1", + ) + .bind(id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("model version not found".to_string()))?; + + Ok(version) + } + pub(crate) async fn agent_resolve_pricing( + &self, + model_version_id: Uuid, + ) -> Result<(Option, Option), AppError> { + let version = sqlx::query_as::<_, AiModelVersionModel>( + "SELECT id, model, version, provider_model_name, \ + input_price_per_million, output_price_per_million, cached_input_price_per_million, \ + training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \ + FROM ai_model_version WHERE id = $1", + ) + .bind(model_version_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("model version not found".to_string()))?; + + Ok(( + version.input_price_per_million, + version.output_price_per_million, + )) + } + + /// Build the system prompt enriched with workspace and user context. + async fn build_system_prompt_with_context( + &self, + session: &model::agent::AgentSessionModel, + user_id: Uuid, + ) -> String { + let base = session + .system_prompt + .clone() + .unwrap_or_else(|| ai::agent::config::default_system_prompt().to_string()); + + let mut context_section = String::new(); + + // Workspace context + if let Some(wk_id) = session.wk { + let wk: Option<(String,)> = sqlx::query_as( + "SELECT name FROM workspace WHERE id = $1") + .bind(wk_id) + .fetch_optional(self.db.reader()) + .await + .unwrap_or(None); + if let Some((wk_name,)) = wk { + context_section.push_str(&format!( + "- You are operating in workspace \"{wk_name}\" (id: {wk_id}).\n" + )); + context_section.push_str( + " All file operations, repo access, and code changes are scoped to this workspace.\n" + ); + } + } + + // User context + if let Some(session_user_id) = session.user { + let u: Option<(String, String)> = sqlx::query_as( + "SELECT display_name, username FROM \"user\" WHERE id = $1") + .bind(session_user_id) + .fetch_optional(self.db.reader()) + .await + .unwrap_or(None); + if let Some((display_name, username)) = u { + let name = if display_name.is_empty() { &username } else { &display_name }; + context_section.push_str(&format!( + "- The current user is {name} (username: {username}, id: {session_user_id}).\n" + )); + } + } else { + let u: Option<(String, String)> = sqlx::query_as( + "SELECT display_name, username FROM \"user\" WHERE id = $1") + .bind(user_id) + .fetch_optional(self.db.reader()) + .await + .unwrap_or(None); + if let Some((display_name, username)) = u { + let name = if display_name.is_empty() { &username } else { &display_name }; + context_section.push_str(&format!( + "- The current user is {name} (username: {username}, id: {user_id}).\n" + )); + } + } + + if context_section.is_empty() { + return base; + } + + format!( + "{base}\n\n\nThe following is provided for context. Always use this information\nwhen tailoring your responses, resolving references, and scoping operations.\n\n{context_section}" + ) + } +} diff --git a/lib/service/agent/context.rs b/lib/service/agent/context.rs new file mode 100644 index 0000000..38c8ab8 --- /dev/null +++ b/lib/service/agent/context.rs @@ -0,0 +1,266 @@ +use std::time::Duration; + +use ai::{ + agent::request::{ + AgentContextChunk, AgentMessage, AgentRequest, + }, + client::AiClient, + rag::{ + RagClient, RagConfig, RagDocument, + }, +}; +use db::sqlx; +use uuid::Uuid; + +use super::types::SessionContext; +use crate::error::AppError; +use crate::AppService; +const MAX_HISTORY_MESSAGES: u32 = 50; +const MAX_HISTORY_CHARS: usize = 500_000; +const MAX_HISTORY_ESTIMATED_TOKENS: u64 = 64_000; + +impl AppService { + pub(crate) async fn agent_build_request( + &self, + ai_client: &AiClient, + ctx: &SessionContext, + conversation_id: Option, + input: String, + timeout_secs: Option, + ) -> Result { + let mut request = AgentRequest::new(input.clone()); + if let Some(secs) = timeout_secs { + request = request.with_timeout(Duration::from_secs(secs)); + } + if let Some(conv_id) = conversation_id { + let mut all_messages = Vec::new(); + let compacted: Option = sqlx::query_scalar( + "SELECT compacted_summary FROM agent_conversation WHERE id = $1", + ) + .bind(conv_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .flatten(); + + if let Some(summary) = compacted { + all_messages.push(AgentMessage::User(format!( + "[Previous conversation summary]\n{}\n[End of summary — messages below are the most recent verbatim exchanges]", + summary + ))); + } + + let messages = self + .agent_load_conversation_messages(conv_id) + .await?; + all_messages.extend(messages); + + request = request.with_messages(all_messages); + } + let kb_context = self + .agent_load_knowledge_base(ai_client, ctx, &input) + .await?; + let (memories_text, _memory_rows) = self.agent_load_memories(ctx.session_id).await?; + + let mut all_context = kb_context; + if !memories_text.is_empty() { + all_context.push(AgentContextChunk::new( + "long_term_memory", + memories_text, + )); + } + if !all_context.is_empty() { + request = request.with_context(all_context); + } + + Ok(request) + } + pub(crate) async fn agent_load_conversation_messages( + &self, + conversation_id: Uuid, + ) -> Result, AppError> { + let rows: Vec<(String, String)> = sqlx::query_as( + "SELECT m.role, m.content \ + FROM agent_message m \ + INNER JOIN agent_conversation c ON c.id = m.conversation \ + WHERE m.conversation = $1 \ + AND m.deleted_at IS NULL \ + AND m.status = 'completed' \ + AND c.deleted_at IS NULL \ + ORDER BY m.created_at ASC \ + LIMIT $2", + ) + .bind(conversation_id) + .bind(MAX_HISTORY_MESSAGES as i64) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let messages: Vec = rows + .into_iter() + .map(|(role, content)| match role.as_str() { + "assistant" => AgentMessage::Assistant(content), + _ => AgentMessage::User(content), + }) + .collect(); + let mut total_chars: usize = messages + .iter() + .map(|m| match m { + AgentMessage::User(c) | AgentMessage::Assistant(c) => c.len(), + }) + .sum(); + + let mut result = messages; + while total_chars > MAX_HISTORY_CHARS && !result.is_empty() { + let removed = result.remove(0); + let removed_len = match &removed { + AgentMessage::User(c) | AgentMessage::Assistant(c) => c.len(), + }; + total_chars = total_chars.saturating_sub(removed_len); + tracing::debug!( + removed_chars = removed_len, + remaining_chars = total_chars, + "trimmed oldest message to fit context window" + ); + } + let mut estimated_tokens: u64 = result + .iter() + .map(|m| match m { + AgentMessage::User(c) | AgentMessage::Assistant(c) => { + ai::agent::helpers::estimate_tokens(c) + } + }) + .sum(); + let mut trimmed_for_tokens = 0usize; + while estimated_tokens > MAX_HISTORY_ESTIMATED_TOKENS && !result.is_empty() { + let removed = result.remove(0); + estimated_tokens -= match &removed { + AgentMessage::User(c) | AgentMessage::Assistant(c) => { + ai::agent::helpers::estimate_tokens(c) + } + }; + trimmed_for_tokens += 1; + } + if trimmed_for_tokens > 0 { + tracing::info!( + trimmed = trimmed_for_tokens, + estimated_tokens = estimated_tokens, + "trimmed oldest messages to stay within token budget" + ); + } + + Ok(result) + } + pub(crate) async fn agent_load_knowledge_base( + &self, + ai_client: &AiClient, + ctx: &SessionContext, + query: &str, + ) -> Result, AppError> { + let knowledge_base_ids: Option = sqlx::query_scalar( + "SELECT knowledge_base_ids FROM agent_session WHERE id = $1", + ) + .bind(ctx.session_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .flatten(); + + let kb_id_str = match knowledge_base_ids { + Some(ref s) if !s.trim().is_empty() => s.clone(), + _ => return Ok(Vec::new()), + }; + + let kb_ids: Vec = kb_id_str + .split(',') + .filter_map(|s| Uuid::parse_str(s.trim()).ok()) + .collect(); + + if kb_ids.is_empty() { + return Ok(Vec::new()); + } + + let qdrant_url = self + .config + .qdrant_url() + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + let vector_size = self + .config + .get_embed_model_dimensions() + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + let rag_config = RagConfig::new(qdrant_url, "agent_knowledge", vector_size) + .map_err(|e| AppError::InternalServerError(e.to_string()))? + .with_api_key( + self.config + .qdrant_api_key() + .map_err(|e| AppError::InternalServerError(e.to_string()))?, + ); + + let rag = RagClient::connect(ai_client, rag_config) + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + let mut all_hits: Vec = Vec::new(); + for kb_id in &kb_ids { + let session_key = format!("kb:{kb_id}"); + match rag.search_session(&session_key, query).await { + Ok(hits) => { + for hit in hits { + all_hits.push(AgentContextChunk::from(ai::rag::RagSearchHit { + id: hit.id, + session_id: hit.session_id, + score: hit.score, + content: hit.content, + metadata: hit.metadata, + })); + } + } + Err(e) => { + tracing::warn!( + kb_id = %kb_id, + error = %e, + "agent: RAG search failed for knowledge base, skipping" + ); + } + } + } + + Ok(all_hits) + } + #[allow(dead_code)] + pub(crate) async fn agent_upsert_knowledge( + &self, + ai_client: &AiClient, + kb_id: Uuid, + documents: Vec, + ) -> Result<(), AppError> { + let qdrant_url = self + .config + .qdrant_url() + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + let vector_size = self + .config + .get_embed_model_dimensions() + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + let rag_config = RagConfig::new(qdrant_url, "agent_knowledge", vector_size) + .map_err(|e| AppError::InternalServerError(e.to_string()))? + .with_api_key( + self.config + .qdrant_api_key() + .map_err(|e| AppError::InternalServerError(e.to_string()))?, + ); + + let rag = RagClient::connect(ai_client, rag_config) + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + rag.ensure_collection() + .await + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + let session_key = format!("kb:{kb_id}"); + rag.upsert_documents(&session_key, documents) + .await + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + Ok(()) + } +} diff --git a/lib/service/agent/conversation.rs b/lib/service/agent/conversation.rs new file mode 100644 index 0000000..7c5fdd8 --- /dev/null +++ b/lib/service/agent/conversation.rs @@ -0,0 +1,627 @@ +use chrono::Utc; +use db::sqlx; +use model::agent::AgentConversationModel; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; +use uuid::Uuid; + +use crate::error::AppError; +use crate::AppService; + +#[derive(Debug, Clone, Deserialize, ToSchema)] +pub struct CreateConversation { + pub title: String, +} + +#[derive(Debug, Clone, Deserialize, ToSchema)] +pub struct UpdateConversation { + pub title: Option, +} + +#[derive(Debug, Clone, Serialize, ToSchema)] +pub struct ConversationResponse { + pub id: Uuid, + pub session_id: Uuid, + pub title: String, + pub created_by: Uuid, + pub last_message_at: Option>, + pub archived_at: Option>, + pub created_at: chrono::DateTime, + pub updated_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, ToSchema)] +pub struct ToolCallResponse { + pub id: String, + pub name: String, + pub arguments: serde_json::Value, + pub output: Option, + pub error: Option, + pub status: String, + pub elapsed_ms: Option, +} + +impl From for ToolCallResponse { + fn from(m: model::agent::AgentToolCallLogModel) -> Self { + Self { + id: m.tool_call_id.unwrap_or_default(), + name: m.tool_name, + arguments: m.arguments.as_deref().and_then(|s| serde_json::from_str(s).ok()).unwrap_or_default(), + output: m.result.as_deref().and_then(|s| serde_json::from_str(s).ok()), + error: m.error, + status: m.status, + elapsed_ms: m.latency_ms, + } + } +} + +#[derive(Debug, Clone, Serialize, ToSchema)] +pub struct MessageResponse { + pub id: Uuid, + pub conversation_id: Uuid, + pub parent_id: Option, + pub role: String, + pub author: Option, + pub content: String, + pub content_type: String, + pub status: String, + pub model_invocation: Option, + pub reasoning_content: Option, + #[serde(default)] + pub tool_calls: Vec, + pub created_at: chrono::DateTime, + pub updated_at: chrono::DateTime, +} +#[derive(Debug, Clone, sqlx::FromRow)] +struct ConversationWithSessionRow { + pub id: Uuid, + pub session: Uuid, + pub title: String, + pub created_by: Uuid, + pub last_message_at: Option>, + pub archived_at: Option>, + pub created_at: chrono::DateTime, + pub updated_at: chrono::DateTime, + pub session_name: Option, + #[allow(dead_code)] + pub session_wk: Option, +} +#[derive(Debug, Clone, Serialize, ToSchema)] +pub struct ConversationWithSessionResponse { + pub id: Uuid, + pub session_id: Uuid, + pub session_name: Option, + pub title: String, + pub created_by: Uuid, + pub last_message_at: Option>, + pub archived_at: Option>, + pub created_at: chrono::DateTime, + pub updated_at: chrono::DateTime, +} + +impl From for ConversationWithSessionResponse { + fn from(r: ConversationWithSessionRow) -> Self { + Self { + id: r.id, + session_id: r.session, + session_name: r.session_name, + title: r.title, + created_by: r.created_by, + last_message_at: r.last_message_at, + archived_at: r.archived_at, + created_at: r.created_at, + updated_at: r.updated_at, + } + } +} + +impl AppService { + pub(crate) async fn agent_require_conversation_access( + &self, + user_id: Uuid, + conversation_id: Uuid, + ) -> Result { + let conv = sqlx::query_as::<_, AgentConversationModel>( + "SELECT id, session, title, created_by, last_message_at, \ + archived_at, compacted_summary, created_at, updated_at, deleted_at \ + FROM agent_conversation \ + WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(conversation_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?; + let session: (Option, Option) = sqlx::query_as( + "SELECT \"user\", wk \ + FROM agent_session \ + WHERE id = $1 AND deleted_at IS NULL AND enabled = true", + ) + .bind(conv.session) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?; + + let (session_user, session_wk) = session; + if session_user != Some(user_id) { + if let Some(wk) = session_wk { + let _ = crate::AppService::workspace_require_member( + &*self, wk, user_id, + ) + .await?; + } else { + return Err(AppError::PermissionDenied); + } + } + + Ok(conv) + } + async fn agent_require_session_access( + &self, + user_id: Uuid, + session_id: Uuid, + ) -> Result<(), AppError> { + let session: (Option, Option) = sqlx::query_as( + "SELECT \"user\", wk \ + FROM agent_session \ + WHERE id = $1 AND deleted_at IS NULL AND enabled = true", + ) + .bind(session_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?; + + if session.0 != Some(user_id) { + if let Some(wk) = session.1 { + let _ = crate::AppService::workspace_require_member( + &*self, wk, user_id, + ) + .await?; + } else { + return Err(AppError::PermissionDenied); + } + } + + Ok(()) + } +} + +impl AppService { + pub async fn agent_conversation_create( + &self, + user_id: Uuid, + session_id: Uuid, + params: CreateConversation, + ) -> Result { + self.agent_require_session_access(user_id, session_id) + .await?; + + let id = Uuid::now_v7(); + let now = Utc::now(); + let row = sqlx::query_as::<_, AgentConversationModel>( + "INSERT INTO agent_conversation \ + (id, session, title, created_by, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $5) \ + RETURNING id, session, title, created_by, last_message_at, \ + archived_at, compacted_summary, created_at, updated_at, deleted_at", + ) + .bind(id) + .bind(session_id) + .bind(¶ms.title) + .bind(user_id) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(row.into()) + } + + pub async fn agent_conversation_list( + &self, + user_id: Uuid, + session_id: Uuid, + ) -> Result, AppError> { + self.agent_require_session_access(user_id, session_id) + .await?; + + let rows = sqlx::query_as::<_, AgentConversationModel>( + "SELECT id, session, title, created_by, last_message_at, \ + archived_at, compacted_summary, created_at, updated_at, deleted_at \ + FROM agent_conversation \ + WHERE session = $1 AND deleted_at IS NULL \ + ORDER BY last_message_at DESC NULLS LAST, created_at DESC \ + LIMIT 100", + ) + .bind(session_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(Into::into).collect()) + } + pub async fn agent_conversation_list_all( + &self, + user_id: Uuid, + wk: Option<&str>, + ) -> Result, AppError> { + let rows: Vec = if let Some(wk_name) = wk { + sqlx::query_as( + "SELECT c.id, c.session, c.title, c.created_by, c.last_message_at, \ + c.archived_at, c.created_at, c.updated_at, \ + s.name as session_name, s.wk as session_wk \ + FROM agent_conversation c \ + INNER JOIN agent_session s ON s.id = c.session \ + WHERE c.deleted_at IS NULL AND s.deleted_at IS NULL AND s.enabled = true \ + AND (s.\"user\" = $1 OR (s.wk = (SELECT id FROM workspace WHERE name = $2) AND s.wk IN (SELECT wk FROM wk_member WHERE \"user\" = $1))) \ + ORDER BY c.last_message_at DESC NULLS LAST, c.created_at DESC \ + LIMIT 100", + ) + .bind(user_id) + .bind(wk_name) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + } else { + sqlx::query_as( + "SELECT c.id, c.session, c.title, c.created_by, c.last_message_at, \ + c.archived_at, c.created_at, c.updated_at, \ + s.name as session_name, s.wk as session_wk \ + FROM agent_conversation c \ + INNER JOIN agent_session s ON s.id = c.session \ + WHERE c.deleted_at IS NULL AND s.deleted_at IS NULL AND s.enabled = true \ + AND s.\"user\" = $1 \ + ORDER BY c.last_message_at DESC NULLS LAST, c.created_at DESC \ + LIMIT 100", + ) + .bind(user_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + }; + + Ok(rows.into_iter().map(Into::into).collect()) + } + + pub async fn agent_conversation_get( + &self, + user_id: Uuid, + conversation_id: Uuid, + ) -> Result { + Ok(self + .agent_require_conversation_access(user_id, conversation_id) + .await? + .into()) + } + + pub async fn agent_conversation_update( + &self, + user_id: Uuid, + conversation_id: Uuid, + params: UpdateConversation, + ) -> Result { + let existing = self + .agent_require_conversation_access(user_id, conversation_id) + .await?; + + let title = params.title.unwrap_or(existing.title); + let now = Utc::now(); + + let row = sqlx::query_as::<_, AgentConversationModel>( + "UPDATE agent_conversation SET title = $1, updated_at = $2 \ + WHERE id = $3 AND deleted_at IS NULL \ + RETURNING id, session, title, created_by, last_message_at, \ + archived_at, compacted_summary, created_at, updated_at, deleted_at", + ) + .bind(&title) + .bind(now) + .bind(conversation_id) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(row.into()) + } + + pub async fn agent_conversation_delete( + &self, + user_id: Uuid, + conversation_id: Uuid, + ) -> Result<(), AppError> { + self.agent_require_conversation_access(user_id, conversation_id) + .await?; + + let now = Utc::now(); + let rows = sqlx::query( + "UPDATE agent_conversation SET deleted_at = $1, updated_at = $1 \ + WHERE id = $2 AND deleted_at IS NULL", + ) + .bind(now) + .bind(conversation_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if rows.rows_affected() == 0 { + return Err(AppError::NotFound("conversation not found".to_string())); + } + Ok(()) + } + + pub async fn agent_conversation_archive( + &self, + user_id: Uuid, + conversation_id: Uuid, + ) -> Result { + self.agent_require_conversation_access(user_id, conversation_id) + .await?; + + let now = Utc::now(); + let row = sqlx::query_as::<_, AgentConversationModel>( + "UPDATE agent_conversation SET archived_at = $1, updated_at = $1 \ + WHERE id = $2 AND deleted_at IS NULL \ + RETURNING id, session, title, created_by, last_message_at, \ + archived_at, compacted_summary, created_at, updated_at, deleted_at", + ) + .bind(now) + .bind(conversation_id) + .fetch_optional(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?; + + Ok(row.into()) + } + + pub async fn agent_conversation_unarchive( + &self, + user_id: Uuid, + conversation_id: Uuid, + ) -> Result { + self.agent_require_conversation_access(user_id, conversation_id) + .await?; + + let now = Utc::now(); + let row = sqlx::query_as::<_, AgentConversationModel>( + "UPDATE agent_conversation SET archived_at = NULL, updated_at = $1 \ + WHERE id = $2 AND deleted_at IS NULL \ + RETURNING id, session, title, created_by, last_message_at, \ + archived_at, compacted_summary, created_at, updated_at, deleted_at", + ) + .bind(now) + .bind(conversation_id) + .fetch_optional(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?; + + Ok(row.into()) + } +} + +impl AppService { + pub async fn agent_message_list( + &self, + user_id: Uuid, + conversation_id: Uuid, + limit: Option, + before: Option, + ) -> Result, AppError> { + self.agent_require_conversation_access(user_id, conversation_id) + .await?; + + let limit = limit.unwrap_or(50).min(100) as i64; + + let rows = if let Some(before_id) = before { + sqlx::query_as::<_, model::agent::AgentMessageModel>( + "SELECT id, conversation, parent, role, author, content, content_type, \ + status, model_invocation, reasoning_content, created_at, updated_at, deleted_at \ + FROM agent_message \ + WHERE conversation = $1 AND deleted_at IS NULL \ + AND created_at < (SELECT created_at FROM agent_message WHERE id = $2 AND conversation = $1) \ + ORDER BY created_at DESC LIMIT $3", + ) + .bind(conversation_id) + .bind(before_id) + .bind(limit) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + } else { + sqlx::query_as::<_, model::agent::AgentMessageModel>( + "SELECT id, conversation, parent, role, author, content, content_type, \ + status, model_invocation, reasoning_content, created_at, updated_at, deleted_at \ + FROM agent_message \ + WHERE conversation = $1 AND deleted_at IS NULL \ + ORDER BY created_at DESC LIMIT $2", + ) + .bind(conversation_id) + .bind(limit) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + }; + + // Fetch tool calls for all assistant messages in one query. + let message_ids: Vec = rows.iter().map(|r| r.id).collect(); + let tool_call_logs = if !message_ids.is_empty() { + sqlx::query_as::<_, model::agent::AgentToolCallLogModel>( + "SELECT id, invocation, session, conversation, message, tool_call_id, \ + tool_name, arguments, result, error, status, \ + started_at, finished_at, latency_ms \ + FROM agent_tool_call_log \ + WHERE message = ANY($1) \ + ORDER BY started_at ASC", + ) + .bind(&message_ids) + .fetch_all(self.db.reader()) + .await + .unwrap_or_default() + } else { + Vec::new() + }; + + // Group tool calls by message_id. + let mut tool_calls_by_message: std::collections::HashMap> = + std::collections::HashMap::new(); + for log in tool_call_logs { + if let Some(msg_id) = log.message { + tool_calls_by_message + .entry(msg_id) + .or_default() + .push(log.into()); + } + } + + let mut messages: Vec = rows + .into_iter() + .map(|row| { + let mut msg: MessageResponse = row.into(); + msg.tool_calls = tool_calls_by_message + .remove(&msg.id) + .unwrap_or_default(); + msg + }) + .collect(); + messages.reverse(); + Ok(messages) + } + pub async fn agent_conversation_fork( + &self, + user_id: Uuid, + source_conversation_id: Uuid, + up_to_message_id: Option, + title_override: Option<&str>, + ) -> Result { + let source = self + .agent_require_conversation_access(user_id, source_conversation_id) + .await?; + + let session_id = source.session; + let base_title = title_override + .map(|t| t.to_string()) + .unwrap_or_else(|| format!("{} (fork)", source.title)); + + let new_id = Uuid::now_v7(); + let now = Utc::now(); + sqlx::query( + "INSERT INTO agent_conversation \ + (id, session, title, created_by, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $5)", + ) + .bind(new_id) + .bind(session_id) + .bind(&base_title) + .bind(user_id) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let messages = if let Some(msg_id) = up_to_message_id { + sqlx::query_as::<_, model::agent::AgentMessageModel>( + "SELECT id, conversation, parent, role, author, content, content_type, \ + status, model_invocation, reasoning_content, \ + created_at, updated_at, deleted_at \ + FROM agent_message \ + WHERE conversation = $1 \ + AND deleted_at IS NULL \ + AND created_at <= (SELECT created_at FROM agent_message WHERE id = $2) \ + ORDER BY created_at ASC", + ) + .bind(source_conversation_id) + .bind(msg_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + } else { + sqlx::query_as::<_, model::agent::AgentMessageModel>( + "SELECT id, conversation, parent, role, author, content, content_type, \ + status, model_invocation, reasoning_content, \ + created_at, updated_at, deleted_at \ + FROM agent_message \ + WHERE conversation = $1 AND deleted_at IS NULL \ + ORDER BY created_at ASC", + ) + .bind(source_conversation_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + }; + + for msg in &messages { + let new_msg_id = Uuid::now_v7(); + sqlx::query( + "INSERT INTO agent_message \ + (id, conversation, parent, role, author, content, content_type, \ + status, model_invocation, reasoning_content, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $11)", + ) + .bind(new_msg_id) + .bind(new_id) + .bind::>(None) + .bind(&msg.role) + .bind(msg.author) + .bind(&msg.content) + .bind(&msg.content_type) + .bind(&msg.status) + .bind(msg.model_invocation) + .bind(&msg.reasoning_content) + .bind(msg.created_at) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let _ = sqlx::query( + "INSERT INTO agent_message_fork \ + (source_message, forked_conversation, forked_by, created_at) \ + VALUES ($1, $2, $3, $4)", + ) + .bind(msg.id) + .bind(new_id) + .bind(user_id) + .bind(now) + .execute(self.db.writer()) + .await; + } + + self.agent_conversation_get(user_id, new_id).await + } +} + +impl From for ConversationResponse { + fn from(m: AgentConversationModel) -> Self { + Self { + id: m.id, + session_id: m.session, + title: m.title, + created_by: m.created_by, + last_message_at: m.last_message_at, + archived_at: m.archived_at, + created_at: m.created_at, + updated_at: m.updated_at, + } + } +} + +impl From for MessageResponse { + fn from(m: model::agent::AgentMessageModel) -> Self { + Self { + id: m.id, + conversation_id: m.conversation, + parent_id: m.parent, + role: m.role, + author: m.author, + content: m.content, + content_type: m.content_type, + status: m.status, + model_invocation: m.model_invocation, + reasoning_content: m.reasoning_content, + tool_calls: Vec::new(), + created_at: m.created_at, + updated_at: m.updated_at, + } + } +} diff --git a/lib/service/agent/git_tools/blame.rs b/lib/service/agent/git_tools/blame.rs new file mode 100644 index 0000000..1384ede --- /dev/null +++ b/lib/service/agent/git_tools/blame.rs @@ -0,0 +1,90 @@ +use ai::error::AiResult; +use ai::tool::tools::FunctionCall; +use async_trait::async_trait; +use git::rpc::proto as p; +use git::rpc::proto::blame_service_client::BlameServiceClient; +use serde_json::{json, Value}; + +use super::helpers::{arg_str, arg_opt_str, git_ctx, require_repo_member, rpc_err}; +use crate::agent::run::AppAgentContext; + +pub struct GitBlameTool; + +impl GitBlameTool { + pub fn new() -> Self { Self } +} + +impl Default for GitBlameTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitBlameTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_blame" } + + fn description(&self) -> &'static str { + "Blame a file to see which commits authored each line range." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "path": { "type": "string", "description": "File path to blame" }, + "rev": { "type": "string", "description": "Revision/branch (optional)" }, + "start_line": { "type": "integer", "description": "Start line number (optional)" }, + "end_line": { "type": "integer", "description": "End line number (optional)" } + }, + "required": ["workspace", "repo", "path"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let path = arg_str(&args, "path")?; + let rev = arg_opt_str(&args, "rev").map(String::from); + let start_line = args.get("start_line").and_then(|v| v.as_u64()).map(|v| v as u32); + let end_line = args.get("end_line").and_then(|v| v.as_u64()).map(|v| v as u32); + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = BlameServiceClient::new(git.channel.clone()); + + if let (Some(start), Some(end)) = (start_line, end_line) { + let resp = client + .blame_lines(p::BlameLinesRequest { + repo_id: repo.id.to_string(), path: path.to_string(), rev, + start_line: start, end_line: end, + }) + .await.map_err(rpc_err)?.into_inner(); + + let lines: Vec = resp.lines.iter().map(|l| json!({ + "line_no": l.line_no, + "content": l.content, + "commit_oid": l.commit_oid.as_ref().map(|o| &o.value), + })).collect(); + + Ok(json!({ "lines": lines, "count": lines.len() })) + } else { + let resp = client + .blame_file(p::BlameFileRequest { + repo_id: repo.id.to_string(), path: path.to_string(), rev, options: None, + }) + .await.map_err(rpc_err)?.into_inner(); + + let hunks: Vec = resp.hunks.iter().map(|h| json!({ + "commit_oid": h.commit_oid.as_ref().map(|o| &o.value), + "final_start_line": h.final_start_line, + "final_lines": h.final_lines, + })).collect(); + + Ok(json!({ "hunks": hunks, "count": hunks.len() })) + } + } +} diff --git a/lib/service/agent/git_tools/branch.rs b/lib/service/agent/git_tools/branch.rs new file mode 100644 index 0000000..f80f58b --- /dev/null +++ b/lib/service/agent/git_tools/branch.rs @@ -0,0 +1,302 @@ +use ai::error::{AiError, AiResult}; +use ai::tool::tools::FunctionCall; +use async_trait::async_trait; +use git::rpc::proto as p; +use git::rpc::proto::branch_service_client::BranchServiceClient; +use serde_json::{json, Value}; + +use super::helpers::{arg_str, git_ctx, require_repo_member, rpc_err}; +use crate::agent::run::AppAgentContext; + +pub struct GitBranchListTool; + +impl GitBranchListTool { + pub fn new() -> Self { Self } +} + +impl Default for GitBranchListTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitBranchListTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_branch_list" } + + fn description(&self) -> &'static str { + "List all branches in a repository with their HEAD commit OID." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" } + }, + "required": ["workspace", "repo"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = BranchServiceClient::new(git.channel.clone()); + let resp = client + .branch_list(p::BranchListRequest { repo_id: repo.id.to_string() }) + .await + .map_err(rpc_err)? + .into_inner(); + + let branches: Vec = resp.branches.iter().map(|b| json!({ + "name": b.name, + "oid": b.oid.as_ref().map(|o| &o.value).unwrap_or(&String::new()), + "is_head": b.is_head, + "is_current": b.is_current, + })).collect(); + + Ok(json!({ "branches": branches, "count": branches.len() })) + } +} + +pub struct GitBranchInfoTool; + +impl GitBranchInfoTool { + pub fn new() -> Self { Self } +} + +impl Default for GitBranchInfoTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitBranchInfoTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_branch_info" } + + fn description(&self) -> &'static str { + "Get detailed information about a single branch, including its HEAD OID and upstream." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "branch": { "type": "string", "description": "Branch name" } + }, + "required": ["workspace", "repo", "branch"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let branch = arg_str(&args, "branch")?; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = BranchServiceClient::new(git.channel.clone()); + let resp = client + .branch_info(p::BranchInfoRequest { + repo_id: repo.id.to_string(), + branch: branch.to_string(), + }) + .await + .map_err(rpc_err)? + .into_inner(); + + let b = resp.branch.ok_or_else(|| AiError::Config(format!("branch '{branch}' not found")))?; + Ok(json!({ + "name": b.name, + "oid": b.oid.as_ref().map(|o| &o.value), + "is_head": b.is_head, + "is_current": b.is_current, + "upstream": b.upstream, + })) + } +} + +pub struct GitBranchAheadBehindTool; + +impl GitBranchAheadBehindTool { + pub fn new() -> Self { Self } +} + +impl Default for GitBranchAheadBehindTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitBranchAheadBehindTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_branch_ahead_behind" } + + fn description(&self) -> &'static str { + "Compare a local branch with its remote tracking branch. Returns commits ahead and behind." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "local_branch": { "type": "string", "description": "Local branch name" }, + "remote_branch": { "type": "string", "description": "Remote tracking branch name" } + }, + "required": ["workspace", "repo", "local_branch", "remote_branch"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let local_branch = arg_str(&args, "local_branch")?; + let remote_branch = arg_str(&args, "remote_branch")?; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = BranchServiceClient::new(git.channel.clone()); + let resp = client + .branch_ahead_behind(p::BranchAheadBehindRequest { + repo_id: repo.id.to_string(), + local_branch: local_branch.to_string(), + remote_branch: remote_branch.to_string(), + }) + .await + .map_err(rpc_err)? + .into_inner(); + + Ok(json!({ "ahead": resp.ahead, "behind": resp.behind })) + } +} + +pub struct GitBranchDeleteTool; + +impl GitBranchDeleteTool { + pub fn new() -> Self { Self } +} + +impl Default for GitBranchDeleteTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitBranchDeleteTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_branch_delete" } + + fn description(&self) -> &'static str { + "Delete a branch from the repository. Requires write access." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "name": { "type": "string", "description": "Branch name to delete" }, + "force": { "type": "boolean", "description": "Force delete (even if not merged)" } + }, + "required": ["workspace", "repo", "name"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let name = arg_str(&args, "name")?; + let force = args.get("force").and_then(|v| v.as_bool()).unwrap_or(false); + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = BranchServiceClient::new(git.channel.clone()); + client + .branch_delete(p::BranchDeleteRequest { + repo_id: repo.id.to_string(), + params: Some(p::BranchDeleteParams { + name: name.to_string(), + force, + }), + }) + .await + .map_err(rpc_err)?; + + Ok(json!({ "success": true, "branch": name })) + } +} + +pub struct GitCreateBranchTool; + +impl GitCreateBranchTool { + pub fn new() -> Self { Self } +} + +impl Default for GitCreateBranchTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitCreateBranchTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_create_branch" } + + fn description(&self) -> &'static str { + "Create a new branch in a repository. Requires write access." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "name": { "type": "string", "description": "New branch name" }, + "oid": { "type": "string", "description": "Commit OID to branch from" }, + "force": { "type": "boolean", "description": "Force create (overwrite existing)" } + }, + "required": ["workspace", "repo", "name", "oid"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let name = arg_str(&args, "name")?; + let oid = arg_str(&args, "oid")?; + let force = args.get("force").and_then(|v| v.as_bool()).unwrap_or(false); + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = BranchServiceClient::new(git.channel.clone()); + client + .branch_fork(p::BranchForkRequest { + repo_id: repo.id.to_string(), + params: Some(p::BranchForkParams { + name: name.to_string(), + oid: Some(p::ObjectId { value: oid.to_string() }), + force, + }), + }) + .await + .map_err(rpc_err)?; + + Ok(json!({ "success": true, "branch": name, "oid": oid })) + } +} diff --git a/lib/service/agent/git_tools/commit.rs b/lib/service/agent/git_tools/commit.rs new file mode 100644 index 0000000..0eb232b --- /dev/null +++ b/lib/service/agent/git_tools/commit.rs @@ -0,0 +1,438 @@ +use ai::error::{AiError, AiResult}; +use ai::tool::tools::FunctionCall; +use async_trait::async_trait; +use git::rpc::proto as p; +use git::rpc::proto::commit_service_client::CommitServiceClient; +use serde_json::{json, Value}; + +use super::helpers::{arg_str, arg_opt_str, arg_u64, git_ctx, require_repo_member, rpc_err}; +use crate::agent::run::AppAgentContext; + +pub struct GitCommitHistoryTool; + +impl GitCommitHistoryTool { + pub fn new() -> Self { + Self + } +} + +impl Default for GitCommitHistoryTool { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl FunctionCall for GitCommitHistoryTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { + "git_commit_history" + } + + fn description(&self) -> &'static str { + "List recent commits on a branch. Returns commit OID, message, author, and timestamp." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { + "type": "string", + "description": "Workspace name (e.g. 'my-org')" + }, + "repo": { + "type": "string", + "description": "Repository name" + }, + "branch": { + "type": "string", + "description": "Branch name (optional, defaults to default branch)" + }, + "limit": { + "type": "integer", + "description": "Max commits to return (default 20, max 100)" + }, + "skip": { + "type": "integer", + "description": "Number of commits to skip (for pagination)" + } + }, + "required": ["workspace", "repo"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let branch = arg_opt_str(&args, "branch").map(String::from); + let limit = arg_u64(&args, "limit", 20).min(100); + let skip = arg_u64(&args, "skip", 0); + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = CommitServiceClient::new(git.channel.clone()); + let resp = client + .commit_history(p::CommitHistoryRequest { + repo_id: repo.id.to_string(), + limit, + skip, + sort: p::CommitWalkSort::Time as i32, + branch, + }) + .await + .map_err(rpc_err)? + .into_inner(); + + let commits: Vec = resp + .commits + .iter() + .map(|c| { + json!({ + "oid": c.oid.as_ref().map(|o| &o.value).unwrap_or(&String::new()), + "summary": c.summary, + "message": c.message, + "author_name": c.author.as_ref().map(|a| &a.name).unwrap_or(&String::new()), + "author_email": c.author.as_ref().map(|a| &a.email).unwrap_or(&String::new()), + "time": c.author.as_ref().map(|a| a.time_secs).unwrap_or(0), + }) + }) + .collect(); + + Ok(json!({ "commits": commits, "count": commits.len() })) + } +} + +pub struct GitCommitInfoTool; + +impl GitCommitInfoTool { + pub fn new() -> Self { + Self + } +} + +impl Default for GitCommitInfoTool { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl FunctionCall for GitCommitInfoTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { + "git_commit_info" + } + + fn description(&self) -> &'static str { + "Get detailed information about a specific commit by its OID." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "oid": { "type": "string", "description": "Commit OID (SHA)" } + }, + "required": ["workspace", "repo", "oid"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let oid = arg_str(&args, "oid")?; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = CommitServiceClient::new(git.channel.clone()); + let resp = client + .commit_info(p::CommitInfoRequest { + repo_id: repo.id.to_string(), + oid: Some(p::ObjectId { value: oid.to_string() }), + }) + .await + .map_err(rpc_err)? + .into_inner(); + + let c = resp + .commit + .ok_or_else(|| AiError::Response("commit not found".to_string()))?; + let parent_ids: Vec = c.parent_ids.iter().map(|o| o.value.clone()).collect(); + + Ok(json!({ + "oid": c.oid.as_ref().map(|o| &o.value), + "summary": c.summary, + "message": c.message, + "author_name": c.author.as_ref().map(|a| &a.name), + "author_email": c.author.as_ref().map(|a| &a.email), + "author_time": c.author.as_ref().map(|a| a.time_secs), + "committer_name": c.committer.as_ref().map(|a| &a.name), + "tree_id": c.tree_id.as_ref().map(|o| &o.value), + "parent_ids": parent_ids, + })) + } +} + + +pub struct GitCommitExistsTool; + +impl GitCommitExistsTool { + pub fn new() -> Self { Self } +} + +impl Default for GitCommitExistsTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitCommitExistsTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_commit_exists" } + + fn description(&self) -> &'static str { + "Check whether a specific commit OID exists in the repository." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "oid": { "type": "string", "description": "Commit OID (SHA)" } + }, + "required": ["workspace", "repo", "oid"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let oid = arg_str(&args, "oid")?; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = CommitServiceClient::new(git.channel.clone()); + let resp = client + .commit_exists(p::CommitExistsRequest { + repo_id: repo.id.to_string(), + oid: Some(p::ObjectId { value: oid.to_string() }), + }) + .await + .map_err(rpc_err)? + .into_inner(); + + Ok(json!({ "exists": resp.exists })) + } +} + + + +pub struct GitCherryPickTool; + +impl GitCherryPickTool { + pub fn new() -> Self { Self } +} + +impl Default for GitCherryPickTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitCherryPickTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_cherry_pick" } + + fn description(&self) -> &'static str { + "Cherry-pick a commit onto the current branch. Requires write access." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "oid": { "type": "string", "description": "Commit OID to cherry-pick" }, + "message": { "type": "string", "description": "Override commit message (optional)" }, + "update_ref": { "type": "string", "description": "Branch ref to update (optional)" } + }, + "required": ["workspace", "repo", "oid"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let oid = arg_str(&args, "oid")?; + let message = args.get("message").and_then(|v| v.as_str()).map(String::from); + let update_ref = arg_opt_str(&args, "update_ref").map(String::from); + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = CommitServiceClient::new(git.channel.clone()); + let resp = client + .cherry_pick(p::CherryPickRequest { + repo_id: repo.id.to_string(), + params: Some(p::CommitCherryPickParams { + cherrypick_oid: Some(p::ObjectId { value: oid.to_string() }), + message, + update_ref, + ..Default::default() + }), + }) + .await + .map_err(rpc_err)? + .into_inner(); + + Ok(json!({ + "success": true, + "new_oid": resp.oid.as_ref().map(|o| &o.value), + })) + } +} + + +pub struct GitCommitCreateTool; + +impl GitCommitCreateTool { + pub fn new() -> Self { + Self + } +} + +impl Default for GitCommitCreateTool { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl FunctionCall for GitCommitCreateTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { + "git_commit_create" + } + + fn description(&self) -> &'static str { + "Create a commit with new or updated files in a workspace repository. Author is set to the requesting user, committer is redpanda . Provide workspace name, repo name, branch, commit message, and a list of files (path and content for each)." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { + "type": "string", + "description": "Workspace name (e.g. 'my-org')" + }, + "repo": { + "type": "string", + "description": "Repository name" + }, + "branch": { + "type": "string", + "description": "Branch name to commit to (e.g. 'main'). If the branch does not exist, it will be created." + }, + "message": { + "type": "string", + "description": "Commit message" + }, + "files": { + "type": "array", + "description": "List of file changes", + "items": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "File path relative to repo root (e.g. 'src/main.rs')" + }, + "content": { + "type": "string", + "description": "File content as a string" + } + }, + "required": ["path", "content"] + } + } + }, + "required": ["workspace", "repo", "branch", "message", "files"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let branch = arg_str(&args, "branch")?; + let message = arg_str(&args, "message")?; + + let files_val = args + .get("files") + .and_then(|v| v.as_array()) + .ok_or_else(|| AiError::Config("'files' must be an array of {path, content} objects".to_string()))?; + + let mut file_changes: Vec = Vec::new(); + for f in files_val { + let path = f + .get("path") + .and_then(|v| v.as_str()) + .ok_or_else(|| AiError::Config("each file must have a 'path' field".to_string()))?; + let content = f + .get("content") + .and_then(|v| v.as_str()) + .ok_or_else(|| AiError::Config("each file must have a 'content' field".to_string()))?; + file_changes.push(super::helpers::FileChange { + path: path.to_string(), + content: content.as_bytes().to_vec(), + }); + } + + if file_changes.is_empty() { + return Err(AiError::Config("'files' array must not be empty".to_string())); + } + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = CommitServiceClient::new(git.channel.clone()); + let resp = client + .create_commit(tonic::Request::new(p::CreateCommitRequest { + repo_id: repo.id.to_string(), + branch: branch.to_string(), + message: message.to_string(), + author_name: ctx.user_id.to_string(), + author_email: format!("{}@gitdata.ai", ctx.user_id), + committer_name: "redpanda".to_string(), + committer_email: "redpanda@gitdata.ai".to_string(), + files: file_changes + .into_iter() + .map(|fc| p::FileChange { + path: fc.path, + content: fc.content, + }) + .collect(), + })) + .await + .map_err(rpc_err)? + .into_inner(); + + Ok(json!({ + "success": true, + "oid": resp.oid.as_ref().map(|o| &o.value), + "files_committed": files_val.len(), + })) + } +} diff --git a/lib/service/agent/git_tools/diff.rs b/lib/service/agent/git_tools/diff.rs new file mode 100644 index 0000000..f5c12d4 --- /dev/null +++ b/lib/service/agent/git_tools/diff.rs @@ -0,0 +1,188 @@ +use ai::error::{AiError, AiResult}; +use ai::tool::tools::FunctionCall; +use async_trait::async_trait; +use git::rpc::proto as p; +use git::rpc::proto::diff_service_client::DiffServiceClient; +use serde_json::{json, Value}; + +use super::helpers::{arg_str, git_ctx, require_repo_member, rpc_err}; +use crate::agent::run::AppAgentContext; + +pub struct GitDiffStatsTool; + +impl GitDiffStatsTool { + pub fn new() -> Self { Self } +} + +impl Default for GitDiffStatsTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitDiffStatsTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_diff_stats" } + + fn description(&self) -> &'static str { + "Get diff statistics between two commits: files changed, insertions, deletions." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "old_oid": { "type": "string", "description": "Base commit OID" }, + "new_oid": { "type": "string", "description": "Target commit OID" } + }, + "required": ["workspace", "repo", "old_oid", "new_oid"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let old_oid = arg_str(&args, "old_oid")?; + let new_oid = arg_str(&args, "new_oid")?; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = DiffServiceClient::new(git.channel.clone()); + let resp = client + .diff_stats(p::DiffStatsRequest { + repo_id: repo.id.to_string(), + old_oid: Some(p::ObjectId { value: old_oid.to_string() }), + new_oid: Some(p::ObjectId { value: new_oid.to_string() }), + options: None, + }) + .await + .map_err(rpc_err)? + .into_inner(); + + let result = resp.result.ok_or_else(|| AiError::Response("no diff result".to_string()))?; + let stats = result.stats.ok_or_else(|| AiError::Response("no stats".to_string()))?; + + let files: Vec = result.deltas.iter().map(|d| { + let status = match p::DiffDeltaStatus::try_from(d.status) { + Ok(p::DiffDeltaStatus::Added) => "added", + Ok(p::DiffDeltaStatus::Deleted) => "deleted", + Ok(p::DiffDeltaStatus::Modified) => "modified", + Ok(p::DiffDeltaStatus::Renamed) => "renamed", + _ => "unknown", + }; + let old_path = d.old_file.as_ref().and_then(|f| f.path.as_deref()).unwrap_or(""); + let new_path = d.new_file.as_ref().and_then(|f| f.path.as_deref()).unwrap_or(""); + json!({ "status": status, "old_path": old_path, "new_path": new_path, "hunks": d.hunks.len() }) + }).collect(); + + Ok(json!({ + "files_changed": stats.files_changed, + "insertions": stats.insertions, + "deletions": stats.deletions, + "files": files, + })) + } +} + +pub struct GitDiffPatchTool; + +impl GitDiffPatchTool { + pub fn new() -> Self { Self } +} + +impl Default for GitDiffPatchTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitDiffPatchTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_diff_patch" } + + fn description(&self) -> &'static str { + "Get the full diff (unified format) between two commits, including line-level changes." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "old_oid": { "type": "string", "description": "Base commit OID" }, + "new_oid": { "type": "string", "description": "Target commit OID" }, + "context_lines": { "type": "integer", "description": "Lines of context around changes (default 3)" } + }, + "required": ["workspace", "repo", "old_oid", "new_oid"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let old_oid = arg_str(&args, "old_oid")?; + let new_oid = arg_str(&args, "new_oid")?; + let ctx_lines = args.get("context_lines").and_then(|v| v.as_u64()).unwrap_or(3) as u32; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = DiffServiceClient::new(git.channel.clone()); + let resp = client + .diff_patch(p::DiffPatchRequest { + repo_id: repo.id.to_string(), + old_oid: Some(p::ObjectId { value: old_oid.to_string() }), + new_oid: Some(p::ObjectId { value: new_oid.to_string() }), + options: Some(p::DiffOptions { + context_lines: ctx_lines, + ..Default::default() + }), + }) + .await + .map_err(rpc_err)? + .into_inner(); + + let result = resp.result.ok_or_else(|| AiError::Response("no diff result".to_string()))?; + let stats = result.stats.ok_or_else(|| AiError::Response("no stats".to_string()))?; + + let mut patch_text = String::new(); + for delta in &result.deltas { + let status = match p::DiffDeltaStatus::try_from(delta.status) { + Ok(p::DiffDeltaStatus::Added) => "added", + Ok(p::DiffDeltaStatus::Deleted) => "deleted", + Ok(p::DiffDeltaStatus::Modified) => "modified", + Ok(p::DiffDeltaStatus::Renamed) => "renamed", + _ => "unknown", + }; + let old = delta.old_file.as_ref().and_then(|f| f.path.as_deref()).unwrap_or("unknown"); + let new = delta.new_file.as_ref().and_then(|f| f.path.as_deref()).unwrap_or("unknown"); + patch_text.push_str(&format!("--- {}\n+++ {}\n@@ status: {status} @@\n", old, new)); + + for hunk in &delta.hunks { + patch_text.push_str(&hunk.header); + patch_text.push('\n'); + for line in &delta.lines { + patch_text.push_str(&format!("{}{}\n", line.origin, line.content)); + } + patch_text.push('\n'); + } + } + + let truncated = patch_text.len() > 32_000; + if truncated { + patch_text = format!("{}...(truncated)", &patch_text[..32_000]); + } + + Ok(json!({ + "files_changed": stats.files_changed, + "insertions": stats.insertions, + "deletions": stats.deletions, + "patch": patch_text, + "truncated": truncated, + })) + } +} diff --git a/lib/service/agent/git_tools/helpers.rs b/lib/service/agent/git_tools/helpers.rs new file mode 100644 index 0000000..e0843fb --- /dev/null +++ b/lib/service/agent/git_tools/helpers.rs @@ -0,0 +1,112 @@ +use ai::error::{AiError, AiResult}; + +#[derive(Debug, Clone)] +pub struct FileChange { + pub path: String, + pub content: Vec, +} +use ai::tool::register::ToolRegister; +use db::sqlx; +use model::repos::RepoModel; +use serde_json::Value; +use uuid::Uuid; + +use crate::agent::run::{AppAgentContext, GitAgentContext}; + +pub fn register_git_tools(tools: &mut ToolRegister) { + tools.register(super::commit::GitCommitHistoryTool::new()); + tools.register(super::commit::GitCommitInfoTool::new()); + tools.register(super::commit::GitCommitExistsTool::new()); + tools.register(super::commit::GitCherryPickTool::new()); + tools.register(super::commit::GitCommitCreateTool::new()); + tools.register(super::branch::GitBranchListTool::new()); + tools.register(super::branch::GitBranchInfoTool::new()); + tools.register(super::branch::GitBranchAheadBehindTool::new()); + tools.register(super::branch::GitCreateBranchTool::new()); + tools.register(super::branch::GitBranchDeleteTool::new()); + tools.register(super::tree::GitTreeEntriesTool::new()); + tools.register(super::tree::GitFileContentTool::new()); + tools.register(super::diff::GitDiffStatsTool::new()); + tools.register(super::diff::GitDiffPatchTool::new()); + tools.register(super::blame::GitBlameTool::new()); + tools.register(super::tag::GitTagListTool::new()); + tools.register(super::tag::GitTagInfoTool::new()); + tools.register(super::tag::GitCreateTagTool::new()); + tools.register(super::tag::GitDeleteTagTool::new()); + tools.register(super::merge::GitMergeBaseTool::new()); + tools.register(super::merge::GitMergeAnalysisTool::new()); + tools.register(super::merge::GitMergeIsConflictedTool::new()); +} + +pub(super) async fn require_repo_member( + git: &GitAgentContext, + user_id: Uuid, + workspace_name: &str, + repo_name: &str, +) -> AiResult { + let wk_id: Uuid = sqlx::query_scalar( + "SELECT id FROM workspace WHERE name = $1", + ) + .bind(workspace_name) + .fetch_optional(git.db.reader()) + .await + .map_err(AiError::Database)? + .ok_or_else(|| AiError::Config(format!("workspace '{workspace_name}' not found")))?; + + let is_member: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM wk_member \ + WHERE wk = $1 AND \"user\" = $2 AND leave_at IS NULL", + ) + .bind(wk_id) + .bind(user_id) + .fetch_one(git.db.reader()) + .await + .map_err(AiError::Database)?; + + if is_member == 0 { + return Err(AiError::Config(format!( + "user is not a member of workspace '{workspace_name}'" + ))); + } + + let repo: RepoModel = sqlx::query_as( + "SELECT id, wk, name, description, default_branch, visibility, \ + size_bytes, is_archived, is_template, is_mirror, created_by, \ + storage_path, created_at, updated_at, deleted_at \ + FROM repo WHERE wk = $1 AND name = $2 AND deleted_at IS NULL", + ) + .bind(wk_id) + .bind(repo_name) + .fetch_optional(git.db.reader()) + .await + .map_err(AiError::Database)? + .ok_or_else(|| AiError::Config(format!("repo '{repo_name}' not found")))?; + + Ok(repo) +} + +pub(super) fn git_ctx(ctx: &AppAgentContext) -> AiResult<&GitAgentContext> { + ctx.git + .as_ref() + .ok_or_else(|| AiError::Config("git tools are not available in this session".to_string())) +} + +pub(super) fn rpc_err(status: tonic::Status) -> AiError { + AiError::Api(format!("git rpc error: {}", status.message())) +} + +pub(super) fn arg_str<'a>(args: &'a Value, key: &str) -> AiResult<&'a str> { + args.get(key) + .and_then(|v| v.as_str()) + .ok_or_else(|| AiError::Config(format!("'{key}' parameter is required"))) +} + +pub(super) fn arg_opt_str<'a>(args: &'a Value, key: &str) -> Option<&'a str> { + args.get(key).and_then(|v| v.as_str()) +} + +pub(super) fn arg_u64(args: &Value, key: &str, default: u64) -> u64 { + args.get(key) + .and_then(|v| v.as_u64()) + .unwrap_or(default) +} diff --git a/lib/service/agent/git_tools/merge.rs b/lib/service/agent/git_tools/merge.rs new file mode 100644 index 0000000..1626a5b --- /dev/null +++ b/lib/service/agent/git_tools/merge.rs @@ -0,0 +1,204 @@ +use ai::error::{AiError, AiResult}; +use ai::tool::tools::FunctionCall; +use async_trait::async_trait; +use git::rpc::proto as p; +use git::rpc::proto::merge_service_client::MergeServiceClient; +use serde_json::{json, Value}; + +use super::helpers::{arg_str, git_ctx, require_repo_member, rpc_err}; +use crate::agent::run::AppAgentContext; + +pub struct GitMergeBaseTool; + +impl GitMergeBaseTool { + pub fn new() -> Self { Self } +} + +impl Default for GitMergeBaseTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitMergeBaseTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_merge_base" } + + fn description(&self) -> &'static str { + "Find the common ancestor (merge base) of two commits." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "oid_a": { "type": "string", "description": "First commit OID" }, + "oid_b": { "type": "string", "description": "Second commit OID" } + }, + "required": ["workspace", "repo", "oid_a", "oid_b"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let oid_a = arg_str(&args, "oid_a")?; + let oid_b = arg_str(&args, "oid_b")?; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = MergeServiceClient::new(git.channel.clone()); + let resp = client + .merge_base(p::MergeBaseRequest { + repo_id: repo.id.to_string(), + oid_a: Some(p::ObjectId { value: oid_a.to_string() }), + oid_b: Some(p::ObjectId { value: oid_b.to_string() }), + }) + .await + .map_err(rpc_err)? + .into_inner(); + + Ok(json!({ "base_oid": resp.base_oid.map(|o| o.value) })) + } +} + +pub struct GitMergeAnalysisTool; + +impl GitMergeAnalysisTool { + pub fn new() -> Self { Self } +} + +impl Default for GitMergeAnalysisTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitMergeAnalysisTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_merge_analysis" } + + fn description(&self) -> &'static str { + "Analyze whether two commits can be merged (fast-forward, normal, up-to-date, etc)." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "oid_a": { "type": "string", "description": "First commit OID" }, + "oid_b": { "type": "string", "description": "Second commit OID" } + }, + "required": ["workspace", "repo", "oid_a", "oid_b"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let oid_a = arg_str(&args, "oid_a")?; + let oid_b = arg_str(&args, "oid_b")?; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = MergeServiceClient::new(git.channel.clone()); + let resp = client + .merge_analysis(p::MergeAnalysisRequest { + repo_id: repo.id.to_string(), + oid_a: Some(p::ObjectId { value: oid_a.to_string() }), + oid_b: Some(p::ObjectId { value: oid_b.to_string() }), + }) + .await + .map_err(rpc_err)? + .into_inner(); + + let analysis = resp.analysis.ok_or_else(|| AiError::Response("no analysis".to_string()))?; + let pref = resp.preference.ok_or_else(|| AiError::Response("no preference".to_string()))?; + + // Determine overall status + let status = if analysis.is_up_to_date { + "up_to_date" + } else if analysis.is_fast_forward { + "fast_forward" + } else if analysis.is_normal { + "normal_merge" + } else if analysis.is_unborn { + "unborn" + } else { + "none" + }; + + Ok(json!({ + "status": status, + "analysis": { + "is_none": analysis.is_none, + "is_normal": analysis.is_normal, + "is_up_to_date": analysis.is_up_to_date, + "is_fast_forward": analysis.is_fast_forward, + "is_unborn": analysis.is_unborn, + }, + "preference": { + "is_none": pref.is_none, + "is_no_fast_forward": pref.is_no_fast_forward, + "is_fastforward_only": pref.is_fastforward_only, + }, + })) + } +} + +pub struct GitMergeIsConflictedTool; + +impl GitMergeIsConflictedTool { + pub fn new() -> Self { Self } +} + +impl Default for GitMergeIsConflictedTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitMergeIsConflictedTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_merge_is_conflicted" } + + fn description(&self) -> &'static str { + "Check if the repository is currently in a conflicted merge state." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" } + }, + "required": ["workspace", "repo"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = MergeServiceClient::new(git.channel.clone()); + let resp = client + .merge_is_conflicted(p::MergeIsConflictedRequest { + repo_id: repo.id.to_string(), + }) + .await + .map_err(rpc_err)? + .into_inner(); + + Ok(json!({ "is_conflicted": resp.is_conflicted })) + } +} diff --git a/lib/service/agent/git_tools/mod.rs b/lib/service/agent/git_tools/mod.rs new file mode 100644 index 0000000..e12ffce --- /dev/null +++ b/lib/service/agent/git_tools/mod.rs @@ -0,0 +1,10 @@ +pub mod blame; +pub mod branch; +pub mod commit; +pub mod diff; +pub mod helpers; +pub mod merge; +pub mod tag; +pub mod tree; + +pub use helpers::register_git_tools; diff --git a/lib/service/agent/git_tools/tag.rs b/lib/service/agent/git_tools/tag.rs new file mode 100644 index 0000000..81c9960 --- /dev/null +++ b/lib/service/agent/git_tools/tag.rs @@ -0,0 +1,243 @@ +use ai::error::{AiError, AiResult}; +use ai::tool::tools::FunctionCall; +use async_trait::async_trait; +use git::rpc::proto as p; +use git::rpc::proto::tag_service_client::TagServiceClient; +use serde_json::{json, Value}; + +use super::helpers::{arg_str, arg_opt_str, git_ctx, require_repo_member, rpc_err}; +use crate::agent::run::AppAgentContext; + + +pub struct GitTagListTool; + +impl GitTagListTool { + pub fn new() -> Self { Self } +} + +impl Default for GitTagListTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitTagListTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_tag_list" } + + fn description(&self) -> &'static str { "List all tags in a repository." } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" } + }, + "required": ["workspace", "repo"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = TagServiceClient::new(git.channel.clone()); + let resp = client + .tag_list(p::TagListRequest { repo_id: repo.id.to_string() }) + .await + .map_err(rpc_err)? + .into_inner(); + + let tags: Vec = resp.tags.iter().map(|t| json!({ + "name": t.name, + "oid": t.oid.as_ref().map(|o| &o.value), + "target": t.target.as_ref().map(|o| &o.value), + "is_annotated": t.is_annotated, + "message": t.message, + "tagger": t.tagger, + })).collect(); + + Ok(json!({ "tags": tags, "count": tags.len() })) + } +} + + +pub struct GitCreateTagTool; + +impl GitCreateTagTool { + pub fn new() -> Self { Self } +} + +impl Default for GitCreateTagTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitCreateTagTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_create_tag" } + + fn description(&self) -> &'static str { "Create a new tag pointing at a commit." } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "name": { "type": "string", "description": "Tag name" }, + "target_oid": { "type": "string", "description": "Target commit OID" }, + "message": { "type": "string", "description": "Tag message (for annotated tags)" }, + "force": { "type": "boolean", "description": "Force overwrite existing tag" } + }, + "required": ["workspace", "repo", "name", "target_oid"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let name = arg_str(&args, "name")?; + let target_oid = arg_str(&args, "target_oid")?; + let message = arg_opt_str(&args, "message").map(String::from); + let force = args.get("force").and_then(|v| v.as_bool()).unwrap_or(false); + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = TagServiceClient::new(git.channel.clone()); + let resp = client + .tag_init(p::TagInitRequest { + repo_id: repo.id.to_string(), + params: Some(p::TagInitParams { + name: name.to_string(), + target: Some(p::ObjectId { value: target_oid.to_string() }), + message, + tagger: None, + force, + }), + }) + .await + .map_err(rpc_err)? + .into_inner(); + + let oid = resp.oid.map(|o| o.value).unwrap_or_default(); + Ok(json!({ "success": true, "tag": name, "oid": oid })) + } +} + + +pub struct GitTagInfoTool; + +impl GitTagInfoTool { + pub fn new() -> Self { Self } +} + +impl Default for GitTagInfoTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitTagInfoTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_tag_info" } + + fn description(&self) -> &'static str { "Get detailed information about a specific tag." } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "name": { "type": "string", "description": "Tag name" } + }, + "required": ["workspace", "repo", "name"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let name = arg_str(&args, "name")?; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = TagServiceClient::new(git.channel.clone()); + let resp = client + .tag_info(p::TagInfoRequest { repo_id: repo.id.to_string(), name: name.to_string() }) + .await + .map_err(rpc_err)? + .into_inner(); + + let t = resp.tag.ok_or_else(|| AiError::Config(format!("tag '{name}' not found")))?; + Ok(json!({ + "name": t.name, + "oid": t.oid.as_ref().map(|o| &o.value), + "target": t.target.as_ref().map(|o| &o.value), + "is_annotated": t.is_annotated, + "message": t.message, + "tagger": t.tagger, + "tagger_email": t.tagger_email, + })) + } +} + + +pub struct GitDeleteTagTool; + +impl GitDeleteTagTool { + pub fn new() -> Self { Self } +} + +impl Default for GitDeleteTagTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitDeleteTagTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_delete_tag" } + + fn description(&self) -> &'static str { "Delete a tag from the repository. Requires write access." } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "name": { "type": "string", "description": "Tag name to delete" } + }, + "required": ["workspace", "repo", "name"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let name = arg_str(&args, "name")?; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut client = TagServiceClient::new(git.channel.clone()); + client + .tag_delete(p::TagDeleteRequest { + repo_id: repo.id.to_string(), + params: Some(p::TagDeleteParams { name: name.to_string() }), + }) + .await + .map_err(rpc_err)?; + + Ok(json!({ "success": true, "tag": name })) + } +} diff --git a/lib/service/agent/git_tools/tree.rs b/lib/service/agent/git_tools/tree.rs new file mode 100644 index 0000000..a651caf --- /dev/null +++ b/lib/service/agent/git_tools/tree.rs @@ -0,0 +1,197 @@ +use ai::error::{AiError, AiResult}; +use ai::tool::tools::FunctionCall; +use async_trait::async_trait; +use git::rpc::proto as p; +use git::rpc::proto::blob_service_client::BlobServiceClient; +use git::rpc::proto::commit_service_client::CommitServiceClient; +use git::rpc::proto::tree_service_client::TreeServiceClient; +use serde_json::{json, Value}; + +use super::helpers::{arg_str, arg_opt_str, git_ctx, require_repo_member, rpc_err}; +use crate::agent::run::AppAgentContext; + + +pub struct GitTreeEntriesTool; + +impl GitTreeEntriesTool { + pub fn new() -> Self { Self } +} + +impl Default for GitTreeEntriesTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitTreeEntriesTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_tree_entries" } + + fn description(&self) -> &'static str { + "List files and subdirectories at a given path in a commit's tree. Use this to explore repo structure." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "commit_oid": { "type": "string", "description": "Commit OID to read the tree from" }, + "path": { "type": "string", "description": "Directory path (empty string for root)" } + }, + "required": ["workspace", "repo", "commit_oid"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let commit_oid = arg_str(&args, "commit_oid")?; + let path = arg_opt_str(&args, "path").unwrap_or(""); + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut commit_client = CommitServiceClient::new(git.channel.clone()); + let commit_resp = commit_client + .commit_info(p::CommitInfoRequest { + repo_id: repo.id.to_string(), + oid: Some(p::ObjectId { value: commit_oid.to_string() }), + }) + .await + .map_err(rpc_err)? + .into_inner(); + + let tree_oid = commit_resp + .commit + .and_then(|c| c.tree_id) + .ok_or_else(|| AiError::Response("commit has no tree".to_string()))?; + + let mut client = TreeServiceClient::new(git.channel.clone()); + let resp = client + .tree_entries(p::TreeEntriesRequest { + repo_id: repo.id.to_string(), + oid: Some(tree_oid), + base_path: path.to_string(), + last: false, + }) + .await + .map_err(rpc_err)? + .into_inner(); + + let entries: Vec = resp.entries.iter().map(|e| { + let kind = match p::TreeKind::try_from(e.kind) { + Ok(p::TreeKind::Blob) => "file", + Ok(p::TreeKind::Tree) => "dir", + Ok(p::TreeKind::LfsPointer) => "lfs", + _ => "unknown", + }; + json!({ + "name": e.name, + "oid": e.oid.as_ref().map(|o| &o.value), + "kind": kind, + "is_binary": e.is_binary, + "is_lfs": e.is_lfs, + }) + }).collect(); + + Ok(json!({ "entries": entries, "count": entries.len() })) + } +} + + + +pub struct GitFileContentTool; + +impl GitFileContentTool { + pub fn new() -> Self { Self } +} + +impl Default for GitFileContentTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for GitFileContentTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "git_file_content" } + + fn description(&self) -> &'static str { + "Read the content of a file at a given path from a specific commit. Returns the file content as text." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "repo": { "type": "string", "description": "Repository name" }, + "commit_oid": { "type": "string", "description": "Commit OID" }, + "path": { "type": "string", "description": "File path in the repo" } + }, + "required": ["workspace", "repo", "commit_oid", "path"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let repo_name = arg_str(&args, "repo")?; + let commit_oid = arg_str(&args, "commit_oid")?; + let path = arg_str(&args, "path")?; + + let repo = require_repo_member(git, ctx.user_id, workspace, repo_name).await?; + + let mut tree_client = TreeServiceClient::new(git.channel.clone()); + let entry_resp = tree_client + .tree_entry_by_path_from_commit(p::TreeEntryByPathFromCommitRequest { + repo_id: repo.id.to_string(), + commit_oid: Some(p::ObjectId { value: commit_oid.to_string() }), + path: path.to_string(), + }) + .await + .map_err(rpc_err)? + .into_inner(); + + let entry = entry_resp + .entry + .ok_or_else(|| AiError::Config(format!("file not found: {path}")))?; + + if entry.kind == p::TreeKind::Tree as i32 { + return Err(AiError::Config(format!("'{path}' is a directory, not a file"))); + } + + let blob_oid = entry + .oid + .ok_or_else(|| AiError::Response("entry has no oid".to_string()))?; + + let mut blob_client = BlobServiceClient::new(git.channel.clone()); + let blob_resp = blob_client + .blob_load(p::BlobLoadRequest { + repo_id: repo.id.to_string(), + id: Some(blob_oid), + path: path.to_string(), + }) + .await + .map_err(rpc_err)? + .into_inner(); + + let content = String::from_utf8_lossy(&blob_resp.blob).to_string(); + + let truncated = content.len() > 64_000; + let content = if truncated { + format!("{}...(truncated)", &content[..64_000]) + } else { + content + }; + + Ok(json!({ + "path": path, + "content": content, + "size": blob_resp.blob.len(), + "truncated": truncated, + })) + } +} diff --git a/lib/service/agent/issue_tools/helpers.rs b/lib/service/agent/issue_tools/helpers.rs new file mode 100644 index 0000000..b6e835d --- /dev/null +++ b/lib/service/agent/issue_tools/helpers.rs @@ -0,0 +1,59 @@ +use ai::error::{AiError, AiResult}; +use ai::tool::register::ToolRegister; +use db::sqlx; +use serde_json::Value; +use uuid::Uuid; + +use crate::agent::run::{AppAgentContext, GitAgentContext}; + +pub fn register_issue_tools(tools: &mut ToolRegister) { + tools.register(super::issue::IssueListTool::new()); + tools.register(super::issue::IssueGetTool::new()); + tools.register(super::issue::IssueCommentsTool::new()); + tools.register(super::issue::IssueEventsTool::new()); +} + +pub(super) async fn require_workspace_member( + git: &GitAgentContext, + user_id: Uuid, + workspace_name: &str, +) -> AiResult { + let wk_id: Uuid = sqlx::query_scalar( + "SELECT id FROM workspace WHERE name = $1", + ) + .bind(workspace_name) + .fetch_optional(git.db.reader()) + .await + .map_err(AiError::Database)? + .ok_or_else(|| AiError::Config(format!("workspace '{workspace_name}' not found")))?; + + let is_member: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM wk_member \ + WHERE wk = $1 AND \"user\" = $2 AND leave_at IS NULL", + ) + .bind(wk_id) + .bind(user_id) + .fetch_one(git.db.reader()) + .await + .map_err(AiError::Database)?; + + if is_member == 0 { + return Err(AiError::Config(format!( + "user is not a member of workspace '{workspace_name}'" + ))); + } + + Ok(wk_id) +} + +pub(super) fn git_ctx(ctx: &AppAgentContext) -> AiResult<&GitAgentContext> { + ctx.git + .as_ref() + .ok_or_else(|| AiError::Config("issue tools are not available in this session".to_string())) +} + +pub(super) fn arg_str<'a>(args: &'a Value, key: &str) -> AiResult<&'a str> { + args.get(key) + .and_then(|v| v.as_str()) + .ok_or_else(|| AiError::Config(format!("'{key}' parameter is required"))) +} diff --git a/lib/service/agent/issue_tools/issue.rs b/lib/service/agent/issue_tools/issue.rs new file mode 100644 index 0000000..6cb3769 --- /dev/null +++ b/lib/service/agent/issue_tools/issue.rs @@ -0,0 +1,353 @@ +use ai::error::{AiError, AiResult}; +use ai::tool::tools::FunctionCall; +use async_trait::async_trait; +use db::sqlx; +use serde_json::{json, Value}; + +use super::helpers::{arg_str, git_ctx, require_workspace_member}; +use crate::agent::run::AppAgentContext; + +pub struct IssueListTool; + +impl IssueListTool { + pub fn new() -> Self { Self } +} + +impl Default for IssueListTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for IssueListTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "issue_list" } + + fn description(&self) -> &'static str { + "List issues in a workspace with optional filters: state (open/closed), priority, label, milestone, assignee." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "state": { "type": "string", "description": "Filter by state: 'open' or 'closed'" }, + "priority": { "type": "string", "description": "Filter by priority" }, + "label": { "type": "string", "description": "Filter by label name" }, + "milestone": { "type": "string", "description": "Filter by milestone title" }, + "assignee": { "type": "string", "description": "Filter by assignee username" }, + "limit": { "type": "integer", "description": "Max results (default 20, max 100)" } + }, + "required": ["workspace"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; + let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(20).min(100); + + let mut conditions = vec!["i.wk = $1".to_string(), "i.deleted_at IS NULL".to_string()]; + let mut params: Vec = vec![wk_id.to_string()]; + let mut idx = 2i32; + + for (arg, col) in [ + ("state", "i.state"), + ("priority", "i.priority"), + ] { + if let Some(v) = args.get(arg).and_then(|v| v.as_str()).filter(|s| !s.is_empty()) { + conditions.push(format!("{col} = ${idx}")); + params.push(v.to_string()); + idx += 1; + } + } + + if let Some(v) = args.get("label").and_then(|v| v.as_str()).filter(|s| !s.is_empty()) { + conditions.push(format!("EXISTS(SELECT 1 FROM issue_label il INNER JOIN label l ON l.id = il.label WHERE il.issue = i.id AND l.name = ${idx})")); + params.push(v.to_string()); + idx += 1; + } + + if let Some(v) = args.get("milestone").and_then(|v| v.as_str()).filter(|s| !s.is_empty()) { + conditions.push(format!("EXISTS(SELECT 1 FROM issue_milestone im INNER JOIN milestone m ON m.id = im.milestone WHERE im.issue = i.id AND m.title = ${idx})")); + params.push(v.to_string()); + idx += 1; + } + + if let Some(v) = args.get("assignee").and_then(|v| v.as_str()).filter(|s| !s.is_empty()) { + conditions.push(format!("EXISTS(SELECT 1 FROM issue_assignee ia INNER JOIN \"user\" u ON u.id = ia.\"user\" WHERE ia.issue = i.id AND u.username = ${idx})")); + params.push(v.to_string()); + idx += 1; + } + + let where_clause = conditions.join(" AND "); + let query = format!( + "SELECT i.number, i.title, i.body, i.state, i.priority, \ + i.closed_at, i.due_at, i.created_at \ + FROM issue i WHERE {where_clause} \ + ORDER BY i.created_at DESC LIMIT ${idx}", + ); + + let mut q = sqlx::query_as::<_, IssueRow>(db::sqlx::AssertSqlSafe(query)); + q = q.bind(wk_id); + for i in 1..params.len() { + q = q.bind(¶ms[i]); + } + q = q.bind(limit); + + let rows = q.fetch_all(git.db.reader()).await.map_err(AiError::Database)?; + + let issues: Vec = rows.iter().map(|r| json!({ + "number": r.number, + "title": r.title, + "state": r.state, + "priority": r.priority, + "body": r.body.as_ref().map(|b| if b.len() > 500 { format!("{}...", &b[..500]) } else { b.clone() }), + "created_at": r.created_at.to_rfc3339(), + })).collect(); + + Ok(json!({ "issues": issues, "count": issues.len() })) + } +} + +#[derive(sqlx::FromRow)] +struct IssueRow { + number: i64, + title: String, + body: Option, + state: String, + priority: String, + due_at: Option>, + closed_at: Option>, + created_at: chrono::DateTime, +} + +pub struct IssueGetTool; + +impl IssueGetTool { + pub fn new() -> Self { Self } +} + +impl Default for IssueGetTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for IssueGetTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "issue_get" } + + fn description(&self) -> &'static str { + "Get full details of a single issue by its number." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "number": { "type": "integer", "description": "Issue number" } + }, + "required": ["workspace", "number"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let number = args.get("number").and_then(|v| v.as_i64()) + .ok_or_else(|| AiError::Config("'number' parameter is required".to_string()))?; + let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; + + let row = sqlx::query_as::<_, IssueRow>( + "SELECT number, title, body, state, priority, \ + closed_at, due_at, created_at \ + FROM issue WHERE wk = $1 AND number = $2 AND deleted_at IS NULL", + ) + .bind(wk_id) + .bind(number) + .fetch_optional(git.db.reader()) + .await + .map_err(AiError::Database)? + .ok_or_else(|| AiError::Config(format!("issue #{number} not found")))?; + + // Load labels + #[derive(sqlx::FromRow)] + struct LabelRow { name: String } + + let labels: Vec = sqlx::query_as::<_, LabelRow>( + "SELECT l.name FROM label l \ + INNER JOIN issue_label il ON il.label = l.id \ + WHERE il.issue = (SELECT id FROM issue WHERE wk = $1 AND number = $2)", + ) + .bind(wk_id).bind(number) + .fetch_all(git.db.reader()).await.map_err(AiError::Database)? + .iter().map(|r| r.name.clone()).collect(); + + // Load assignees + #[derive(sqlx::FromRow)] + struct AssigneeRow { username: String } + + let assignees: Vec = sqlx::query_as::<_, AssigneeRow>( + "SELECT u.username FROM \"user\" u \ + INNER JOIN issue_assignee ia ON ia.\"user\" = u.id \ + WHERE ia.issue = (SELECT id FROM issue WHERE wk = $1 AND number = $2)", + ) + .bind(wk_id).bind(number) + .fetch_all(git.db.reader()).await.map_err(AiError::Database)? + .iter().map(|r| r.username.clone()).collect(); + + Ok(json!({ + "number": row.number, + "title": row.title, + "body": row.body, + "state": row.state, + "priority": row.priority, + "labels": labels, + "assignees": assignees, + "created_at": row.created_at.to_rfc3339(), + "due_at": row.due_at.map(|d| d.to_rfc3339()), + "closed_at": row.closed_at.map(|d| d.to_rfc3339()), + })) + } +} + +pub struct IssueCommentsTool; + +impl IssueCommentsTool { + pub fn new() -> Self { Self } +} + +impl Default for IssueCommentsTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for IssueCommentsTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "issue_comments" } + + fn description(&self) -> &'static str { + "List comments on an issue, ordered by time." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "number": { "type": "integer", "description": "Issue number" }, + "limit": { "type": "integer", "description": "Max results (default 50)" } + }, + "required": ["workspace", "number"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let number = args.get("number").and_then(|v| v.as_i64()) + .ok_or_else(|| AiError::Config("'number' parameter is required".to_string()))?; + let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; + let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(50).min(200); + + #[derive(sqlx::FromRow)] + struct CommentRow { + body: String, + username: String, + created_at: chrono::DateTime, + } + + let rows: Vec = sqlx::query_as( + "SELECT ic.body, u.username, ic.created_at \ + FROM issue_comment ic \ + INNER JOIN \"user\" u ON u.id = ic.author \ + WHERE ic.issue = (SELECT id FROM issue WHERE wk = $1 AND number = $2) \ + AND ic.deleted_at IS NULL \ + ORDER BY ic.created_at ASC LIMIT $3", + ) + .bind(wk_id).bind(number).bind(limit) + .fetch_all(git.db.reader()).await.map_err(AiError::Database)?; + + let comments: Vec = rows.iter().map(|r| json!({ + "author": r.username, + "body": r.body, + "created_at": r.created_at.to_rfc3339(), + })).collect(); + + Ok(json!({ "issue_number": number, "comments": comments, "count": comments.len() })) + } +} + +pub struct IssueEventsTool; + +impl IssueEventsTool { + pub fn new() -> Self { Self } +} + +impl Default for IssueEventsTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for IssueEventsTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "issue_events" } + + fn description(&self) -> &'static str { + "List the timeline of events for an issue (created, commented, closed, labeled, etc)." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "number": { "type": "integer", "description": "Issue number" } + }, + "required": ["workspace", "number"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let number = args.get("number").and_then(|v| v.as_i64()) + .ok_or_else(|| AiError::Config("'number' parameter is required".to_string()))?; + let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; + + #[derive(sqlx::FromRow)] + struct EventRow { + event: String, + from_value: Option, + to_value: Option, + username: Option, + created_at: chrono::DateTime, + } + + let rows: Vec = sqlx::query_as( + "SELECT e.event, e.from_value, e.to_value, u.username, e.created_at \ + FROM issue_event e \ + LEFT JOIN \"user\" u ON u.id = e.actor \ + WHERE e.issue = (SELECT id FROM issue WHERE wk = $1 AND number = $2) \ + ORDER BY e.created_at ASC", + ) + .bind(wk_id).bind(number) + .fetch_all(git.db.reader()).await.map_err(AiError::Database)?; + + let events: Vec = rows.iter().map(|r| json!({ + "event": r.event, + "actor": r.username, + "from": r.from_value, + "to": r.to_value, + "created_at": r.created_at.to_rfc3339(), + })).collect(); + + Ok(json!({ "issue_number": number, "events": events, "count": events.len() })) + } +} diff --git a/lib/service/agent/issue_tools/mod.rs b/lib/service/agent/issue_tools/mod.rs new file mode 100644 index 0000000..bcf33ae --- /dev/null +++ b/lib/service/agent/issue_tools/mod.rs @@ -0,0 +1,4 @@ +pub mod helpers; +pub mod issue; + +pub use helpers::register_issue_tools; diff --git a/lib/service/agent/memory.rs b/lib/service/agent/memory.rs new file mode 100644 index 0000000..91da680 --- /dev/null +++ b/lib/service/agent/memory.rs @@ -0,0 +1,246 @@ +use ai::{ + error::{AiError, AiResult}, + tool::tools::FunctionCall, +}; +use async_trait::async_trait; +use chrono::Utc; +use db::sqlx; +use serde_json::{json, Value}; +use tracing::info; +use uuid::Uuid; + +use super::run::AppAgentContext; +use crate::error::AppError; +use crate::AppService; +pub struct SaveMemoryTool; + +impl SaveMemoryTool { + pub fn new() -> Self { + Self + } +} + +impl Default for SaveMemoryTool { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl FunctionCall for SaveMemoryTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { + "save_memory" + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "key": { + "type": "string", + "description": "A short, descriptive key for the memory (e.g. 'user_preference_language', 'project_architecture')" + }, + "value": { + "type": "string", + "description": "The information to remember" + }, + "importance": { + "type": "integer", + "description": "Importance level 0-10 (10 = most important). Default: 5", + "minimum": 0, + "maximum": 10 + } + }, + "required": ["key", "value"] + }) + } + + async fn call( + &self, + context: &mut Self::Context, + args: Value, + ) -> AiResult { + let key = args + .get("key") + .and_then(|v| v.as_str()) + .ok_or_else(|| AiError::Config("key parameter is required".to_string()))?; + + let value = args + .get("value") + .and_then(|v| v.as_str()) + .ok_or_else(|| AiError::Config("value parameter is required".to_string()))?; + + let importance = args + .get("importance") + .and_then(|v| v.as_i64()) + .unwrap_or(5) + .clamp(0, 10) as i32; + + context.pending_memories.push(PendingMemory { + key: key.to_string(), + value: value.to_string(), + importance, + }); + + Ok(json!({ + "success": true, + "key": key, + "message": format!("Memory '{}' saved (importance: {})", key, importance) + })) + } +} +pub struct RecallMemoryTool { + memories_json: String, +} + +impl RecallMemoryTool { + pub fn new(memories_json: String) -> Self { + Self { memories_json } + } +} + +#[async_trait] +impl FunctionCall for RecallMemoryTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { + "recall_memory" + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Optional search query to filter memories by key or content" + } + } + }) + } + + async fn call( + &self, + _context: &mut Self::Context, + args: Value, + ) -> AiResult { + let query = args + .get("query") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + if query.is_empty() { + return Ok(json!({ + "memories": self.memories_json, + "count": "all" + })); + } + + Ok(json!({ + "memories": self.memories_json, + "query": query, + "hint": "Search the memories above for matches to your query" + })) + } +} +#[derive(Debug, Clone)] +pub struct PendingMemory { + pub key: String, + pub value: String, + pub importance: i32, +} + +impl AppService { + pub(crate) async fn agent_load_memories( + &self, + session_id: Uuid, + ) -> Result<(String, Vec<(Uuid, String, String, i32)>), AppError> { + let rows: Vec<(Uuid, String, String, i32)> = sqlx::query_as( + "SELECT id, key, value, importance \ + FROM agent_long_term_memory \ + WHERE session = $1 AND deleted_at IS NULL \ + ORDER BY importance DESC, updated_at DESC \ + LIMIT 50", + ) + .bind(session_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if rows.is_empty() { + return Ok((String::new(), rows)); + } + + let mut formatted = String::from("Long-term memories for this session:\n"); + for (_, key, value, importance) in &rows { + formatted.push_str(&format!( + "- [{}] {} (importance: {})\n", + key, value, importance + )); + } + + Ok((formatted, rows)) + } + pub(crate) async fn agent_persist_memories( + &self, + session_id: Uuid, + memories: &[PendingMemory], + ) -> Result<(), AppError> { + if memories.is_empty() { + return Ok(()); + } + + let now = Utc::now(); + for mem in memories { + sqlx::query( + "INSERT INTO agent_long_term_memory \ + (id, session, key, value, importance, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $6) \ + ON CONFLICT (session, key) WHERE deleted_at IS NULL \ + DO UPDATE SET value = $4, importance = $5, updated_at = $6", + ) + .bind(Uuid::now_v7()) + .bind(session_id) + .bind(&mem.key) + .bind(&mem.value) + .bind(mem.importance) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + } + + info!( + session_id = %session_id, + count = memories.len(), + "persisted long-term memories from agent run" + ); + + Ok(()) + } + #[allow(dead_code)] + pub(crate) async fn agent_touch_memories( + &self, + memory_ids: &[Uuid], + ) -> Result<(), AppError> { + if memory_ids.is_empty() { + return Ok(()); + } + + let now = Utc::now(); + sqlx::query( + "UPDATE agent_long_term_memory \ + SET last_used_at = $1 \ + WHERE id = ANY($2::uuid[])", + ) + .bind(now) + .bind(memory_ids) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } +} diff --git a/lib/service/agent/memory_provider.rs b/lib/service/agent/memory_provider.rs new file mode 100644 index 0000000..686a7d1 --- /dev/null +++ b/lib/service/agent/memory_provider.rs @@ -0,0 +1,168 @@ +use ai::error::AiResult; +use ai::memory::{MemoryEntry, MemoryProvider}; +use async_trait::async_trait; +use chrono::Utc; +use db::sqlx; +use uuid::Uuid; +#[derive(Clone)] +pub struct SimpleMemoryProvider { + db: db::AppDatabase, +} + +impl SimpleMemoryProvider { + pub fn new(db: db::AppDatabase) -> Self { + Self { db } + } +} + +#[async_trait] +impl MemoryProvider for SimpleMemoryProvider { + fn name(&self) -> &'static str { + "simple" + } + + async fn save( + &self, + session_id: Uuid, + key: &str, + value: &str, + importance: i32, + ) -> AiResult<()> { + let now = Utc::now(); + sqlx::query( + "INSERT INTO agent_long_term_memory \ + (id, session, key, value, importance, last_used_at, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $6, $6) \ + ON CONFLICT ON CONSTRAINT idx_agent_ltm_session_key \ + DO UPDATE SET value = $4, importance = $5, last_used_at = $6, updated_at = $6", + ) + .bind(Uuid::now_v7()) + .bind(session_id) + .bind(key) + .bind(value) + .bind(importance) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| { + ai::error::AiError::Response(format!("memory save error: {e}")) + })?; + + Ok(()) + } + + async fn recall( + &self, + session_id: Uuid, + query: &str, + limit: usize, + ) -> AiResult> { + use db::sqlx::FromRow; + + #[derive(Debug, FromRow)] + struct Row { + key: String, + value: String, + importance: i32, + last_used_at: Option>, + } + + let rows: Vec = sqlx::query_as( + "SELECT key, value, importance, last_used_at \ + FROM agent_long_term_memory \ + WHERE session = $1 \ + AND deleted_at IS NULL \ + AND (value ILIKE $2 OR key ILIKE $2) \ + ORDER BY importance DESC, last_used_at DESC NULLS LAST \ + LIMIT $3", + ) + .bind(session_id) + .bind(format!("%{query}%")) + .bind(limit as i64) + .fetch_all(self.db.reader()) + .await + .map_err(|e| { + ai::error::AiError::Response(format!("memory recall error: {e}")) + })?; + + let entries: Vec = rows + .into_iter() + .map(|r| MemoryEntry { + key: r.key, + value: r.value, + importance: r.importance, + last_used_at: r.last_used_at.map(|dt| dt.to_rfc3339()), + }) + .collect(); + if !entries.is_empty() { + let now = Utc::now(); + let _ = sqlx::query( + "UPDATE agent_long_term_memory \ + SET last_used_at = $1, updated_at = $1 \ + WHERE session = $2 AND value ILIKE $3 AND deleted_at IS NULL", + ) + .bind(now) + .bind(session_id) + .bind(format!("%{query}%")) + .execute(self.db.writer()) + .await; + } + + Ok(entries) + } + + async fn forget(&self, session_id: Uuid, key: &str) -> AiResult<()> { + let now = Utc::now(); + sqlx::query( + "UPDATE agent_long_term_memory \ + SET deleted_at = $1, updated_at = $1 \ + WHERE session = $2 AND key = $3 AND deleted_at IS NULL", + ) + .bind(now) + .bind(session_id) + .bind(key) + .execute(self.db.writer()) + .await + .map_err(|e| { + ai::error::AiError::Response(format!("memory forget error: {e}")) + })?; + + Ok(()) + } + + async fn build_context_block(&self, session_id: Uuid) -> AiResult { + use db::sqlx::FromRow; + + #[derive(Debug, FromRow)] + struct Entry { + key: String, + value: String, + } + + let rows: Vec = sqlx::query_as( + "SELECT key, value \ + FROM agent_long_term_memory \ + WHERE session = $1 AND deleted_at IS NULL \ + ORDER BY importance DESC, last_used_at DESC NULLS LAST \ + LIMIT 20", + ) + .bind(session_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| { + ai::error::AiError::Response(format!("memory context error: {e}")) + })?; + + if rows.is_empty() { + return Ok(String::new()); + } + + let mut block = String::from("\n"); + for row in &rows { + block.push_str(&format!("- {}: {}\n", row.key, row.value)); + } + block.push_str(""); + + Ok(block) + } +} diff --git a/lib/service/agent/mod.rs b/lib/service/agent/mod.rs new file mode 100644 index 0000000..77a66c0 --- /dev/null +++ b/lib/service/agent/mod.rs @@ -0,0 +1,17 @@ +pub mod billing; +pub mod compaction; +pub mod config; +pub mod context; +pub mod conversation; +pub mod git_tools; +pub mod issue_tools; +pub mod memory; +pub mod memory_provider; +pub mod persistence; +pub mod run; +pub mod session; +pub mod sse; +pub mod tools; +pub mod trace; +pub mod types; +pub mod workspace_tools; diff --git a/lib/service/agent/persistence.rs b/lib/service/agent/persistence.rs new file mode 100644 index 0000000..8429b40 --- /dev/null +++ b/lib/service/agent/persistence.rs @@ -0,0 +1,194 @@ +use chrono::Utc; +use db::sqlx; +use uuid::Uuid; + +use super::types::{AgentCostInfo, AgentStepInfo, AgentToolCallInfo, BillingRecord, SessionContext}; +use crate::error::AppError; +use crate::AppService; + +impl AppService { + pub(super) async fn persist_user_message( + &self, + conversation_id: Uuid, + user_id: Uuid, + content: &str, + ) -> Result { + let message_id = Uuid::now_v7(); + let now = Utc::now(); + sqlx::query( + "INSERT INTO agent_message \ + (id, conversation, role, author, content, content_type, status, created_at, updated_at) \ + VALUES ($1, $2, 'user', $3, $4, 'text', 'completed', $5, $5)", + ) + .bind(message_id) + .bind(conversation_id) + .bind(user_id) + .bind(content) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(message_id) + } + + pub(super) async fn persist_assistant_message( + &self, + conversation_id: Uuid, + _session_id: Uuid, + content: &str, + reasoning_content: Option<&str>, + invocation_id: Uuid, + ) -> Result { + let message_id = Uuid::now_v7(); + let now = Utc::now(); + sqlx::query( + "INSERT INTO agent_message \ + (id, conversation, role, content, content_type, status, \ + model_invocation, reasoning_content, created_at, updated_at) \ + VALUES ($1, $2, 'assistant', $3, 'text', 'completed', $4, $5, $6, $6)", + ) + .bind(message_id) + .bind(conversation_id) + .bind(content) + .bind(invocation_id) + .bind(reasoning_content) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(message_id) + } + + pub(super) async fn persist_billing_and_deduct( + &self, + ctx: &SessionContext, + invocation_id: Uuid, + input_tokens: i64, + output_tokens: i64, + ) -> Result, AppError> { + let cost_result = self + .agent_calculate_cost(ctx.model_version_id, input_tokens, output_tokens) + .await?; + + let (cost, currency) = match cost_result { + Some((c, cur)) => (c, cur), + None => { + let record = BillingRecord { + invocation_id, + session_id: ctx.session_id, + model_version_id: ctx.model_version_id, + input_tokens, + output_tokens, + cached_input_tokens: 0, + cache_read_tokens: 0, + cache_write_tokens: 0, + reasoning_tokens: 0, + total_tokens: input_tokens.saturating_add(output_tokens), + cost: None, + currency: None, + created_at: Utc::now(), + }; + self.agent_record_usage(&record).await?; + return Ok(None); + } + }; + + let record = BillingRecord { + invocation_id, + session_id: ctx.session_id, + model_version_id: ctx.model_version_id, + input_tokens, + output_tokens, + cached_input_tokens: 0, + cache_read_tokens: 0, + cache_write_tokens: 0, + reasoning_tokens: 0, + total_tokens: input_tokens.saturating_add(output_tokens), + cost: Some(cost), + currency: Some(currency.clone()), + created_at: Utc::now(), + }; + self.agent_record_usage(&record).await?; + + if let Err(e) = self.agent_deduct_billing(ctx, cost).await { + tracing::warn!( + invocation_id = %invocation_id, + error = %e, + "agent billing deduction failed" + ); + } + + Ok(Some(AgentCostInfo { + amount: cost.to_string(), + currency, + })) + } + + pub(super) async fn update_conversation_timestamp( + &self, + conversation_id: Uuid, + ) -> Result<(), AppError> { + let now = Utc::now(); + sqlx::query( + "UPDATE agent_conversation SET last_message_at = $1, updated_at = $1 WHERE id = $2", + ) + .bind(now) + .bind(conversation_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(()) + } + + pub(super) async fn update_conversation_title( + &self, + conversation_id: Uuid, + title: &str, + ) -> Result<(), AppError> { + let now = Utc::now(); + sqlx::query( + "UPDATE agent_conversation SET title = $1, updated_at = $2 WHERE id = $3", + ) + .bind(title) + .bind(now) + .bind(conversation_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(()) + } +} + +#[allow(dead_code)] +pub(super) fn step_info_from_agent(step: ai::agent::AgentStep) -> AgentStepInfo { + AgentStepInfo { + index: step.index, + assistant: step.assistant, + tool_calls: step + .tool_calls + .into_iter() + .map(tool_call_info_from_record) + .collect(), + reflection: step.reflection, + } +} + +#[allow(dead_code)] +pub(super) fn tool_call_info_from_record(record: ai::agent::ToolCallRecord) -> AgentToolCallInfo { + AgentToolCallInfo { + id: record.id, + name: record.name, + arguments: record.arguments, + output: record.output, + error: record.error, + elapsed_ms: record.elapsed_ms, + } +} + +pub(super) fn stream_error(error: &str) -> String { + let payload = serde_json::json!({ + "type": "error", + "error": error, + }); + format!("data: {}\n\n", payload) +} diff --git a/lib/service/agent/run.rs b/lib/service/agent/run.rs new file mode 100644 index 0000000..1c17ff3 --- /dev/null +++ b/lib/service/agent/run.rs @@ -0,0 +1,283 @@ +use std::sync::Arc; +use std::time::Duration; + +use ai::{ + agent::RigAgent, + tool::register::ToolRegister, +}; +use cache::AppCache; +use db::AppDatabase; +use tonic::transport::Channel; +use tracing::{info, warn}; +use uuid::Uuid; + +use super::types::{ + AgentRunRequest, AgentRunResponse, AgentUsageInfo, +}; +use crate::error::AppError; +use crate::AppService; + +#[derive(Clone)] +pub struct GitAgentContext { + pub channel: Channel, + pub db: AppDatabase, + pub cache: AppCache, +} + +#[derive(Clone)] +pub struct AppAgentContext { + pub user_id: Uuid, + pub session_id: Uuid, + pub conversation_id: Uuid, + pub pending_title: Option, + pub pending_memories: Vec, + pub git: Option, +} + +impl AppService { + pub async fn agent_run( + &self, + user_id: Uuid, + req: AgentRunRequest, + ) -> Result { + let ctx = self.agent_session_context(req.session_id, user_id).await?; + + let conversation_id = req + .conversation_id + .ok_or_else(|| AppError::BadRequest("conversation_id is required".to_string()))?; + let conversation = self + .agent_require_conversation_access(user_id, conversation_id) + .await?; + if conversation.session != ctx.session_id { + return Err(AppError::BadRequest( + "conversation does not belong to session".to_string(), + )); + } + + let ai_client = self + .agent_build_ai_client(ctx.model_version_id) + .await?; + + let agent_config = self.agent_build_config( + &ctx, + req.max_steps, + ); + + self.agent_maybe_compact(&ai_client, &ctx.provider_model_name, conversation_id) + .await + .unwrap_or_else(|e| { + warn!(error = %e, "compaction check failed, continuing"); + }); + + let mut tools: ToolRegister = ToolRegister::new(); + if conversation.title == "New Chat" || conversation.title.trim().is_empty() { + tools.register(super::tools::SetTitleTool::new()); + } + tools.register(super::memory::SaveMemoryTool::new()); + + let (memories_text, _memory_rows) = + self.agent_load_memories(ctx.session_id).await?; + if !memories_text.is_empty() { + tools.register(super::memory::RecallMemoryTool::new(memories_text)); + } + + super::git_tools::register_git_tools(&mut tools); + super::workspace_tools::register_workspace_tools(&mut tools); + super::issue_tools::register_issue_tools(&mut tools); + + let agent_ctx = AppAgentContext { + user_id, + session_id: ctx.session_id, + conversation_id, + pending_title: None, + pending_memories: Vec::new(), + git: Some(GitAgentContext { + channel: self.git.clone(), + db: self.db.clone(), + cache: self.cache.clone(), + }), + }; + + let shared_ctx = Arc::new(tokio::sync::Mutex::new(agent_ctx)); + let mut tool_set = ai::agent::RigToolSet::from_register(&tools, shared_ctx); + let rig_tools = tool_set.take_tools(); + + let agent = RigAgent::new(ai_client.clone(), agent_config) + .map_err(|e| AppError::AiError(e))?; + + let timeout_secs = req.timeout_secs.unwrap_or(300); + + let agent_request = self + .agent_build_request( + &ai_client, + &ctx, + req.conversation_id, + req.input.clone(), + Some(timeout_secs), + ) + .await?; + + let invocation_id = Uuid::now_v7(); + info!( + invocation_id = %invocation_id, + session_id = %ctx.session_id, + user_id = %user_id, + billing_target = ?ctx.billing_target, + "agent run starting" + ); + + let user_message_id = self + .persist_user_message(conversation_id, user_id, &req.input) + .await?; + let result = match tokio::time::timeout( + Duration::from_secs(timeout_secs), + agent.chat(agent_request, rig_tools), + ) + .await + { + Ok(Ok(output)) => Ok(output), + Ok(Err(e)) => Err(e), + Err(_) => Err(ai::error::AiError::Timeout { + seconds: timeout_secs, + }), + }; + + let agent_ctx = tool_set.into_context(); + + match result { + Ok(output) => { + let message_id = self + .persist_assistant_message( + conversation_id, + ctx.session_id, + &output, + None, + invocation_id, + ) + .await?; + + let cost_info = + self.persist_billing_and_deduct( + &ctx, + invocation_id, + 0, // input_tokens not tracked in chat() mode + 0, // output_tokens not tracked in chat() mode + ) + .await?; + + self.agent_record_invocation( + invocation_id, + ctx.session_id, + Some(conversation_id), + Some(message_id), + ctx.model_version_id, + "completed", + None, + ) + .await?; + + self.update_conversation_timestamp(conversation_id) + .await?; + + let title = agent_ctx + .pending_title + .filter(|t| !t.trim().is_empty()) + .or_else(|| { + let first_line = + req.input.lines().next().unwrap_or(&req.input); + let truncated: String = + first_line.chars().take(50).collect(); + if truncated.trim().is_empty() { + None + } else if first_line.len() > 50 { + Some(format!("{}…", truncated.trim_end())) + } else { + Some(truncated.trim().to_string()) + } + }); + + if let Some(new_title) = &title { + if let Err(e) = self + .update_conversation_title( + conversation_id, + new_title, + ) + .await + { + warn!( + conversation_id = %conversation_id, + error = %e, + "failed to update conversation title" + ); + } + } + + if let Err(e) = self + .agent_persist_memories( + ctx.session_id, + &agent_ctx.pending_memories, + ) + .await + { + warn!( + invocation_id = %invocation_id, + error = %e, + "failed to persist agent memories" + ); + } + + info!( + invocation_id = %invocation_id, + message_id = %message_id, + "agent run completed successfully" + ); + + Ok(AgentRunResponse { + message_id, + conversation_id, + output, + steps: Vec::new(), + usage: AgentUsageInfo { + input_tokens: 0, + output_tokens: 0, + total_tokens: 0, + }, + cost: cost_info, + }) + } + Err(e) => { + warn!( + invocation_id = %invocation_id, + error = %e, + "agent run failed" + ); + + let error_content = format!( + "I encountered an error while processing your request: {e}" + ); + let _ = self + .persist_assistant_message( + conversation_id, + ctx.session_id, + &error_content, + None, + invocation_id, + ) + .await; + + self.agent_record_invocation( + invocation_id, + ctx.session_id, + Some(conversation_id), + Some(user_message_id), + ctx.model_version_id, + "failed", + Some(&e.to_string()), + ) + .await?; + + Err(AppError::AiError(e)) + } + } + } +} diff --git a/lib/service/agent/session.rs b/lib/service/agent/session.rs new file mode 100644 index 0000000..f1e15fb --- /dev/null +++ b/lib/service/agent/session.rs @@ -0,0 +1,486 @@ +use chrono::Utc; +use db::sqlx; +use model::agent::AgentSessionModel; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; +use uuid::Uuid; + +use crate::error::AppError; +use crate::AppService; + +#[derive(Debug, Clone, Deserialize, ToSchema)] +pub struct CreateAgentSession { + pub name: String, + pub agent_kind: String, + pub model_version: Uuid, + #[serde(default)] + pub description: Option, + #[serde(default)] + pub system_prompt: Option, + #[serde(default)] + pub temperature: Option, + pub max_output_tokens: Option, + pub tool_policy: Option, + pub toolset_json: Option, + pub memory_provider: Option, + pub memory_provider_config: Option, + pub iteration_budget: Option, + pub source: Option, + pub visibility: Option, + pub wk: Option, + pub knowledge_base_ids: Option>, + pub variables: Option, +} + +#[derive(Debug, Clone, Deserialize, ToSchema)] +pub struct UpdateAgentSession { + pub name: Option, + pub description: Option, + pub system_prompt: Option, + pub temperature: Option, + pub max_output_tokens: Option, + pub model_version: Option, + pub tool_policy: Option, + pub toolset_json: Option, + pub memory_provider: Option, + pub memory_provider_config: Option, + pub iteration_budget: Option, + pub visibility: Option, + pub enabled: Option, + pub knowledge_base_ids: Option>, + pub variables: Option, +} + +#[derive(Debug, Clone, Serialize, ToSchema)] +pub struct AgentSessionResponse { + pub id: Uuid, + pub name: String, + pub description: Option, + pub agent_kind: String, + pub model_version: Option, + pub system_prompt: Option, + pub temperature: Option, + pub max_output_tokens: Option, + pub tool_policy: Option, + pub toolset_json: Option, + pub memory_provider: Option, + pub iteration_budget: Option, + pub source: Option, + pub parent_session_id: Option, + pub visibility: String, + pub version: i32, + pub enabled: bool, + pub user: Option, + pub wk: Option, + pub variables: Option, + pub published_at: Option>, + pub created_at: chrono::DateTime, + pub updated_at: chrono::DateTime, +} + +impl AppService { + pub async fn agent_session_create( + &self, + user_id: Uuid, + params: CreateAgentSession, + ) -> Result { + let wk_uuid: Option = if let Some(ref wk_name) = params.wk { + let wk = crate::AppService::workspace_resolve( + &*self, wk_name, + ) + .await?; + let _ = crate::AppService::workspace_require_member( + &*self, wk.id, user_id, + ) + .await?; + Some(wk.id) + } else { + None + }; + + let id = Uuid::now_v7(); + let now = Utc::now(); + let visibility = params.visibility.unwrap_or_else(|| "private".to_string()); + let kb_ids = params.knowledge_base_ids.map(|ids| { + ids.iter().map(|id| id.to_string()).collect::>().join(",") + }); + + let row = sqlx::query_as::<_, AgentSessionModel>( + "INSERT INTO agent_session \ + (id, \"user\", wk, name, description, agent_kind, model_version, \ + system_prompt, temperature, max_output_tokens, tool_policy, \ + knowledge_base_ids, variables, visibility, version, enabled, \ + source, toolset_json, memory_provider, memory_provider_config, iteration_budget, \ + created_by, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, 1, true, \ + 'api', '{}', 'simple', '{}', 90, \ + $15, $16, $16) \ + RETURNING id, \"user\", wk, name, description, agent_kind, model_version, \ + system_prompt, temperature, max_output_tokens, tool_policy, \ + knowledge_base_ids, variables, visibility, version, \ + published_at, rollback_from_version, enabled, \ + source, parent_session_id, toolset_json, \ + memory_provider, memory_provider_config, iteration_budget, \ + created_by, created_at, updated_at, deleted_at", + ) + .bind(id) + .bind(user_id) + .bind(wk_uuid) + .bind(¶ms.name) + .bind(¶ms.description) + .bind(¶ms.agent_kind) + .bind(params.model_version) + .bind(¶ms.system_prompt) + .bind(params.temperature) + .bind(params.max_output_tokens) + .bind(¶ms.tool_policy) + .bind(&kb_ids) + .bind(¶ms.variables) + .bind(&visibility) + .bind(user_id) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(row.into()) + } + + pub async fn agent_session_list( + &self, + user_id: Uuid, + ) -> Result, AppError> { + let rows = sqlx::query_as::<_, AgentSessionModel>( + "SELECT id, \"user\", wk, name, description, agent_kind, model_version, \ + system_prompt, temperature, max_output_tokens, tool_policy, \ + knowledge_base_ids, variables, visibility, version, \ + published_at, rollback_from_version, enabled, \ + source, parent_session_id, toolset_json, \ + memory_provider, memory_provider_config, iteration_budget, \ + created_by, created_at, updated_at, deleted_at \ + FROM agent_session \ + WHERE (\"user\" = $1 OR wk IN (SELECT wk FROM wk_member WHERE \"user\" = $1)) \ + AND deleted_at IS NULL \ + ORDER BY updated_at DESC", + ) + .bind(user_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(Into::into).collect()) + } + + pub async fn agent_session_get( + &self, + user_id: Uuid, + session_id: Uuid, + ) -> Result { + let row = sqlx::query_as::<_, AgentSessionModel>( + "SELECT id, \"user\", wk, name, description, agent_kind, model_version, \ + system_prompt, temperature, max_output_tokens, tool_policy, \ + knowledge_base_ids, variables, visibility, version, \ + published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \ + created_at, updated_at, deleted_at \ + FROM agent_session \ + WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(session_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?; + + if row.user != Some(user_id) { + if let Some(wk) = row.wk { + let _ = crate::AppService::workspace_require_member( + &*self, wk, user_id, + ) + .await?; + } else { + return Err(AppError::PermissionDenied); + } + } + + Ok(row.into()) + } + + pub async fn agent_session_update( + &self, + user_id: Uuid, + session_id: Uuid, + params: UpdateAgentSession, + ) -> Result { + let existing = sqlx::query_as::<_, AgentSessionModel>( + "SELECT id, \"user\", wk, name, description, agent_kind, model_version, \ + system_prompt, temperature, max_output_tokens, tool_policy, \ + knowledge_base_ids, variables, visibility, version, \ + published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \ + created_at, updated_at, deleted_at \ + FROM agent_session \ + WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(session_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?; + + if existing.user != Some(user_id) { + if let Some(wk) = existing.wk { + let _ = crate::AppService::workspace_require_admin( + &*self, wk, user_id, + ) + .await?; + } else { + return Err(AppError::PermissionDenied); + } + } + + let now = Utc::now(); + let name = params.name.unwrap_or(existing.name); + let description = params.description.or(existing.description); + let system_prompt = params.system_prompt.or(existing.system_prompt); + let temperature = params.temperature.or(existing.temperature); + let max_output_tokens = params.max_output_tokens.or(existing.max_output_tokens); + let model_version = params.model_version.or(existing.model_version); + let tool_policy = params.tool_policy.or(existing.tool_policy); + let toolset_json = params.toolset_json.or(existing.toolset_json); + let memory_provider = params.memory_provider.or(existing.memory_provider); + let memory_provider_config = params.memory_provider_config.or(existing.memory_provider_config); + let iteration_budget = params.iteration_budget.or(existing.iteration_budget); + let visibility = params.visibility.unwrap_or(existing.visibility); + let enabled = params.enabled.unwrap_or(existing.enabled); + let kb_ids = params + .knowledge_base_ids + .map(|ids| ids.iter().map(|id| id.to_string()).collect::>().join(",")) + .or(existing.knowledge_base_ids); + let variables = params.variables.or(existing.variables); + + let row = sqlx::query_as::<_, AgentSessionModel>( + "UPDATE agent_session SET \ + name = $1, description = $2, system_prompt = $3, temperature = $4, \ + max_output_tokens = $5, model_version = $6, tool_policy = $7, \ + toolset_json = $8, memory_provider = $9, \ + memory_provider_config = $10, iteration_budget = $11, \ + visibility = $12, enabled = $13, knowledge_base_ids = $14, \ + variables = $15, updated_at = $16 \ + WHERE id = $17 AND deleted_at IS NULL \ + RETURNING id, \"user\", wk, name, description, agent_kind, model_version, \ + system_prompt, temperature, max_output_tokens, tool_policy, \ + knowledge_base_ids, variables, visibility, version, \ + published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \ + created_at, updated_at, deleted_at", + ) + .bind(&name) + .bind(&description) + .bind(&system_prompt) + .bind(temperature) + .bind(max_output_tokens) + .bind(model_version) + .bind(&tool_policy) + .bind(&toolset_json) + .bind(&memory_provider) + .bind(&memory_provider_config) + .bind(iteration_budget) + .bind(&visibility) + .bind(enabled) + .bind(&kb_ids) + .bind(&variables) + .bind(now) + .bind(session_id) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(row.into()) + } + + pub async fn agent_session_delete( + &self, + user_id: Uuid, + session_id: Uuid, + ) -> Result<(), AppError> { + let existing = sqlx::query_as::<_, AgentSessionModel>( + "SELECT id, \"user\", wk, name, description, agent_kind, model_version, \ + system_prompt, temperature, max_output_tokens, tool_policy, \ + knowledge_base_ids, variables, visibility, version, \ + published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \ + created_at, updated_at, deleted_at \ + FROM agent_session \ + WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(session_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?; + + if existing.user != Some(user_id) { + if let Some(wk) = existing.wk { + let _ = crate::AppService::workspace_require_admin( + &*self, wk, user_id, + ) + .await?; + } else { + return Err(AppError::PermissionDenied); + } + } + + let now = Utc::now(); + sqlx::query( + "UPDATE agent_session SET deleted_at = $1, updated_at = $1 WHERE id = $2", + ) + .bind(now) + .bind(session_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } + pub async fn agent_session_search( + &self, + user_id: Uuid, + query: &str, + limit: u32, + ) -> Result, AppError> { + let rows = sqlx::query_as::<_, AgentSessionModel>( + "SELECT DISTINCT s.id, s.\"user\", s.wk, s.name, s.description, \ + s.agent_kind, s.model_version, \ + s.system_prompt, s.temperature, s.max_output_tokens, s.tool_policy, \ + s.knowledge_base_ids, s.variables, s.visibility, s.version, \ + s.published_at, s.rollback_from_version, s.enabled, \ + s.source, s.parent_session_id, s.toolset_json, \ + s.memory_provider, s.memory_provider_config, s.iteration_budget, \ + s.created_by, s.created_at, s.updated_at, s.deleted_at \ + FROM agent_session s \ + INNER JOIN agent_message m ON m.conversation IN ( \ + SELECT id FROM agent_conversation WHERE session = s.id \ + ) \ + WHERE (s.\"user\" = $1 OR s.wk IN (SELECT wk FROM wk_member WHERE \"user\" = $1)) \ + AND s.deleted_at IS NULL \ + AND m.deleted_at IS NULL \ + AND m.search_vector @@ plainto_tsquery('english', $2) \ + ORDER BY s.updated_at DESC \ + LIMIT $3", + ) + .bind(user_id) + .bind(query) + .bind(limit as i64) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(Into::into).collect()) + } + pub async fn agent_session_update_toolsets( + &self, + user_id: Uuid, + session_id: Uuid, + enabled: Option>, + disabled: Option>, + ) -> Result { + let existing = sqlx::query_as::<_, AgentSessionModel>( + "SELECT id, \"user\", wk, name, description, agent_kind, model_version, \ + system_prompt, temperature, max_output_tokens, tool_policy, \ + knowledge_base_ids, variables, visibility, version, \ + published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \ + created_at, updated_at, deleted_at \ + FROM agent_session \ + WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(session_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?; + + if existing.user != Some(user_id) { + if let Some(wk) = existing.wk { + let _ = crate::AppService::workspace_require_admin( + &*self, wk, user_id, + ) + .await?; + } else { + return Err(AppError::PermissionDenied); + } + } + + let toolset_json = { + let mut current: serde_json::Map = + existing + .toolset_json + .as_deref() + .and_then(|s| serde_json::from_str(s).ok()) + .unwrap_or_default(); + + if let Some(en) = enabled { + current.insert( + "enabled".to_string(), + serde_json::Value::Array( + en.into_iter().map(serde_json::Value::String).collect(), + ), + ); + } + if let Some(dis) = disabled { + current.insert( + "disabled".to_string(), + serde_json::Value::Array( + dis.into_iter().map(serde_json::Value::String).collect(), + ), + ); + } + Some(serde_json::to_string(¤t).unwrap_or_default()) + }; + + let now = Utc::now(); + let row = sqlx::query_as::<_, AgentSessionModel>( + "UPDATE agent_session SET toolset_json = $1, updated_at = $2 \ + WHERE id = $3 AND deleted_at IS NULL \ + RETURNING id, \"user\", wk, name, description, agent_kind, model_version, \ + system_prompt, temperature, max_output_tokens, tool_policy, \ + knowledge_base_ids, variables, visibility, version, \ + published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \ + created_at, updated_at, deleted_at", + ) + .bind(&toolset_json) + .bind(now) + .bind(session_id) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(row.into()) + } +} + +impl From for AgentSessionResponse { + fn from(m: AgentSessionModel) -> Self { + Self { + id: m.id, + name: m.name, + description: m.description, + agent_kind: m.agent_kind, + model_version: m.model_version, + system_prompt: m.system_prompt, + temperature: m.temperature, + max_output_tokens: m.max_output_tokens, + tool_policy: m.tool_policy, + toolset_json: m.toolset_json, + memory_provider: m.memory_provider, + iteration_budget: m.iteration_budget, + source: m.source, + parent_session_id: m.parent_session_id, + visibility: m.visibility, + version: m.version, + enabled: m.enabled, + user: m.user, + wk: m.wk, + variables: m.variables, + published_at: m.published_at, + created_at: m.created_at, + updated_at: m.updated_at, + } + } +} diff --git a/lib/service/agent/sse.rs b/lib/service/agent/sse.rs new file mode 100644 index 0000000..e3d081b --- /dev/null +++ b/lib/service/agent/sse.rs @@ -0,0 +1,378 @@ +use std::sync::Arc; + +use ai::agent::{RigAgent, RigStreamChunk, RigToolSet}; +use ai::tool::register::ToolRegister; +use serde_json::{json, Value}; +use tokio::sync::mpsc; +use tracing::{error, info, warn}; +use uuid::Uuid; + +use super::run::AppAgentContext; +use super::types::AgentRunRequest; +use crate::error::AppError; +use crate::AppService; + +impl AppService { + pub async fn agent_run_streaming( + &self, + user_id: Uuid, + req: AgentRunRequest, + ) -> Result, AppError> { + let ctx = self.agent_session_context(req.session_id, user_id).await?; + let conversation_id = req + .conversation_id + .ok_or_else(|| AppError::BadRequest("conversation_id is required".to_string()))?; + let conversation = self + .agent_require_conversation_access(user_id, conversation_id) + .await?; + if conversation.session != ctx.session_id { + return Err(AppError::BadRequest( + "conversation does not belong to session".to_string(), + )); + } + + let ai_client = self.agent_build_ai_client(ctx.model_version_id).await?; + let agent_config = self.agent_build_config(&ctx, req.max_steps); + + self.agent_maybe_compact(&ai_client, &ctx.provider_model_name, conversation_id) + .await + .unwrap_or_else(|e| { + warn!(error = %e, "compaction check failed, continuing"); + }); + + let mut tools: ToolRegister = ToolRegister::new(); + if conversation.title == "New Chat" || conversation.title.trim().is_empty() { + tools.register(super::tools::SetTitleTool::new()); + } + tools.register(super::memory::SaveMemoryTool::new()); + let (memories_text, _memory_rows) = + self.agent_load_memories(ctx.session_id).await?; + if !memories_text.is_empty() { + tools.register(super::memory::RecallMemoryTool::new(memories_text)); + } + + // Git RPC tools + super::git_tools::register_git_tools(&mut tools); + super::workspace_tools::register_workspace_tools(&mut tools); + super::issue_tools::register_issue_tools(&mut tools); + + let (tx, rx) = mpsc::unbounded_channel::(); + + let agent = RigAgent::new(ai_client.clone(), agent_config) + .map_err(|e| AppError::AiError(e))?; + + let timeout_secs = req.timeout_secs.unwrap_or(300); + + let agent_request = self + .agent_build_request( + &ai_client, + &ctx, + req.conversation_id, + req.input.clone(), + Some(timeout_secs), + ) + .await?; + + let invocation_id = Uuid::now_v7(); + let ctx_clone = ctx.clone(); + let self_clone = self.clone(); + + info!( + invocation_id = %invocation_id, + session_id = %ctx.session_id, + user_id = %user_id, + "agent sse stream starting" + ); + + if let Err(e) = self.cache.set::( + &format!("agent:stream:active:{}", conversation_id), + &invocation_id, + ).await { + warn!(error = %e, "agent sse: failed to mark stream active"); + } + + let user_message_id = match self + .persist_user_message(conversation_id, user_id, &req.input) + .await + { + Ok(id) => Some(id), + Err(e) => { + let _ = tx.send(super::persistence::stream_error("failed to persist user message")); + let _ = self.cache + .remove(&format!("agent:stream:active:{}", conversation_id)) + .await; + return Err(e); + } + }; + + let first_input = req.input.clone(); + + let shared_ctx = Arc::new(tokio::sync::Mutex::new(AppAgentContext { + user_id, + session_id: ctx.session_id, + conversation_id, + pending_title: None, + pending_memories: Vec::new(), + git: Some(super::run::GitAgentContext { + channel: self.git.clone(), + db: self.db.clone(), + cache: self.cache.clone(), + }), + })); + + let mut tool_set = RigToolSet::from_register(&tools, shared_ctx); + let rig_tools = tool_set.take_tools(); + let (mut chunk_rx, handle) = agent.run(agent_request, rig_tools); + + let trace_svc = self.clone(); + + tokio::spawn(async move { + let mut tracer = super::trace::TraceAccumulator::new( + trace_svc, invocation_id, conversation_id, + ); + let mut phase: &str = "think"; + while let Some(chunk) = chunk_rx.recv().await { + let (new_phase, sse_event) = process_chunk_with_phase(&chunk, phase, &mut tracer).await; + if new_phase != phase { + phase = new_phase; + let _ = tx.send(phase_sse(phase)); + } + if let Some(sse) = sse_event { + let _ = tx.send(sse); + } + } + let agent_result = match tokio::time::timeout( + std::time::Duration::from_secs(timeout_secs), + handle, + ) + .await + { + Ok(Ok(inner)) => inner, + Ok(Err(e)) => Err(ai::error::AiError::Response(e.to_string())), + Err(_) => Err(ai::error::AiError::Timeout { seconds: timeout_secs }), + }; + + let _ = self_clone.cache + .remove(&format!("agent:stream:active:{}", conversation_id)) + .await; + + let agent_ctx = tool_set.into_context(); + + match agent_result { + Ok(result) => { + let reasoning_content: Option = { + let collected: Vec = result + .steps + .iter() + .filter_map(|step| step.reasoning_content.clone()) + .collect(); + if collected.is_empty() { None } else { Some(collected.join("\n\n")) } + }; + + match self_clone + .persist_assistant_message( + conversation_id, + ctx_clone.session_id, + &result.output, + reasoning_content.as_deref(), + invocation_id, + ) + .await + { + Ok(msg_id) => { + for step in &result.steps { + for tc in &step.tool_calls { + let _ = self_clone + .agent_record_tool_call( + invocation_id, ctx_clone.session_id, + Some(conversation_id), Some(msg_id), + &tc.id, &tc.name, + Some(&tc.arguments.to_string()), + tc.output.as_ref().map(|v| v.to_string()).as_deref(), + tc.error.as_deref(), + if tc.error.is_some() { "error" } else { "success" }, + tc.elapsed_ms, + ) + .await; + } + } + + let _ = self_clone.persist_billing_and_deduct( + &ctx_clone, invocation_id, + result.input_tokens, result.output_tokens, + ).await; + + let _ = self_clone.agent_record_invocation( + invocation_id, ctx_clone.session_id, + Some(conversation_id), Some(msg_id), + ctx_clone.model_version_id, "completed", None, + ).await; + + let _ = self_clone.update_conversation_timestamp(conversation_id).await; + + let title = agent_ctx.pending_title + .filter(|t| !t.trim().is_empty()) + .or_else(|| { + let first_line = first_input.lines().next().unwrap_or(&first_input); + let truncated: String = first_line.chars().take(50).collect(); + if truncated.trim().is_empty() { None } + else { Some(if first_line.len() > 50 { format!("{}…", truncated.trim_end()) } else { truncated.trim().to_string() }) } + }); + if let Some(new_title) = &title { + if self_clone.update_conversation_title(conversation_id, new_title).await.is_ok() { + let title_event = serde_json::json!({ + "type": "title_updated", + "conversation_id": conversation_id.to_string(), + "title": new_title, + }); + let _ = tx.send(format!("data: {}\n\n", title_event)); + } + } + + if !agent_ctx.pending_memories.is_empty() { + let _ = self_clone.agent_persist_memories( + ctx_clone.session_id, &agent_ctx.pending_memories, + ).await; + } + + let _ = tx.send(done_sse_with_phase(msg_id, &result.output, "summarize")); + info!(invocation_id = %invocation_id, message_id = %msg_id, "agent sse stream completed"); + } + Err(e) => { + error!(error = %e, "sse: failed to persist assistant message"); + let _ = tx.send(super::persistence::stream_error("persistence failed")); + } + } + } + Err(e) => { + warn!(invocation_id = %invocation_id, error = %e, "agent sse stream failed"); + let _ = tx.send(super::persistence::stream_error(&e.to_string())); + let error_content = format!( + "I encountered an error while processing your request: {}", e + ); + let _ = self_clone.persist_assistant_message( + conversation_id, ctx_clone.session_id, &error_content, None, invocation_id, + ).await; + let _ = self_clone.agent_record_invocation( + invocation_id, ctx_clone.session_id, + Some(conversation_id), user_message_id, + ctx_clone.model_version_id, "failed", Some(&e.to_string()), + ).await; + } + } + }); + + Ok(rx) + } +} +async fn process_chunk_with_phase( + chunk: &RigStreamChunk, + _current_phase: &str, + tracer: &mut super::trace::TraceAccumulator, +) -> (&'static str, Option) { + match chunk { + RigStreamChunk::Thinking { content, .. } => { + tracer.feed_thinking(content).await; + ("think", Some(format_chunk_sse(chunk))) + } + RigStreamChunk::TextDelta { content, .. } => { + tracer.feed_text(content).await; + ("answer", Some(format_chunk_sse(chunk))) + } + RigStreamChunk::ToolCallStarted { tool_call_id, tool_name, arguments } => { + let args_val: Value = serde_json::from_str(arguments).unwrap_or(Value::Null); + tracer.feed_tool_call(tool_call_id, tool_name, &args_val).await; + ("act", Some(format_chunk_sse(chunk))) + } + RigStreamChunk::ToolCallFinished { tool_call_id, tool_name, output, error } => { + let out_val: Option = match output { + o if o.is_empty() => None, + o => serde_json::from_str(o).ok(), + }; + tracer.feed_tool_result(tool_call_id, tool_name, out_val.as_ref(), error.as_deref(), 0).await; + ("act", Some(format_chunk_sse(chunk))) + } + RigStreamChunk::Final { content, input_tokens, output_tokens } => { + tracer.finish(content, *input_tokens as i64, *output_tokens as i64).await; + ("summarize", Some(format_chunk_sse(chunk))) + } + RigStreamChunk::Failed { .. } => ("summarize", None), + RigStreamChunk::SubagentStarted { .. } => ("act", Some(format_chunk_sse(chunk))), + RigStreamChunk::SubagentCompleted { .. } => ("act", Some(format_chunk_sse(chunk))), + RigStreamChunk::SubagentFailed { .. } => ("summarize", Some(format_chunk_sse(chunk))), + } +} + +fn format_chunk_sse(chunk: &RigStreamChunk) -> String { + let payload = match chunk { + RigStreamChunk::TextDelta { index, content } => json!({ + "type": "delta", "index": index, "content": content, + }), + RigStreamChunk::Thinking { index, content } => json!({ + "type": "thinking", "index": index, "content": content, + }), + RigStreamChunk::ToolCallStarted { tool_call_id, tool_name, arguments } => json!({ + "type": "tool_call_started", + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "arguments": arguments, + }), + RigStreamChunk::ToolCallFinished { tool_call_id, tool_name, output, error } => json!({ + "type": "tool_call_finished", + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "output": output, + "error": error, + }), + RigStreamChunk::SubagentStarted { subagent_id, role, task } => json!({ + "type": "subagent_started", + "subagent_id": subagent_id, + "role": role, + "task": task, + }), + RigStreamChunk::SubagentCompleted { subagent_id, role, task, output } => json!({ + "type": "subagent_completed", + "subagent_id": subagent_id, + "role": role, + "task": task, + "output": output, + }), + RigStreamChunk::SubagentFailed { error } => json!({ + "type": "subagent_failed", + "error": error, + }), + RigStreamChunk::Final { .. } | RigStreamChunk::Failed { .. } => return String::new(), + }; + format!("data: {}\n\n", payload) +} + +fn phase_sse(phase: &str) -> String { + let payload = json!({ + "type": "phase_change", + "phase": phase, + "label": phase_label(phase), + }); + format!("data: {}\n\n", payload) +} + +fn phase_label(phase: &str) -> &str { + match phase { + "think" => "Thinking", + "answer" => "Answering", + "act" => "Acting", + "summarize" => "Summarizing", + _ => phase, + } +} + +fn done_sse_with_phase(message_id: Uuid, output: &str, phase: &str) -> String { + let payload = json!({ + "type": "done", + "message_id": message_id.to_string(), + "status": "completed", + "phase": phase, + "label": phase_label(phase), + "output": output, + }); + format!("data: {}\n\n", payload) +} diff --git a/lib/service/agent/tools.rs b/lib/service/agent/tools.rs new file mode 100644 index 0000000..6aa4668 --- /dev/null +++ b/lib/service/agent/tools.rs @@ -0,0 +1,70 @@ +use ai::error::{AiError, AiResult}; +use ai::tool::tools::FunctionCall; +use async_trait::async_trait; +use serde_json::{json, Value}; + +use super::run::AppAgentContext; +pub struct SetTitleTool; + +impl SetTitleTool { + pub fn new() -> Self { + Self + } +} + +impl Default for SetTitleTool { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl FunctionCall for SetTitleTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { + "set_conversation_title" + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "A concise, descriptive title for the conversation (max 100 characters)" + } + }, + "required": ["title"] + }) + } + + async fn call( + &self, + context: &mut Self::Context, + args: Value, + ) -> AiResult { + let title = args + .get("title") + .and_then(|v| v.as_str()) + .ok_or_else(|| AiError::Config("title parameter is required".to_string()))?; + + let title = title.trim(); + if title.is_empty() { + return Err(AiError::Config("title cannot be empty".to_string())); + } + + let title = if title.len() > 100 { + &title[..100] + } else { + title + }; + + context.pending_title = Some(title.to_string()); + + Ok(json!({ + "success": true, + "title": title + })) + } +} diff --git a/lib/service/agent/trace.rs b/lib/service/agent/trace.rs new file mode 100644 index 0000000..975acd1 --- /dev/null +++ b/lib/service/agent/trace.rs @@ -0,0 +1,268 @@ +use chrono::Utc; +use db::sqlx; +use model::agent::AgentTraceModel; +use serde_json::{json, Value}; +use uuid::Uuid; + +use crate::error::AppError; +use crate::AppService; + +pub struct TraceContext { + pub invocation_id: Uuid, + pub conversation_id: Uuid, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct TraceReplay { + pub invocation_id: Uuid, + pub conversation_id: Uuid, + pub phases: Vec, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct TracePhaseRow { + pub id: Uuid, + pub sequence: i32, + pub phase: String, + pub label: String, + pub content: Option, + pub tool_calls: Option, + pub tool_results: Option, + pub input_tokens: Option, + pub output_tokens: Option, + pub metadata: Option, + pub created_at: chrono::DateTime, +} + +impl AppService { + pub async fn trace_record( + &self, + ctx: &TraceContext, + sequence: i32, + phase: &str, + content: Option<&str>, + tool_calls: Option<&Value>, + tool_results: Option<&Value>, + input_tokens: Option, + output_tokens: Option, + metadata: Option<&Value>, + ) -> Result { + let id = Uuid::now_v7(); + let now = Utc::now(); + sqlx::query( + "INSERT INTO agent_trace \ + (id, invocation, conversation, sequence, phase, content, \ + tool_calls, tool_results, input_tokens, output_tokens, metadata, created_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)", + ) + .bind(id) + .bind(ctx.invocation_id) + .bind(ctx.conversation_id) + .bind(sequence) + .bind(phase) + .bind(content) + .bind(tool_calls) + .bind(tool_results) + .bind(input_tokens) + .bind(output_tokens) + .bind(metadata) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(id) + } + + pub async fn trace_replay_by_invocation( + &self, + invocation_id: Uuid, + ) -> Result { + let rows = sqlx::query_as::<_, AgentTraceModel>( + "SELECT id, invocation, conversation, sequence, phase, content, \ + tool_calls, tool_results, input_tokens, output_tokens, metadata, created_at \ + FROM agent_trace WHERE invocation = $1 ORDER BY sequence ASC", + ) + .bind(invocation_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let conversation_id = rows.first().map(|r| r.conversation).unwrap_or(Uuid::nil()); + + Ok(TraceReplay { + invocation_id, + conversation_id, + phases: rows + .into_iter() + .map(|r| TracePhaseRow { + id: r.id, + sequence: r.sequence, + phase: r.phase.clone(), + label: r.phase_label().to_string(), + content: r.content, + tool_calls: r.tool_calls, + tool_results: r.tool_results, + input_tokens: r.input_tokens, + output_tokens: r.output_tokens, + metadata: r.metadata, + created_at: r.created_at, + }) + .collect(), + }) + } + + pub async fn trace_replay_by_conversation( + &self, + conversation_id: Uuid, + ) -> Result, AppError> { + let rows = sqlx::query_as::<_, AgentTraceModel>( + "SELECT id, invocation, conversation, sequence, phase, content, \ + tool_calls, tool_results, input_tokens, output_tokens, metadata, created_at \ + FROM agent_trace WHERE conversation = $1 ORDER BY invocation, sequence ASC", + ) + .bind(conversation_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut grouped: std::collections::BTreeMap> = + std::collections::BTreeMap::new(); + for row in rows { + grouped.entry(row.invocation).or_default().push(row); + } + + Ok(grouped + .into_iter() + .map(|(invocation_id, rows)| TraceReplay { + invocation_id, + conversation_id, + phases: rows + .into_iter() + .map(|r| TracePhaseRow { + id: r.id, + sequence: r.sequence, + phase: r.phase.clone(), + label: r.phase_label().to_string(), + content: r.content, + tool_calls: r.tool_calls, + tool_results: r.tool_results, + input_tokens: r.input_tokens, + output_tokens: r.output_tokens, + metadata: r.metadata, + created_at: r.created_at, + }) + .collect(), + }) + .collect()) + } +} + +pub struct TraceAccumulator { + ctx: TraceContext, + seq: i32, + think_buf: String, + answer_buf: String, + think_tokens: i64, + answer_tokens: i64, + svc: AppService, +} + +impl TraceAccumulator { + pub fn new(svc: AppService, invocation_id: Uuid, conversation_id: Uuid) -> Self { + Self { + ctx: TraceContext { invocation_id, conversation_id }, + seq: 0, + think_buf: String::new(), + answer_buf: String::new(), + think_tokens: 0, + answer_tokens: 0, + svc, + } + } + + pub async fn feed_thinking(&mut self, chunk: &str) { + self.think_buf.push_str(chunk); + self.think_tokens += (chunk.chars().count() as f64 / 2.5).ceil() as i64; + } + + pub async fn feed_text(&mut self, chunk: &str) { + if !self.think_buf.is_empty() { + self.flush_think().await; + } + self.answer_buf.push_str(chunk); + self.answer_tokens += (chunk.chars().count() as f64 / 2.5).ceil() as i64; + } + + pub async fn feed_tool_call(&mut self, tool_call_id: &str, tool_name: &str, args: &Value) { + if !self.answer_buf.is_empty() { + self.flush_answer().await; + } + let _ = self.svc.trace_record( + &self.ctx, self.seq, "act", + None, + Some(&json!({ "tool_call_id": tool_call_id, "name": tool_name, "arguments": args })), + None, + None, None, None, + ).await; + self.seq += 1; + } + + pub async fn feed_tool_result(&mut self, tool_call_id: &str, tool_name: &str, + output: Option<&Value>, error: Option<&str>, elapsed_ms: i64) { + let _ = self.svc.trace_record( + &self.ctx, self.seq, "act", + None, + None, + Some(&json!({ + "tool_call_id": tool_call_id, + "name": tool_name, + "output": output, + "error": error, + "elapsed_ms": elapsed_ms, + })), + None, None, None, + ).await; + self.seq += 1; + } + + pub async fn finish(&mut self, output: &str, input_tokens: i64, output_tokens: i64) { + if !self.think_buf.is_empty() { + self.flush_think().await; + } + if !self.answer_buf.is_empty() { + self.flush_answer().await; + } + let _ = self.svc.trace_record( + &self.ctx, self.seq, "summarize", + Some(output), + None, None, + Some(input_tokens), Some(output_tokens), + None, + ).await; + } + + async fn flush_think(&mut self) { + let content = std::mem::take(&mut self.think_buf); + let tokens = self.think_tokens; + self.think_tokens = 0; + let _ = self.svc.trace_record( + &self.ctx, self.seq, "think", + Some(&content), None, None, + Some(tokens), None, None, + ).await; + self.seq += 1; + } + + async fn flush_answer(&mut self) { + let content = std::mem::take(&mut self.answer_buf); + let tokens = self.answer_tokens; + self.answer_tokens = 0; + let _ = self.svc.trace_record( + &self.ctx, self.seq, "answer", + Some(&content), None, None, + None, Some(tokens), None, + ).await; + self.seq += 1; + } +} diff --git a/lib/service/agent/types.rs b/lib/service/agent/types.rs new file mode 100644 index 0000000..ff29692 --- /dev/null +++ b/lib/service/agent/types.rs @@ -0,0 +1,102 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use utoipa::ToSchema; +use uuid::Uuid; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "snake_case")] +pub enum BillingTarget { + User, + Workspace, +} +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct AgentRunRequest { + pub session_id: Uuid, + pub conversation_id: Option, + pub input: String, + #[serde(default)] + pub stream: bool, + pub max_steps: Option, + pub timeout_secs: Option, +} +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct AgentRunResponse { + pub message_id: Uuid, + pub conversation_id: Uuid, + pub output: String, + pub steps: Vec, + pub usage: AgentUsageInfo, + pub cost: Option, +} +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct AgentStepInfo { + pub index: usize, + pub assistant: Option, + pub tool_calls: Vec, + pub reflection: Option, +} +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct AgentToolCallInfo { + pub id: String, + pub name: String, + pub arguments: Value, + pub output: Option, + pub error: Option, + pub elapsed_ms: Option, +} +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct AgentUsageInfo { + pub input_tokens: i64, + pub output_tokens: i64, + pub total_tokens: i64, +} +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct AgentCostInfo { + pub amount: String, + pub currency: String, +} +#[derive(Debug, Clone)] +pub(crate) struct SessionContext { + pub session_id: Uuid, + pub user_id: Option, + pub workspace_id: Option, + pub system_prompt: String, + pub model_version_id: Uuid, + pub provider_model_name: String, + pub temperature: Option, + pub max_output_tokens: Option, + pub tool_policy_json: Option, + pub toolset_json: Option, + pub variables_json: Option, + pub iteration_budget: Option, + pub memory_provider: Option, + #[allow(dead_code)] + pub source: Option, + #[allow(dead_code)] + pub parent_session_id: Option, + pub billing_target: BillingTarget, +} +#[derive(Debug, Clone)] +pub(crate) struct BillingRecord { + pub invocation_id: Uuid, + pub session_id: Uuid, + pub model_version_id: Uuid, + pub input_tokens: i64, + pub output_tokens: i64, + pub cached_input_tokens: i64, + pub cache_read_tokens: i64, + pub cache_write_tokens: i64, + pub reasoning_tokens: i64, + pub total_tokens: i64, + pub cost: Option, + pub currency: Option, + pub created_at: DateTime, +} +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub(crate) struct RunPersistState { + pub message_id: Uuid, + pub conversation_id: Uuid, + pub invocation_id: Uuid, + pub billing: BillingRecord, +} diff --git a/lib/service/agent/workspace_tools/helpers.rs b/lib/service/agent/workspace_tools/helpers.rs new file mode 100644 index 0000000..16a5691 --- /dev/null +++ b/lib/service/agent/workspace_tools/helpers.rs @@ -0,0 +1,59 @@ +use ai::error::{AiError, AiResult}; +use ai::tool::register::ToolRegister; +use db::sqlx; +use serde_json::Value; +use uuid::Uuid; + +use crate::agent::run::{AppAgentContext, GitAgentContext}; + +pub fn register_workspace_tools(tools: &mut ToolRegister) { + tools.register(super::workspace::WorkspaceInfoTool::new()); + tools.register(super::workspace::WorkspaceMembersTool::new()); + tools.register(super::workspace::WorkspaceGroupsTool::new()); + tools.register(super::workspace::WorkspaceGroupMembersTool::new()); +} + +pub(super) async fn require_workspace_member( + git: &GitAgentContext, + user_id: Uuid, + workspace_name: &str, +) -> AiResult { + let wk_id: Uuid = sqlx::query_scalar( + "SELECT id FROM workspace WHERE name = $1", + ) + .bind(workspace_name) + .fetch_optional(git.db.reader()) + .await + .map_err(AiError::Database)? + .ok_or_else(|| AiError::Config(format!("workspace '{workspace_name}' not found")))?; + + let is_member: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM wk_member \ + WHERE wk = $1 AND \"user\" = $2 AND leave_at IS NULL", + ) + .bind(wk_id) + .bind(user_id) + .fetch_one(git.db.reader()) + .await + .map_err(AiError::Database)?; + + if is_member == 0 { + return Err(AiError::Config(format!( + "user is not a member of workspace '{workspace_name}'" + ))); + } + + Ok(wk_id) +} + +pub(super) fn git_ctx(ctx: &AppAgentContext) -> AiResult<&GitAgentContext> { + ctx.git + .as_ref() + .ok_or_else(|| AiError::Config("workspace tools are not available in this session".to_string())) +} + +pub(super) fn arg_str<'a>(args: &'a Value, key: &str) -> AiResult<&'a str> { + args.get(key) + .and_then(|v| v.as_str()) + .ok_or_else(|| AiError::Config(format!("'{key}' parameter is required"))) +} diff --git a/lib/service/agent/workspace_tools/mod.rs b/lib/service/agent/workspace_tools/mod.rs new file mode 100644 index 0000000..bfc187a --- /dev/null +++ b/lib/service/agent/workspace_tools/mod.rs @@ -0,0 +1,4 @@ +pub mod helpers; +pub mod workspace; + +pub use helpers::register_workspace_tools; diff --git a/lib/service/agent/workspace_tools/workspace.rs b/lib/service/agent/workspace_tools/workspace.rs new file mode 100644 index 0000000..9cd5ac6 --- /dev/null +++ b/lib/service/agent/workspace_tools/workspace.rs @@ -0,0 +1,278 @@ +use ai::error::{AiError, AiResult}; +use ai::tool::tools::FunctionCall; +use async_trait::async_trait; +use db::sqlx; +use serde_json::{json, Value}; +use uuid::Uuid; + +use super::helpers::{arg_str, git_ctx, require_workspace_member}; +use crate::agent::run::AppAgentContext; + +pub struct WorkspaceInfoTool; + +impl WorkspaceInfoTool { + pub fn new() -> Self { Self } +} + +impl Default for WorkspaceInfoTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for WorkspaceInfoTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "workspace_info" } + + fn description(&self) -> &'static str { + "Get information about a workspace: name, description, avatar." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" } + }, + "required": ["workspace"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let _wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; + + let row = sqlx::query_as::<_, (String, String, String, chrono::DateTime)>( + "SELECT name, description, avatar_url, created_at FROM workspace WHERE name = $1", + ) + .bind(workspace) + .fetch_optional(git.db.reader()) + .await + .map_err(AiError::Database)? + .ok_or_else(|| AiError::Config(format!("workspace '{workspace}' not found")))?; + + let member_count: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM wk_member WHERE wk = (SELECT id FROM workspace WHERE name = $1) AND leave_at IS NULL", + ) + .bind(workspace) + .fetch_one(git.db.reader()) + .await + .map_err(AiError::Database)?; + + Ok(json!({ + "name": row.0, + "description": row.1, + "avatar_url": row.2, + "created_at": row.3.to_rfc3339(), + "member_count": member_count, + })) + } +} + +pub struct WorkspaceMembersTool; + +impl WorkspaceMembersTool { + pub fn new() -> Self { Self } +} + +impl Default for WorkspaceMembersTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for WorkspaceMembersTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "workspace_members" } + + fn description(&self) -> &'static str { + "List all members of a workspace with their roles." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "limit": { "type": "integer", "description": "Max results (default 50)" } + }, + "required": ["workspace"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; + let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(50).min(200); + + #[derive(sqlx::FromRow)] + struct MemberRow { + username: String, + display_name: Option, + owner: bool, + admin: bool, + } + + let rows: Vec = sqlx::query_as( + "SELECT u.username, u.display_name, m.owner, m.admin \ + FROM wk_member m \ + INNER JOIN \"user\" u ON u.id = m.\"user\" \ + WHERE m.wk = $1 AND m.leave_at IS NULL \ + ORDER BY m.owner DESC, m.admin DESC \ + LIMIT $2", + ) + .bind(wk_id) + .bind(limit) + .fetch_all(git.db.reader()) + .await + .map_err(AiError::Database)?; + + let members: Vec = rows.iter().map(|r| json!({ + "username": r.username, + "display_name": r.display_name, + "role": if r.owner { "owner" } else if r.admin { "admin" } else { "member" }, + })).collect(); + + Ok(json!({ "members": members, "count": members.len() })) + } +} + +pub struct WorkspaceGroupsTool; + +impl WorkspaceGroupsTool { + pub fn new() -> Self { Self } +} + +impl Default for WorkspaceGroupsTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for WorkspaceGroupsTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "workspace_groups" } + + fn description(&self) -> &'static str { + "List all user groups in a workspace." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" } + }, + "required": ["workspace"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; + + #[derive(sqlx::FromRow)] + struct GroupRow { + id: Uuid, + name: String, + created_at: chrono::DateTime, + } + + let rows: Vec = sqlx::query_as( + "SELECT id, name, created_at FROM wk_group \ + WHERE wk = $1 AND is_deleted = FALSE ORDER BY name ASC", + ) + .bind(wk_id) + .fetch_all(git.db.reader()) + .await + .map_err(AiError::Database)?; + + // For each group, count members + let mut groups: Vec = Vec::new(); + for row in &rows { + let count: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM wk_gp_member WHERE gp = $1", + ) + .bind(row.id) + .fetch_one(git.db.reader()) + .await + .map_err(AiError::Database)?; + + groups.push(json!({ + "name": row.name, + "member_count": count, + "created_at": row.created_at.to_rfc3339(), + })); + } + + Ok(json!({ "groups": groups, "count": groups.len() })) + } +} + +pub struct WorkspaceGroupMembersTool; + +impl WorkspaceGroupMembersTool { + pub fn new() -> Self { Self } +} + +impl Default for WorkspaceGroupMembersTool { + fn default() -> Self { Self::new() } +} + +#[async_trait] +impl FunctionCall for WorkspaceGroupMembersTool { + type Context = AppAgentContext; + + fn name(&self) -> &'static str { "workspace_group_members" } + + fn description(&self) -> &'static str { + "List members of a specific workspace group." + } + + fn schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "workspace": { "type": "string", "description": "Workspace name" }, + "group_name": { "type": "string", "description": "Group name" } + }, + "required": ["workspace", "group_name"] + }) + } + + async fn call(&self, ctx: &mut AppAgentContext, args: Value) -> AiResult { + let git = git_ctx(ctx)?; + let workspace = arg_str(&args, "workspace")?; + let group_name = arg_str(&args, "group_name")?; + let wk_id = require_workspace_member(git, ctx.user_id, workspace).await?; + + #[derive(sqlx::FromRow)] + struct MemberRow { + username: String, + display_name: Option, + } + + let rows: Vec = sqlx::query_as( + "SELECT u.username, u.display_name \ + FROM wk_gp_member gm \ + INNER JOIN wk_group g ON g.id = gm.gp \ + INNER JOIN \"user\" u ON u.id = gm.\"user\" \ + WHERE g.wk = $1 AND g.name = $2 AND g.is_deleted = FALSE", + ) + .bind(wk_id) + .bind(group_name) + .fetch_all(git.db.reader()) + .await + .map_err(AiError::Database)?; + + let members: Vec = rows.iter().map(|r| json!({ + "username": r.username, + "display_name": r.display_name, + })).collect(); + + Ok(json!({ "group": group_name, "members": members, "count": members.len() })) + } +} diff --git a/lib/service/ai/card.rs b/lib/service/ai/card.rs new file mode 100644 index 0000000..22574c2 --- /dev/null +++ b/lib/service/ai/card.rs @@ -0,0 +1,33 @@ +use crate::AppService; +use crate::ai::types::AiModelCardResponse; +use crate::error::AppError; +use db::sqlx; +use model::ai::AiModelCardModel; +use session::Session; + +impl AppService { + pub async fn ai_model_card( + &self, + ctx: &Session, + model_id: uuid::Uuid, + ) -> Result, AppError> { + let _user_uid = self.ai_require_login(ctx).await?; + self.ai_card_get_inner(model_id).await + } + + pub(crate) async fn ai_card_get_inner( + &self, + model_id: uuid::Uuid, + ) -> Result, AppError> { + let card = sqlx::query_as::<_, AiModelCardModel>( + "SELECT model, overview, strengths, limitations, safety_notes, eval_summary, metadata, created_at, updated_at \ + FROM ai_model_card WHERE model = $1", + ) + .bind(model_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(card.map(AiModelCardResponse::from)) + } +} diff --git a/lib/service/ai/discussion.rs b/lib/service/ai/discussion.rs new file mode 100644 index 0000000..b028f03 --- /dev/null +++ b/lib/service/ai/discussion.rs @@ -0,0 +1,44 @@ +use crate::AppService; +use crate::ai::types::AiDiscussionResponse; +use crate::error::AppError; +use crate::{Pagination, session_user}; +use db::sqlx; +use model::ai::AiModelDiscussionModel; +use session::Session; + +impl AppService { + pub async fn ai_model_discussions( + &self, + ctx: &Session, + model_id: uuid::Uuid, + pagination: Pagination, + ) -> Result, AppError> { + let _user_uid = session_user(ctx)?; + + let discussions = sqlx::query_as::<_, AiModelDiscussionModel>( + "SELECT id, model, \"user\", parent, body, created_at, updated_at, deleted_at \ + FROM ai_model_discussion WHERE model = $1 AND deleted_at IS NULL \ + ORDER BY created_at DESC LIMIT $2 OFFSET $3", + ) + .bind(model_id) + .bind(pagination.limit() as i64) + .bind(pagination.offset() as i64) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut results = Vec::new(); + for d in discussions { + let user = self.users_find_by_id(d.user).await?; + results.push(AiDiscussionResponse { + id: d.id, + author: crate::issues::types::issue_author(user), + parent: d.parent, + body: d.body, + created_at: d.created_at, + updated_at: d.updated_at, + }); + } + Ok(results) + } +} diff --git a/lib/service/ai/like.rs b/lib/service/ai/like.rs new file mode 100644 index 0000000..3e1b645 --- /dev/null +++ b/lib/service/ai/like.rs @@ -0,0 +1,49 @@ +use crate::AppService; +use crate::ai::types::AiLikeResponse; +use crate::error::AppError; +use db::sqlx; +use model::ai::AiModelLikeModel; +use session::Session; + +impl AppService { + pub async fn ai_model_likes( + &self, + ctx: &Session, + model_id: uuid::Uuid, + ) -> Result, AppError> { + let _user_uid = self.ai_require_login(ctx).await?; + + let likes = sqlx::query_as::<_, AiModelLikeModel>( + "SELECT model, \"user\", created_at FROM ai_model_like WHERE model = $1 \ + ORDER BY created_at DESC", + ) + .bind(model_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut results = Vec::new(); + for l in likes { + let user = self.users_find_by_id(l.user).await?; + results.push(AiLikeResponse { + user: crate::issues::types::issue_author(user), + created_at: l.created_at, + }); + } + Ok(results) + } + + pub(crate) async fn ai_like_count_inner( + &self, + model_id: uuid::Uuid, + ) -> Result { + let count = sqlx::query_scalar::<_, i64>( + "SELECT COUNT(*) FROM ai_model_like WHERE model = $1", + ) + .bind(model_id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(count) + } +} diff --git a/lib/service/ai/mod.rs b/lib/service/ai/mod.rs new file mode 100644 index 0000000..3361e6c --- /dev/null +++ b/lib/service/ai/mod.rs @@ -0,0 +1,24 @@ +pub mod card; +pub mod discussion; +pub mod like; +pub mod model; +pub mod provider; +pub mod sync; +pub mod tag; +pub mod types; +pub mod version; + +use crate::AppService; +use crate::error::AppError; +use crate::session_user; +use session::Session; +use uuid::Uuid; + +impl AppService { + pub(crate) async fn ai_require_login( + &self, + ctx: &Session, + ) -> Result { + session_user(ctx) + } +} diff --git a/lib/service/ai/model.rs b/lib/service/ai/model.rs new file mode 100644 index 0000000..b495b56 --- /dev/null +++ b/lib/service/ai/model.rs @@ -0,0 +1,160 @@ +use crate::AppService; +use crate::ai::types::{AiModelFilter, AiModelListItem, AiModelResponse}; +use crate::error::AppError; +use crate::{Pagination, session_user}; +use db::sqlx; +use db::sqlx::AssertSqlSafe; +use model::ai::AiModelModel; +use model::ai::AiProviderModel; +use session::Session; + +impl AppService { + pub async fn ai_model_list( + &self, + ctx: &Session, + filter: AiModelFilter, + pagination: Pagination, + ) -> Result, AppError> { + let _user_uid = session_user(ctx)?; + + let mut conditions = vec![ + "m.public = true".to_string(), + "m.deleted_at IS NULL".to_string(), + ]; + let mut param_idx = 1; + + if filter.enabled.is_some() { + conditions.push(format!("m.enabled = ${param_idx}")); + param_idx += 1; + } + if filter.provider.is_some() { + conditions.push(format!("m.provider = ${param_idx}")); + param_idx += 1; + } + if filter.modality.is_some() { + conditions.push(format!("m.modality = ${param_idx}")); + param_idx += 1; + } + if filter.name.is_some() { + conditions.push(format!("m.name ILIKE ${param_idx}")); + param_idx += 1; + } + + let where_clause = conditions.join(" AND "); + let limit_idx = param_idx; + let offset_idx = param_idx + 1; + + let query = format!( + "SELECT m.id, m.provider, m.name, m.display_name, m.description, m.modality, \ + m.context_window, m.input_token_limit, m.output_token_limit, \ + m.enabled, m.public, m.created_at, m.updated_at, m.deleted_at \ + FROM ai_model m WHERE {where_clause} \ + ORDER BY m.display_name ASC LIMIT ${limit_idx} OFFSET ${offset_idx}" + ); + + let mut q = sqlx::query_as::<_, AiModelModel>(AssertSqlSafe(query)); + if let Some(enabled) = &filter.enabled { + q = q.bind(enabled); + } + if let Some(provider) = &filter.provider { + q = q.bind(provider); + } + if let Some(modality) = &filter.modality { + q = q.bind(modality); + } + if let Some(name) = &filter.name { + q = q.bind(format!("%{}%", name)); + } + q = q + .bind(pagination.limit() as i64) + .bind(pagination.offset() as i64); + + let models = q + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut results = Vec::new(); + for m in models { + let provider = self.ai_provider_by_id(m.provider).await?; + results.push(AiModelListItem { + id: m.id, + name: m.name, + display_name: m.display_name, + description: crate::non_empty( + m.description.unwrap_or_default(), + ), + modality: m.modality, + provider_name: provider.name, + provider_logo_url: provider.logo_url, + context_window: m.context_window, + input_token_limit: m.input_token_limit, + output_token_limit: m.output_token_limit, + enabled: m.enabled, + created_at: m.created_at, + updated_at: m.updated_at, + }); + } + Ok(results) + } + pub async fn ai_model_get( + &self, + ctx: &Session, + id: uuid::Uuid, + ) -> Result { + let _user_uid = session_user(ctx)?; + + let m = sqlx::query_as::<_, AiModelModel>( + "SELECT id, provider, name, display_name, description, modality, \ + context_window, input_token_limit, output_token_limit, \ + enabled, public, created_at, updated_at, deleted_at \ + FROM ai_model WHERE id = $1 AND (public = true OR deleted_at IS NULL)", + ) + .bind(id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::NotFound("model not found".to_string()))?; + + let provider = self.ai_provider_by_id(m.provider).await?; + let card = self.ai_card_get_inner(m.id).await?; + let versions = self.ai_version_list_inner(m.id).await?; + let tags = self.ai_tag_list_inner(m.id).await?; + let like_count = self.ai_like_count_inner(m.id).await?; + + Ok(AiModelResponse { + id: m.id, + name: m.name, + display_name: m.display_name, + description: crate::non_empty(m.description.unwrap_or_default()), + modality: m.modality, + context_window: m.context_window, + input_token_limit: m.input_token_limit, + output_token_limit: m.output_token_limit, + enabled: m.enabled, + public: m.public, + provider_name: provider.name, + provider_logo_url: provider.logo_url, + card, + versions, + tags, + like_count, + created_at: m.created_at, + updated_at: m.updated_at, + }) + } + + pub(crate) async fn ai_provider_by_id( + &self, + id: uuid::Uuid, + ) -> Result { + sqlx::query_as::<_, AiProviderModel>( + "SELECT id, name, base_url, website_url, logo_url, enabled, created_at, updated_at \ + FROM ai_provider WHERE id = $1", + ) + .bind(id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string())) + } +} diff --git a/lib/service/ai/provider.rs b/lib/service/ai/provider.rs new file mode 100644 index 0000000..de45160 --- /dev/null +++ b/lib/service/ai/provider.rs @@ -0,0 +1,47 @@ +use crate::AppService; +use crate::ai::types::AiProviderResponse; +use crate::error::AppError; +use db::sqlx; +use model::ai::AiProviderModel; +use session::Session; + +impl AppService { + pub async fn ai_provider_list( + &self, + ctx: &Session, + ) -> Result, AppError> { + self.ai_require_login(ctx).await?; + + let providers = sqlx::query_as::<_, AiProviderModel>( + "SELECT id, name, base_url, website_url, logo_url, enabled, created_at, updated_at \ + FROM ai_provider WHERE enabled = true ORDER BY name ASC", + ) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(providers + .into_iter() + .map(AiProviderResponse::from) + .collect()) + } + pub async fn ai_provider_get( + &self, + ctx: &Session, + id: uuid::Uuid, + ) -> Result { + self.ai_require_login(ctx).await?; + + let provider = sqlx::query_as::<_, AiProviderModel>( + "SELECT id, name, base_url, website_url, logo_url, enabled, created_at, updated_at \ + FROM ai_provider WHERE id = $1", + ) + .bind(id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::NotFound("provider not found".to_string()))?; + + Ok(AiProviderResponse::from(provider)) + } +} diff --git a/lib/service/ai/sync.rs b/lib/service/ai/sync.rs new file mode 100644 index 0000000..ecc1c1e --- /dev/null +++ b/lib/service/ai/sync.rs @@ -0,0 +1,585 @@ +use std::time::Duration; + +use ai::sync::{UpstreamModel, UpstreamPricing}; +use ai::client::EndpointConfig; +use chrono::Utc; +use db::sqlx::{self, types::Decimal}; +use model::ai::{AiModelModel, AiModelVersionModel, AiProviderModel}; +use serde::Serialize; +use tokio::time::interval; +use utoipa::ToSchema; +use uuid::Uuid; + +use crate::{AppService, error::AppError}; + +#[derive(Debug, Clone, Serialize, ToSchema)] +pub struct SyncModelsResponse { + pub models_created: i64, + pub models_updated: i64, + pub models_offline: i64, + pub models_deactivated: i64, + pub versions_created: i64, + pub pricing_created: i64, + pub pricing_updated: i64, +} + +#[derive(Debug, Clone)] +pub struct SyncResult { + pub provider_id: Uuid, + pub provider_name: String, + pub total: u32, + pub synced: u32, + pub skipped: u32, +} + +fn extract_provider_name(model: &UpstreamModel) -> String { + if let Some(owned) = &model.owned_by { + if !owned.is_empty() { + return normalize_provider_name(owned); + } + } + normalize_provider_name(model.id.split('/').next().unwrap_or("unknown")) +} + +fn normalize_provider_name(slug: &str) -> String { + match slug { + "openai" => "openai", + "anthropic" => "anthropic", + "google" | "google-ai" => "google", + "mistralai" => "mistral", + "meta-llama" | "meta" => "meta", + "deepseek" => "deepseek", + "azure" | "azure-openai" => "azure", + "x-ai" | "xai" => "xai", + "moonshot" => "moonshot", + "alibaba" | "qwen" => "qwen", + s => s, + } + .to_string() +} + +fn provider_display_name(name: &str) -> String { + match name { + "openai" => "OpenAI", + "anthropic" => "Anthropic", + "google" => "Google DeepMind", + "mistral" => "Mistral AI", + "meta" => "Meta", + "deepseek" => "DeepSeek", + "azure" => "Microsoft Azure", + "xai" => "xAI", + "moonshot" => "Moonshot AI", + "qwen" => "Alibaba Qwen", + s => s, + } + .to_string() +} + +fn infer_modality(model: &UpstreamModel) -> &'static str { + if let Some(caps) = &model.capabilities { + if caps.vision == Some(true) { + return "multimodal"; + } + } + let lower = model.id.to_lowercase(); + if lower.contains("vision") + || lower.contains("dall-e") + || lower.contains("gpt-image") + || lower.contains("gpt-4o") + { + "multimodal" + } else if lower.contains("embedding") { + "text" + } else if lower.contains("whisper") || lower.contains("audio") { + "audio" + } else { + "text" + } +} + +async fn upsert_provider( + db: &db::AppDatabase, + slug: &str, +) -> Result { + let _display = provider_display_name(slug); + let now = Utc::now(); + + if let Some(existing) = sqlx::query_as::<_, AiProviderModel>( + "SELECT id, name, base_url, website_url, logo_url, enabled, created_at, updated_at \ + FROM ai_provider WHERE name = $1", + ) + .bind(slug) + .fetch_optional(db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + { + sqlx::query( + "UPDATE ai_provider SET updated_at = $1 WHERE id = $2", + ) + .bind(now) + .bind(existing.id) + .execute(db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(existing) + } else { + let id = Uuid::now_v7(); + sqlx::query( + "INSERT INTO ai_provider (id, name, enabled, created_at, updated_at) \ + VALUES ($1, $2, true, $3, $3)", + ) + .bind(id) + .bind(slug) + .bind(now) + .execute(db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(AiProviderModel { + id, + name: slug.to_string(), + base_url: None, + website_url: None, + logo_url: None, + enabled: true, + created_at: now, + updated_at: now, + }) + } +} + +async fn upsert_model( + db: &db::AppDatabase, + provider_id: Uuid, + model: &UpstreamModel, +) -> Result<(AiModelModel, bool), AppError> { + let now = Utc::now(); + let name = &model.id; + let modality = infer_modality(model); + let ctx = model.context_length.map(|c| c as i32); + let max_out = model.max_output_tokens.map(|v| v as i32); + + if let Some(existing) = sqlx::query_as::<_, AiModelModel>( + "SELECT id, provider, name, display_name, description, modality, \ + context_window, input_token_limit, output_token_limit, \ + enabled, public, created_at, updated_at, deleted_at \ + FROM ai_model WHERE name = $1 AND deleted_at IS NULL", + ) + .bind(name) + .fetch_optional(db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + { + sqlx::query( + "UPDATE ai_model SET provider = $1, context_window = $2, \ + output_token_limit = $3, enabled = true, updated_at = $4 \ + WHERE id = $5", + ) + .bind(provider_id) + .bind(ctx) + .bind(max_out) + .bind(now) + .bind(existing.id) + .execute(db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok((existing, false)) + } else { + let id = Uuid::now_v7(); + sqlx::query( + "INSERT INTO ai_model (id, provider, name, display_name, modality, \ + context_window, input_token_limit, output_token_limit, enabled, \ + public, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, true, $9, $9)", + ) + .bind(id) + .bind(provider_id) + .bind(name) + .bind(name) + .bind(modality) + .bind(ctx) + .bind(ctx) + .bind(max_out) + .bind(now) + .execute(db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let inserted = sqlx::query_as::<_, AiModelModel>( + "SELECT id, provider, name, display_name, description, modality, \ + context_window, input_token_limit, output_token_limit, \ + enabled, public, created_at, updated_at, deleted_at \ + FROM ai_model WHERE id = $1", + ) + .bind(id) + .fetch_one(db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok((inserted, true)) + } +} + +async fn upsert_version( + db: &db::AppDatabase, + model_id: Uuid, + provider_model_name: &str, +) -> Result<(AiModelVersionModel, bool), AppError> { + let now = Utc::now(); + + if let Some(existing) = sqlx::query_as::<_, AiModelVersionModel>( + "SELECT id, model, version, provider_model_name, \ + input_price_per_million, output_price_per_million, cached_input_price_per_million, \ + training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \ + FROM ai_model_version WHERE model = $1 AND provider_model_name = $2", + ) + .bind(model_id) + .bind(provider_model_name) + .fetch_optional(db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + { + Ok((existing, false)) + } else { + let id = Uuid::now_v7(); + sqlx::query( + "INSERT INTO ai_model_version (id, model, version, provider_model_name, \ + enabled, created_at, updated_at) \ + VALUES ($1, $2, 'latest', $3, true, $4, $4)", + ) + .bind(id) + .bind(model_id) + .bind(provider_model_name) + .bind(now) + .execute(db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let inserted = sqlx::query_as::<_, AiModelVersionModel>( + "SELECT id, model, version, provider_model_name, \ + input_price_per_million, output_price_per_million, cached_input_price_per_million, \ + training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \ + FROM ai_model_version WHERE id = $1", + ) + .bind(id) + .fetch_one(db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok((inserted, true)) + } +} + +async fn upsert_pricing( + db: &db::AppDatabase, + version_id: Uuid, + pricing: Option<&UpstreamPricing>, +) -> Result { + let Some(p) = pricing else { + return Ok(PricingResult::Skipped); + }; + let input_million: Option = p.prompt.as_deref().and_then(parse_token_price_decimal) + .map(|per_token| per_token * Decimal::from(1_000_000u64)) + .or_else(|| { + p.input + .filter(|v| *v > 0.0) + .map(|v| Decimal::try_from(v).unwrap_or_default()) + }); + + let output_million: Option = p.completion.as_deref().and_then(parse_token_price_decimal) + .map(|per_token| per_token * Decimal::from(1_000_000u64)) + .or_else(|| { + p.output + .filter(|v| *v > 0.0) + .map(|v| Decimal::try_from(v).unwrap_or_default()) + }); + + let cache_input: Option = p.cache_read + .filter(|v| *v > 0.0) + .map(|v| Decimal::try_from(v).unwrap_or_default()); + + if input_million.is_none() && output_million.is_none() { + return Ok(PricingResult::Skipped); + } + + let existing = sqlx::query_scalar::<_, Uuid>( + "SELECT id FROM ai_model_version WHERE id = $1", + ) + .bind(version_id) + .fetch_optional(db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if existing.is_none() { + return Ok(PricingResult::Skipped); + } + + let count = sqlx::query_scalar::<_, i64>( + "SELECT COUNT(*) FROM ai_model_version \ + WHERE id = $1 AND input_price_per_million IS NOT NULL", + ) + .bind(version_id) + .fetch_one(db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "UPDATE ai_model_version SET \ + input_price_per_million = COALESCE($1, input_price_per_million), \ + output_price_per_million = COALESCE($2, output_price_per_million), \ + cached_input_price_per_million = COALESCE($3, cached_input_price_per_million), \ + updated_at = $4 \ + WHERE id = $5", + ) + .bind(&input_million) + .bind(&output_million) + .bind(&cache_input) + .bind(Utc::now()) + .bind(version_id) + .execute(db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if count > 0 { + Ok(PricingResult::Updated) + } else { + Ok(PricingResult::Created) + } +} + +enum PricingResult { + Created, + Updated, + Skipped, +} + +fn parse_token_price_decimal(s: &str) -> Option { + use std::str::FromStr; + Decimal::from_str(s).ok() +} + +async fn disable_all_models(db: &db::AppDatabase) -> Result { + let result = sqlx::query( + "UPDATE ai_model SET enabled = false, updated_at = $1 WHERE enabled = true", + ) + .bind(Utc::now()) + .execute(db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(result.rows_affected() as i64) +} + +async fn deactivate_orphaned_models(db: &db::AppDatabase) -> Result { + let now = Utc::now(); + sqlx::query( + "UPDATE ai_model_version SET enabled = false, updated_at = $1 \ + WHERE model IN (SELECT id FROM ai_model WHERE enabled = false AND deleted_at IS NULL)", + ) + .bind(now) + .execute(db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let result = sqlx::query( + "UPDATE ai_model SET deleted_at = $1, updated_at = $1 \ + WHERE enabled = false AND deleted_at IS NULL", + ) + .bind(now) + .execute(db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(result.rows_affected() as i64) +} + +async fn sync_models_from_upstream( + db: &db::AppDatabase, + upstream_models: Vec, +) -> SyncModelsResponse { + let models_offline = disable_all_models(db).await.unwrap_or(0); + + tracing::info!( + upstream_total = upstream_models.len(), + "syncing models from upstream" + ); + + let mut models_created = 0i64; + let mut models_updated = 0i64; + let mut versions_created = 0i64; + let mut pricing_created = 0i64; + let mut pricing_updated = 0i64; + + for model in &upstream_models { + let provider_slug = extract_provider_name(model); + let provider = match upsert_provider(db, &provider_slug).await { + Ok(p) => p, + Err(e) => { + tracing::warn!( + provider = %provider_slug, + error = %e, + "sync: upsert_provider error" + ); + continue; + } + }; + + let (model_record, _is_new) = match upsert_model(db, provider.id, model).await { + Ok((m, created)) => { + if created { + models_created += 1; + } else { + models_updated += 1; + } + (m, created) + } + Err(e) => { + tracing::warn!( + model = %model.id, + error = %e, + "sync: upsert_model error" + ); + continue; + } + }; + + let (version_record, version_is_new) = + match upsert_version(db, model_record.id, &model.id).await { + Ok(v) => v, + Err(e) => { + tracing::warn!( + model = %model.id, + error = %e, + "sync: upsert_version error" + ); + continue; + } + }; + if version_is_new { + versions_created += 1; + } + + match upsert_pricing(db, version_record.id, model.pricing.as_ref()).await { + Ok(PricingResult::Created) => pricing_created += 1, + Ok(PricingResult::Updated) => pricing_updated += 1, + Ok(PricingResult::Skipped) => {} + Err(e) => { + tracing::warn!( + model = %model.id, + error = %e, + "sync: upsert_pricing error" + ); + } + } + } + + let deactivated = deactivate_orphaned_models(db).await.unwrap_or(0); + + SyncModelsResponse { + models_created, + models_updated, + models_offline, + models_deactivated: deactivated, + versions_created, + pricing_created, + pricing_updated, + } +} + +impl AppService { + pub async fn sync_upstream_models(&self) -> Result { + let api_key = self + .config + .ai_api_key() + .map_err(|e| AppError::InternalServerError(format!("AI API key not configured: {}", e)))?; + + let base_url = self.config.ai_basic_url().unwrap_or_default(); + + let config = EndpointConfig::new(&base_url, &api_key) + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + let upstream_models = ai::sync::list_models(&config) + .await + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + tracing::info!( + model_count = upstream_models.len(), + "sync_upstream_models: {} models from upstream", + upstream_models.len() + ); + + let result = sync_models_from_upstream(&self.db, upstream_models).await; + + tracing::info!( + models_created = result.models_created, + models_updated = result.models_updated, + versions_created = result.versions_created, + pricing_created = result.pricing_created, + pricing_updated = result.pricing_updated, + "sync_upstream_models: complete" + ); + + Ok(result) + } +} + +pub fn spawn_model_sync_loop(service: AppService) -> tokio::task::JoinHandle<()> { + let db = service.db.clone(); + let config = service.config.clone(); + + tokio::spawn(async move { + sync_once(&db, &config).await; + + let mut tick = interval(Duration::from_secs(60 * 10)); + loop { + tick.tick().await; + sync_once(&db, &config).await; + } + }) +} + +async fn sync_once(db: &db::AppDatabase, config: &config::AppConfig) { + let api_key = match config.ai_api_key() { + Ok(k) => k, + Err(e) => { + tracing::warn!(error = %e, "Model sync: AI API key not configured"); + return; + } + }; + + let base_url = config.ai_basic_url().unwrap_or_default(); + + let endpoint_config = match EndpointConfig::new(&base_url, &api_key) { + Ok(c) => c, + Err(e) => { + tracing::warn!(error = %e, "Model sync: invalid endpoint config"); + return; + } + }; + + let upstream_models = match ai::sync::list_models(&endpoint_config).await { + Ok(models) => models, + Err(e) => { + tracing::warn!(error = %e, "Model sync: failed to list upstream models"); + return; + } + }; + + tracing::info!( + model_count = upstream_models.len(), + "Model sync: {} models from upstream", + upstream_models.len() + ); + + let result = sync_models_from_upstream(db, upstream_models).await; + + tracing::info!( + models_created = result.models_created, + models_updated = result.models_updated, + versions_created = result.versions_created, + pricing_created = result.pricing_created, + pricing_updated = result.pricing_updated, + "Model sync complete" + ); +} diff --git a/lib/service/ai/tag.rs b/lib/service/ai/tag.rs new file mode 100644 index 0000000..70d9f4b --- /dev/null +++ b/lib/service/ai/tag.rs @@ -0,0 +1,31 @@ +use crate::AppService; +use crate::error::AppError; +use db::sqlx; +use model::ai::AiModelModelTagModel; +use session::Session; + +impl AppService { + pub async fn ai_model_tags( + &self, + ctx: &Session, + model_id: uuid::Uuid, + ) -> Result, AppError> { + let _user_uid = self.ai_require_login(ctx).await?; + self.ai_tag_list_inner(model_id).await + } + + pub(crate) async fn ai_tag_list_inner( + &self, + model_id: uuid::Uuid, + ) -> Result, AppError> { + let tags = sqlx::query_as::<_, AiModelModelTagModel>( + "SELECT model, tag, created_at FROM ai_model_model_tag WHERE model = $1 ORDER BY tag ASC", + ) + .bind(model_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(tags.into_iter().map(|t| t.tag).collect()) + } +} diff --git a/lib/service/ai/types.rs b/lib/service/ai/types.rs new file mode 100644 index 0000000..555cd28 --- /dev/null +++ b/lib/service/ai/types.rs @@ -0,0 +1,170 @@ +use crate::issues::types::IssueAuthor; +use crate::non_empty; +use model::ai::*; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct AiProviderResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + pub name: String, + pub base_url: Option, + pub website_url: Option, + pub logo_url: Option, + pub enabled: bool, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +impl From for AiProviderResponse { + fn from(p: AiProviderModel) -> Self { + AiProviderResponse { + id: p.id, + name: p.name, + base_url: non_empty(p.base_url.unwrap_or_default()), + website_url: non_empty(p.website_url.unwrap_or_default()), + logo_url: non_empty(p.logo_url.unwrap_or_default()), + enabled: p.enabled, + created_at: p.created_at, + updated_at: p.updated_at, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct AiModelListItem { + #[schema(value_type = String)] + pub id: uuid::Uuid, + pub name: String, + pub display_name: String, + pub description: Option, + pub modality: String, + pub provider_name: String, + pub provider_logo_url: Option, + pub context_window: Option, + pub input_token_limit: Option, + pub output_token_limit: Option, + pub enabled: bool, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct AiModelResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + pub name: String, + pub display_name: String, + pub description: Option, + pub modality: String, + pub context_window: Option, + pub input_token_limit: Option, + pub output_token_limit: Option, + pub enabled: bool, + pub public: bool, + pub provider_name: String, + pub provider_logo_url: Option, + pub card: Option, + pub versions: Vec, + pub tags: Vec, + pub like_count: i64, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct AiModelCardResponse { + pub overview: Option, + pub strengths: Option, + pub limitations: Option, + pub safety_notes: Option, + pub eval_summary: Option, + pub metadata: Option, +} + +impl From for AiModelCardResponse { + fn from(c: AiModelCardModel) -> Self { + AiModelCardResponse { + overview: non_empty(c.overview.unwrap_or_default()), + strengths: non_empty(c.strengths.unwrap_or_default()), + limitations: non_empty(c.limitations.unwrap_or_default()), + safety_notes: non_empty(c.safety_notes.unwrap_or_default()), + eval_summary: non_empty(c.eval_summary.unwrap_or_default()), + metadata: non_empty(c.metadata.unwrap_or_default()), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct AiModelVersionResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + pub version: String, + pub provider_model_name: String, + pub input_price_per_million: Option, + pub output_price_per_million: Option, + pub cached_input_price_per_million: Option, + pub training_cutoff: Option, + #[schema(value_type = Option)] + pub released_at: Option>, + #[schema(value_type = Option)] + pub deprecated_at: Option>, + pub enabled: bool, +} + +impl From for AiModelVersionResponse { + fn from(v: AiModelVersionModel) -> Self { + AiModelVersionResponse { + id: v.id, + version: v.version, + provider_model_name: v.provider_model_name, + input_price_per_million: v + .input_price_per_million + .map(|d| d.to_string()), + output_price_per_million: v + .output_price_per_million + .map(|d| d.to_string()), + cached_input_price_per_million: v + .cached_input_price_per_million + .map(|d| d.to_string()), + training_cutoff: v.training_cutoff, + released_at: v.released_at, + deprecated_at: v.deprecated_at, + enabled: v.enabled, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct AiDiscussionResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + pub author: IssueAuthor, + pub parent: Option, + pub body: String, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct AiLikeResponse { + pub user: IssueAuthor, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Deserialize, utoipa::IntoParams)] +pub struct AiModelFilter { + pub enabled: Option, + pub provider: Option, + pub modality: Option, + pub name: Option, +} diff --git a/lib/service/ai/version.rs b/lib/service/ai/version.rs new file mode 100644 index 0000000..1f6ac52 --- /dev/null +++ b/lib/service/ai/version.rs @@ -0,0 +1,40 @@ +use crate::AppService; +use crate::ai::types::AiModelVersionResponse; +use crate::error::AppError; +use db::sqlx; +use model::ai::AiModelVersionModel; +use session::Session; + +impl AppService { + pub async fn ai_model_versions( + &self, + ctx: &Session, + model_id: uuid::Uuid, + ) -> Result, AppError> { + let _user_uid = self.ai_require_login(ctx).await?; + + self.ai_version_list_inner(model_id).await + } + + pub(crate) async fn ai_version_list_inner( + &self, + model_id: uuid::Uuid, + ) -> Result, AppError> { + let versions = sqlx::query_as::<_, AiModelVersionModel>( + "SELECT id, model, version, provider_model_name, \ + input_price_per_million, output_price_per_million, cached_input_price_per_million, \ + training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \ + FROM ai_model_version WHERE model = $1 AND enabled = true \ + ORDER BY released_at DESC", + ) + .bind(model_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(versions + .into_iter() + .map(AiModelVersionResponse::from) + .collect()) + } +} diff --git a/lib/service/auth/captcha.rs b/lib/service/auth/captcha.rs new file mode 100644 index 0000000..e655a7e --- /dev/null +++ b/lib/service/auth/captcha.rs @@ -0,0 +1,72 @@ +use session::Session; +use utoipa::{IntoParams, ToSchema}; + +use crate::{ + AppService, auth::rsa::RsaResponse, constant_time_eq, error::AppError, +}; + +#[derive( + serde::Deserialize, serde::Serialize, Clone, Debug, ToSchema, IntoParams, +)] +pub struct CaptchaQuery { + pub w: u32, + pub h: u32, + pub dark: bool, + pub rsa: bool, +} + +#[derive(serde::Serialize, ToSchema)] +pub struct CaptchaResponse { + pub base64: String, + pub rsa: Option, + pub req: CaptchaQuery, +} + +impl AppService { + const CAPTCHA_KEY: &'static str = "captcha"; + const CAPTCHA_LENGTH: usize = 4; + pub async fn auth_captcha( + &self, + context: &Session, + query: CaptchaQuery, + ) -> Result { + let CaptchaQuery { w, h, dark, rsa } = query; + let captcha = captcha_rs::CaptchaBuilder::new() + .width(w) + .height(h) + .dark_mode(dark) + .length(Self::CAPTCHA_LENGTH) + .build(); + + let base64 = captcha.to_base64(); + + let text = captcha.text; + context.insert(Self::CAPTCHA_KEY, text).ok(); + Ok(CaptchaResponse { + base64, + rsa: if rsa { + Some(self.auth_rsa(context).await?) + } else { + None + }, + req: CaptchaQuery { w, h, dark, rsa }, + }) + } + pub async fn auth_check_captcha( + &self, + context: &Session, + captcha: String, + ) -> Result<(), AppError> { + let text = context + .get::(Self::CAPTCHA_KEY) + .map_err(|_| AppError::CaptchaError)? + .ok_or(AppError::CaptchaError)?; + if !constant_time_eq(&text.to_lowercase(), &captcha.to_lowercase()) { + context.remove(Self::CAPTCHA_KEY); + tracing::warn!(ip = ?context.ip_address(), "Captcha verification failed"); + return Err(AppError::CaptchaError); + } + context.remove(Self::CAPTCHA_KEY); + Ok(()) + } +} diff --git a/lib/service/auth/email.rs b/lib/service/auth/email.rs new file mode 100644 index 0000000..d17c1bc --- /dev/null +++ b/lib/service/auth/email.rs @@ -0,0 +1,197 @@ +use argon2::{Argon2, PasswordHash, password_hash::PasswordVerifier}; +use db::sqlx; +use email::EmailMessage; +use model::users::{UserEmailModel, user_pass::UserPasswordModel}; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, error::AppError}; + +#[derive(Debug, Clone, Deserialize, Serialize, utoipa::ToSchema)] +pub struct EmailChangeRequest { + pub new_email: String, + pub password: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize, utoipa::ToSchema)] +pub struct EmailVerifyRequest { + pub token: String, +} + +#[derive(Debug, Clone, Serialize, utoipa::ToSchema)] +pub struct EmailResponse { + pub email: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +struct PendingEmailChange { + user_uid: uuid::Uuid, + new_email: String, +} + +impl AppService { + const EMAIL_CHANGE_PREFIX: &'static str = "auth:email_change:"; + + pub async fn auth_get_email( + &self, + ctx: &Session, + ) -> Result { + let user_uid = ctx.user().ok_or(AppError::Unauthorized)?; + let email = sqlx::query_as::<_, UserEmailModel>( + "SELECT \"user\", email, created_at, active, last_use_login, updated_at \ + FROM user_email WHERE \"user\" = $1 AND active = true", + ) + .bind(user_uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(EmailResponse { + email: email.map(|e| e.email), + }) + } + + pub async fn auth_email_change_request( + &self, + ctx: &Session, + params: EmailChangeRequest, + ) -> Result<(), AppError> { + let user_uid = ctx.user().ok_or(AppError::Unauthorized)?; + let password = self.auth_rsa_decode(ctx, params.password).await?; + + let user_password = sqlx::query_as::<_, UserPasswordModel>( + "SELECT \"user\", hash, salt, is_active, reason, created_at, updated_at \ + FROM user_password WHERE \"user\" = $1 AND is_active = true", + ) + .bind(user_uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::UserNotFound)?; + + let hash = PasswordHash::new(&user_password.hash) + .map_err(|_| AppError::UserNotFound)?; + Argon2::default() + .verify_password(password.as_bytes(), &hash) + .map_err(|_| AppError::InvalidPassword)?; + + let existing = sqlx::query_as::<_, UserEmailModel>( + "SELECT \"user\", email, created_at, active, last_use_login, updated_at \ + FROM user_email WHERE email = $1 AND active = true", + ) + .bind(¶ms.new_email) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + if existing.is_some() { + return Err(AppError::EmailExists); + } + + let token = Self::generate_email_change_token(); + let cache_key = format!("{}{}", Self::EMAIL_CHANGE_PREFIX, token); + self.cache + .set( + &cache_key, + &PendingEmailChange { + user_uid, + new_email: params.new_email.clone(), + }, + ) + .await + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + let domain = self + .config + .main_domain() + .map_err(|_| AppError::DoMainNotSet)?; + let verify_link = + format!("{}/auth/verify-email?token={}", domain, token); + + self.email + .send(EmailMessage { + to: params.new_email.clone(), + subject: "Confirm Email Change".to_string(), + body: format!( + "You requested to change your GitDataAI email address.\n\n\ + Confirm the change here:\n\n{}\n\n\ + If you did not request this change, ignore this email.", + verify_link + ), + }) + .await + .map_err(|e| { + tracing::error!(error = %e, new_email = %params.new_email, "Failed to queue email change verification"); + AppError::InternalServerError(e.to_string()) + })?; + + tracing::info!(new_email = %params.new_email, user_uid = %user_uid, "Email change verification queued"); + Ok(()) + } + + pub async fn auth_email_verify( + &self, + params: EmailVerifyRequest, + ) -> Result<(), AppError> { + if params.token.is_empty() { + return Err(AppError::BadRequest( + "missing email verification token".to_string(), + )); + } + let cache_key = + format!("{}{}", Self::EMAIL_CHANGE_PREFIX, params.token); + let pending = self + .cache + .get::(&cache_key) + .await + .map_err(|e| AppError::InternalServerError(e.to_string()))? + .ok_or(AppError::NotFound( + "invalid or expired email verification token".to_string(), + ))?; + + let existing = sqlx::query_as::<_, UserEmailModel>( + "SELECT \"user\", email, created_at, active, last_use_login, updated_at \ + FROM user_email WHERE email = $1 AND active = true", + ) + .bind(&pending.new_email) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + if existing.is_some() { + return Err(AppError::EmailExists); + } + + let now = chrono::Utc::now(); + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + sqlx::query("UPDATE user_email SET active = false, updated_at = $1 WHERE \"user\" = $2") + .bind(now) + .bind(pending.user_uid) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + sqlx::query( + "INSERT INTO user_email (\"user\", email, created_at, active, last_use_login, updated_at) \ + VALUES ($1, $2, $3, true, NULL, $3)", + ) + .bind(pending.user_uid) + .bind(&pending.new_email) + .bind(now) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + txn.commit().await.map_err(|_| AppError::TxnError)?; + + let _ = self.cache.remove(&cache_key).await; + tracing::info!(new_email = %pending.new_email, user_uid = %pending.user_uid, "Email changed successfully"); + Ok(()) + } + + fn generate_email_change_token() -> String { + use rand::{RngExt, distr::Alphanumeric}; + + #[allow(deprecated)] + let mut rng = rand::rng(); + let token: String = + (0..64).map(|_| rng.sample(Alphanumeric) as char).collect(); + format!("emc_{}", token) + } +} diff --git a/lib/service/auth/login.rs b/lib/service/auth/login.rs new file mode 100644 index 0000000..903e3bc --- /dev/null +++ b/lib/service/auth/login.rs @@ -0,0 +1,171 @@ +use argon2::{ + Argon2, PasswordHash, + password_hash::{PasswordHasher, PasswordVerifier}, +}; +use db::sqlx; +use model::users::{UserModel, user_pass::UserPasswordModel}; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, error::AppError}; + +#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)] +pub struct LoginParams { + pub username: String, + pub password: String, + pub captcha: String, + pub totp_code: Option, +} + +impl AppService { + pub const TOTP_KEY: &'static str = "totp_key"; + pub async fn auth_login( + &self, + params: LoginParams, + context: Session, + ) -> Result<(), AppError> { + self.auth_check_captcha(&context, params.captcha).await?; + let password = self.auth_rsa_decode(&context, params.password).await?; + let user = match self.auth_find_user_by_username(¶ms.username).await + { + Ok(user) => user, + Err(_) => { + match self.auth_find_user_by_email(¶ms.username).await { + Ok(user) => user, + Err(_) => { + let _ = Argon2::default() + .hash_password(password.as_bytes()); + return Err(AppError::UserNotFound); + } + } + } + }; + + let user_password = sqlx::query_as::<_, UserPasswordModel>( + "SELECT \"user\", hash, salt, is_active, reason, created_at, updated_at \ + FROM user_password WHERE \"user\" = $1 AND is_active = true", + ) + .bind(user.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::UserNotFound)?; + + let password_hash = PasswordHash::new(&user_password.hash) + .map_err(|_| AppError::UserNotFound)?; + if Argon2::default() + .verify_password(password.as_bytes(), &password_hash) + .is_err() + { + tracing::warn!(username = %params.username, ip = ?context.ip_address(), "Login failed: invalid password"); + return Err(AppError::UserNotFound); + } + + if context + .get::(Self::TOTP_KEY) + .ok() + .flatten() + .is_some() + { + if let Some(ref totp_code) = params.totp_code { + if !self.auth_2fa_verify_login(&context, totp_code).await? { + return Err(AppError::InvalidTwoFactorCode); + } + } else { + return Err(AppError::InvalidTwoFactorCode); + } + } else if self.auth_2fa_status_by_uid(user.id).await?.is_enabled { + let totp_session_key = uuid::Uuid::new_v4().to_string(); + context + .insert(Self::TOTP_KEY, totp_session_key.clone()) + .map_err(|_| AppError::InternalError)?; + self.cache + .set(&totp_session_key, &user.id) + .await + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + tracing::info!(username = %params.username, ip = ?context.ip_address(), "Login 2FA triggered"); + return Err(AppError::TwoFactorRequired); + } + + sqlx::query("UPDATE \"user\" SET last_sign_in_at = $1, updated_at = $1 WHERE id = $2") + .bind(chrono::Utc::now()) + .bind(user.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + context.renew(); + context.set_user(user.id); + context.remove(Self::RSA_PRIVATE_KEY); + context.remove(Self::RSA_PUBLIC_KEY); + tracing::info!(user_uid = %user.id, username = %user.username, ip = ?context.ip_address(), "User logged in successfully"); + Ok(()) + } + + pub(crate) async fn auth_find_user_by_username( + &self, + username: &str, + ) -> Result { + sqlx::query_as::<_, UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, allow_use, can_search, \ + last_sign_in_at, created_at, updated_at \ + FROM \"user\" WHERE username = $1", + ) + .bind(username) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::UserNotFound) + } + + pub(crate) async fn auth_find_user_by_email( + &self, + email: &str, + ) -> Result { + sqlx::query_as::<_, UserModel>( + "SELECT u.id, u.username, u.display_name, u.avatar_url, u.website_url, u.allow_use, \ + u.can_search, u.last_sign_in_at, u.created_at, u.updated_at \ + FROM \"user\" u \ + INNER JOIN user_email e ON e.\"user\" = u.id \ + WHERE e.email = $1 AND e.active = true", + ) + .bind(email) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::UserNotFound) + } + + pub(crate) async fn auth_find_user_by_uid( + &self, + uid: uuid::Uuid, + ) -> Result { + sqlx::query_as::<_, UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, allow_use, can_search, \ + last_sign_in_at, created_at, updated_at \ + FROM \"user\" WHERE id = $1", + ) + .bind(uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::UserNotFound) + } + + pub(crate) fn validate_password_strength( + password: &str, + ) -> Result<(), AppError> { + if password.len() < 8 { + return Err(AppError::PasswordTooWeak); + } + + let has_uppercase = password.chars().any(|c| c.is_uppercase()); + let has_lowercase = password.chars().any(|c| c.is_lowercase()); + let has_digit = password.chars().any(|c| c.is_numeric()); + + if !has_uppercase || !has_lowercase || !has_digit { + return Err(AppError::PasswordTooWeak); + } + Ok(()) + } +} diff --git a/lib/service/auth/logout.rs b/lib/service/auth/logout.rs new file mode 100644 index 0000000..bc04c23 --- /dev/null +++ b/lib/service/auth/logout.rs @@ -0,0 +1,14 @@ +use session::Session; + +use crate::{AppService, error::AppError}; + +impl AppService { + pub async fn auth_logout(&self, context: &Session) -> Result<(), AppError> { + if let Some(user_uid) = context.user() { + tracing::info!(user_uid = %user_uid, ip = ?context.ip_address(), "User logged out"); + } + context.clear_user(); + context.clear(); + Ok(()) + } +} diff --git a/lib/service/auth/me.rs b/lib/service/auth/me.rs new file mode 100644 index 0000000..55e3beb --- /dev/null +++ b/lib/service/auth/me.rs @@ -0,0 +1,41 @@ +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, error::AppError}; + +#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)] +pub struct ContextMe { + pub id: uuid::Uuid, + pub username: String, + pub display_name: Option, + pub avatar_url: Option, + pub has_unread_notifications: u64, + pub language: String, + pub timezone: String, +} + +impl AppService { + pub async fn auth_me(&self, ctx: Session) -> Result { + let user_id = ctx.user().ok_or(AppError::Unauthorized)?; + let user = self.auth_find_user_by_uid(user_id).await?; + let unread = self.unread_notifications_count(user_id).await.unwrap_or(0); + + Ok(ContextMe { + id: user.id, + username: user.username, + display_name: if user.display_name.is_empty() { + None + } else { + Some(user.display_name) + }, + avatar_url: if user.avatar_url.is_empty() { + None + } else { + Some(user.avatar_url) + }, + has_unread_notifications: unread as u64, + language: "en".to_string(), + timezone: "UTC".to_string(), + }) + } +} diff --git a/lib/service/auth/mod.rs b/lib/service/auth/mod.rs new file mode 100644 index 0000000..f35e534 --- /dev/null +++ b/lib/service/auth/mod.rs @@ -0,0 +1,9 @@ +pub mod captcha; +pub mod email; +pub mod login; +pub mod logout; +pub mod me; +pub mod register; +pub mod reset_pass; +pub mod rsa; +pub mod totp; diff --git a/lib/service/auth/register.rs b/lib/service/auth/register.rs new file mode 100644 index 0000000..e24d7bb --- /dev/null +++ b/lib/service/auth/register.rs @@ -0,0 +1,92 @@ +use argon2::{Argon2, password_hash::PasswordHasher}; +use db::sqlx; +use model::users::UserModel; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, error::AppError}; + +#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)] +pub struct RegisterParams { + pub username: String, + pub email: String, + pub password: String, + pub captcha: String, +} + +impl AppService { + pub async fn auth_register( + &self, + params: RegisterParams, + context: &Session, + ) -> Result { + self.auth_check_captcha(context, params.captcha).await?; + let password = self.auth_rsa_decode(context, params.password).await?; + Self::validate_password_strength(&password)?; + + let username_exists = self + .auth_find_user_by_username(¶ms.username) + .await + .is_ok(); + let email_exists = + self.auth_find_user_by_email(¶ms.email).await.is_ok(); + if username_exists || email_exists { + return Err(AppError::AccountAlreadyExists); + } + + let user_id = uuid::Uuid::now_v7(); + let now = chrono::Utc::now(); + let password_hash = Argon2::default() + .hash_password(password.as_bytes()) + .map_err(|e| AppError::PasswordHashError(e.to_string()))? + .to_string(); + + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + + let user = sqlx::query_as::<_, UserModel>( + "INSERT INTO \"user\" \ + (id, username, display_name, avatar_url, website_url, allow_use, can_search, \ + last_sign_in_at, created_at, updated_at) \ + VALUES ($1, $2, $3, '', '', true, true, NULL, $4, $4) \ + RETURNING id, username, display_name, avatar_url, website_url, allow_use, can_search, \ + last_sign_in_at, created_at, updated_at", + ) + .bind(user_id) + .bind(¶ms.username) + .bind(¶ms.username) + .bind(now) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO user_email (\"user\", email, created_at, active, last_use_login, updated_at) \ + VALUES ($1, $2, $3, true, NULL, $3)", + ) + .bind(user_id) + .bind(¶ms.email) + .bind(now) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO user_password (\"user\", hash, salt, is_active, reason, created_at, updated_at) \ + VALUES ($1, $2, '', true, NULL, $3, $3)", + ) + .bind(user_id) + .bind(&password_hash) + .bind(now) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + txn.commit().await.map_err(|_| AppError::TxnError)?; + + context.set_user(user_id); + context.remove(Self::RSA_PRIVATE_KEY); + context.remove(Self::RSA_PUBLIC_KEY); + tracing::info!(user_uid = %user_id, username = %user.username, "User registered successfully"); + Ok(user) + } +} diff --git a/lib/service/auth/reset_pass.rs b/lib/service/auth/reset_pass.rs new file mode 100644 index 0000000..9fc7e49 --- /dev/null +++ b/lib/service/auth/reset_pass.rs @@ -0,0 +1,153 @@ +use argon2::{Argon2, PasswordHasher}; +use chrono::{Duration, Utc}; +use db::sqlx; +use email::EmailMessage; +use rand::{RngExt, distr::Alphanumeric}; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, error::AppError}; + +#[derive(Debug, Clone, Deserialize, Serialize, utoipa::ToSchema)] +pub struct ResetPasswordRequest { + pub email: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize, utoipa::ToSchema)] +pub struct ResetPasswordVerifyParams { + pub token: String, + pub password: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct PendingResetPassword { + user_uid: uuid::Uuid, + created_at: chrono::DateTime, +} + +impl AppService { + const RESET_PASS_PREFIX: &'static str = "auth:reset_pass:"; + const RESET_PASS_EXPIRY_HOURS: i64 = 1; + pub async fn auth_reset_password_request( + &self, + params: ResetPasswordRequest, + ) -> Result<(), AppError> { + let user = self.auth_find_user_by_email(¶ms.email).await.ok(); + + if let Some(user) = user { + let token = Self::generate_reset_token(); + let cache_key = format!("{}{}", Self::RESET_PASS_PREFIX, token); + let now = chrono::Utc::now(); + + if let Err(e) = self + .cache + .set( + &cache_key, + &PendingResetPassword { + user_uid: user.id, + created_at: now, + }, + ) + .await + { + tracing::error!(error = %e, user_uid = %user.id, "Failed to cache reset token"); + return Ok(()); + } + + let domain = match self.config.main_domain() { + Ok(d) => d, + Err(e) => { + tracing::error!(error = %e, "Domain not configured for password reset"); + return Ok(()); + } + }; + let reset_link = + format!("{}/auth/reset-password?token={}", domain, token); + + if let Err(e) = self + .email + .send(EmailMessage { + to: params.email.clone(), + subject: "Reset Your Password".to_string(), + body: format!( + "You requested to reset your GitDataAI password.\n\n\ + Reset your password here:\n\n{}\n\n\ + If you did not request this, ignore this email.", + reset_link + ), + }) + .await + { + tracing::error!(error = %e, email = %params.email, "Failed to queue password reset email"); + } + + tracing::info!(email = %params.email, user_uid = %user.id, "Password reset email queued"); + } + + Ok(()) + } + + pub async fn auth_reset_password_verify( + &self, + context: &Session, + params: ResetPasswordVerifyParams, + ) -> Result<(), AppError> { + if params.token.is_empty() { + return Err(AppError::InvalidResetToken); + } + + let cache_key = format!("{}{}", Self::RESET_PASS_PREFIX, params.token); + let pending = self + .cache + .get::(&cache_key) + .await + .map_err(|e| AppError::InternalServerError(e.to_string()))? + .ok_or(AppError::InvalidResetToken)?; + + if Utc::now() - pending.created_at + > Duration::hours(Self::RESET_PASS_EXPIRY_HOURS) + { + let _ = self.cache.remove(&cache_key).await; + return Err(AppError::ResetTokenExpired); + } + + self.cache + .remove(&cache_key) + .await + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + let password = self.auth_rsa_decode(context, params.password).await?; + Self::validate_password_strength(&password)?; + + let password_hash = Argon2::default() + .hash_password(password.as_bytes()) + .map_err(|e| AppError::PasswordHashError(e.to_string()))? + .to_string(); + + let now = chrono::Utc::now(); + let result = sqlx::query( + "UPDATE user_password SET hash = $1, updated_at = $2 WHERE \"user\" = $3 AND is_active = true", + ) + .bind(&password_hash) + .bind(now) + .bind(pending.user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if result.rows_affected() == 0 { + return Err(AppError::InvalidResetToken); + } + + tracing::info!(user_uid = %pending.user_uid, "Password reset successfully"); + Ok(()) + } + + fn generate_reset_token() -> String { + #[allow(deprecated)] + let mut rng = rand::rng(); + let token: String = + (0..64).map(|_| rng.sample(Alphanumeric) as char).collect(); + format!("rst_{}", token) + } +} diff --git a/lib/service/auth/rsa.rs b/lib/service/auth/rsa.rs new file mode 100644 index 0000000..7398e04 --- /dev/null +++ b/lib/service/auth/rsa.rs @@ -0,0 +1,150 @@ +use base64::Engine; +use chacha20poly1305::{ChaCha20Poly1305, KeyInit, Nonce, aead::Aead}; +use hkdf::Hkdf; +use rand_chacha::{ChaCha12Rng, rand_core::SeedableRng}; +use rsa::{ + Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey, + pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey, EncodeRsaPublicKey}, +}; +use serde::{Deserialize, Serialize}; +use session::Session; +use sha2::Sha256; + +use crate::{AppService, error::AppError}; + +#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)] +pub struct RsaResponse { + pub public_key: String, +} + +impl AppService { + pub const RSA_PRIVATE_KEY: &'static str = "rsa:private"; + pub const RSA_PUBLIC_KEY: &'static str = "rsa:public"; + const RSA_BIT_SIZE: usize = 2048; + + fn derive_rsa_encryption_key(&self) -> [u8; 32] { + let secret = self + .config + .env + .get("APP_SESSION_SECRET") + .map(|s| s.as_str()) + .expect("APP_SESSION_SECRET must be set in production. Do not use fallback keys."); + let hk = Hkdf::::new( + Some(b"rsa-session-encryption"), + secret.as_bytes(), + ); + let mut okm = [0u8; 32]; + hk.expand(b"rsa-private-key-aead", &mut okm) + .expect("HKDF expand within hash length"); + okm + } + + fn encrypt_rsa_key(&self, plaintext: &str) -> Result { + let key = self.derive_rsa_encryption_key(); + let cipher = ChaCha20Poly1305::new_from_slice(&key) + .expect("32-byte key is valid for ChaCha20Poly1305"); + let nonce_bytes: [u8; 12] = rand::random(); + let nonce = Nonce::from(nonce_bytes); + let ciphertext = cipher + .encrypt(&nonce, plaintext.as_bytes()) + .map_err(|_| AppError::RsaGenerationError)?; + let mut combined = nonce_bytes.to_vec(); + combined.extend_from_slice(&ciphertext); + Ok(base64::engine::general_purpose::STANDARD.encode(&combined)) + } + + fn decrypt_rsa_key(&self, encrypted: &str) -> Result { + let key = self.derive_rsa_encryption_key(); + let cipher = ChaCha20Poly1305::new_from_slice(&key) + .expect("32-byte key is valid for ChaCha20Poly1305"); + let combined = base64::engine::general_purpose::STANDARD + .decode(encrypted) + .map_err(|_| AppError::RsaDecodeError)?; + if combined.len() < 12 { + return Err(AppError::RsaDecodeError); + } + let mut nonce_bytes = [0u8; 12]; + nonce_bytes.copy_from_slice(&combined[..12]); + let nonce = Nonce::from(nonce_bytes); + let plaintext = cipher + .decrypt(&nonce, &combined[12..]) + .map_err(|_| AppError::RsaDecodeError)?; + Ok(String::from_utf8(plaintext) + .map_err(|_| AppError::RsaDecodeError)?) + } + + pub async fn auth_rsa( + &self, + context: &Session, + ) -> Result { + if context + .get::(Self::RSA_PRIVATE_KEY) + .ok() + .flatten() + .is_some() + && context + .get::(Self::RSA_PUBLIC_KEY) + .ok() + .flatten() + .is_some() + { + let public_key = context + .get::(Self::RSA_PUBLIC_KEY) + .ok() + .flatten() + .expect("checked above"); + return Ok(RsaResponse { public_key }); + } + + let seed: [u8; 32] = rand::random(); + let mut rng = ChaCha12Rng::from_seed(seed); + let priv_key = RsaPrivateKey::new(&mut rng, Self::RSA_BIT_SIZE) + .map_err(|_| { + tracing::error!("RSA key generation failed"); + AppError::RsaGenerationError + })?; + let pub_key = RsaPublicKey::from(&priv_key); + let priv_pem = priv_key + .to_pkcs1_pem(Default::default()) + .map_err(|_| AppError::RsaGenerationError)? + .to_string(); + let public_key = pub_key + .to_pkcs1_pem(Default::default()) + .map_err(|_| AppError::RsaGenerationError)? + .to_string(); + + context + .insert(Self::RSA_PRIVATE_KEY, self.encrypt_rsa_key(&priv_pem)?) + .map_err(|_| AppError::RsaGenerationError)?; + context + .insert(Self::RSA_PUBLIC_KEY, public_key.clone()) + .map_err(|_| AppError::RsaGenerationError)?; + + Ok(RsaResponse { public_key }) + } + + pub async fn auth_rsa_decode( + &self, + context: &Session, + data: String, + ) -> Result { + let encrypted_priv = context + .get::(Self::RSA_PRIVATE_KEY) + .map_err(|_| AppError::RsaDecodeError)? + .ok_or(AppError::RsaDecodeError)?; + let priv_pem = self.decrypt_rsa_key(&encrypted_priv)?; + + let priv_key = RsaPrivateKey::from_pkcs1_pem(&priv_pem).map_err(|_| { + tracing::warn!(ip = ?context.ip_address(), "RSA decode failed: invalid private key"); + AppError::RsaDecodeError + })?; + let cipher = base64::engine::general_purpose::STANDARD + .decode(&data) + .map_err(|_| AppError::RsaDecodeError)?; + let decrypted = priv_key.decrypt(Pkcs1v15Encrypt, &cipher).map_err(|_| { + tracing::warn!(ip = ?context.ip_address(), "RSA decrypt failed"); + AppError::RsaDecodeError + })?; + Ok(String::from_utf8_lossy(&decrypted).to_string()) + } +} diff --git a/lib/service/auth/totp.rs b/lib/service/auth/totp.rs new file mode 100644 index 0000000..03bd0c4 --- /dev/null +++ b/lib/service/auth/totp.rs @@ -0,0 +1,430 @@ +use argon2::{Argon2, PasswordHash, password_hash::PasswordVerifier}; +use db::sqlx; +use hmac::{Hmac, KeyInit, Mac}; +use model::users::{User2FaModel, user_pass::UserPasswordModel}; +use rand::RngExt; +use serde::{Deserialize, Serialize}; +use session::Session; +use sha1::Sha1; +use sha2::{Digest, Sha256}; +use uuid::Uuid; + +use crate::{AppService, constant_time_eq, error::AppError}; + +#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)] +pub struct Enable2FAResponse { + pub secret: String, + pub qr_code: String, + pub backup_codes: Vec, +} + +#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)] +pub struct Verify2FAParams { + pub code: String, +} + +#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)] +pub struct Disable2FAParams { + pub code: String, + pub password: String, +} + +#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)] +pub struct Get2FAStatusResponse { + pub is_enabled: bool, + pub method: Option, + pub has_backup_codes: bool, +} +impl AppService { + pub async fn auth_2fa_enable( + &self, + context: &Session, + ) -> Result { + let user_uid = context.user().ok_or(AppError::Unauthorized)?; + let user = self.auth_find_user_by_uid(user_uid).await?; + + let existing = self.find_2fa(user_uid).await?; + if existing.as_ref().is_some_and(|two_fa| two_fa.enabled) { + return Err(AppError::TwoFactorAlreadyEnabled); + } + + let secret = self.generate_totp_secret(); + let backup_codes = self.generate_backup_codes(10); + let qr_code = format!( + "otpauth://totp/GitDataAI:{}?secret={}&issuer=GitDataAI", + user.username, secret + ); + let now = chrono::Utc::now(); + let hashed_backup_codes = + Self::hash_backup_codes(&backup_codes).join("."); + + if existing.is_some() { + sqlx::query( + "UPDATE user_2fa SET secret = $1, backup_codes = $2, enabled = false, updated_at = $3 \ + WHERE \"user\" = $4", + ) + .bind(&secret) + .bind(&hashed_backup_codes) + .bind(now) + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + } else { + sqlx::query( + "INSERT INTO user_2fa (\"user\", secret, backup_codes, enabled, created_at, updated_at) \ + VALUES ($1, $2, $3, false, $4, $4)", + ) + .bind(user_uid) + .bind(&secret) + .bind(&hashed_backup_codes) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + } + + Ok(Enable2FAResponse { + secret, + qr_code, + backup_codes, + }) + } + + pub async fn auth_2fa_verify_and_enable( + &self, + context: &Session, + params: Verify2FAParams, + ) -> Result<(), AppError> { + let user_uid = context.user().ok_or(AppError::Unauthorized)?; + let two_fa = self + .find_2fa(user_uid) + .await? + .ok_or(AppError::TwoFactorNotSetup)?; + if two_fa.enabled { + return Err(AppError::TwoFactorAlreadyEnabled); + } + let secret = + two_fa.secret.as_ref().ok_or(AppError::TwoFactorNotSetup)?; + if !self.verify_totp_code(secret, ¶ms.code)? { + return Err(AppError::InvalidTwoFactorCode); + } + + sqlx::query("UPDATE user_2fa SET enabled = true, updated_at = $1 WHERE \"user\" = $2") + .bind(chrono::Utc::now()) + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(()) + } + + pub async fn auth_2fa_disable( + &self, + context: &Session, + params: Disable2FAParams, + ) -> Result<(), AppError> { + let user_uid = context.user().ok_or(AppError::Unauthorized)?; + let password = self.auth_rsa_decode(context, params.password).await?; + self.verify_user_password(user_uid, &password).await?; + + let two_fa = self + .find_2fa(user_uid) + .await? + .ok_or(AppError::TwoFactorNotSetup)?; + if !two_fa.enabled { + return Err(AppError::TwoFactorNotEnabled); + } + if !self + .verify_2fa_or_backup_code(&two_fa, ¶ms.code) + .await? + { + return Err(AppError::InvalidTwoFactorCode); + } + + sqlx::query("DELETE FROM user_2fa WHERE \"user\" = $1") + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(()) + } + + pub(crate) async fn auth_2fa_verify( + &self, + user_uid: Uuid, + code: &str, + ) -> Result { + let Some(two_fa) = self.find_2fa(user_uid).await? else { + return Ok(true); + }; + if !two_fa.enabled { + return Ok(true); + } + self.verify_2fa_or_backup_code(&two_fa, code).await + } + + pub(crate) async fn auth_2fa_status_by_uid( + &self, + user_uid: Uuid, + ) -> Result { + let Some(two_fa) = self.find_2fa(user_uid).await? else { + return Ok(Get2FAStatusResponse { + is_enabled: false, + method: None, + has_backup_codes: false, + }); + }; + Ok(Get2FAStatusResponse { + is_enabled: two_fa.enabled, + method: Some("totp".to_string()), + has_backup_codes: !two_fa.backup_codes.is_empty(), + }) + } + + pub async fn auth_2fa_status( + &self, + context: &Session, + ) -> Result { + let user_uid = context.user().ok_or(AppError::Unauthorized)?; + self.auth_2fa_status_by_uid(user_uid).await + } + + pub async fn auth_2fa_verify_login( + &self, + context: &Session, + code: &str, + ) -> Result { + let Some(totp_key) = + context.get::(Self::TOTP_KEY).ok().flatten() + else { + return Ok(false); + }; + let Some(user_uid) = self + .cache + .get::(&totp_key) + .await + .map_err(|e| AppError::InternalServerError(e.to_string()))? + else { + return Ok(false); + }; + let verified = self.auth_2fa_verify(user_uid, code).await?; + if verified { + context.remove(Self::TOTP_KEY); + let _ = self.cache.remove(&totp_key).await; + context.set_user(user_uid); + } + Ok(verified) + } + + pub async fn auth_2fa_regenerate_backup_codes( + &self, + context: &Session, + password: String, + ) -> Result, AppError> { + let user_uid = context.user().ok_or(AppError::Unauthorized)?; + let password = self.auth_rsa_decode(context, password).await?; + self.verify_user_password(user_uid, &password).await?; + let two_fa = self + .find_2fa(user_uid) + .await? + .ok_or(AppError::TwoFactorNotSetup)?; + if !two_fa.enabled { + return Err(AppError::TwoFactorNotEnabled); + } + + let backup_codes = self.generate_backup_codes(10); + sqlx::query("UPDATE user_2fa SET backup_codes = $1, updated_at = $2 WHERE \"user\" = $3") + .bind(Self::hash_backup_codes(&backup_codes).join(".")) + .bind(chrono::Utc::now()) + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(backup_codes) + } + fn generate_totp_secret(&self) -> String { + const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; + #[allow(deprecated)] + let mut rng = rand::rng(); + (0..32) + .map(|_| { + #[allow(deprecated)] + let idx = rng.random_range(0..CHARSET.len()); + CHARSET[idx] as char + }) + .collect() + } + + fn generate_backup_codes(&self, count: usize) -> Vec { + #[allow(deprecated)] + let mut rng = rand::rng(); + (0..count) + .map(|_| { + format!( + "{:04}-{:04}-{:04}", + rng.random_range(0..10000), + rng.random_range(0..10000), + rng.random_range(0..10000) + ) + }) + .collect() + } + + fn hash_backup_code(code: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(code.as_bytes()); + hasher + .finalize() + .iter() + .map(|b| format!("{:02x}", b)) + .collect::() + } + + fn hash_backup_codes(codes: &[String]) -> Vec { + codes.iter().map(|c| Self::hash_backup_code(c)).collect() + } + + fn verify_totp_code( + &self, + secret: &str, + code: &str, + ) -> Result { + let now = chrono::Utc::now().timestamp() as u64; + let time_step = 30; + let counter = now / time_step; + + for offset in [-1i64, 0, 1] { + let test_counter = (counter as i64 + offset) as u64; + let expected_code = + self.generate_totp_code(secret, test_counter)?; + if constant_time_eq(&expected_code, code) { + return Ok(true); + } + } + + Ok(false) + } + + fn generate_totp_code( + &self, + secret: &str, + counter: u64, + ) -> Result { + let secret_bytes = self.decode_base32(secret)?; + + let counter_bytes = counter.to_be_bytes(); + + let mut mac = Hmac::::new_from_slice(&secret_bytes) + .map_err(|_| AppError::InvalidTwoFactorCode)?; + mac.update(&counter_bytes); + let result = mac.finalize().into_bytes(); + + let offset = (result[19] & 0x0f) as usize; + let code = u32::from_be_bytes([ + result[offset] & 0x7f, + result[offset + 1], + result[offset + 2], + result[offset + 3], + ]); + + Ok(format!("{:06}", code % 1_000_000)) + } + + fn decode_base32(&self, input: &str) -> Result, AppError> { + const CHARSET: &str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; + let input = input.to_uppercase().replace("=", ""); + let mut bits = 0u64; + let mut bit_count = 0; + let mut output = Vec::new(); + + for c in input.chars() { + let val = + CHARSET.find(c).ok_or(AppError::InvalidTwoFactorCode)? as u64; + bits = (bits << 5) | val; + bit_count += 5; + + if bit_count >= 8 { + bit_count -= 8; + output.push((bits >> bit_count) as u8); + bits &= (1 << bit_count) - 1; + } + } + + Ok(output) + } + + async fn verify_user_password( + &self, + user_uid: Uuid, + password: &str, + ) -> Result<(), AppError> { + let user_password = sqlx::query_as::<_, UserPasswordModel>( + "SELECT \"user\", hash, salt, is_active, reason, created_at, updated_at \ + FROM user_password WHERE \"user\" = $1 AND is_active = true", + ) + .bind(user_uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::UserNotFound)?; + + let password_hash = PasswordHash::new(&user_password.hash) + .map_err(|_| AppError::InvalidPassword)?; + + Argon2::default() + .verify_password(password.as_bytes(), &password_hash) + .map_err(|_| AppError::InvalidPassword)?; + + Ok(()) + } + + async fn find_2fa( + &self, + user_uid: Uuid, + ) -> Result, AppError> { + sqlx::query_as::<_, User2FaModel>( + "SELECT \"user\", secret, backup_codes, enabled, created_at, updated_at \ + FROM user_2fa WHERE \"user\" = $1", + ) + .bind(user_uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string())) + } + + async fn verify_2fa_or_backup_code( + &self, + two_fa: &User2FaModel, + code: &str, + ) -> Result { + let secret = + two_fa.secret.as_ref().ok_or(AppError::TwoFactorNotSetup)?; + if self.verify_totp_code(secret, code)? { + return Ok(true); + } + + let hashed_code = Self::hash_backup_code(code); + let mut backup_codes: Vec = two_fa + .backup_codes + .split('.') + .filter(|code| !code.is_empty()) + .map(ToOwned::to_owned) + .collect(); + if backup_codes.contains(&hashed_code) { + backup_codes.retain(|stored| stored != &hashed_code); + sqlx::query( + "UPDATE user_2fa SET backup_codes = $1, updated_at = $2 WHERE \"user\" = $3", + ) + .bind(backup_codes.join(".")) + .bind(chrono::Utc::now()) + .bind(two_fa.user) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + return Ok(true); + } + + Ok(false) + } +} diff --git a/lib/service/error.rs b/lib/service/error.rs new file mode 100644 index 0000000..e7f3a6a --- /dev/null +++ b/lib/service/error.rs @@ -0,0 +1,145 @@ +use std::fmt; + +#[derive(Debug)] +pub enum AppError { + UserNotFound, + RsaGenerationError, + RsaDecodeError, + CaptchaError, + TwoFactorRequired, + Unauthorized, + DoMainNotSet, + UserNameExists, + EmailExists, + AccountAlreadyExists, + TxnError, + PasswordHashError(String), + TwoFactorAlreadyEnabled, + TwoFactorNotSetup, + InvalidTwoFactorCode, + TwoFactorNotEnabled, + DatabaseError(String), + InvalidPassword, + PasswordTooWeak, + ProjectNotFound, + NoPower, + InternalError, + NotFound(String), + RoleParseError, + ProjectNameAlreadyExists, + RepoNameAlreadyExists, + AvatarUploadError(String), + InternalServerError(String), + PermissionDenied, + RepoNotFound, + RepoForBidAccess, + SerdeError(serde_json::Error), + Io(std::io::Error), + BadRequest(String), + Forbidden(String), + Conflict(String), + InvalidResetToken, + ResetTokenExpired, + ResetTokenUsed, + IssueNotFound, + LabelNotFound, + MilestoneNotFound, + PullRequestNotFound, + CommentNotFound, + GitRpcError(String), + AiError(ai::error::AiError), +} + +impl fmt::Display for AppError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AppError::UserNotFound => write!(f, "user not found"), + AppError::RsaGenerationError => { + write!(f, "RSA key generation failed") + } + AppError::RsaDecodeError => write!(f, "RSA decode failed"), + AppError::CaptchaError => write!(f, "captcha verification failed"), + AppError::TwoFactorRequired => { + write!(f, "two-factor authentication required") + } + AppError::Unauthorized => write!(f, "unauthorized"), + AppError::DoMainNotSet => write!(f, "domain not configured"), + AppError::UserNameExists => write!(f, "username already exists"), + AppError::EmailExists => write!(f, "email already exists"), + AppError::AccountAlreadyExists => { + write!(f, "account already exists") + } + AppError::TxnError => write!(f, "transaction error"), + AppError::PasswordHashError(e) => { + write!(f, "password hash error: {}", e) + } + AppError::TwoFactorAlreadyEnabled => { + write!(f, "two-factor already enabled") + } + AppError::TwoFactorNotSetup => write!(f, "two-factor not setup"), + AppError::InvalidTwoFactorCode => { + write!(f, "invalid two-factor code") + } + AppError::TwoFactorNotEnabled => { + write!(f, "two-factor not enabled") + } + AppError::DatabaseError(e) => write!(f, "database error: {}", e), + AppError::InvalidPassword => write!(f, "invalid password"), + AppError::PasswordTooWeak => write!(f, "password too weak"), + AppError::ProjectNotFound => write!(f, "project not found"), + AppError::NoPower => write!(f, "permission denied"), + AppError::InternalError => write!(f, "internal error"), + AppError::NotFound(msg) => write!(f, "not found: {}", msg), + AppError::RoleParseError => write!(f, "role parse error"), + AppError::ProjectNameAlreadyExists => { + write!(f, "project name already exists") + } + AppError::RepoNameAlreadyExists => { + write!(f, "repo name already exists") + } + AppError::AvatarUploadError(e) => { + write!(f, "avatar upload error: {}", e) + } + AppError::InternalServerError(e) => { + write!(f, "internal server error: {}", e) + } + AppError::PermissionDenied => write!(f, "permission denied"), + AppError::RepoNotFound => write!(f, "repo not found"), + AppError::RepoForBidAccess => write!(f, "repo access forbidden"), + AppError::SerdeError(e) => write!(f, "serde error: {}", e), + AppError::Io(e) => write!(f, "IO error: {}", e), + AppError::BadRequest(msg) => write!(f, "bad request: {}", msg), + AppError::Forbidden(msg) => write!(f, "forbidden: {}", msg), + AppError::Conflict(msg) => write!(f, "conflict: {}", msg), + AppError::InvalidResetToken => write!(f, "invalid reset token"), + AppError::ResetTokenExpired => write!(f, "reset token expired"), + AppError::ResetTokenUsed => write!(f, "reset token already used"), + AppError::IssueNotFound => write!(f, "issue not found"), + AppError::LabelNotFound => write!(f, "label not found"), + AppError::MilestoneNotFound => write!(f, "milestone not found"), + AppError::PullRequestNotFound => { + write!(f, "pull request not found") + } + AppError::CommentNotFound => write!(f, "comment not found"), + AppError::GitRpcError(e) => write!(f, "git rpc error: {}", e), + AppError::AiError(e) => write!(f, "ai error: {}", e), + } + } +} + +impl std::error::Error for AppError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + AppError::SerdeError(e) => Some(e), + AppError::Io(e) => Some(e), + AppError::AiError(e) => Some(e), + _ => None, + } + } +} + +impl From for AppError { + fn from(e: ai::error::AiError) -> Self { + AppError::AiError(e) + } +} diff --git a/lib/service/git/archive.rs b/lib/service/git/archive.rs new file mode 100644 index 0000000..5e78d11 --- /dev/null +++ b/lib/service/git/archive.rs @@ -0,0 +1,48 @@ +use git::rpc::{ + proto as p, proto::archive_service_client::ArchiveServiceClient, +}; +use session::Session; + +use crate::{AppService, error::AppError, git::rpc_err}; + +impl AppService { + pub async fn git_archive_tar( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + options: Option, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = ArchiveServiceClient::new(self.git.clone()); + let resp = client + .archive_tar(tonic::Request::new(p::ArchiveTarRequest { + repo_id: repo.id.to_string(), + options, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_archive_zip( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + options: Option, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = ArchiveServiceClient::new(self.git.clone()); + let resp = client + .archive_zip(tonic::Request::new(p::ArchiveZipRequest { + repo_id: repo.id.to_string(), + options, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } +} diff --git a/lib/service/git/blame.rs b/lib/service/git/blame.rs new file mode 100644 index 0000000..1cfd6da --- /dev/null +++ b/lib/service/git/blame.rs @@ -0,0 +1,82 @@ +use git::rpc::{proto as p, proto::blame_service_client::BlameServiceClient}; +use session::Session; + +use crate::{AppService, error::AppError, git::rpc_err}; + +impl AppService { + pub async fn git_blame_file( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + path: String, + rev: Option, + options: Option, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = BlameServiceClient::new(self.git.clone()); + let resp = client + .blame_file(tonic::Request::new(p::BlameFileRequest { + repo_id: repo.id.to_string(), + path, + rev, + options, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_blame_hunk( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + path: String, + rev: Option, + start_line: u32, + end_line: u32, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = BlameServiceClient::new(self.git.clone()); + let resp = client + .blame_hunk(tonic::Request::new(p::BlameHunkRequest { + repo_id: repo.id.to_string(), + path, + rev, + start_line, + end_line, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_blame_lines( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + path: String, + rev: Option, + start_line: u32, + end_line: u32, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = BlameServiceClient::new(self.git.clone()); + let resp = client + .blame_lines(tonic::Request::new(p::BlameLinesRequest { + repo_id: repo.id.to_string(), + path, + rev, + start_line, + end_line, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } +} diff --git a/lib/service/git/blob.rs b/lib/service/git/blob.rs new file mode 100644 index 0000000..06c7880 --- /dev/null +++ b/lib/service/git/blob.rs @@ -0,0 +1,113 @@ +use git::rpc::{proto as p, proto::blob_service_client::BlobServiceClient}; +use session::Session; + +use crate::{AppService, error::AppError, git::rpc_err}; + +impl AppService { + pub async fn git_blob_load( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + oid: String, + path: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = BlobServiceClient::new(self.git.clone()); + let resp = client + .blob_load(tonic::Request::new(p::BlobLoadRequest { + repo_id: repo.id.to_string(), + id: Some(p::ObjectId { value: oid }), + path, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_blob_size( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + oid: String, + path: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = BlobServiceClient::new(self.git.clone()); + let resp = client + .blob_size(tonic::Request::new(p::BlobSizeRequest { + repo_id: repo.id.to_string(), + id: Some(p::ObjectId { value: oid }), + path, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_blob_exists( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + oid: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = BlobServiceClient::new(self.git.clone()); + let resp = client + .blob_exists(tonic::Request::new(p::BlobExistsRequest { + repo_id: repo.id.to_string(), + id: Some(p::ObjectId { value: oid }), + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_blob_is_binary( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + oid: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = BlobServiceClient::new(self.git.clone()); + let resp = client + .blob_is_binary(tonic::Request::new(p::BlobIsBinaryRequest { + repo_id: repo.id.to_string(), + id: Some(p::ObjectId { value: oid }), + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_blob_upload( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + path: String, + blob: Vec, + ) -> Result { + let repo = self.git_require_admin(ctx, wk_name, repo_name).await?; + let mut client = BlobServiceClient::new(self.git.clone()); + let resp = client + .blob_upload(tonic::Request::new(p::BlobUploadRequest { + repo_id: repo.id.to_string(), + blob, + path, + })) + .await + .map_err(rpc_err)? + .into_inner(); + self.queue_sync(repo.id).await; + Ok(resp) + } +} diff --git a/lib/service/git/branch.rs b/lib/service/git/branch.rs new file mode 100644 index 0000000..ec97703 --- /dev/null +++ b/lib/service/git/branch.rs @@ -0,0 +1,196 @@ +use git::rpc::{proto as p, proto::branch_service_client::BranchServiceClient}; +use session::Session; + +use crate::{AppService, Pagination, error::AppError, git::rpc_err}; + +impl AppService { + pub async fn git_branch_list( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + pagination: Pagination, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = BranchServiceClient::new(self.git.clone()); + let mut resp = client + .branch_list(tonic::Request::new(p::BranchListRequest { + repo_id: repo.id.to_string(), + })) + .await + .map_err(rpc_err)? + .into_inner(); + + let offset = pagination.offset() as usize; + let limit = pagination.limit() as usize; + if offset > 0 || resp.branches.len() > limit { + let start = offset.min(resp.branches.len()); + let end = (start + limit).min(resp.branches.len()); + resp.branches = resp.branches.drain(start..end).collect(); + } + Ok(resp) + } + + pub async fn git_branch_info( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + branch: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = BranchServiceClient::new(self.git.clone()); + let resp = client + .branch_info(tonic::Request::new(p::BranchInfoRequest { + repo_id: repo.id.to_string(), + branch, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_branch_summary( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = BranchServiceClient::new(self.git.clone()); + let resp = client + .branch_summary(tonic::Request::new(p::BranchSummaryRequest { + repo_id: repo.id.to_string(), + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_branch_head( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = BranchServiceClient::new(self.git.clone()); + let resp = client + .branch_head(tonic::Request::new(p::BranchHeadRequest { + repo_id: repo.id.to_string(), + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_branch_ahead_behind( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + local_branch: String, + remote_branch: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = BranchServiceClient::new(self.git.clone()); + let resp = client + .branch_ahead_behind(tonic::Request::new( + p::BranchAheadBehindRequest { + repo_id: repo.id.to_string(), + local_branch, + remote_branch, + }, + )) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_branch_upstream( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + branch: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = BranchServiceClient::new(self.git.clone()); + let resp = client + .branch_upstream(tonic::Request::new(p::BranchUpstreamRequest { + repo_id: repo.id.to_string(), + branch, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_branch_fork( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: p::BranchForkParams, + ) -> Result { + let repo = self.git_require_admin(ctx, wk_name, repo_name).await?; + let mut client = BranchServiceClient::new(self.git.clone()); + let resp = client + .branch_fork(tonic::Request::new(p::BranchForkRequest { + repo_id: repo.id.to_string(), + params: Some(params), + })) + .await + .map_err(rpc_err)? + .into_inner(); + self.queue_sync(repo.id).await; + Ok(resp) + } + + pub async fn git_branch_delete( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: p::BranchDeleteParams, + ) -> Result { + let repo = self.git_require_admin(ctx, wk_name, repo_name).await?; + let mut client = BranchServiceClient::new(self.git.clone()); + let resp = client + .branch_delete(tonic::Request::new(p::BranchDeleteRequest { + repo_id: repo.id.to_string(), + params: Some(params), + })) + .await + .map_err(rpc_err)? + .into_inner(); + self.queue_sync(repo.id).await; + Ok(resp) + } + + pub async fn git_branch_rename( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: p::BranchReNameParams, + ) -> Result { + let repo = self.git_require_admin(ctx, wk_name, repo_name).await?; + let mut client = BranchServiceClient::new(self.git.clone()); + let resp = client + .branch_rename(tonic::Request::new(p::BranchRenameRequest { + repo_id: repo.id.to_string(), + params: Some(params), + })) + .await + .map_err(rpc_err)? + .into_inner(); + self.queue_sync(repo.id).await; + Ok(resp) + } +} diff --git a/lib/service/git/commit.rs b/lib/service/git/commit.rs new file mode 100644 index 0000000..5b82c19 --- /dev/null +++ b/lib/service/git/commit.rs @@ -0,0 +1,192 @@ +use git::rpc::{proto as p, proto::commit_service_client::CommitServiceClient}; +use session::Session; + +use crate::{AppService, error::AppError, git::rpc_err}; + +impl AppService { + pub async fn git_commit_info( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + oid: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = CommitServiceClient::new(self.git.clone()); + let resp = client + .commit_info(tonic::Request::new(p::CommitInfoRequest { + repo_id: repo.id.to_string(), + oid: Some(p::ObjectId { value: oid }), + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_commit_history( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + limit: u64, + skip: u64, + sort: i32, + branch: Option, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = CommitServiceClient::new(self.git.clone()); + let resp = client + .commit_history(tonic::Request::new(p::CommitHistoryRequest { + repo_id: repo.id.to_string(), + limit, + skip, + sort, + branch, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_commit_summary( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = CommitServiceClient::new(self.git.clone()); + let resp = client + .commit_summary(tonic::Request::new(p::CommitSummaryRequest { + repo_id: repo.id.to_string(), + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_commit_walk( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: p::CommitWalkParams, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = CommitServiceClient::new(self.git.clone()); + let resp = client + .commit_walk(tonic::Request::new(p::CommitWalkRequest { + repo_id: repo.id.to_string(), + params: Some(params), + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_commit_refs( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = CommitServiceClient::new(self.git.clone()); + let resp = client + .commit_refs(tonic::Request::new(p::CommitRefsRequest { + repo_id: repo.id.to_string(), + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_commit_prefix( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + prefix: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = CommitServiceClient::new(self.git.clone()); + let resp = client + .commit_prefix(tonic::Request::new(p::CommitPrefixRequest { + repo_id: repo.id.to_string(), + prefix, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_commit_exists( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + oid: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = CommitServiceClient::new(self.git.clone()); + let resp = client + .commit_exists(tonic::Request::new(p::CommitExistsRequest { + repo_id: repo.id.to_string(), + oid: Some(p::ObjectId { value: oid }), + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_cherry_pick( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: p::CommitCherryPickParams, + ) -> Result { + let repo = self.git_require_admin(ctx, wk_name, repo_name).await?; + let mut client = CommitServiceClient::new(self.git.clone()); + let resp = client + .cherry_pick(tonic::Request::new(p::CherryPickRequest { + repo_id: repo.id.to_string(), + params: Some(params), + })) + .await + .map_err(rpc_err)? + .into_inner(); + self.queue_sync(repo.id).await; + Ok(resp) + } + + pub async fn git_cherry_pick_sequence( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: p::CommitCherryPickSequence, + ) -> Result { + let repo = self.git_require_admin(ctx, wk_name, repo_name).await?; + let mut client = CommitServiceClient::new(self.git.clone()); + let resp = client + .cherry_pick_sequence(tonic::Request::new( + p::CherryPickSequenceRequest { + repo_id: repo.id.to_string(), + params: Some(params), + }, + )) + .await + .map_err(rpc_err)? + .into_inner(); + self.queue_sync(repo.id).await; + Ok(resp) + } +} diff --git a/lib/service/git/commit_status.rs b/lib/service/git/commit_status.rs new file mode 100644 index 0000000..36f38e6 --- /dev/null +++ b/lib/service/git/commit_status.rs @@ -0,0 +1,151 @@ +use chrono::Utc; +use db::sqlx; +use model::repos::RepoCommitStatusModel; +use session::Session; +use uuid::Uuid; + +use crate::error::AppError; +use crate::AppService; + +#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] +pub struct CommitStatusResponse { + pub id: Uuid, + pub commit_sha: String, + pub state: String, + pub target_url: Option, + pub description: Option, + pub context: String, + pub creator: Uuid, + pub created_at: chrono::DateTime, +} + +#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] +pub struct CombinedCommitStatus { + pub sha: String, + pub state: String, + pub total_count: i64, + pub statuses: Vec, +} + +#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)] +pub struct CreateCommitStatus { + pub state: String, + pub target_url: Option, + pub description: Option, + pub context: Option, +} + +impl AppService { + pub async fn git_commit_status_list( + &self, + repo_id: Uuid, + commit_sha: &str, + ) -> Result, AppError> { + let rows = sqlx::query_as::<_, RepoCommitStatusModel>( + "SELECT id, repo, commit_sha, state, target_url, description, \ + context, creator, created_at, updated_at \ + FROM repo_commit_status \ + WHERE repo = $1 AND commit_sha = $2 \ + ORDER BY created_at DESC", + ) + .bind(repo_id) + .bind(commit_sha) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(status_to_response).collect()) + } + + pub async fn git_commit_status_combined( + &self, + repo_id: Uuid, + commit_sha: &str, + ) -> Result { + let statuses = self.git_commit_status_list(repo_id, commit_sha).await?; + let state = combined_state(&statuses); + Ok(CombinedCommitStatus { + sha: commit_sha.to_string(), + state, + total_count: statuses.len() as i64, + statuses, + }) + } + + pub async fn git_commit_status_create( + &self, + repo_id: Uuid, + user_id: Uuid, + commit_sha: &str, + params: CreateCommitStatus, + ) -> Result { + if !["pending", "success", "failure", "error"].contains(¶ms.state.as_str()) { + return Err(AppError::BadRequest( + "state must be one of: pending, success, failure, error".to_string(), + )); + } + + let id = Uuid::now_v7(); + let now = Utc::now(); + let context = params.context.unwrap_or_else(|| "default".to_string()); + + let row = sqlx::query_as::<_, RepoCommitStatusModel>( + "INSERT INTO repo_commit_status (id, repo, commit_sha, state, target_url, \ + description, context, creator, created_at, updated_at) \ + VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$9) RETURNING *", + ) + .bind(id) + .bind(repo_id) + .bind(commit_sha) + .bind(¶ms.state) + .bind(¶ms.target_url) + .bind(¶ms.description) + .bind(&context) + .bind(user_id) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(status_to_response(row)) + } +} + +impl AppService { + pub async fn git_commit_status_list_by_name(&self, ctx: &Session, wk: &str, repo: &str, sha: &str) -> Result, AppError> { + let repo = self.git_require_member(ctx, wk, repo).await?; + self.git_commit_status_list(repo.id, sha).await + } + pub async fn git_commit_status_combined_by_name(&self, ctx: &Session, wk: &str, repo: &str, sha: &str) -> Result { + let repo = self.git_require_member(ctx, wk, repo).await?; + self.git_commit_status_combined(repo.id, sha).await + } + pub async fn git_commit_status_create_by_name(&self, ctx: &Session, user_id: Uuid, wk: &str, repo: &str, sha: &str, params: CreateCommitStatus) -> Result { + let repo = self.git_require_member(ctx, wk, repo).await?; + self.git_commit_status_create(repo.id, user_id, sha, params).await + } +} + +fn status_to_response(s: RepoCommitStatusModel) -> CommitStatusResponse { + CommitStatusResponse { + id: s.id, + commit_sha: s.commit_sha, + state: s.state, + target_url: s.target_url, + description: s.description, + context: s.context, + creator: s.creator, + created_at: s.created_at, + } +} + +fn combined_state(statuses: &[CommitStatusResponse]) -> String { + if statuses.is_empty() { + return "pending".to_string(); + } + let has = |s: &str| statuses.iter().any(|st| st.state == s); + (if has("error") { "error" } + else if has("failure") { "failure" } + else if has("pending") { "pending" } + else { "success" }).to_string() +} diff --git a/lib/service/git/compare.rs b/lib/service/git/compare.rs new file mode 100644 index 0000000..aa4b8b1 --- /dev/null +++ b/lib/service/git/compare.rs @@ -0,0 +1,114 @@ +use git::rpc::{proto as p, proto::commit_service_client::CommitServiceClient}; +use session::Session; + +use crate::{AppService, error::AppError, git::rpc_err}; + +#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] +pub struct CompareResponse { + pub base_commit: CompareCommit, + pub head_commit: CompareCommit, + pub ahead_by: i32, + pub behind_by: i32, + pub total_commits: i32, + pub commits: Vec, + pub files_changed: u64, + pub insertions: u64, + pub deletions: u64, +} + +#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] +pub struct CompareCommit { + pub sha: String, + pub message: String, + pub author_name: Option, + pub author_email: Option, +} + +impl AppService { + pub async fn git_compare( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + base: &str, + head: &str, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = CommitServiceClient::new(self.git.clone()); + + fn oid(s: &str) -> p::ObjectId { + p::ObjectId { value: s.to_string() } + } + + let base_info = client + .commit_info(tonic::Request::new(p::CommitInfoRequest { + repo_id: repo.id.to_string(), + oid: Some(oid(base)), + })) + .await.map_err(rpc_err)?.into_inner(); + + let head_info = client + .commit_info(tonic::Request::new(p::CommitInfoRequest { + repo_id: repo.id.to_string(), + oid: Some(oid(head)), + })) + .await.map_err(rpc_err)?.into_inner(); + + let history = client + .commit_history(tonic::Request::new(p::CommitHistoryRequest { + repo_id: repo.id.to_string(), + limit: 250, skip: 0, sort: 0, + branch: Some(format!("{base}..{head}")), + })) + .await.map_err(rpc_err)?.into_inner(); + + let commits: Vec = history.commits.into_iter().map(|c| { + let author_name = c.author.as_ref().map(|a| a.name.clone()); + let author_email = c.author.as_ref().map(|a| a.email.clone()); + CompareCommit { + sha: c.oid.map(|o| o.value).unwrap_or_default(), + message: c.summary, + author_name, + author_email, + } + }).collect(); + + let diff = crate::AppService::git_diff_stats( + self, ctx, wk_name, repo_name, + base.to_string(), head.to_string(), None, + ).await?; + + let stats = diff.result.and_then(|r| r.stats); + let files_changed = stats.as_ref().map(|s| s.files_changed).unwrap_or(0); + let insertions = stats.as_ref().map(|s| s.insertions).unwrap_or(0); + let deletions = stats.as_ref().map(|s| s.deletions).unwrap_or(0); + + Ok(CompareResponse { + base_commit: cmt(base_info.commit), + head_commit: cmt(head_info.commit), + ahead_by: commits.len() as i32, + behind_by: 0, + total_commits: commits.len() as i32, + commits, + files_changed, + insertions, + deletions, + }) + } +} + +fn cmt(c: Option) -> CompareCommit { + c.map(|c| { + let author_name = c.author.as_ref().map(|a| a.name.clone()); + let author_email = c.author.as_ref().map(|a| a.email.clone()); + CompareCommit { + sha: c.oid.map(|o| o.value).unwrap_or_default(), + message: c.message, + author_name, + author_email, + } + }).unwrap_or_else(|| CompareCommit { + sha: String::new(), message: String::new(), + author_name: None, author_email: None, + }) +} diff --git a/lib/service/git/contents.rs b/lib/service/git/contents.rs new file mode 100644 index 0000000..552c248 --- /dev/null +++ b/lib/service/git/contents.rs @@ -0,0 +1,182 @@ +use db::sqlx; +use git::rpc::{ + proto as p, + proto::blob_service_client::BlobServiceClient, +}; +use session::Session; + +use crate::{AppService, error::AppError, git::rpc_err}; + +#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] +pub struct ContentResponse { + pub path: String, + pub name: String, + #[serde(rename = "type")] + pub content_type: String, + pub size: i64, + pub encoding: Option, + pub content: Option, +} + +#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)] +pub struct CreateContent { + pub message: String, + pub content: String, + pub branch: Option, +} + +#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)] +pub struct UpdateContent { + pub message: String, + pub content: String, + pub sha: String, + pub branch: Option, +} + +impl AppService { + pub async fn git_contents_get_by_name(&self, ctx: &Session, wk: &str, repo: &str, path: &str, ref_name: Option<&str>) -> Result { + let _ = self.git_require_member(ctx, wk, repo).await?; + self.git_contents_get(ctx, wk, repo, path, ref_name).await + } + pub async fn git_contents_create_by_name(&self, ctx: &Session, wk: &str, repo: &str, path: &str, params: CreateContent) -> Result { + let _ = self.git_require_member(ctx, wk, repo).await?; + self.git_contents_create(ctx, wk, repo, path, params).await + } + pub async fn git_contents_update_by_name(&self, ctx: &Session, wk: &str, repo: &str, path: &str, params: UpdateContent) -> Result { + let _ = self.git_require_member(ctx, wk, repo).await?; + self.git_contents_update(ctx, wk, repo, path, params).await + } + pub async fn git_contents_delete_by_name(&self, ctx: &Session, wk: &str, repo: &str, path: &str, msg: &str, sha: &str, branch: Option<&str>) -> Result<(), AppError> { + let _ = self.git_require_member(ctx, wk, repo).await?; + self.git_contents_delete(ctx, wk, repo, path, msg, sha, branch).await + } +} + +impl AppService { + pub async fn git_contents_get( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + path: &str, + _ref_name: Option<&str>, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut blob_client = BlobServiceClient::new(self.git.clone()); + + let empty_oid = p::ObjectId { value: String::new() }; + + let resp = blob_client + .blob_load(tonic::Request::new(p::BlobLoadRequest { + repo_id: repo.id.to_string(), + id: Some(empty_oid.clone()), + path: path.to_string(), + })) + .await.map_err(rpc_err)?.into_inner(); + + let is_binary = blob_client + .blob_is_binary(tonic::Request::new(p::BlobIsBinaryRequest { + repo_id: repo.id.to_string(), + id: Some(empty_oid.clone()), + })) + .await.map(|r| r.into_inner().is_binary).unwrap_or(false); + + let size_resp = blob_client + .blob_size(tonic::Request::new(p::BlobSizeRequest { + repo_id: repo.id.to_string(), + id: Some(empty_oid), + path: String::new(), + })) + .await.map_err(rpc_err)?.into_inner(); + + let blob_data = resp.blob; + let size = size_resp.size as i64; + + let (encoding, content) = if is_binary { + (Some("base64".to_string()), Some(base64_encode(&blob_data))) + } else { + (None, Some(String::from_utf8_lossy(&blob_data).to_string())) + }; + + Ok(ContentResponse { + path: path.to_string(), + name: path.rsplit('/').next().unwrap_or(path).to_string(), + content_type: "file".to_string(), + size, + encoding, + content, + }) + } + + pub async fn git_contents_create( + &self, ctx: &Session, wk_name: &str, repo_name: &str, path: &str, + params: CreateContent, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let user = crate::session_user(ctx)?; + + let user_model = sqlx::query_as::<_, model::users::UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, allow_use, can_search, \ + last_sign_in_at, created_at, updated_at FROM \"user\" WHERE id = $1", + ) + .bind(user) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::UserNotFound)?; + + let display = user_model.display_name.clone(); + let username = user_model.username.clone(); + let author_name = if display.is_empty() { username.clone() } else { display }; + + let file_size = params.content.len() as i64; + let content_bytes = params.content.clone().into_bytes(); + let mut client = p::commit_service_client::CommitServiceClient::new(self.git.clone()); + let resp = client + .create_commit(tonic::Request::new(p::CreateCommitRequest { + repo_id: repo.id.to_string(), + branch: params.branch.unwrap_or_else(|| repo.default_branch.clone()), + message: params.message, + author_name: author_name.clone(), + author_email: format!("{username}@gitdata.ai"), + committer_name: "redpanda".to_string(), + committer_email: "redpanda@gitdata.ai".to_string(), + files: vec![p::FileChange { + path: path.to_string(), + content: content_bytes, + }], + })) + .await + .map_err(rpc_err)? + .into_inner(); + + let _oid = resp.oid.map(|o| o.value).unwrap_or_default(); + Ok(ContentResponse { + path: path.to_string(), + name: path.rsplit('/').next().unwrap_or(path).to_string(), + content_type: "file".to_string(), + size: file_size, + encoding: None, + content: Some(params.content), + }) + } + + pub async fn git_contents_update( + &self, _ctx: &Session, _wk: &str, _repo: &str, _path: &str, + _params: UpdateContent, + ) -> Result { + Err(AppError::InternalServerError("contents update not yet implemented".to_string())) + } + + pub async fn git_contents_delete( + &self, _ctx: &Session, _wk: &str, _repo: &str, _path: &str, + _message: &str, _sha: &str, _branch: Option<&str>, + ) -> Result<(), AppError> { + Err(AppError::InternalServerError("contents delete not yet implemented".to_string())) + } +} + +fn base64_encode(data: &[u8]) -> String { + use base64::Engine as _; + base64::engine::general_purpose::STANDARD.encode(data) +} diff --git a/lib/service/git/contributor.rs b/lib/service/git/contributor.rs new file mode 100644 index 0000000..2ed9c90 --- /dev/null +++ b/lib/service/git/contributor.rs @@ -0,0 +1,56 @@ +use db::sqlx; +use serde::{Deserialize, Serialize}; +use session::Session; +use utoipa::ToSchema; + +use crate::{AppService, Pagination, error::AppError}; + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct ContributorDto { + pub name: String, + pub email: String, + #[schema(value_type = Option)] + pub user_id: Option, + pub commit_count: i64, +} + +impl AppService { + pub async fn git_repo_contributors( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + pagination: Pagination, + ) -> Result, AppError> { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + + let offset = pagination.offset() as i64; + let limit = pagination.limit() as i64; + + let rows = sqlx::query_as::<_, (String, String, Option, Option)>( + "SELECT rc.name, rc.email, rc.user, COUNT(rco.id)::bigint AS commit_count \ + FROM repo_committer rc \ + LEFT JOIN repo_commit rco ON rc.id = rco.author \ + WHERE rc.repo = $1 \ + GROUP BY rc.id, rc.name, rc.email, rc.user \ + ORDER BY commit_count DESC \ + OFFSET $2 LIMIT $3", + ) + .bind(repo.id) + .bind(offset) + .bind(limit) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows + .into_iter() + .map(|(name, email, user_id, commit_count)| ContributorDto { + name, + email, + user_id, + commit_count: commit_count.unwrap_or(0), + }) + .collect()) + } +} diff --git a/lib/service/git/diff.rs b/lib/service/git/diff.rs new file mode 100644 index 0000000..12e3a03 --- /dev/null +++ b/lib/service/git/diff.rs @@ -0,0 +1,128 @@ +use git::rpc::{proto as p, proto::diff_service_client::DiffServiceClient}; +use session::Session; + +use crate::{AppService, error::AppError, git::rpc_err}; + +impl AppService { + pub async fn git_diff_stats( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + old_oid: String, + new_oid: String, + options: Option, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = DiffServiceClient::new(self.git.clone()); + let resp = client + .diff_stats(tonic::Request::new(p::DiffStatsRequest { + repo_id: repo.id.to_string(), + old_oid: Some(p::ObjectId { value: old_oid }), + new_oid: Some(p::ObjectId { value: new_oid }), + options, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_diff_patch( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + old_oid: String, + new_oid: String, + options: Option, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = DiffServiceClient::new(self.git.clone()); + let resp = client + .diff_patch(tonic::Request::new(p::DiffPatchRequest { + repo_id: repo.id.to_string(), + old_oid: Some(p::ObjectId { value: old_oid }), + new_oid: Some(p::ObjectId { value: new_oid }), + options, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_diff_patch_side_by_side( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + old_oid: String, + new_oid: String, + options: Option, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = DiffServiceClient::new(self.git.clone()); + let resp = client + .diff_patch_side_by_side(tonic::Request::new( + p::DiffPatchSideBySideRequest { + repo_id: repo.id.to_string(), + old_oid: Some(p::ObjectId { value: old_oid }), + new_oid: Some(p::ObjectId { value: new_oid }), + options, + }, + )) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_diff_tree_to_tree( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + old_tree: String, + new_tree: String, + options: Option, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = DiffServiceClient::new(self.git.clone()); + let resp = client + .diff_tree_to_tree(tonic::Request::new(p::DiffTreeToTreeRequest { + repo_id: repo.id.to_string(), + old_tree: Some(p::ObjectId { value: old_tree }), + new_tree: Some(p::ObjectId { value: new_tree }), + options, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_diff_index_to_tree( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + tree_oid: String, + options: Option, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = DiffServiceClient::new(self.git.clone()); + let resp = client + .diff_index_to_tree(tonic::Request::new( + p::DiffIndexToTreeRequest { + repo_id: repo.id.to_string(), + tree_oid: Some(p::ObjectId { value: tree_oid }), + options, + }, + )) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } +} diff --git a/lib/service/git/fork.rs b/lib/service/git/fork.rs new file mode 100644 index 0000000..92e9a4a --- /dev/null +++ b/lib/service/git/fork.rs @@ -0,0 +1,204 @@ +use db::sqlx; +use git::rpc::{proto as p, proto::fork_service_client::ForkServiceClient}; +use model::repos::RepoModel; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, Pagination, error::AppError, session_user}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct ForkResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + pub name: String, + pub description: Option, + pub default_branch: String, + pub visibility: String, + #[schema(value_type = String)] + pub source_repo: uuid::Uuid, + #[schema(value_type = String)] + pub forked_by: uuid::Uuid, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreateFork { + pub name: Option, + pub visibility: Option, +} + +#[derive(db::sqlx::FromRow)] +struct ForkListRow { + source_repo: uuid::Uuid, + forked_by: uuid::Uuid, + fork_created_at: chrono::DateTime, + repo_id: uuid::Uuid, + repo_name: String, + repo_description: Option, + repo_default_branch: String, + repo_visibility: String, +} + +impl AppService { + pub async fn repo_fork_create( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: CreateFork, + ) -> Result { + let user_uid = session_user(ctx)?; + let src_wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(src_wk.id, user_uid).await?; + let source_repo = self.repo_resolve(src_wk.id, repo_name).await?; + + if source_repo.visibility == "private" { + return Err(AppError::Forbidden( + "cannot fork a private repo".to_string(), + )); + } + + let fork_name = params.name.unwrap_or_else(|| source_repo.name.clone()); + let fork_visibility = params + .visibility + .unwrap_or_else(|| source_repo.visibility.clone()); + + let existing = sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM repo WHERE wk = $1 AND name = $2 AND deleted_at IS NULL AND created_by = $3)", + ) + .bind(src_wk.id) + .bind(&fork_name) + .bind(user_uid) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + if existing { + return Err(AppError::Conflict("fork already exists".to_string())); + } + + let repo_id = uuid::Uuid::now_v7(); + let now = chrono::Utc::now(); + let description = source_repo.description.clone(); + let default_branch = source_repo.default_branch.clone(); + + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + let _fork_repo = sqlx::query_as::<_, RepoModel>( + "INSERT INTO repo (id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, 0, false, false, false, $7, '', $8, $8) \ + RETURNING id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at", + ) + .bind(repo_id) + .bind(src_wk.id) + .bind(&fork_name) + .bind(&description) + .bind(&default_branch) + .bind(&fork_visibility) + .bind(user_uid) + .bind(now) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let storage_root = self + .config + .repos_root() + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + let mut client = ForkServiceClient::new(self.git.clone()); + let rpc_resp = client + .fork_bare(tonic::Request::new(p::ForkBareRequest { + storage_root, + source_storage_path: source_repo.storage_path.clone(), + params: Some(p::ForkRepoParams { + namespace: src_wk.name.clone(), + repo_name: fork_name.clone(), + default_branch: default_branch.clone(), + description: description.clone(), + enable_lfs: false, + }), + })) + .await + .map_err(crate::git::rpc_err)? + .into_inner(); + let fork_repo = sqlx::query_as::<_, RepoModel>( + "UPDATE repo SET storage_path = $1 WHERE id = $2 \ + RETURNING id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at", + ) + .bind(&rpc_resp.storage_path) + .bind(repo_id) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + sqlx::query( + "INSERT INTO repo_fork (id, repo, source_repo, forked_by, created_at) \ + VALUES ($1, $2, $3, $4, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(fork_repo.id) + .bind(source_repo.id) + .bind(user_uid) + .bind(now) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + txn.commit().await.map_err(|_| AppError::TxnError)?; + self.queue_sync(repo_id).await; + + Ok(ForkResponse { + id: fork_repo.id, + name: fork_repo.name, + description: fork_repo.description, + default_branch: fork_repo.default_branch, + visibility: fork_repo.visibility, + source_repo: source_repo.id, + forked_by: user_uid, + created_at: fork_repo.created_at, + }) + } + + pub async fn repo_fork_list( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + pagination: Pagination, + ) -> Result, AppError> { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + + let rows = sqlx::query_as::<_, ForkListRow>( + "SELECT f.id as fork_id, f.source_repo, f.forked_by, f.created_at as fork_created_at, \ + r.id as repo_id, r.name as repo_name, r.description as repo_description, \ + r.default_branch as repo_default_branch, r.visibility as repo_visibility, \ + r.created_at as repo_created_at \ + FROM repo_fork f \ + INNER JOIN repo r ON r.id = f.repo AND r.deleted_at IS NULL \ + WHERE f.source_repo = $1 \ + ORDER BY f.created_at DESC \ + OFFSET $2 LIMIT $3", + ) + .bind(repo.id) + .bind(pagination.offset() as i64) + .bind(pagination.limit() as i64) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows + .into_iter() + .map(|row| ForkResponse { + id: row.repo_id, + name: row.repo_name, + description: row.repo_description, + default_branch: row.repo_default_branch, + visibility: row.repo_visibility, + source_repo: row.source_repo, + forked_by: row.forked_by, + created_at: row.fork_created_at, + }) + .collect()) + } +} diff --git a/lib/service/git/init.rs b/lib/service/git/init.rs new file mode 100644 index 0000000..9140d2c --- /dev/null +++ b/lib/service/git/init.rs @@ -0,0 +1,217 @@ +use db::sqlx; +use git::rpc::{proto as p, proto::init_service_client::InitServiceClient}; +use model::repos::RepoModel; +use session::Session; + +use crate::{AppService, error::AppError, git::rpc_err, session_user}; + +#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)] +pub struct CreateRepo { + pub name: String, + pub description: Option, + pub default_branch: Option, + pub visibility: Option, + pub initialize_with_readme: Option, + pub enable_lfs: Option, +} + +#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)] +pub struct CloneRepo { + pub name: String, + pub source_url: String, + pub description: Option, + pub visibility: Option, +} + +impl AppService { + pub async fn git_init_bare( + &self, + ctx: &Session, + wk_name: &str, + params: CreateRepo, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + + let name = params.name.trim(); + if name.is_empty() { + return Err(AppError::BadRequest( + "repo name is required".to_string(), + )); + } + + let existing = sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM repo WHERE wk = $1 AND name = $2 AND deleted_at IS NULL)", + ) + .bind(wk.id) + .bind(name) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if existing { + return Err(AppError::RepoNameAlreadyExists); + } + + let default_branch = + params.default_branch.unwrap_or_else(|| "main".to_string()); + let visibility = + params.visibility.unwrap_or_else(|| "private".to_string()); + let description = params.description.unwrap_or_default(); + let now = chrono::Utc::now(); + let repo_id = uuid::Uuid::now_v7(); + + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + + let _repo = sqlx::query_as::<_, RepoModel>( + "INSERT INTO repo (id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, 0, false, false, false, $7, '', $8, $8) \ + RETURNING id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at", + ) + .bind(repo_id) + .bind(wk.id) + .bind(name) + .bind(&description) + .bind(&default_branch) + .bind(&visibility) + .bind(user_uid) + .bind(now) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let storage_root = self + .config + .repos_root() + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + let mut client = InitServiceClient::new(self.git.clone()); + let rpc_resp = client + .init_bare(tonic::Request::new(p::InitBareRequest { + storage_root, + params: Some(p::InitRepoParams { + namespace: wk.name.clone(), + repo_name: name.to_string(), + default_branch, + description: Some(description), + initialize_with_readme: params + .initialize_with_readme + .unwrap_or(false), + enable_lfs: params.enable_lfs.unwrap_or(false), + }), + })) + .await + .map_err(rpc_err)? + .into_inner(); + + let repo = sqlx::query_as::<_, RepoModel>( + "UPDATE repo SET storage_path = $1 WHERE id = $2 \ + RETURNING id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at", + ) + .bind(&rpc_resp.storage_path) + .bind(repo_id) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + txn.commit().await.map_err(|_| AppError::TxnError)?; + self.queue_sync(repo_id).await; + Ok(repo) + } + + pub async fn git_clone_bare( + &self, + ctx: &Session, + wk_name: &str, + params: CloneRepo, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + + let name = params.name.trim(); + if name.is_empty() { + return Err(AppError::BadRequest("repo name is required".to_string())); + } + + let source_url = params.source_url.trim(); + if source_url.is_empty() { + return Err(AppError::BadRequest("source URL is required".to_string())); + } + + let existing = sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM repo WHERE wk = $1 AND name = $2 AND deleted_at IS NULL)", + ) + .bind(wk.id) + .bind(name) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if existing { + return Err(AppError::RepoNameAlreadyExists); + } + + let visibility = + params.visibility.unwrap_or_else(|| "private".to_string()); + let description = params.description.unwrap_or_default(); + let now = chrono::Utc::now(); + let repo_id = uuid::Uuid::now_v7(); + + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + + let _repo = sqlx::query_as::<_, RepoModel>( + "INSERT INTO repo (id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, '', $5, 0, false, false, true, $6, '', $7, $7) \ + RETURNING id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at", + ) + .bind(repo_id) + .bind(wk.id) + .bind(name) + .bind(&description) + .bind(&visibility) + .bind(user_uid) + .bind(now) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let storage_root = self + .config + .repos_root() + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + let mut client = InitServiceClient::new(self.git.clone()); + let rpc_resp = client + .clone_bare(tonic::Request::new(p::CloneBareRequest { + storage_root, + source_url: source_url.to_string(), + namespace: wk.name.clone(), + repo_name: name.to_string(), + })) + .await + .map_err(rpc_err)? + .into_inner(); + + let repo = sqlx::query_as::<_, RepoModel>( + "UPDATE repo SET storage_path = $1 WHERE id = $2 \ + RETURNING id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at", + ) + .bind(&rpc_resp.storage_path) + .bind(repo_id) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + txn.commit().await.map_err(|_| AppError::TxnError)?; + self.queue_sync(repo_id).await; + Ok(repo) + } +} diff --git a/lib/service/git/language.rs b/lib/service/git/language.rs new file mode 100644 index 0000000..350d97c --- /dev/null +++ b/lib/service/git/language.rs @@ -0,0 +1,42 @@ +use db::sqlx; +use serde::{Deserialize, Serialize}; +use session::Session; +use utoipa::ToSchema; + +use crate::{AppService, error::AppError}; + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct LanguageStatDto { + pub language: String, + pub bytes: i64, + pub percentage: f32, +} + +impl AppService { + pub async fn git_repo_languages( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result, AppError> { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + + let rows = sqlx::query_as::<_, (String, i64, f32)>( + "SELECT language, bytes, percentage FROM repo_language WHERE repo = $1 \ + ORDER BY bytes DESC", + ) + .bind(repo.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows + .into_iter() + .map(|(language, bytes, percentage)| LanguageStatDto { + language, + bytes, + percentage, + }) + .collect()) + } +} diff --git a/lib/service/git/mod.rs b/lib/service/git/mod.rs new file mode 100644 index 0000000..cfe1078 --- /dev/null +++ b/lib/service/git/mod.rs @@ -0,0 +1,95 @@ +pub mod archive; +pub mod blame; +pub mod blob; +pub mod branch; +pub mod commit; +pub mod commit_status; +pub mod compare; +pub mod contents; +pub mod contributor; +pub mod diff; +pub mod fork; +pub mod init; +pub mod language; +pub mod protect; +pub mod readme; +pub mod refs; +pub mod release; +pub mod repo; +pub mod star; +pub mod tag; +pub mod tree; +pub mod watch; +pub mod webhook; + +use db::sqlx; +use git::sync::{ReceiveSyncService, RepoReceiveSyncTask}; +use model::repos::RepoModel; +use session::Session; + +use crate::{AppService, error::AppError, session_user}; + +impl AppService { + pub(crate) async fn queue_sync(&self, repo_uid: uuid::Uuid) { + let sync_service = ReceiveSyncService::new(self.redis_pool.clone()); + sync_service.send(RepoReceiveSyncTask { repo_uid }).await; + } + + pub(crate) async fn repo_resolve( + &self, + wk_id: uuid::Uuid, + repo_name: &str, + ) -> Result { + sqlx::query_as::<_, RepoModel>( + "SELECT id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at \ + FROM repo WHERE wk = $1 AND name = $2 AND deleted_at IS NULL", + ) + .bind(wk_id) + .bind(repo_name) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::RepoNotFound) + } + + pub(crate) async fn git_require_member( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + self.repo_resolve(wk.id, repo_name).await + } + + pub(crate) async fn git_require_admin( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + self.repo_resolve(wk.id, repo_name).await + } +} + +pub(crate) fn rpc_err(status: tonic::Status) -> AppError { + match status.code() { + tonic::Code::NotFound => { + AppError::NotFound(status.message().to_string()) + } + tonic::Code::PermissionDenied => AppError::PermissionDenied, + tonic::Code::InvalidArgument => { + AppError::BadRequest(status.message().to_string()) + } + tonic::Code::AlreadyExists => { + AppError::Conflict(status.message().to_string()) + } + _ => AppError::GitRpcError(status.message().to_string()), + } +} diff --git a/lib/service/git/protect.rs b/lib/service/git/protect.rs new file mode 100644 index 0000000..7bd369e --- /dev/null +++ b/lib/service/git/protect.rs @@ -0,0 +1,259 @@ +use db::sqlx; +use model::repos::RepoProtectModel; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, Pagination, error::AppError, session_user}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct ProtectResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + #[schema(value_type = String)] + pub repo: uuid::Uuid, + pub pattern: String, + pub require_pull_request: bool, + pub required_approvals: i32, + pub require_status_checks: bool, + pub required_status_contexts: Vec, + pub enforce_admins: bool, + pub allow_force_pushes: bool, + pub allow_deletions: bool, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +pub(crate) fn protect_response(p: RepoProtectModel) -> ProtectResponse { + ProtectResponse { + id: p.id, + repo: p.repo, + pattern: p.pattern, + require_pull_request: p.require_pull_request, + required_approvals: p.required_approvals, + require_status_checks: p.require_status_checks, + required_status_contexts: p + .required_status_contexts + .split('.') + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + .collect(), + enforce_admins: p.enforce_admins, + allow_force_pushes: p.allow_force_pushes, + allow_deletions: p.allow_deletions, + created_at: p.created_at, + updated_at: p.updated_at, + } +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreateProtect { + pub pattern: String, + pub require_pull_request: Option, + pub required_approvals: Option, + pub require_status_checks: Option, + pub required_status_contexts: Option>, + pub enforce_admins: Option, + pub allow_force_pushes: Option, + pub allow_deletions: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateProtect { + pub pattern: Option, + pub require_pull_request: Option, + pub required_approvals: Option, + pub require_status_checks: Option, + pub required_status_contexts: Option>, + pub enforce_admins: Option, + pub allow_force_pushes: Option, + pub allow_deletions: Option, +} + +impl AppService { + pub async fn repo_protect_list( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + pagination: Pagination, + ) -> Result, AppError> { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + + let rows = sqlx::query_as::<_, RepoProtectModel>( + "SELECT id, repo, pattern, require_pull_request, required_approvals, \ + require_status_checks, required_status_contexts, enforce_admins, \ + allow_force_pushes, allow_deletions, created_at, updated_at \ + FROM repo_protect WHERE repo = $1 \ + ORDER BY pattern ASC OFFSET $2 LIMIT $3", + ) + .bind(repo.id) + .bind(pagination.offset() as i64) + .bind(pagination.limit() as i64) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(protect_response).collect()) + } + + pub async fn repo_protect_create( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: CreateProtect, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let repo = self.repo_resolve(wk.id, repo_name).await?; + + let pattern = params.pattern.trim(); + if pattern.is_empty() { + return Err(AppError::BadRequest( + "pattern is required".to_string(), + )); + } + + let contexts = params + .required_status_contexts + .unwrap_or_default() + .join("."); + + let id = uuid::Uuid::now_v7(); + let now = chrono::Utc::now(); + + let row = sqlx::query_as::<_, RepoProtectModel>( + "INSERT INTO repo_protect \ + (id, repo, pattern, require_pull_request, required_approvals, \ + require_status_checks, required_status_contexts, enforce_admins, \ + allow_force_pushes, allow_deletions, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $11) \ + RETURNING id, repo, pattern, require_pull_request, required_approvals, \ + require_status_checks, required_status_contexts, enforce_admins, \ + allow_force_pushes, allow_deletions, created_at, updated_at", + ) + .bind(id) + .bind(repo.id) + .bind(pattern) + .bind(params.require_pull_request.unwrap_or(true)) + .bind(params.required_approvals.unwrap_or(1)) + .bind(params.require_status_checks.unwrap_or(false)) + .bind(&contexts) + .bind(params.enforce_admins.unwrap_or(false)) + .bind(params.allow_force_pushes.unwrap_or(false)) + .bind(params.allow_deletions.unwrap_or(false)) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(protect_response(row)) + } + + pub async fn repo_protect_update( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + protect_id: uuid::Uuid, + params: UpdateProtect, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let repo = self.repo_resolve(wk.id, repo_name).await?; + + let existing = sqlx::query_as::<_, RepoProtectModel>( + "SELECT id, repo, pattern, require_pull_request, required_approvals, \ + require_status_checks, required_status_contexts, enforce_admins, \ + allow_force_pushes, allow_deletions, created_at, updated_at \ + FROM repo_protect WHERE id = $1 AND repo = $2", + ) + .bind(protect_id) + .bind(repo.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::NotFound("protected branch rule not found".to_string()))?; + + let pattern = params.pattern.unwrap_or(existing.pattern); + let require_pull_request = params + .require_pull_request + .unwrap_or(existing.require_pull_request); + let required_approvals = params + .required_approvals + .unwrap_or(existing.required_approvals); + let require_status_checks = params + .require_status_checks + .unwrap_or(existing.require_status_checks); + let contexts = params + .required_status_contexts + .map(|c| c.join(".")) + .unwrap_or(existing.required_status_contexts); + let enforce_admins = + params.enforce_admins.unwrap_or(existing.enforce_admins); + let allow_force_pushes = params + .allow_force_pushes + .unwrap_or(existing.allow_force_pushes); + let allow_deletions = + params.allow_deletions.unwrap_or(existing.allow_deletions); + + let row = sqlx::query_as::<_, RepoProtectModel>( + "UPDATE repo_protect SET \ + pattern = $1, require_pull_request = $2, required_approvals = $3, \ + require_status_checks = $4, required_status_contexts = $5, enforce_admins = $6, \ + allow_force_pushes = $7, allow_deletions = $8, updated_at = $9 \ + WHERE id = $10 \ + RETURNING id, repo, pattern, require_pull_request, required_approvals, \ + require_status_checks, required_status_contexts, enforce_admins, \ + allow_force_pushes, allow_deletions, created_at, updated_at", + ) + .bind(&pattern) + .bind(require_pull_request) + .bind(required_approvals) + .bind(require_status_checks) + .bind(&contexts) + .bind(enforce_admins) + .bind(allow_force_pushes) + .bind(allow_deletions) + .bind(chrono::Utc::now()) + .bind(protect_id) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(protect_response(row)) + } + + pub async fn repo_protect_delete( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + protect_id: uuid::Uuid, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let repo = self.repo_resolve(wk.id, repo_name).await?; + + let result = + sqlx::query("DELETE FROM repo_protect WHERE id = $1 AND repo = $2") + .bind(protect_id) + .bind(repo.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if result.rows_affected() == 0 { + return Err(AppError::NotFound( + "protected branch rule not found".to_string(), + )); + } + + Ok(()) + } +} diff --git a/lib/service/git/readme.rs b/lib/service/git/readme.rs new file mode 100644 index 0000000..d7e252c --- /dev/null +++ b/lib/service/git/readme.rs @@ -0,0 +1,139 @@ +use serde::{Deserialize, Serialize}; +use session::Session; +use utoipa::ToSchema; + +use crate::{AppService, error::AppError}; + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct ReadmeDto { + pub content: String, + pub html: String, +} + +impl AppService { + pub async fn git_repo_readme( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result, AppError> { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + + let readme_names = [ + "README.md", + "README.markdown", + "README.txt", + "README", + "Readme.md", + "readme.md", + ]; + + for name in &readme_names { + match self + .git_tree_entry_by_path_from_commit_for_readme( + &repo.id, + name, + ) + .await? + { + Some((content, oid)) => { + return self + .git_blob_load_for_readme( + &repo, &content, &oid, + ) + .await; + } + None => continue, + } + } + + Ok(None) + } +} +impl AppService { + async fn git_tree_entry_by_path_from_commit_for_readme( + &self, + repo_id: &uuid::Uuid, + readme_name: &str, + ) -> Result, AppError> { + use git::rpc::proto as p; + use git::rpc::proto::tree_service_client::TreeServiceClient; + use crate::git::rpc_err; + + let mut client = TreeServiceClient::new(self.git.clone()); + let mut commit_client = + git::rpc::proto::commit_service_client::CommitServiceClient::new(self.git.clone()); + let summary_resp = commit_client + .commit_summary(tonic::Request::new(p::CommitSummaryRequest { + repo_id: repo_id.to_string(), + })) + .await + .map_err(rpc_err)? + .into_inner(); + + let head_commit = match summary_resp.summary.and_then(|s| s.head) { + Some(c) => c, + None => return Ok(None), + }; + + let tree_id = match head_commit.tree_id { + Some(id) => id.value, + None => return Ok(None), + }; + + let resp = client + .tree_entry_by_path(tonic::Request::new(p::TreeEntryByPathRequest { + repo_id: repo_id.to_string(), + tree_oid: Some(p::ObjectId { value: tree_id.clone() }), + path: readme_name.to_string(), + })) + .await + .map_err(rpc_err)? + .into_inner(); + + match resp.entry { + Some(entry) => { + let oid = entry.oid.map(|o| o.value).unwrap_or_default(); + if oid.is_empty() { + Ok(None) + } else { + Ok(Some((readme_name.to_string(), oid))) + } + } + None => Ok(None), + } + } + async fn git_blob_load_for_readme( + &self, + repo: &model::repos::RepoModel, + _path: &str, + oid: &str, + ) -> Result, AppError> { + use git::rpc::proto as p; + use git::rpc::proto::blob_service_client::BlobServiceClient; + use crate::git::rpc_err; + + let mut client = BlobServiceClient::new(self.git.clone()); + let resp = client + .blob_load(tonic::Request::new(p::BlobLoadRequest { + repo_id: repo.id.to_string(), + id: Some(p::ObjectId { value: oid.to_string() }), + path: String::new(), + })) + .await + .map_err(rpc_err)? + .into_inner(); + + let content = String::from_utf8_lossy(&resp.blob).to_string(); + if content.is_empty() { + return Ok(None); + } + + let html = comrak::markdown_to_html(&content, &comrak::ComrakOptions::default()); + + Ok(Some(super::readme::ReadmeDto { + content, + html, + })) + } +} diff --git a/lib/service/git/refs.rs b/lib/service/git/refs.rs new file mode 100644 index 0000000..e861daa --- /dev/null +++ b/lib/service/git/refs.rs @@ -0,0 +1,76 @@ +use db::sqlx; +use model::repos::RepoRefModel; +use session::Session; +use uuid::Uuid; + +use crate::error::AppError; +use crate::AppService; + +#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] +pub struct GitRefResponse { + pub name: String, + pub kind: String, + pub target_sha: String, + pub is_default: bool, + pub is_protected: bool, +} + +impl AppService { + pub async fn git_ref_list_by_name(&self, ctx: &Session, wk: &str, repo: &str) -> Result, AppError> { + let repo = self.git_require_member(ctx, wk, repo).await?; + self.git_ref_list(repo.id).await + } + pub async fn git_ref_get_by_name(&self, ctx: &Session, wk: &str, repo: &str, ref_name: &str) -> Result { + let repo = self.git_require_member(ctx, wk, repo).await?; + self.git_ref_get(repo.id, ref_name).await + } +} + +impl AppService { + pub async fn git_ref_list( + &self, + repo_id: Uuid, + ) -> Result, AppError> { + let refs = sqlx::query_as::<_, RepoRefModel>( + "SELECT id, repo, name, kind, target_sha, is_default, is_protected, \ + created_at, updated_at FROM repo_ref WHERE repo = $1 ORDER BY name", + ) + .bind(repo_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(refs.into_iter().map(|r| GitRefResponse { + name: r.name, + kind: r.kind, + target_sha: r.target_sha, + is_default: r.is_default, + is_protected: r.is_protected, + }).collect()) + } + + pub async fn git_ref_get( + &self, + repo_id: Uuid, + ref_name: &str, + ) -> Result { + let r = sqlx::query_as::<_, RepoRefModel>( + "SELECT id, repo, name, kind, target_sha, is_default, is_protected, \ + created_at, updated_at FROM repo_ref WHERE repo = $1 AND name = $2", + ) + .bind(repo_id) + .bind(ref_name) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("ref not found".to_string()))?; + + Ok(GitRefResponse { + name: r.name, + kind: r.kind, + target_sha: r.target_sha, + is_default: r.is_default, + is_protected: r.is_protected, + }) + } +} diff --git a/lib/service/git/release.rs b/lib/service/git/release.rs new file mode 100644 index 0000000..967d8b9 --- /dev/null +++ b/lib/service/git/release.rs @@ -0,0 +1,386 @@ +use chrono::Utc; +use db::sqlx; +use model::repos::{ + RepoReleaseAssetModel, RepoReleaseModel, +}; +use session::Session; +use uuid::Uuid; + +use crate::error::AppError; +use crate::AppService; + +#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] +pub struct ReleaseResponse { + pub id: Uuid, + pub tag_name: String, + pub target_commit_sha: String, + pub name: String, + pub body: Option, + pub draft: bool, + pub prerelease: bool, + pub author: Uuid, + pub assets: Vec, + pub published_at: Option>, + pub created_at: chrono::DateTime, +} + +#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] +pub struct ReleaseAssetResponse { + pub id: Uuid, + pub name: String, + pub content_type: Option, + pub size: i64, + pub download_count: i64, + pub created_at: chrono::DateTime, +} + +#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)] +pub struct CreateRelease { + pub tag_name: String, + pub target_commit_sha: Option, + pub name: String, + pub body: Option, + #[serde(default)] + pub draft: bool, + #[serde(default)] + pub prerelease: bool, +} + +#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)] +pub struct UpdateRelease { + pub tag_name: Option, + pub name: Option, + pub body: Option>, + pub draft: Option, + pub prerelease: Option, +} + +impl AppService { + pub async fn git_release_list( + &self, + repo_id: Uuid, + ) -> Result, AppError> { + let releases = sqlx::query_as::<_, RepoReleaseModel>( + "SELECT id, repo, tag_name, target_commit_sha, name, body, \ + draft, prerelease, author, published_at, created_at, updated_at \ + FROM repo_release WHERE repo = $1 ORDER BY created_at DESC LIMIT 50", + ) + .bind(repo_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut result = Vec::with_capacity(releases.len()); + for r in releases { + let assets = self.git_release_assets(r.id).await?; + result.push(release_to_response(r, assets)); + } + Ok(result) + } + + pub async fn git_release_get( + &self, + repo_id: Uuid, + release_id: Uuid, + ) -> Result { + let r = sqlx::query_as::<_, RepoReleaseModel>( + "SELECT id, repo, tag_name, target_commit_sha, name, body, \ + draft, prerelease, author, published_at, created_at, updated_at \ + FROM repo_release WHERE id = $1 AND repo = $2", + ) + .bind(release_id) + .bind(repo_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("release not found".to_string()))?; + + let assets = self.git_release_assets(r.id).await?; + Ok(release_to_response(r, assets)) + } + + pub async fn git_release_get_by_tag( + &self, + repo_id: Uuid, + tag: &str, + ) -> Result { + let r = sqlx::query_as::<_, RepoReleaseModel>( + "SELECT id, repo, tag_name, target_commit_sha, name, body, \ + draft, prerelease, author, published_at, created_at, updated_at \ + FROM repo_release WHERE repo = $1 AND tag_name = $2", + ) + .bind(repo_id) + .bind(tag) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("release not found".to_string()))?; + + let assets = self.git_release_assets(r.id).await?; + Ok(release_to_response(r, assets)) + } + + pub async fn git_release_create( + &self, + _session: &Session, + repo_id: Uuid, + user_id: Uuid, + params: CreateRelease, + ) -> Result { + let id = Uuid::now_v7(); + let now = Utc::now(); + let published_at = if params.draft { None } else { Some(now) }; + + let target = if let Some(ref sha) = params.target_commit_sha { + if !sha.trim().is_empty() { sha.clone() } else { self.default_branch_sha(repo_id).await? } + } else { + self.default_branch_sha(repo_id).await? + }; + + let r = sqlx::query_as::<_, RepoReleaseModel>( + "INSERT INTO repo_release (id, repo, tag_name, target_commit_sha, name, body, \ + draft, prerelease, author, published_at, created_at, updated_at) \ + VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$11) \ + RETURNING id, repo, tag_name, target_commit_sha, name, body, \ + draft, prerelease, author, published_at, created_at, updated_at", + ) + .bind(id) + .bind(repo_id) + .bind(¶ms.tag_name) + .bind(&target) + .bind(¶ms.name) + .bind(¶ms.body) + .bind(params.draft) + .bind(params.prerelease) + .bind(user_id) + .bind(published_at) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(release_to_response(r, Vec::new())) + } + + pub async fn git_release_update( + &self, + repo_id: Uuid, + release_id: Uuid, + params: UpdateRelease, + ) -> Result { + let existing = self.git_release_get(repo_id, release_id).await?; + let now = Utc::now(); + + let tag_name = params.tag_name.unwrap_or(existing.tag_name); + let name = params.name.unwrap_or(existing.name); + let body = params.body.unwrap_or(existing.body); + let draft = params.draft.unwrap_or(existing.draft); + let prerelease = params.prerelease.unwrap_or(existing.prerelease); + + let published_at = if draft { None } else { existing.published_at.or(Some(now)) }; + + let r = sqlx::query_as::<_, RepoReleaseModel>( + "UPDATE repo_release SET tag_name=$1, name=$2, body=$3, draft=$4, \ + prerelease=$5, published_at=$6, updated_at=$7 \ + WHERE id=$8 AND repo=$9 \ + RETURNING id, repo, tag_name, target_commit_sha, name, body, \ + draft, prerelease, author, published_at, created_at, updated_at", + ) + .bind(&tag_name) + .bind(&name) + .bind(&body) + .bind(draft) + .bind(prerelease) + .bind(published_at) + .bind(now) + .bind(release_id) + .bind(repo_id) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let assets = self.git_release_assets(r.id).await?; + Ok(release_to_response(r, assets)) + } + + pub async fn git_release_delete( + &self, + repo_id: Uuid, + release_id: Uuid, + ) -> Result<(), AppError> { + let rows = sqlx::query( + "DELETE FROM repo_release WHERE id = $1 AND repo = $2", + ) + .bind(release_id) + .bind(repo_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + if rows.rows_affected() == 0 { + return Err(AppError::NotFound("release not found".to_string())); + } + Ok(()) + } + + pub async fn git_release_delete_by_tag( + &self, + repo_id: Uuid, + tag: &str, + ) -> Result<(), AppError> { + let release_id: Option = sqlx::query_scalar( + "SELECT id FROM repo_release WHERE repo = $1 AND tag_name = $2", + ) + .bind(repo_id) + .bind(tag) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let release_id = release_id + .ok_or_else(|| AppError::NotFound("release not found".to_string()))?; + self.git_release_delete(repo_id, release_id).await + } + + async fn git_release_assets( + &self, + release_id: Uuid, + ) -> Result, AppError> { + let assets = sqlx::query_as::<_, RepoReleaseAssetModel>( + "SELECT id, release_id, name, content_type, size, download_count, \ + storage_path, uploader, created_at \ + FROM repo_release_asset WHERE release_id = $1 ORDER BY created_at", + ) + .bind(release_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(assets.into_iter().map(|a| ReleaseAssetResponse { + id: a.id, + name: a.name, + content_type: a.content_type, + size: a.size, + download_count: a.download_count, + created_at: a.created_at, + }).collect()) + } + + pub async fn git_release_asset_create( + &self, + repo_id: Uuid, + release_id: Uuid, + user_id: Uuid, + name: String, + content_type: Option, + size: i64, + storage_path: String, + ) -> Result { + let _ = self.git_release_get(repo_id, release_id).await?; + + let id = Uuid::now_v7(); + let now = Utc::now(); + let a = sqlx::query_as::<_, RepoReleaseAssetModel>( + "INSERT INTO repo_release_asset (id, release_id, name, content_type, size, \ + download_count, storage_path, uploader, created_at) \ + VALUES ($1,$2,$3,$4,$5,0,$6,$7,$8) RETURNING *", + ) + .bind(id) + .bind(release_id) + .bind(&name) + .bind(&content_type) + .bind(size) + .bind(&storage_path) + .bind(user_id) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(ReleaseAssetResponse { + id: a.id, + name: a.name, + content_type: a.content_type, + size: a.size, + download_count: a.download_count, + created_at: a.created_at, + }) + } + + pub async fn git_release_asset_delete( + &self, + repo_id: Uuid, + release_id: Uuid, + asset_id: Uuid, + ) -> Result<(), AppError> { + let _ = self.git_release_get(repo_id, release_id).await?; + sqlx::query( + "DELETE FROM repo_release_asset WHERE id = $1 AND release_id = $2", + ) + .bind(asset_id) + .bind(release_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(()) + } +} + +impl AppService { + pub async fn git_release_list_by_name(&self, ctx: &Session, _user_id: Uuid, wk: &str, repo: &str) -> Result, AppError> { + let repo = self.git_require_member(ctx, wk, repo).await?; + self.git_release_list(repo.id).await + } + pub async fn git_release_get_by_name(&self, ctx: &Session, _user_id: Uuid, wk: &str, repo: &str, id: Uuid) -> Result { + let repo = self.git_require_member(ctx, wk, repo).await?; + self.git_release_get(repo.id, id).await + } + pub async fn git_release_get_by_tag_name(&self, ctx: &Session, _user_id: Uuid, wk: &str, repo: &str, tag: &str) -> Result { + let repo = self.git_require_member(ctx, wk, repo).await?; + self.git_release_get_by_tag(repo.id, tag).await + } + pub async fn git_release_create_by_name(&self, ctx: &Session, user_id: Uuid, wk: &str, repo: &str, params: CreateRelease) -> Result { + let repo = self.git_require_member(ctx, wk, repo).await?; + self.git_release_create(ctx, repo.id, user_id, params).await + } + pub async fn git_release_update_by_name(&self, ctx: &Session, _user_id: Uuid, wk: &str, repo: &str, id: Uuid, params: UpdateRelease) -> Result { + let repo = self.git_require_member(ctx, wk, repo).await?; + self.git_release_update(repo.id, id, params).await + } + pub async fn git_release_delete_by_name(&self, ctx: &Session, _user_id: Uuid, wk: &str, repo: &str, id: Uuid) -> Result<(), AppError> { + let repo = self.git_require_member(ctx, wk, repo).await?; + self.git_release_delete(repo.id, id).await + } + pub async fn git_release_delete_by_tag_name(&self, ctx: &Session, _user_id: Uuid, wk: &str, repo: &str, tag: &str) -> Result<(), AppError> { + let repo = self.git_require_member(ctx, wk, repo).await?; + self.git_release_delete_by_tag(repo.id, tag).await + } +} + +impl AppService { + async fn default_branch_sha(&self, repo_id: Uuid) -> Result { + sqlx::query_scalar("SELECT target_sha FROM repo_ref WHERE repo = $1 AND is_default = true") + .bind(repo_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::BadRequest("no default branch to target".to_string())) + } +} + +fn release_to_response( + r: RepoReleaseModel, + assets: Vec, +) -> ReleaseResponse { + ReleaseResponse { + id: r.id, + tag_name: r.tag_name, + target_commit_sha: r.target_commit_sha, + name: r.name, + body: r.body, + draft: r.draft, + prerelease: r.prerelease, + author: r.author, + assets, + published_at: r.published_at, + created_at: r.created_at, + } +} diff --git a/lib/service/git/repo.rs b/lib/service/git/repo.rs new file mode 100644 index 0000000..4315b68 --- /dev/null +++ b/lib/service/git/repo.rs @@ -0,0 +1,441 @@ +use db::sqlx; +use git::rpc::{proto as p, proto::init_service_client::InitServiceClient}; +use model::repos::{RepoModel, RepoTopicModel}; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, Pagination, error::AppError, git::rpc_err, session_user}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct RepoResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + pub name: String, + pub description: Option, + pub default_branch: String, + pub visibility: String, + #[schema(value_type = i64)] + pub size_bytes: i64, + pub is_archived: bool, + pub is_template: bool, + pub is_mirror: bool, + #[schema(value_type = String)] + pub created_by: uuid::Uuid, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +pub(crate) fn repo_response(repo: RepoModel) -> RepoResponse { + RepoResponse { + id: repo.id, + name: repo.name, + description: repo.description, + default_branch: repo.default_branch, + visibility: repo.visibility, + size_bytes: repo.size_bytes, + is_archived: repo.is_archived, + is_template: repo.is_template, + is_mirror: repo.is_mirror, + created_by: repo.created_by, + created_at: repo.created_at, + updated_at: repo.updated_at, + } +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateRepo { + pub name: Option, + pub description: Option, + pub default_branch: Option, + pub visibility: Option, + pub is_archived: Option, + pub is_template: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct TransferRepo { + pub target_workspace: String, +} + +#[derive(Debug, Clone, Deserialize, utoipa::IntoParams)] +pub struct RepoFilter { + pub visibility: Option, + pub is_archived: Option, + pub search: Option, +} + +impl AppService { + pub async fn repo_list( + &self, + ctx: &Session, + wk_name: &str, + filter: RepoFilter, + pagination: Pagination, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + + let mut param_idx = 2; + let mut sql = String::from( + "SELECT id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at \ + FROM repo WHERE wk = $1 AND deleted_at IS NULL", + ); + + if filter.visibility.is_some() { + sql.push_str(&format!(" AND visibility = ${param_idx}")); + param_idx += 1; + } + if filter.is_archived.is_some() { + sql.push_str(&format!(" AND is_archived = ${param_idx}")); + param_idx += 1; + } + if filter.search.is_some() { + sql.push_str(&format!(" AND name ILIKE ${param_idx}")); + param_idx += 1; + } + + let offset_idx = param_idx; + let limit_idx = param_idx + 1; + sql.push_str(&format!( + " ORDER BY name ASC OFFSET ${offset_idx} LIMIT ${limit_idx}" + )); + + let mut q = sqlx::query_as::<_, RepoModel>(sqlx::AssertSqlSafe(sql)) + .bind(wk.id); + + if let Some(vis) = &filter.visibility { + q = q.bind(vis.clone()); + } + if let Some(archived) = filter.is_archived { + q = q.bind(archived); + } + if let Some(search) = &filter.search { + q = q.bind(format!("%{}%", search)); + } + + q = q + .bind(pagination.offset() as i64) + .bind(pagination.limit() as i64); + + let rows = q + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(repo_response).collect()) + } + + pub async fn repo_get( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + Ok(repo_response(repo)) + } + + pub async fn repo_update( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: UpdateRepo, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let mut repo = self.repo_resolve(wk.id, repo_name).await?; + + let next_name = match params.name { + Some(name) => { + let name = name.trim(); + if name.is_empty() { + return Err(AppError::BadRequest( + "repo name is required".to_string(), + )); + } + if name != repo.name { + let existing = sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM repo WHERE wk = $1 AND name = $2 AND deleted_at IS NULL)", + ) + .bind(wk.id) + .bind(name) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + if existing { + return Err(AppError::RepoNameAlreadyExists); + } + Some(name.to_string()) + } else { + None + } + } + None => None, + }; + + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + + if let Some(next_name) = &next_name { + sqlx::query( + "INSERT INTO repo_history_name (id, repo, name, changed_by, created_at) \ + VALUES ($1, $2, $3, $4, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(repo.id) + .bind(&repo.name) + .bind(user_uid) + .bind(chrono::Utc::now()) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + repo.name = next_name.clone(); + } + + if let Some(desc) = params.description { + repo.description = if desc.is_empty() { None } else { Some(desc) }; + } + let mut default_branch_changed = false; + if let Some(branch) = params.default_branch { + repo.default_branch = branch; + default_branch_changed = true; + } + if let Some(vis) = params.visibility { + repo.visibility = vis; + } + if let Some(archived) = params.is_archived { + repo.is_archived = archived; + } + if let Some(template) = params.is_template { + repo.is_template = template; + } + + let updated = sqlx::query_as::<_, RepoModel>( + "UPDATE repo SET name = $1, description = $2, default_branch = $3, \ + visibility = $4, is_archived = $5, is_template = $6, updated_at = $7 \ + WHERE id = $8 \ + RETURNING id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at", + ) + .bind(&repo.name) + .bind(&repo.description) + .bind(&repo.default_branch) + .bind(&repo.visibility) + .bind(repo.is_archived) + .bind(repo.is_template) + .bind(chrono::Utc::now()) + .bind(repo.id) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + txn.commit().await.map_err(|_| AppError::TxnError)?; + if default_branch_changed { + let mut client = InitServiceClient::new(self.git.clone()); + let _ = client + .set_default_branch(tonic::Request::new( + p::SetDefaultBranchRequest { + repo_id: repo.id.to_string(), + branch_name: repo.default_branch.clone(), + }, + )) + .await + .map_err(rpc_err); + self.queue_sync(repo.id).await; + } + + Ok(repo_response(updated)) + } + + pub async fn repo_archive( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + self.repo_update( + ctx, + wk_name, + repo_name, + UpdateRepo { + name: None, + description: None, + default_branch: None, + visibility: None, + is_archived: Some(true), + is_template: None, + }, + ) + .await + } + + pub async fn repo_delete( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_owner(wk.id, user_uid).await?; + let repo = self.repo_resolve(wk.id, repo_name).await?; + + sqlx::query("UPDATE repo SET deleted_at = $1 WHERE id = $2") + .bind(chrono::Utc::now()) + .bind(repo.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } + + pub async fn repo_transfer( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: TransferRepo, + ) -> Result { + let user_uid = session_user(ctx)?; + let src_wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_owner(src_wk.id, user_uid).await?; + let repo = self.repo_resolve(src_wk.id, repo_name).await?; + + let target_wk = + self.workspace_resolve(¶ms.target_workspace).await?; + self.workspace_require_admin(target_wk.id, user_uid).await?; + + let existing = sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM repo WHERE wk = $1 AND name = $2 AND deleted_at IS NULL)", + ) + .bind(target_wk.id) + .bind(&repo.name) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + if existing { + return Err(AppError::Conflict( + "repo name already exists in target workspace".to_string(), + )); + } + + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + + sqlx::query( + "INSERT INTO repo_history_name (id, repo, name, changed_by, created_at) \ + VALUES ($1, $2, $3, $4, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(repo.id) + .bind(&format!("{}/{}", src_wk.name, repo.name)) + .bind(user_uid) + .bind(chrono::Utc::now()) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let updated = sqlx::query_as::<_, RepoModel>( + "UPDATE repo SET wk = $1, updated_at = $2 WHERE id = $3 \ + RETURNING id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at", + ) + .bind(target_wk.id) + .bind(chrono::Utc::now()) + .bind(repo.id) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + txn.commit().await.map_err(|_| AppError::TxnError)?; + Ok(repo_response(updated)) + } + + pub async fn repo_topics( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result, AppError> { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + + let rows = sqlx::query_as::<_, RepoTopicModel>( + "SELECT repo, topic, created_at FROM repo_topic WHERE repo = $1", + ) + .bind(repo.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(|r| r.topic).collect()) + } + + pub async fn repo_update_topics( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + topics: Vec, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let repo = self.repo_resolve(wk.id, repo_name).await?; + + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + + sqlx::query("DELETE FROM repo_topic WHERE repo = $1") + .bind(repo.id) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + for topic in &topics { + let topic = topic.trim(); + if topic.is_empty() { + continue; + } + sqlx::query( + "INSERT INTO repo_topic (repo, topic, created_at) VALUES ($1, $2, $3)", + ) + .bind(repo.id) + .bind(topic) + .bind(chrono::Utc::now()) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + } + + txn.commit().await.map_err(|_| AppError::TxnError)?; + Ok(topics) + } + + /// CMDK BFF: list repo names + descriptions for a workspace. + pub async fn repo_list_inner( + &self, + wk_name: &str, + ) -> Result)>, AppError> { + let wk = sqlx::query_as::<_, (uuid::Uuid,)>( + "SELECT id FROM workspace WHERE name = $1" + ) + .bind(wk_name) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("workspace not found".to_string()))?; + + let rows = sqlx::query_as::<_, (String, Option)>( + "SELECT name, description FROM repo WHERE wk = $1 AND deleted_at IS NULL ORDER BY updated_at DESC" + ) + .bind(wk.0) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(rows) + } +} diff --git a/lib/service/git/star.rs b/lib/service/git/star.rs new file mode 100644 index 0000000..089b034 --- /dev/null +++ b/lib/service/git/star.rs @@ -0,0 +1,98 @@ +use db::sqlx; +use session::Session; + +use crate::{AppService, error::AppError, session_user}; + +impl AppService { + pub async fn git_repo_star( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + let user_uid = session_user(ctx)?; + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + + sqlx::query( + "INSERT INTO repo_star (repo, \"user\", created_at) VALUES ($1, $2, $3) \ + ON CONFLICT (repo, \"user\") DO NOTHING", + ) + .bind(repo.id) + .bind(user_uid) + .bind(chrono::Utc::now()) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let count: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM repo_star WHERE repo = $1", + ) + .bind(repo.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(serde_json::json!({ "starred": true, "count": count.0 })) + } + + pub async fn git_repo_unstar( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + let user_uid = session_user(ctx)?; + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + + sqlx::query( + "DELETE FROM repo_star WHERE repo = $1 AND \"user\" = $2", + ) + .bind(repo.id) + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let count: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM repo_star WHERE repo = $1", + ) + .bind(repo.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(serde_json::json!({ "starred": false, "count": count.0 })) + } + + pub async fn git_repo_star_status( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + let user_uid = session_user(ctx)?; + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + + let exists: (bool,) = sqlx::query_as( + "SELECT EXISTS(SELECT 1 FROM repo_star WHERE repo = $1 AND \"user\" = $2)", + ) + .bind(repo.id) + .bind(user_uid) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let count: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM repo_star WHERE repo = $1", + ) + .bind(repo.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(serde_json::json!({ + "starred": exists.0, + "count": count.0, + })) + } +} diff --git a/lib/service/git/tag.rs b/lib/service/git/tag.rs new file mode 100644 index 0000000..096bae0 --- /dev/null +++ b/lib/service/git/tag.rs @@ -0,0 +1,157 @@ +use git::rpc::{proto as p, proto::tag_service_client::TagServiceClient}; +use session::Session; + +use crate::{AppService, Pagination, error::AppError, git::rpc_err}; + +impl AppService { + pub async fn git_tag_list( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + pagination: Pagination, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = TagServiceClient::new(self.git.clone()); + let mut resp = client + .tag_list(tonic::Request::new(p::TagListRequest { + repo_id: repo.id.to_string(), + })) + .await + .map_err(rpc_err)? + .into_inner(); + + let offset = pagination.offset() as usize; + let limit = pagination.limit() as usize; + if offset > 0 || resp.tags.len() > limit { + let start = offset.min(resp.tags.len()); + let end = (start + limit).min(resp.tags.len()); + resp.tags = resp.tags.drain(start..end).collect(); + } + Ok(resp) + } + + pub async fn git_tag_info( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + name: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = TagServiceClient::new(self.git.clone()); + let resp = client + .tag_info(tonic::Request::new(p::TagInfoRequest { + repo_id: repo.id.to_string(), + name, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_tag_summary( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = TagServiceClient::new(self.git.clone()); + let resp = client + .tag_summary(tonic::Request::new(p::TagSummaryRequest { + repo_id: repo.id.to_string(), + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_tag_init( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: p::TagInitParams, + ) -> Result { + let repo = self.git_require_admin(ctx, wk_name, repo_name).await?; + let mut client = TagServiceClient::new(self.git.clone()); + let resp = client + .tag_init(tonic::Request::new(p::TagInitRequest { + repo_id: repo.id.to_string(), + params: Some(params), + })) + .await + .map_err(rpc_err)? + .into_inner(); + self.queue_sync(repo.id).await; + Ok(resp) + } + + pub async fn git_tag_delete( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: p::TagDeleteParams, + ) -> Result { + let repo = self.git_require_admin(ctx, wk_name, repo_name).await?; + let mut client = TagServiceClient::new(self.git.clone()); + let resp = client + .tag_delete(tonic::Request::new(p::TagDeleteRequest { + repo_id: repo.id.to_string(), + params: Some(params), + })) + .await + .map_err(rpc_err)? + .into_inner(); + self.queue_sync(repo.id).await; + Ok(resp) + } + + pub async fn git_tag_rename( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: p::TagRenameParams, + ) -> Result { + let repo = self.git_require_admin(ctx, wk_name, repo_name).await?; + let mut client = TagServiceClient::new(self.git.clone()); + let resp = client + .tag_rename(tonic::Request::new(p::TagRenameRequest { + repo_id: repo.id.to_string(), + params: Some(params), + })) + .await + .map_err(rpc_err)? + .into_inner(); + self.queue_sync(repo.id).await; + Ok(resp) + } + + pub async fn git_tag_update_message( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: p::TagUpdateMessageParams, + ) -> Result { + let repo = self.git_require_admin(ctx, wk_name, repo_name).await?; + let mut client = TagServiceClient::new(self.git.clone()); + let resp = client + .tag_update_message(tonic::Request::new( + p::TagUpdateMessageRequest { + repo_id: repo.id.to_string(), + params: Some(params), + }, + )) + .await + .map_err(rpc_err)? + .into_inner(); + self.queue_sync(repo.id).await; + Ok(resp) + } +} diff --git a/lib/service/git/tree.rs b/lib/service/git/tree.rs new file mode 100644 index 0000000..033dbb3 --- /dev/null +++ b/lib/service/git/tree.rs @@ -0,0 +1,98 @@ +use git::rpc::{proto as p, proto::tree_service_client::TreeServiceClient}; +use session::Session; + +use crate::{AppService, error::AppError, git::rpc_err}; + +impl AppService { + pub async fn git_tree_entries( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + oid: String, + base_path: String, + last: bool, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = TreeServiceClient::new(self.git.clone()); + let resp = client + .tree_entries(tonic::Request::new(p::TreeEntriesRequest { + repo_id: repo.id.to_string(), + oid: Some(p::ObjectId { value: oid }), + base_path, + last, + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_tree_entry_by_path( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + tree_oid: String, + path: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = TreeServiceClient::new(self.git.clone()); + let resp = client + .tree_entry_by_path(tonic::Request::new( + p::TreeEntryByPathRequest { + repo_id: repo.id.to_string(), + tree_oid: Some(p::ObjectId { value: tree_oid }), + path, + }, + )) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_tree_entry_by_path_from_commit( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + commit_oid: String, + path: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = TreeServiceClient::new(self.git.clone()); + let resp = client + .tree_entry_by_path_from_commit(tonic::Request::new( + p::TreeEntryByPathFromCommitRequest { + repo_id: repo.id.to_string(), + commit_oid: Some(p::ObjectId { value: commit_oid }), + path, + }, + )) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn git_resolve_tree( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + oid: String, + ) -> Result { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let mut client = TreeServiceClient::new(self.git.clone()); + let resp = client + .resolve_tree(tonic::Request::new(p::ResolveTreeRequest { + repo_id: repo.id.to_string(), + oid: Some(p::ObjectId { value: oid }), + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } +} diff --git a/lib/service/git/watch.rs b/lib/service/git/watch.rs new file mode 100644 index 0000000..b0dd850 --- /dev/null +++ b/lib/service/git/watch.rs @@ -0,0 +1,105 @@ +use db::sqlx; +use session::Session; + +use crate::{AppService, error::AppError, session_user}; + +impl AppService { + pub async fn git_repo_watch( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + level: Option, + ) -> Result { + let user_uid = session_user(ctx)?; + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + let level = level.unwrap_or_else(|| "participating".to_string()); + + sqlx::query( + "INSERT INTO repo_watch (repo, \"user\", \"level\", created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5) \ + ON CONFLICT (repo, \"user\") DO UPDATE SET level = $3, updated_at = $5", + ) + .bind(repo.id) + .bind(user_uid) + .bind(&level) + .bind(chrono::Utc::now()) + .bind(chrono::Utc::now()) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let count: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM repo_watch WHERE repo = $1", + ) + .bind(repo.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(serde_json::json!({ "watching": true, "count": count.0, "level": level })) + } + + pub async fn git_repo_unwatch( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + let user_uid = session_user(ctx)?; + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + + sqlx::query( + "DELETE FROM repo_watch WHERE repo = $1 AND \"user\" = $2", + ) + .bind(repo.id) + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let count: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM repo_watch WHERE repo = $1", + ) + .bind(repo.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(serde_json::json!({ "watching": false, "count": count.0 })) + } + + pub async fn git_repo_watch_status( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result { + let user_uid = session_user(ctx)?; + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + + let watch: Option<(bool, String)> = sqlx::query_as( + "SELECT EXISTS(SELECT 1 FROM repo_watch WHERE repo = $1 AND \"user\" = $2), \ + COALESCE((SELECT level FROM repo_watch WHERE repo = $1 AND \"user\" = $2), '')", + ) + .bind(repo.id) + .bind(user_uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let count: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM repo_watch WHERE repo = $1", + ) + .bind(repo.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(serde_json::json!({ + "watching": watch.as_ref().map(|(w, _)| *w).unwrap_or(false), + "count": count.0, + "level": watch.map(|(_, l)| l).unwrap_or_default(), + })) + } +} diff --git a/lib/service/git/webhook.rs b/lib/service/git/webhook.rs new file mode 100644 index 0000000..4cf068c --- /dev/null +++ b/lib/service/git/webhook.rs @@ -0,0 +1,365 @@ +use db::sqlx; +use hmac::{Hmac, KeyInit, Mac}; +use model::repos::{RepoWebhookDeliveryModel, RepoWebhookModel}; +use serde::{Deserialize, Serialize}; +use session::Session; +use sha2::Sha256; + +use crate::{AppService, Pagination, error::AppError, session_user}; + +type HmacSha256 = Hmac; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum WebhookEvent { + Push, + PushBranch, + PushTag, + Issue, + PullRequest, + Comment, + Release, + Fork, + Wiki, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct WebhookResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + #[schema(value_type = String)] + pub repo: uuid::Uuid, + pub url: String, + pub events: Vec, + pub active: bool, + #[schema(value_type = String)] + pub created_by: uuid::Uuid, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +pub(crate) fn webhook_response(w: RepoWebhookModel) -> WebhookResponse { + WebhookResponse { + id: w.id, + repo: w.repo, + url: w.url, + events: w + .events + .split('.') + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + .collect(), + active: w.active, + created_by: w.created_by, + created_at: w.created_at, + updated_at: w.updated_at, + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct WebhookDeliveryResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + #[schema(value_type = String)] + pub webhook: uuid::Uuid, + pub event: String, + pub response_status: Option, + pub error: Option, + #[schema(value_type = Option)] + pub delivered_at: Option>, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, +} + +pub(crate) fn delivery_response( + d: RepoWebhookDeliveryModel, +) -> WebhookDeliveryResponse { + WebhookDeliveryResponse { + id: d.id, + webhook: d.webhook, + event: d.event, + response_status: d.response_status, + error: d.error, + delivered_at: d.delivered_at, + created_at: d.created_at, + } +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreateWebhook { + pub url: String, + pub secret: Option, + pub events: Vec, + pub active: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateWebhook { + pub url: Option, + pub secret: Option, + pub events: Option>, + pub active: Option, +} + +impl AppService { + pub async fn repo_webhook_list( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + pagination: Pagination, + ) -> Result, AppError> { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + + let rows = sqlx::query_as::<_, RepoWebhookModel>( + "SELECT id, repo, url, secret_hash, events, active, created_by, created_at, updated_at \ + FROM repo_webhook WHERE repo = $1 \ + ORDER BY created_at DESC OFFSET $2 LIMIT $3", + ) + .bind(repo.id) + .bind(pagination.offset() as i64) + .bind(pagination.limit() as i64) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(webhook_response).collect()) + } + + pub async fn repo_webhook_create( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: CreateWebhook, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let repo = self.repo_resolve(wk.id, repo_name).await?; + + let url = params.url.trim(); + if url.is_empty() { + return Err(AppError::BadRequest("url is required".to_string())); + } + + let secret_hash = params.secret.map(|s| { + let mut mac = HmacSha256::new_from_slice(b"gitdata-webhook-secret") + .expect("HMAC can take key of any size"); + mac.update(s.as_bytes()); + let result = mac.finalize(); + let code_bytes = result.into_bytes(); + hex::encode(code_bytes) + }); + + let events = params.events.join("."); + let active = params.active.unwrap_or(true); + let id = uuid::Uuid::now_v7(); + let now = chrono::Utc::now(); + + let row = sqlx::query_as::<_, RepoWebhookModel>( + "INSERT INTO repo_webhook \ + (id, repo, url, secret_hash, events, active, created_by, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $8) \ + RETURNING id, repo, url, secret_hash, events, active, created_by, created_at, updated_at", + ) + .bind(id) + .bind(repo.id) + .bind(url) + .bind(&secret_hash) + .bind(&events) + .bind(active) + .bind(user_uid) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(webhook_response(row)) + } + + pub async fn repo_webhook_update( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + webhook_id: uuid::Uuid, + params: UpdateWebhook, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let repo = self.repo_resolve(wk.id, repo_name).await?; + + let existing = sqlx::query_as::<_, RepoWebhookModel>( + "SELECT id, repo, url, secret_hash, events, active, created_by, created_at, updated_at \ + FROM repo_webhook WHERE id = $1 AND repo = $2", + ) + .bind(webhook_id) + .bind(repo.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::NotFound("webhook not found".to_string()))?; + + let url = params.url.unwrap_or(existing.url); + let events = params + .events + .map(|e| e.join(".")) + .unwrap_or(existing.events); + let active = params.active.unwrap_or(existing.active); + + let secret_hash = if let Some(secret) = params.secret { + let mut mac = HmacSha256::new_from_slice(b"gitdata-webhook-secret") + .expect("HMAC can take key of any size"); + mac.update(secret.as_bytes()); + let result = mac.finalize(); + let code_bytes = result.into_bytes(); + Some(hex::encode(code_bytes)) + } else { + existing.secret_hash + }; + + let row = sqlx::query_as::<_, RepoWebhookModel>( + "UPDATE repo_webhook SET url = $1, secret_hash = $2, events = $3, \ + active = $4, updated_at = $5 WHERE id = $6 \ + RETURNING id, repo, url, secret_hash, events, active, created_by, created_at, updated_at", + ) + .bind(&url) + .bind(&secret_hash) + .bind(&events) + .bind(active) + .bind(chrono::Utc::now()) + .bind(webhook_id) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(webhook_response(row)) + } + + pub async fn repo_webhook_delete( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + webhook_id: uuid::Uuid, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let repo = self.repo_resolve(wk.id, repo_name).await?; + + let result = + sqlx::query("DELETE FROM repo_webhook WHERE id = $1 AND repo = $2") + .bind(webhook_id) + .bind(repo.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if result.rows_affected() == 0 { + return Err(AppError::NotFound("webhook not found".to_string())); + } + + Ok(()) + } + + pub async fn repo_webhook_deliveries( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + webhook_id: uuid::Uuid, + pagination: Pagination, + ) -> Result, AppError> { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + + let rows = sqlx::query_as::<_, RepoWebhookDeliveryModel>( + "SELECT id, repo, webhook, event, request_headers, request_body, \ + response_status, response_headers, response_body, error, delivered_at, created_at \ + FROM repo_webhook_delivery WHERE webhook = $1 AND repo = $2 \ + ORDER BY created_at DESC OFFSET $3 LIMIT $4", + ) + .bind(webhook_id) + .bind(repo.id) + .bind(pagination.offset() as i64) + .bind(pagination.limit() as i64) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(delivery_response).collect()) + } + pub async fn trigger_webhook_event( + &self, + repo_id: uuid::Uuid, + event: &str, + payload: serde_json::Value, + ) -> Result<(), AppError> { + let webhooks = sqlx::query_as::<_, RepoWebhookModel>( + "SELECT id, repo, url, secret_hash, events, active, created_by, created_at, updated_at \ + FROM repo_webhook WHERE repo = $1 AND active = true", + ) + .bind(repo_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + for wh in webhooks { + let subscribed_events: Vec<&str> = + wh.events.split('.').filter(|s| !s.is_empty()).collect(); + let matches = subscribed_events.iter().any(|e| { + *e == event + || *e == "push" + && (event == "push_branch" || event == "push_tag") + }); + + if !matches { + continue; + } + + let delivery_id = uuid::Uuid::now_v7(); + let now = chrono::Utc::now(); + sqlx::query( + "INSERT INTO repo_webhook_delivery \ + (id, repo, webhook, event, request_headers, request_body, \ + response_status, response_headers, response_body, error, delivered_at, created_at) \ + VALUES ($1, $2, $3, $4, NULL, NULL, NULL, NULL, NULL, NULL, NULL, $5)", + ) + .bind(delivery_id) + .bind(repo_id) + .bind(wh.id) + .bind(event) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let task = git::sync::webhook::WebhookDeliveryTask { + id: delivery_id.to_string(), + webhook_id: wh.id.to_string(), + repo_id: repo_id.to_string(), + event: event.to_string(), + url: wh.url.clone(), + secret: wh.secret_hash.clone(), + payload: payload.clone(), + created_at: now, + retry_count: 0, + }; + + if let Err(e) = + git::sync::webhook::enqueue_delivery(task, &self.redis_pool) + .await + { + tracing::error!( + repo_id = %repo_id, + webhook_id = %wh.id, + error = %e, + "failed to enqueue webhook delivery" + ); + } + } + + Ok(()) + } +} diff --git a/lib/service/issues/assignee.rs b/lib/service/issues/assignee.rs new file mode 100644 index 0000000..7898b18 --- /dev/null +++ b/lib/service/issues/assignee.rs @@ -0,0 +1,118 @@ +use db::sqlx; +use model::users::UserModel; +use serde::Deserialize; +use session::Session; + +use super::types::{IssueAuthor, issue_author}; +use crate::{AppService, error::AppError, session_user}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct AssignIssueUser { + pub username: String, +} + +impl AppService { + pub async fn issue_assign( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + params: AssignIssueUser, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + let target = self + .users_find_active_user_by_username(¶ms.username) + .await?; + self.workspace_require_member(wk.id, target.id).await?; + + let now = chrono::Utc::now(); + sqlx::query( + "INSERT INTO issue_assignee (issue, \"user\", assigned_by, created_at) \ + VALUES ($1, $2, $3, $4) \ + ON CONFLICT (issue, \"user\") DO NOTHING", + ) + .bind(issue.id) + .bind(target.id) + .bind(user_uid) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'assigned', NULL, $4, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(¶ms.username) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.issue_assignees(issue.id).await + } + + pub async fn issue_unassign( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + username: &str, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + let target = self.users_find_active_user_by_username(username).await?; + + sqlx::query( + "DELETE FROM issue_assignee WHERE issue = $1 AND \"user\" = $2", + ) + .bind(issue.id) + .bind(target.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'unassigned', $4, NULL, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(username) + .bind(chrono::Utc::now()) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.issue_assignees(issue.id).await + } + + pub(crate) async fn issue_assignees( + &self, + issue_id: uuid::Uuid, + ) -> Result, AppError> { + let rows = sqlx::query_as::<_, UserModel>( + "SELECT u.id, u.username, u.display_name, u.avatar_url, u.website_url, u.allow_use, u.can_search, \ + u.last_sign_in_at, u.created_at, u.updated_at \ + FROM issue_assignee ia \ + INNER JOIN \"user\" u ON u.id = ia.\"user\" \ + WHERE ia.issue = $1 AND u.allow_use = true \ + ORDER BY ia.created_at ASC", + ) + .bind(issue_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(issue_author).collect()) + } +} diff --git a/lib/service/issues/binding.rs b/lib/service/issues/binding.rs new file mode 100644 index 0000000..88529a1 --- /dev/null +++ b/lib/service/issues/binding.rs @@ -0,0 +1,261 @@ +use db::sqlx; +use model::{pull_request::PullRequestModel, repos::RepoModel}; +use serde::Deserialize; +use session::Session; + +use super::types::{ + IssuePullRequestResponse, IssueRepoResponse, issue_pr_response, + issue_repo_response, +}; +use crate::{AppService, error::AppError, session_user}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct BindIssueRepo { + pub repo_id: uuid::Uuid, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct BindIssuePullRequest { + pub pull_request_id: uuid::Uuid, +} + +impl AppService { + pub async fn issue_bind_repo( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + params: BindIssueRepo, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let repo = sqlx::query_as::<_, RepoModel>( + "SELECT id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at \ + FROM repo WHERE id = $1 AND wk = $2 AND deleted_at IS NULL", + ) + .bind(params.repo_id) + .bind(wk.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::RepoNotFound)?; + + let now = chrono::Utc::now(); + sqlx::query( + "INSERT INTO issue_repo (issue, repo, created_at) VALUES ($1, $2, $3) \ + ON CONFLICT (issue, repo) DO NOTHING", + ) + .bind(issue.id) + .bind(repo.id) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'linked_repo', NULL, $4, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(&repo.name) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.issue_repos(issue.id).await + } + + pub async fn issue_unbind_repo( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + repo_id: uuid::Uuid, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let repo = sqlx::query_as::<_, RepoModel>( + "SELECT id, wk, name, description, default_branch, visibility, size_bytes, \ + is_archived, is_template, is_mirror, created_by, storage_path, created_at, updated_at, deleted_at \ + FROM repo WHERE id = $1 AND wk = $2", + ) + .bind(repo_id) + .bind(wk.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::RepoNotFound)?; + + sqlx::query("DELETE FROM issue_repo WHERE issue = $1 AND repo = $2") + .bind(issue.id) + .bind(repo_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'unlinked_repo', $4, NULL, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(&repo.name) + .bind(chrono::Utc::now()) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.issue_repos(issue.id).await + } + + pub async fn issue_bind_pull_request( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + params: BindIssuePullRequest, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let pr = sqlx::query_as::<_, PullRequestModel>( + "SELECT id, repo, number, title, body, state, draft, author, \ + source_repo, source_branch, source_sha, target_branch, target_sha, \ + merged_by, merged_at, closed_by, closed_at, created_at, updated_at, deleted_at \ + FROM pull_request WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(params.pull_request_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::PullRequestNotFound)?; + + let now = chrono::Utc::now(); + sqlx::query( + "INSERT INTO issue_pull_request (issue, pull_request, created_at) VALUES ($1, $2, $3) \ + ON CONFLICT (issue, pull_request) DO NOTHING", + ) + .bind(issue.id) + .bind(pr.id) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'linked_pull_request', NULL, $4, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(format!("#{}", pr.number)) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.issue_pull_requests(issue.id).await + } + + pub async fn issue_unbind_pull_request( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + pull_request_id: uuid::Uuid, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let pr = sqlx::query_as::<_, PullRequestModel>( + "SELECT id, repo, number, title, body, state, draft, author, \ + source_repo, source_branch, source_sha, target_branch, target_sha, \ + merged_by, merged_at, closed_by, closed_at, created_at, updated_at, deleted_at \ + FROM pull_request WHERE id = $1", + ) + .bind(pull_request_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::PullRequestNotFound)?; + + sqlx::query("DELETE FROM issue_pull_request WHERE issue = $1 AND pull_request = $2") + .bind(issue.id) + .bind(pull_request_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'unlinked_pull_request', $4, NULL, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(format!("#{}", pr.number)) + .bind(chrono::Utc::now()) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.issue_pull_requests(issue.id).await + } + + pub(crate) async fn issue_repos( + &self, + issue_id: uuid::Uuid, + ) -> Result, AppError> { + let repos = sqlx::query_as::<_, RepoModel>( + "SELECT r.id, r.wk, r.name, r.description, r.default_branch, r.visibility, r.size_bytes, \ + r.is_archived, r.is_template, r.is_mirror, r.created_by, r.storage_path, r.created_at, r.updated_at, r.deleted_at \ + FROM issue_repo ir \ + INNER JOIN repo r ON r.id = ir.repo \ + WHERE ir.issue = $1 AND r.deleted_at IS NULL \ + ORDER BY r.name ASC", + ) + .bind(issue_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(repos.into_iter().map(issue_repo_response).collect()) + } + + pub(crate) async fn issue_pull_requests( + &self, + issue_id: uuid::Uuid, + ) -> Result, AppError> { + let prs = sqlx::query_as::<_, PullRequestModel>( + "SELECT pr.id, pr.repo, pr.number, pr.title, pr.body, pr.state, pr.draft, pr.author, \ + pr.source_repo, pr.source_branch, pr.source_sha, pr.target_branch, pr.target_sha, \ + pr.merged_by, pr.merged_at, pr.closed_by, pr.closed_at, pr.created_at, pr.updated_at, pr.deleted_at \ + FROM issue_pull_request ip \ + INNER JOIN pull_request pr ON pr.id = ip.pull_request \ + WHERE ip.issue = $1 AND pr.deleted_at IS NULL \ + ORDER BY pr.created_at DESC", + ) + .bind(issue_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(prs.into_iter().map(issue_pr_response).collect()) + } +} diff --git a/lib/service/issues/comment.rs b/lib/service/issues/comment.rs new file mode 100644 index 0000000..2808aa6 --- /dev/null +++ b/lib/service/issues/comment.rs @@ -0,0 +1,209 @@ +use db::sqlx; +use model::issues::IssueCommentModel; +use serde::Deserialize; +use session::Session; + +use super::types::{IssueCommentResponse, issue_author}; +use crate::{AppService, error::AppError, session_user}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreateComment { + pub body: String, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateComment { + pub body: String, +} + +impl AppService { + pub async fn issue_comment_create( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + params: CreateComment, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let body = params.body.trim(); + if body.is_empty() { + return Err(AppError::BadRequest( + "comment body is required".to_string(), + )); + } + + let now = chrono::Utc::now(); + let id = uuid::Uuid::now_v7(); + + let comment = sqlx::query_as::<_, IssueCommentModel>( + "INSERT INTO issue_comment (id, issue, author, body, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $5) \ + RETURNING id, issue, author, body, created_at, updated_at, deleted_at", + ) + .bind(id) + .bind(issue.id) + .bind(user_uid) + .bind(body) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'commented', NULL, $4, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(comment.id) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let author = self.users_find_by_id(user_uid).await?; + Ok(IssueCommentResponse { + id: comment.id, + author: issue_author(author), + body: comment.body, + created_at: comment.created_at, + updated_at: comment.updated_at, + }) + } + + pub async fn issue_comment_list( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let comments = sqlx::query_as::<_, IssueCommentModel>( + "SELECT id, issue, author, body, created_at, updated_at, deleted_at \ + FROM issue_comment WHERE issue = $1 AND deleted_at IS NULL \ + ORDER BY created_at ASC", + ) + .bind(issue.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut results = Vec::new(); + for comment in comments { + let author = self.users_find_by_id(comment.author).await?; + results.push(IssueCommentResponse { + id: comment.id, + author: issue_author(author), + body: comment.body, + created_at: comment.created_at, + updated_at: comment.updated_at, + }); + } + Ok(results) + } + + pub async fn issue_comment_update( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + comment_id: uuid::Uuid, + params: UpdateComment, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let comment = sqlx::query_as::<_, IssueCommentModel>( + "SELECT id, issue, author, body, created_at, updated_at, deleted_at \ + FROM issue_comment WHERE id = $1 AND issue = $2 AND deleted_at IS NULL", + ) + .bind(comment_id) + .bind(issue.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::CommentNotFound)?; + + if comment.author != user_uid { + return Err(AppError::Forbidden( + "only the comment author can edit".to_string(), + )); + } + + let body = params.body.trim(); + if body.is_empty() { + return Err(AppError::BadRequest( + "comment body is required".to_string(), + )); + } + + let now = chrono::Utc::now(); + let updated = sqlx::query_as::<_, IssueCommentModel>( + "UPDATE issue_comment SET body = $1, updated_at = $2 WHERE id = $3 \ + RETURNING id, issue, author, body, created_at, updated_at, deleted_at", + ) + .bind(body) + .bind(now) + .bind(comment_id) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let author = self.users_find_by_id(user_uid).await?; + Ok(IssueCommentResponse { + id: updated.id, + author: issue_author(author), + body: updated.body, + created_at: updated.created_at, + updated_at: updated.updated_at, + }) + } + + pub async fn issue_comment_delete( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + comment_id: uuid::Uuid, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let comment = sqlx::query_as::<_, IssueCommentModel>( + "SELECT id, issue, author, body, created_at, updated_at, deleted_at \ + FROM issue_comment WHERE id = $1 AND issue = $2 AND deleted_at IS NULL", + ) + .bind(comment_id) + .bind(issue.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::CommentNotFound)?; + + if comment.author != user_uid { + self.workspace_require_admin(wk.id, user_uid).await?; + } + + sqlx::query("UPDATE issue_comment SET deleted_at = $1 WHERE id = $2") + .bind(chrono::Utc::now()) + .bind(comment_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } +} diff --git a/lib/service/issues/event.rs b/lib/service/issues/event.rs new file mode 100644 index 0000000..72b421b --- /dev/null +++ b/lib/service/issues/event.rs @@ -0,0 +1,48 @@ +use db::sqlx; +use model::issues::IssueEventModel; +use session::Session; + +use super::types::{IssueEventResponse, issue_author}; +use crate::{AppService, error::AppError, session_user}; + +impl AppService { + pub async fn issue_events( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let events = sqlx::query_as::<_, IssueEventModel>( + "SELECT id, issue, actor, event, from_value, to_value, metadata, created_at \ + FROM issue_event WHERE issue = $1 \ + ORDER BY created_at ASC", + ) + .bind(issue.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut results = Vec::new(); + for event in events { + let actor = if let Some(actor_uid) = event.actor { + let user = self.users_find_by_id(actor_uid).await.ok(); + user.map(issue_author) + } else { + None + }; + results.push(IssueEventResponse { + actor, + event: event.event, + from_value: event.from_value, + to_value: event.to_value, + created_at: event.created_at, + }); + } + Ok(results) + } +} diff --git a/lib/service/issues/issue.rs b/lib/service/issues/issue.rs new file mode 100644 index 0000000..b82e96e --- /dev/null +++ b/lib/service/issues/issue.rs @@ -0,0 +1,501 @@ +use db::{sqlx, sqlx::AssertSqlSafe}; +use model::{issues::IssueModel, users::UserModel}; +use serde::Deserialize; +use session::Session; + +use super::types::{IssueFilter, IssueResponse, issue_author}; +use crate::{AppService, Pagination, error::AppError, session_user}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreateIssue { + pub title: String, + pub body: Option, + pub priority: Option, + pub due_at: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateIssue { + pub title: Option, + pub body: Option, + pub priority: Option, + pub due_at: Option, +} + +impl AppService { + pub async fn issue_create( + &self, + ctx: &Session, + wk_name: &str, + params: CreateIssue, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + + let title = params.title.trim(); + if title.is_empty() { + return Err(AppError::BadRequest( + "issue title is required".to_string(), + )); + } + + let now = chrono::Utc::now(); + let id = uuid::Uuid::now_v7(); + let priority = params.priority.unwrap_or_else(|| "normal".to_string()); + let due_at = params.due_at.and_then(|d| { + chrono::DateTime::parse_from_rfc3339(&d) + .ok() + .map(|dt| dt.to_utc()) + }); + + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + + let issue = sqlx::query_as::<_, IssueModel>( + "INSERT INTO issue (id, wk, number, title, body, state, priority, author, due_at, created_at, updated_at) \ + VALUES ($1, $2, (SELECT COALESCE(MAX(number), 0) + 1 FROM issue WHERE wk = $2 AND deleted_at IS NULL), \ + $3, $4, 'open', $5, $6, $7, $8, $8) \ + RETURNING id, wk, number, title, body, state, priority, author, closed_by, closed_at, due_at, created_at, updated_at, deleted_at", + ) + .bind(id) + .bind(wk.id) + .bind(title) + .bind(¶ms.body) + .bind(&priority) + .bind(user_uid) + .bind(due_at) + .bind(now) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'created', NULL, $4, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(&issue.title) + .bind(now) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + txn.commit().await.map_err(|_| AppError::TxnError)?; + + let author = self.users_find_by_id(user_uid).await?; + Ok(IssueResponse { + number: issue.number, + title: issue.title, + body: issue.body, + state: issue.state, + priority: issue.priority, + due_at: issue.due_at, + author: issue_author(author), + closed_by: None, + closed_at: None, + created_at: issue.created_at, + updated_at: issue.updated_at, + labels: Vec::new(), + assignees: Vec::new(), + milestone: None, + repos: Vec::new(), + pull_requests: Vec::new(), + }) + } + + pub async fn issue_list( + &self, + ctx: &Session, + wk_name: &str, + filter: IssueFilter, + pagination: Pagination, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + + let mut conditions = + vec!["i.wk = $1".to_string(), "i.deleted_at IS NULL".to_string()]; + let mut param_idx = 2; + + if filter.state.is_some() { + conditions.push(format!("i.state = ${param_idx}")); + param_idx += 1; + } + if filter.priority.is_some() { + conditions.push(format!("i.priority = ${param_idx}")); + param_idx += 1; + } + if filter.label.is_some() { + conditions.push(format!( + "EXISTS(SELECT 1 FROM issue_label il INNER JOIN label l ON l.id = il.label WHERE il.issue = i.id AND l.name = ${param_idx})" + )); + param_idx += 1; + } + if filter.milestone.is_some() { + conditions.push(format!( + "EXISTS(SELECT 1 FROM issue_milestone im INNER JOIN milestone m ON m.id = im.milestone WHERE im.issue = i.id AND m.title = ${param_idx})" + )); + param_idx += 1; + } + if filter.assignee.is_some() { + conditions.push(format!( + "EXISTS(SELECT 1 FROM issue_assignee ia INNER JOIN \"user\" u ON u.id = ia.\"user\" WHERE ia.issue = i.id AND u.username = ${param_idx})" + )); + param_idx += 1; + } + + let where_clause = conditions.join(" AND "); + let limit_idx = param_idx; + let offset_idx = param_idx + 1; + + let query = format!( + "SELECT i.id, i.wk, i.number, i.title, i.body, i.state, i.priority, i.author, \ + i.closed_by, i.closed_at, i.due_at, i.created_at, i.updated_at, i.deleted_at \ + FROM issue i WHERE {where_clause} \ + ORDER BY i.created_at DESC LIMIT ${limit_idx} OFFSET ${offset_idx}" + ); + + let mut q = + sqlx::query_as::<_, IssueModel>(AssertSqlSafe(query)).bind(wk.id); + if let Some(state) = &filter.state { + q = q.bind(state); + } + if let Some(priority) = &filter.priority { + q = q.bind(priority); + } + if let Some(label_name) = &filter.label { + q = q.bind(label_name); + } + if let Some(milestone_title) = &filter.milestone { + q = q.bind(milestone_title); + } + if let Some(assignee_username) = &filter.assignee { + q = q.bind(assignee_username); + } + q = q + .bind(pagination.limit() as i64) + .bind(pagination.offset() as i64); + + let issues = q + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut results = Vec::new(); + for issue in issues { + results.push(self.issue_build_response(issue).await?); + } + Ok(results) + } + + pub async fn issue_get( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + + let issue = self.issue_resolve(wk.id, number).await?; + self.issue_build_response(issue).await + } + + pub async fn issue_update( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + params: UpdateIssue, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let mut issue = self.issue_resolve(wk.id, number).await?; + + let now = chrono::Utc::now(); + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + + if let Some(title) = ¶ms.title { + let title = title.trim(); + if title.is_empty() { + return Err(AppError::BadRequest( + "issue title is required".to_string(), + )); + } + if title != &issue.title { + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'title_changed', $4, $5, $6)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(&issue.title) + .bind(title) + .bind(now) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + issue.title = title.to_string(); + } + } + + if let Some(priority) = ¶ms.priority { + if priority != &issue.priority { + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'priority_changed', $4, $5, $6)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(&issue.priority) + .bind(priority) + .bind(now) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + issue.priority = priority.to_string(); + } + } + + let next_body = params.body.map(Some).unwrap_or(issue.body.clone()); + let next_due_at = params + .due_at + .and_then(|d| { + chrono::DateTime::parse_from_rfc3339(&d) + .ok() + .map(|dt| dt.to_utc()) + }) + .or(issue.due_at); + + issue = sqlx::query_as::<_, IssueModel>( + "UPDATE issue SET title = $1, body = $2, priority = $3, due_at = $4, updated_at = $5 WHERE id = $6 \ + RETURNING id, wk, number, title, body, state, priority, author, closed_by, closed_at, due_at, created_at, updated_at, deleted_at", + ) + .bind(&issue.title) + .bind(&next_body) + .bind(&issue.priority) + .bind(next_due_at) + .bind(now) + .bind(issue.id) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + txn.commit().await.map_err(|_| AppError::TxnError)?; + self.issue_build_response(issue).await + } + + pub async fn issue_close( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + if issue.state == "closed" { + return Err(AppError::BadRequest( + "issue is already closed".to_string(), + )); + } + + let now = chrono::Utc::now(); + sqlx::query( + "UPDATE issue SET state = 'closed', closed_by = $1, closed_at = $2, updated_at = $2 WHERE id = $3", + ) + .bind(user_uid) + .bind(now) + .bind(issue.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'closed', $4, 'closed', $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(&issue.state) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let issue = self.issue_resolve(wk.id, number).await?; + self.issue_build_response(issue).await + } + + pub async fn issue_reopen( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + if issue.state == "open" { + return Err(AppError::BadRequest( + "issue is already open".to_string(), + )); + } + + let now = chrono::Utc::now(); + sqlx::query( + "UPDATE issue SET state = 'open', closed_by = NULL, closed_at = NULL, updated_at = $1 WHERE id = $2", + ) + .bind(now) + .bind(issue.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'reopened', $4, 'open', $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(&issue.state) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let issue = self.issue_resolve(wk.id, number).await?; + self.issue_build_response(issue).await + } + + pub async fn issue_delete( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + sqlx::query("UPDATE issue SET deleted_at = $1 WHERE id = $2") + .bind(chrono::Utc::now()) + .bind(issue.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } + + pub(crate) async fn issue_resolve( + &self, + wk_id: uuid::Uuid, + number: i64, + ) -> Result { + sqlx::query_as::<_, IssueModel>( + "SELECT id, wk, number, title, body, state, priority, author, closed_by, closed_at, due_at, \ + created_at, updated_at, deleted_at \ + FROM issue WHERE wk = $1 AND number = $2 AND deleted_at IS NULL", + ) + .bind(wk_id) + .bind(number) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::IssueNotFound) + } + + async fn issue_build_response( + &self, + issue: IssueModel, + ) -> Result { + let author = self.users_find_by_id(issue.author).await?; + + let closed_by = if let Some(closed_uid) = issue.closed_by { + let user = self.users_find_by_id(closed_uid).await?; + Some(issue_author(user)) + } else { + None + }; + + let labels = self.issue_labels(issue.id).await?; + let assignees = self.issue_assignees(issue.id).await?; + let milestone = self.issue_milestone(issue.id).await?; + let repos = self.issue_repos(issue.id).await?; + let pull_requests = self.issue_pull_requests(issue.id).await?; + + Ok(IssueResponse { + number: issue.number, + title: issue.title, + body: issue.body, + state: issue.state, + priority: issue.priority, + due_at: issue.due_at, + author: issue_author(author), + closed_by, + closed_at: issue.closed_at, + created_at: issue.created_at, + updated_at: issue.updated_at, + labels, + assignees, + milestone, + repos, + pull_requests, + }) + } + + pub(crate) async fn users_find_by_id( + &self, + uid: uuid::Uuid, + ) -> Result { + sqlx::query_as::<_, UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, allow_use, can_search, \ + last_sign_in_at, created_at, updated_at \ + FROM \"user\" WHERE id = $1 AND allow_use = true", + ) + .bind(uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::UserNotFound) + } + + /// CMDK BFF: list issue numbers, titles, states for a workspace. + pub async fn issue_list_inner( + &self, + wk_name: &str, + ) -> Result, AppError> { + let wk = sqlx::query_as::<_, (uuid::Uuid,)>( + "SELECT id FROM workspace WHERE name = $1" + ) + .bind(wk_name) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("workspace not found".to_string()))?; + + let rows = sqlx::query_as::<_, (i32, String, String)>( + "SELECT number, title, state FROM issue WHERE wk = $1 AND deleted_at IS NULL ORDER BY updated_at DESC LIMIT 50" + ) + .bind(wk.0) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(rows) + } +} diff --git a/lib/service/issues/label.rs b/lib/service/issues/label.rs new file mode 100644 index 0000000..e6fb740 --- /dev/null +++ b/lib/service/issues/label.rs @@ -0,0 +1,286 @@ +use db::sqlx; +use model::issues::LabelModel; +use serde::Deserialize; +use session::Session; + +use super::types::{LabelResponse, label_response}; +use crate::{AppService, error::AppError, session_user}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreateLabel { + pub name: String, + pub color: String, + pub description: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateLabel { + pub name: Option, + pub color: Option, + pub description: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct AddIssueLabel { + pub label_id: uuid::Uuid, +} + +impl AppService { + pub async fn label_create( + &self, + ctx: &Session, + wk_name: &str, + params: CreateLabel, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + + let name = params.name.trim(); + if name.is_empty() { + return Err(AppError::BadRequest( + "label name is required".to_string(), + )); + } + + let id = uuid::Uuid::now_v7(); + let now = chrono::Utc::now(); + + let label = sqlx::query_as::<_, LabelModel>( + "INSERT INTO label (id, wk, name, color, description, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $6) \ + RETURNING id, wk, name, color, description, created_at, updated_at", + ) + .bind(id) + .bind(wk.id) + .bind(name) + .bind(¶ms.color) + .bind(¶ms.description) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(label_response(label)) + } + + pub async fn label_list( + &self, + ctx: &Session, + wk_name: &str, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + + let labels = sqlx::query_as::<_, LabelModel>( + "SELECT id, wk, name, color, description, created_at, updated_at \ + FROM label WHERE wk = $1 ORDER BY name ASC", + ) + .bind(wk.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(labels.into_iter().map(label_response).collect()) + } + + pub async fn label_update( + &self, + ctx: &Session, + wk_name: &str, + label_id: uuid::Uuid, + params: UpdateLabel, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + + let label = sqlx::query_as::<_, LabelModel>( + "SELECT id, wk, name, color, description, created_at, updated_at \ + FROM label WHERE id = $1 AND wk = $2", + ) + .bind(label_id) + .bind(wk.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::LabelNotFound)?; + + let name = params.name.unwrap_or_else(|| label.name.clone()); + let color = params.color.unwrap_or_else(|| label.color.clone()); + let description = params + .description + .unwrap_or_else(|| label.description.clone().unwrap_or_default()); + + let now = chrono::Utc::now(); + let updated = sqlx::query_as::<_, LabelModel>( + "UPDATE label SET name = $1, color = $2, description = $3, updated_at = $4 WHERE id = $5 \ + RETURNING id, wk, name, color, description, created_at, updated_at", + ) + .bind(name) + .bind(color) + .bind(description) + .bind(now) + .bind(label_id) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(label_response(updated)) + } + + pub async fn label_delete( + &self, + ctx: &Session, + wk_name: &str, + label_id: uuid::Uuid, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + + let exists = sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM label WHERE id = $1 AND wk = $2)", + ) + .bind(label_id) + .bind(wk.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if !exists { + return Err(AppError::LabelNotFound); + } + + sqlx::query("DELETE FROM issue_label WHERE label = $1") + .bind(label_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query("DELETE FROM label WHERE id = $1") + .bind(label_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } + + pub async fn issue_add_label( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + params: AddIssueLabel, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let label = sqlx::query_as::<_, LabelModel>( + "SELECT id, wk, name, color, description, created_at, updated_at \ + FROM label WHERE id = $1 AND wk = $2", + ) + .bind(params.label_id) + .bind(wk.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::LabelNotFound)?; + + let now = chrono::Utc::now(); + sqlx::query( + "INSERT INTO issue_label (issue, label, created_at) VALUES ($1, $2, $3) \ + ON CONFLICT (issue, label) DO NOTHING", + ) + .bind(issue.id) + .bind(label.id) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'labeled', NULL, $4, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(&label.name) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.issue_labels(issue.id).await + } + + pub async fn issue_remove_label( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + label_id: uuid::Uuid, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let label = sqlx::query_as::<_, LabelModel>( + "SELECT id, wk, name, color, description, created_at, updated_at \ + FROM label WHERE id = $1 AND wk = $2", + ) + .bind(label_id) + .bind(wk.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::LabelNotFound)?; + + sqlx::query("DELETE FROM issue_label WHERE issue = $1 AND label = $2") + .bind(issue.id) + .bind(label_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'unlabeled', $4, NULL, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(&label.name) + .bind(chrono::Utc::now()) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.issue_labels(issue.id).await + } + + pub(crate) async fn issue_labels( + &self, + issue_id: uuid::Uuid, + ) -> Result, AppError> { + let labels = sqlx::query_as::<_, LabelModel>( + "SELECT l.id, l.wk, l.name, l.color, l.description, l.created_at, l.updated_at \ + FROM issue_label il \ + INNER JOIN label l ON l.id = il.label \ + WHERE il.issue = $1 \ + ORDER BY l.name ASC", + ) + .bind(issue_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(labels.into_iter().map(label_response).collect()) + } +} diff --git a/lib/service/issues/milestone.rs b/lib/service/issues/milestone.rs new file mode 100644 index 0000000..d5d7ab9 --- /dev/null +++ b/lib/service/issues/milestone.rs @@ -0,0 +1,302 @@ +use db::sqlx; +use model::issues::MilestoneModel; +use serde::Deserialize; +use session::Session; + +use super::types::{MilestoneResponse, milestone_response}; +use crate::{AppService, error::AppError, session_user}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreateMilestone { + pub title: String, + pub description: Option, + pub due_at: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateMilestone { + pub title: Option, + pub description: Option, + pub state: Option, + pub due_at: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct SetIssueMilestone { + pub milestone_id: uuid::Uuid, +} + +impl AppService { + pub async fn milestone_create( + &self, + ctx: &Session, + wk_name: &str, + params: CreateMilestone, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + + let title = params.title.trim(); + if title.is_empty() { + return Err(AppError::BadRequest( + "milestone title is required".to_string(), + )); + } + + let id = uuid::Uuid::now_v7(); + let now = chrono::Utc::now(); + let due_at = params.due_at.and_then(|d| { + chrono::DateTime::parse_from_rfc3339(&d) + .ok() + .map(|dt| dt.to_utc()) + }); + + let milestone = sqlx::query_as::<_, MilestoneModel>( + "INSERT INTO milestone (id, wk, title, description, state, due_at, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, 'open', $5, $6, $6) \ + RETURNING id, wk, title, description, state, due_at, closed_at, created_at, updated_at", + ) + .bind(id) + .bind(wk.id) + .bind(title) + .bind(¶ms.description) + .bind(due_at) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(milestone_response(milestone)) + } + + pub async fn milestone_list( + &self, + ctx: &Session, + wk_name: &str, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + + let milestones = sqlx::query_as::<_, MilestoneModel>( + "SELECT id, wk, title, description, state, due_at, closed_at, created_at, updated_at \ + FROM milestone WHERE wk = $1 ORDER BY created_at ASC", + ) + .bind(wk.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(milestones.into_iter().map(milestone_response).collect()) + } + + pub async fn milestone_update( + &self, + ctx: &Session, + wk_name: &str, + milestone_id: uuid::Uuid, + params: UpdateMilestone, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + + let milestone = sqlx::query_as::<_, MilestoneModel>( + "SELECT id, wk, title, description, state, due_at, closed_at, created_at, updated_at \ + FROM milestone WHERE id = $1 AND wk = $2", + ) + .bind(milestone_id) + .bind(wk.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::MilestoneNotFound)?; + + let title = params.title.unwrap_or_else(|| milestone.title.clone()); + let description = params.description.unwrap_or_else(|| { + milestone.description.clone().unwrap_or_default() + }); + let state = params.state.unwrap_or_else(|| milestone.state.clone()); + let due_at = params + .due_at + .and_then(|d| { + chrono::DateTime::parse_from_rfc3339(&d) + .ok() + .map(|dt| dt.to_utc()) + }) + .or(milestone.due_at); + let closed_at = if state == "closed" && milestone.state != "closed" { + Some(chrono::Utc::now()) + } else { + milestone.closed_at + }; + + let now = chrono::Utc::now(); + let updated = sqlx::query_as::<_, MilestoneModel>( + "UPDATE milestone SET title = $1, description = $2, state = $3, due_at = $4, closed_at = $5, updated_at = $6 \ + WHERE id = $7 \ + RETURNING id, wk, title, description, state, due_at, closed_at, created_at, updated_at", + ) + .bind(title) + .bind(description) + .bind(state) + .bind(due_at) + .bind(closed_at) + .bind(now) + .bind(milestone_id) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(milestone_response(updated)) + } + + pub async fn milestone_delete( + &self, + ctx: &Session, + wk_name: &str, + milestone_id: uuid::Uuid, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + + let exists = sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM milestone WHERE id = $1 AND wk = $2)", + ) + .bind(milestone_id) + .bind(wk.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if !exists { + return Err(AppError::MilestoneNotFound); + } + + sqlx::query("DELETE FROM issue_milestone WHERE milestone = $1") + .bind(milestone_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query("DELETE FROM milestone WHERE id = $1") + .bind(milestone_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } + + pub async fn issue_set_milestone( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + params: SetIssueMilestone, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let milestone = sqlx::query_as::<_, MilestoneModel>( + "SELECT id, wk, title, description, state, due_at, closed_at, created_at, updated_at \ + FROM milestone WHERE id = $1 AND wk = $2", + ) + .bind(params.milestone_id) + .bind(wk.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::MilestoneNotFound)?; + + let now = chrono::Utc::now(); + sqlx::query("DELETE FROM issue_milestone WHERE issue = $1") + .bind(issue.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_milestone (issue, milestone, created_at) VALUES ($1, $2, $3)", + ) + .bind(issue.id) + .bind(milestone.id) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'milestone_changed', NULL, $4, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(&milestone.title) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(Some(milestone_response(milestone))) + } + + pub async fn issue_clear_milestone( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let old_milestone = self.issue_milestone(issue.id).await?; + if let Some(old) = &old_milestone { + sqlx::query("DELETE FROM issue_milestone WHERE issue = $1") + .bind(issue.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO issue_event (id, issue, actor, event, from_value, to_value, created_at) \ + VALUES ($1, $2, $3, 'milestone_removed', $4, NULL, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(issue.id) + .bind(user_uid) + .bind(&old.title) + .bind(chrono::Utc::now()) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + } + + Ok(()) + } + + pub(crate) async fn issue_milestone( + &self, + issue_id: uuid::Uuid, + ) -> Result, AppError> { + let milestone = sqlx::query_as::<_, MilestoneModel>( + "SELECT m.id, m.wk, m.title, m.description, m.state, m.due_at, m.closed_at, m.created_at, m.updated_at \ + FROM issue_milestone im \ + INNER JOIN milestone m ON m.id = im.milestone \ + WHERE im.issue = $1", + ) + .bind(issue_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(milestone.map(milestone_response)) + } +} diff --git a/lib/service/issues/mod.rs b/lib/service/issues/mod.rs new file mode 100644 index 0000000..b66e8e0 --- /dev/null +++ b/lib/service/issues/mod.rs @@ -0,0 +1,9 @@ +pub mod assignee; +pub mod binding; +pub mod comment; +pub mod event; +pub mod issue; +pub mod label; +pub mod milestone; +pub mod reaction; +pub mod types; diff --git a/lib/service/issues/reaction.rs b/lib/service/issues/reaction.rs new file mode 100644 index 0000000..f440d98 --- /dev/null +++ b/lib/service/issues/reaction.rs @@ -0,0 +1,186 @@ +use db::sqlx; +use model::issues::IssueReactionModel; +use serde::Deserialize; +use session::Session; + +use super::types::{IssueReactionResponse, issue_author}; +use crate::{AppService, error::AppError, session_user}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct AddReaction { + pub reaction: String, +} + +impl AppService { + pub async fn issue_add_reaction( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + params: AddReaction, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let id = uuid::Uuid::now_v7(); + let now = chrono::Utc::now(); + + sqlx::query( + "INSERT INTO issue_reaction (id, issue, comment, \"user\", reaction, created_at) \ + VALUES ($1, $2, NULL, $3, $4, $5) \ + ON CONFLICT (issue, comment, \"user\", reaction) DO NOTHING", + ) + .bind(id) + .bind(issue.id) + .bind(user_uid) + .bind(¶ms.reaction) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.issue_reactions_for(issue.id, None).await + } + + pub async fn issue_remove_reaction( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + reaction: &str, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + sqlx::query( + "DELETE FROM issue_reaction WHERE issue = $1 AND comment IS NULL AND \"user\" = $2 AND reaction = $3", + ) + .bind(issue.id) + .bind(user_uid) + .bind(reaction) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.issue_reactions_for(issue.id, None).await + } + + pub async fn issue_comment_add_reaction( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + comment_id: uuid::Uuid, + params: AddReaction, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + let exists = sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM issue_comment WHERE id = $1 AND issue = $2 AND deleted_at IS NULL)", + ) + .bind(comment_id) + .bind(issue.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if !exists { + return Err(AppError::CommentNotFound); + } + + let id = uuid::Uuid::now_v7(); + let now = chrono::Utc::now(); + + sqlx::query( + "INSERT INTO issue_reaction (id, issue, comment, \"user\", reaction, created_at) \ + VALUES ($1, $2, $3, $4, $5, $6) \ + ON CONFLICT (issue, comment, \"user\", reaction) DO NOTHING", + ) + .bind(id) + .bind(issue.id) + .bind(comment_id) + .bind(user_uid) + .bind(¶ms.reaction) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.issue_reactions_for(issue.id, Some(comment_id)).await + } + + pub async fn issue_comment_remove_reaction( + &self, + ctx: &Session, + wk_name: &str, + number: i64, + comment_id: uuid::Uuid, + reaction: &str, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let issue = self.issue_resolve(wk.id, number).await?; + + sqlx::query( + "DELETE FROM issue_reaction WHERE issue = $1 AND comment = $2 AND \"user\" = $3 AND reaction = $4", + ) + .bind(issue.id) + .bind(comment_id) + .bind(user_uid) + .bind(reaction) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.issue_reactions_for(issue.id, Some(comment_id)).await + } + + pub(crate) async fn issue_reactions_for( + &self, + issue_id: uuid::Uuid, + comment_id: Option, + ) -> Result, AppError> { + let reactions = if let Some(cid) = comment_id { + sqlx::query_as::<_, IssueReactionModel>( + "SELECT id, issue, comment, \"user\", reaction, created_at \ + FROM issue_reaction WHERE issue = $1 AND comment = $2 \ + ORDER BY created_at ASC", + ) + .bind(issue_id) + .bind(cid) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + } else { + sqlx::query_as::<_, IssueReactionModel>( + "SELECT id, issue, comment, \"user\", reaction, created_at \ + FROM issue_reaction WHERE issue = $1 AND comment IS NULL \ + ORDER BY created_at ASC", + ) + .bind(issue_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + }; + + let mut results = Vec::new(); + for reaction in reactions { + let user = self.users_find_by_id(reaction.user).await?; + results.push(IssueReactionResponse { + id: reaction.id, + user: issue_author(user), + reaction: reaction.reaction, + created_at: reaction.created_at, + }); + } + Ok(results) + } +} diff --git a/lib/service/issues/types.rs b/lib/service/issues/types.rs new file mode 100644 index 0000000..f6547fd --- /dev/null +++ b/lib/service/issues/types.rs @@ -0,0 +1,164 @@ +use model::{ + issues::{LabelModel, MilestoneModel}, + pull_request::PullRequestModel, + repos::RepoModel, + users::UserModel, +}; +use serde::{Deserialize, Serialize}; + +use crate::non_empty; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct IssueResponse { + pub number: i64, + pub title: String, + pub body: Option, + pub state: String, + pub priority: String, + #[schema(value_type = Option)] + pub due_at: Option>, + pub author: IssueAuthor, + pub closed_by: Option, + #[schema(value_type = Option)] + pub closed_at: Option>, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, + pub labels: Vec, + pub assignees: Vec, + pub milestone: Option, + pub repos: Vec, + pub pull_requests: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct IssueAuthor { + pub username: String, + pub display_name: Option, + pub avatar_url: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct LabelResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + pub name: String, + pub color: String, + pub description: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct MilestoneResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + pub title: String, + pub description: Option, + pub state: String, + #[schema(value_type = Option)] + pub due_at: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct IssueRepoResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct IssuePullRequestResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + pub number: i64, + pub title: String, + pub state: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct IssueCommentResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + pub author: IssueAuthor, + pub body: String, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct IssueEventResponse { + pub actor: Option, + pub event: String, + pub from_value: Option, + pub to_value: Option, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct IssueReactionResponse { + #[schema(value_type = String)] + pub id: uuid::Uuid, + pub user: IssueAuthor, + pub reaction: String, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, +} + +pub(crate) fn issue_author(user: UserModel) -> IssueAuthor { + IssueAuthor { + username: user.username, + display_name: non_empty(user.display_name), + avatar_url: non_empty(user.avatar_url), + } +} + +pub(crate) fn label_response(label: LabelModel) -> LabelResponse { + LabelResponse { + id: label.id, + name: label.name, + color: label.color, + description: non_empty(label.description.unwrap_or_default()), + } +} + +pub(crate) fn milestone_response( + milestone: MilestoneModel, +) -> MilestoneResponse { + MilestoneResponse { + id: milestone.id, + title: milestone.title, + description: non_empty(milestone.description.unwrap_or_default()), + state: milestone.state, + due_at: milestone.due_at, + } +} + +pub(crate) fn issue_repo_response(repo: RepoModel) -> IssueRepoResponse { + IssueRepoResponse { + id: repo.id, + name: repo.name, + } +} + +pub(crate) fn issue_pr_response( + pr: PullRequestModel, +) -> IssuePullRequestResponse { + IssuePullRequestResponse { + id: pr.id, + number: pr.number, + title: pr.title, + state: pr.state, + } +} + +#[derive(Debug, Clone, Deserialize, utoipa::IntoParams)] +pub struct IssueFilter { + pub state: Option, + pub label: Option, + pub milestone: Option, + pub assignee: Option, + pub priority: Option, +} diff --git a/lib/service/lib.rs b/lib/service/lib.rs new file mode 100644 index 0000000..7a19d43 --- /dev/null +++ b/lib/service/lib.rs @@ -0,0 +1,65 @@ +use cache::AppCache; +use config::AppConfig; +use db::AppDatabase; +use deadpool_redis::cluster::Pool as RedisPool; +use email::AppEmail; +use error::AppError; +use serde::Deserialize; +use session::Session; +use storage::AppStorage; +use tonic::transport::Channel; +use uuid::Uuid; + +pub mod agent; +pub mod ai; +pub mod auth; +pub mod error; +pub mod git; +pub mod issues; +pub mod pull_request; +pub mod user; +pub mod users; +pub mod workspace; +pub(crate) fn session_user(ctx: &Session) -> Result { + ctx.user().ok_or(AppError::Unauthorized) +} + +pub(crate) fn non_empty(value: String) -> Option { + if value.is_empty() { None } else { Some(value) } +} + +pub(crate) fn constant_time_eq(a: &str, b: &str) -> bool { + if a.len() != b.len() { + return false; + } + a.bytes() + .zip(b.bytes()) + .fold(0u8, |acc, (x, y)| acc | (x ^ y)) + == 0 +} + +#[derive(Debug, Clone, Deserialize, utoipa::IntoParams, utoipa::ToSchema)] +pub struct Pagination { + pub offset: Option, + pub limit: Option, +} + +impl Pagination { + pub fn offset(&self) -> u32 { + self.offset.unwrap_or(0) + } + pub fn limit(&self) -> u32 { + self.limit.unwrap_or(20).min(100) + } +} + +#[derive(Clone)] +pub struct AppService { + pub db: AppDatabase, + pub cache: AppCache, + pub email: AppEmail, + pub storage: AppStorage, + pub config: AppConfig, + pub git: Channel, + pub redis_pool: RedisPool, +} diff --git a/lib/service/pull_request/assignee.rs b/lib/service/pull_request/assignee.rs new file mode 100644 index 0000000..a94a015 --- /dev/null +++ b/lib/service/pull_request/assignee.rs @@ -0,0 +1,110 @@ +use db::sqlx; +use model::users::UserModel; +use serde::Deserialize; +use session::Session; + +use crate::{ + AppService, error::AppError, issues::types::IssueAuthor, session_user, +}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct AssignPrUser { + pub username: String, +} + +impl AppService { + pub async fn pr_assign( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + params: AssignPrUser, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + let assignee = sqlx::query_as::<_, UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, allow_use, can_search, \ + last_sign_in_at, created_at, updated_at \ + FROM \"user\" WHERE username = $1 AND allow_use = true", + ) + .bind(¶ms.username) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::UserNotFound)?; + + sqlx::query( + "INSERT INTO pull_request_assignee (pull_request, \"user\", assigned_by, created_at) \ + VALUES ($1, $2, $3, $4) ON CONFLICT (pull_request, \"user\") DO NOTHING", + ) + .bind(pr.id) + .bind(assignee.id) + .bind(user_uid) + .bind(chrono::Utc::now()) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.pr_assignees_list(pr.id).await + } + + pub async fn pr_unassign( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + username: &str, + ) -> Result, AppError> { + let _user_uid = session_user(ctx)?; + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + let assignee = sqlx::query_as::<_, UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, allow_use, can_search, \ + last_sign_in_at, created_at, updated_at \ + FROM \"user\" WHERE username = $1", + ) + .bind(username) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::UserNotFound)?; + + sqlx::query("DELETE FROM pull_request_assignee WHERE pull_request = $1 AND \"user\" = $2") + .bind(pr.id) + .bind(assignee.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.pr_assignees_list(pr.id).await + } + + pub(crate) async fn pr_assignees_list( + &self, + pr_id: uuid::Uuid, + ) -> Result, AppError> { + let assignees = sqlx::query_as::<_, UserModel>( + "SELECT u.id, u.username, u.display_name, u.avatar_url, u.website_url, u.allow_use, u.can_search, \ + u.last_sign_in_at, u.created_at, u.updated_at \ + FROM pull_request_assignee pa INNER JOIN \"user\" u ON u.id = pa.\"user\" \ + WHERE pa.pull_request = $1 \ + ORDER BY u.username ASC", + ) + .bind(pr_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(assignees + .into_iter() + .map(|u| crate::issues::types::issue_author(u)) + .collect()) + } +} diff --git a/lib/service/pull_request/comment.rs b/lib/service/pull_request/comment.rs new file mode 100644 index 0000000..67973bb --- /dev/null +++ b/lib/service/pull_request/comment.rs @@ -0,0 +1,187 @@ +use db::sqlx; +use model::pull_request::PullRequestCommentModel; +use serde::Deserialize; +use session::Session; + +use crate::{ + AppService, error::AppError, + pull_request::types::PullRequestCommentResponse, session_user, +}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreatePrComment { + pub body: String, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdatePrComment { + pub body: String, +} + +impl AppService { + pub async fn pr_comment_create( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + params: CreatePrComment, + ) -> Result { + let user_uid = session_user(ctx)?; + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + if params.body.trim().is_empty() { + return Err(AppError::BadRequest( + "comment body is required".to_string(), + )); + } + + let now = chrono::Utc::now(); + let comment = sqlx::query_as::<_, PullRequestCommentModel>( + "INSERT INTO pull_request_comment (id, pull_request, author, body, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $5) \ + RETURNING id, pull_request, author, body, created_at, updated_at, deleted_at", + ) + .bind(uuid::Uuid::now_v7()) + .bind(pr.id) + .bind(user_uid) + .bind(¶ms.body) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let author = self.users_find_by_id(user_uid).await?; + Ok(PullRequestCommentResponse { + id: comment.id, + author: crate::issues::types::issue_author(author), + body: comment.body, + created_at: comment.created_at, + updated_at: comment.updated_at, + }) + } + + pub async fn pr_comment_list( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + ) -> Result, AppError> { + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + let comments = sqlx::query_as::<_, PullRequestCommentModel>( + "SELECT id, pull_request, author, body, created_at, updated_at, deleted_at \ + FROM pull_request_comment WHERE pull_request = $1 AND deleted_at IS NULL \ + ORDER BY created_at ASC", + ) + .bind(pr.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut results = Vec::new(); + for comment in comments { + let author = self.users_find_by_id(comment.author).await?; + results.push(PullRequestCommentResponse { + id: comment.id, + author: crate::issues::types::issue_author(author), + body: comment.body, + created_at: comment.created_at, + updated_at: comment.updated_at, + }); + } + Ok(results) + } + + pub async fn pr_comment_update( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + _number: i64, + comment_id: uuid::Uuid, + params: UpdatePrComment, + ) -> Result { + let user_uid = session_user(ctx)?; + let _repo_id = self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + + let comment = sqlx::query_as::<_, PullRequestCommentModel>( + "SELECT id, pull_request, author, body, created_at, updated_at, deleted_at \ + FROM pull_request_comment WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(comment_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::CommentNotFound)?; + + if comment.author != user_uid { + return Err(AppError::Forbidden( + "only the author can update this comment".to_string(), + )); + } + + let now = chrono::Utc::now(); + let comment = sqlx::query_as::<_, PullRequestCommentModel>( + "UPDATE pull_request_comment SET body = $1, updated_at = $2 WHERE id = $3 \ + RETURNING id, pull_request, author, body, created_at, updated_at, deleted_at", + ) + .bind(¶ms.body) + .bind(now) + .bind(comment_id) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let author = self.users_find_by_id(comment.author).await?; + Ok(PullRequestCommentResponse { + id: comment.id, + author: crate::issues::types::issue_author(author), + body: comment.body, + created_at: comment.created_at, + updated_at: comment.updated_at, + }) + } + + pub async fn pr_comment_delete( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + comment_id: uuid::Uuid, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let _repo_id = self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + + let comment = sqlx::query_as::<_, PullRequestCommentModel>( + "SELECT id, pull_request, author, body, created_at, updated_at, deleted_at \ + FROM pull_request_comment WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(comment_id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::CommentNotFound)?; + + if comment.author != user_uid { + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + } + + sqlx::query( + "UPDATE pull_request_comment SET deleted_at = $1 WHERE id = $2", + ) + .bind(chrono::Utc::now()) + .bind(comment_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } +} diff --git a/lib/service/pull_request/label.rs b/lib/service/pull_request/label.rs new file mode 100644 index 0000000..fb78cae --- /dev/null +++ b/lib/service/pull_request/label.rs @@ -0,0 +1,100 @@ +use db::sqlx; +use model::issues::LabelModel; +use serde::Deserialize; +use session::Session; +use uuid::Uuid; + +use crate::{ + AppService, error::AppError, issues::types::LabelResponse, session_user, +}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct AddPrLabel { + pub label_id: Uuid, +} + +impl AppService { + pub async fn pr_add_label( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + params: AddPrLabel, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let repo = self.repo_resolve(wk.id, repo_name).await?; + let pr = self.pr_resolve(repo.id, number).await?; + + let _label = sqlx::query_as::<_, LabelModel>( + "SELECT id, wk, name, color, description, created_at, updated_at, deleted_at \ + FROM label WHERE id = $1 AND wk = $2 AND deleted_at IS NULL", + ) + .bind(params.label_id) + .bind(wk.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::LabelNotFound)?; + + sqlx::query( + "INSERT INTO pull_request_label (pull_request, label, created_at) VALUES ($1, $2, $3) \ + ON CONFLICT (pull_request, label) DO NOTHING", + ) + .bind(pr.id) + .bind(params.label_id) + .bind(chrono::Utc::now()) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.pr_labels(pr.id).await + } + + pub async fn pr_remove_label( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + label_id: Uuid, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let repo = self.repo_resolve(wk.id, repo_name).await?; + let pr = self.pr_resolve(repo.id, number).await?; + + sqlx::query("DELETE FROM pull_request_label WHERE pull_request = $1 AND label = $2") + .bind(pr.id) + .bind(label_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.pr_labels(pr.id).await + } + + pub(crate) async fn pr_labels( + &self, + pr_id: uuid::Uuid, + ) -> Result, AppError> { + let labels = sqlx::query_as::<_, LabelModel>( + "SELECT l.id, l.wk, l.name, l.color, l.description, l.created_at, l.updated_at, l.deleted_at \ + FROM pull_request_label pl INNER JOIN label l ON l.id = pl.label \ + WHERE pl.pull_request = $1 AND l.deleted_at IS NULL \ + ORDER BY l.name ASC", + ) + .bind(pr_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(labels + .into_iter() + .map(crate::issues::types::label_response) + .collect()) + } +} diff --git a/lib/service/pull_request/merge.rs b/lib/service/pull_request/merge.rs new file mode 100644 index 0000000..14c1d71 --- /dev/null +++ b/lib/service/pull_request/merge.rs @@ -0,0 +1,244 @@ +use db::sqlx; +use git::rpc::{proto as p, proto::merge_service_client::MergeServiceClient}; +use serde::Deserialize; +use session::Session; + +use crate::{ + AppService, error::AppError, git::rpc_err, + pull_request::types::PullRequestResponse, session_user, +}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct MergePullRequest { + pub method: Option, + pub commit_title: Option, + pub commit_message: Option, +} + +impl AppService { + pub async fn pr_merge_analysis( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + ) -> Result { + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + let mut client = MergeServiceClient::new(self.git.clone()); + let resp = client + .merge_analysis(tonic::Request::new(p::MergeAnalysisRequest { + repo_id: repo_id.to_string(), + oid_a: Some(p::ObjectId { + value: pr.source_sha.clone(), + }), + oid_b: Some(p::ObjectId { + value: pr.target_sha.clone(), + }), + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn pr_merge_base( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + ) -> Result { + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + let mut client = MergeServiceClient::new(self.git.clone()); + let resp = client + .merge_base(tonic::Request::new(p::MergeBaseRequest { + repo_id: repo_id.to_string(), + oid_a: Some(p::ObjectId { + value: pr.source_sha.clone(), + }), + oid_b: Some(p::ObjectId { + value: pr.target_sha.clone(), + }), + })) + .await + .map_err(rpc_err)? + .into_inner(); + Ok(resp) + } + + pub async fn pr_merge( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + params: MergePullRequest, + ) -> Result { + let user_uid = session_user(ctx)?; + let (repo_id, _) = + self.pr_resolve_repo_admin(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + if pr.state != "open" { + return Err(AppError::BadRequest( + "pull request is not open".to_string(), + )); + } + if pr.draft { + return Err(AppError::BadRequest( + "draft pull request cannot be merged".to_string(), + )); + } + + let method = params.method.unwrap_or_else(|| "merge".to_string()); + let now = chrono::Utc::now(); + + let merge_result_sha = match method.as_str() { + "merge" => { + let mut client = MergeServiceClient::new(self.git.clone()); + let resp = client + .merge_commit(tonic::Request::new(p::MergeCommitRequest { + repo_id: repo_id.to_string(), + params: Some(p::MergeCommitParams { + their_commit: Some(p::ObjectId { + value: pr.source_sha.clone(), + }), + author: Some(p::CommitSignature { + name: format!("merge: {}", pr.title), + email: "noreply@gitdata.ai".to_string(), + time_secs: now.timestamp(), + offset_minutes: 0, + }), + committer: Some(p::CommitSignature { + name: "gitdata".to_string(), + email: "noreply@gitdata.ai".to_string(), + time_secs: now.timestamp(), + offset_minutes: 0, + }), + message: params.commit_message.unwrap_or_else( + || { + format!( + "Merge pull request #{}: {}", + pr.number, pr.title + ) + }, + ), + update_ref: Some(format!( + "refs/heads/{}", + pr.target_branch + )), + options: None, + }), + })) + .await + .map_err(rpc_err)? + .into_inner(); + resp.oid.map(|oid| oid.value).unwrap_or_default() + } + "squash" => { + let mut client = MergeServiceClient::new(self.git.clone()); + let resp = client + .squash_commit(tonic::Request::new( + p::SquashCommitRequest { + repo_id: repo_id.to_string(), + params: Some(p::SquashCommitParams { + their_commit: Some(p::ObjectId { + value: pr.source_sha.clone(), + }), + options: None, + }), + }, + )) + .await + .map_err(rpc_err)? + .into_inner(); + resp.oid.map(|oid| oid.value).unwrap_or_default() + } + _ => { + return Err(AppError::BadRequest( + "merge method must be 'merge' or 'squash'".to_string(), + )); + } + }; + + sqlx::query( + "UPDATE pull_request SET state = 'merged', merged_by = $1, merged_at = $2, \ + target_sha = $3, updated_at = $2 WHERE id = $4", + ) + .bind(user_uid) + .bind(now) + .bind(&merge_result_sha) + .bind(pr.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let pr = self.pr_resolve(repo_id, number).await?; + self.pr_build_response(pr).await + } + + pub async fn pr_merge_abort( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result<(), AppError> { + let (repo_id, _) = + self.pr_resolve_repo_admin(ctx, wk_name, repo_name).await?; + + let mut client = MergeServiceClient::new(self.git.clone()); + client + .merge_abort(tonic::Request::new(p::MergeAbortRequest { + repo_id: repo_id.to_string(), + })) + .await + .map_err(rpc_err)?; + + Ok(()) + } + + pub async fn pr_update_branch( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + ) -> Result { + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + if pr.state != "open" { + return Err(AppError::BadRequest( + "pull request is not open".to_string(), + )); + } + + let source_sha = self + .branch_head_sha(pr.source_repo, &pr.source_branch) + .await?; + let target_sha = + self.branch_head_sha(repo_id, &pr.target_branch).await?; + let now = chrono::Utc::now(); + + sqlx::query( + "UPDATE pull_request SET source_sha = $1, target_sha = $2, updated_at = $3 WHERE id = $4", + ) + .bind(&source_sha) + .bind(&target_sha) + .bind(now) + .bind(pr.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let pr = self.pr_resolve(repo_id, number).await?; + self.pr_build_response(pr).await + } +} diff --git a/lib/service/pull_request/mod.rs b/lib/service/pull_request/mod.rs new file mode 100644 index 0000000..ef5472b --- /dev/null +++ b/lib/service/pull_request/mod.rs @@ -0,0 +1,78 @@ +pub mod assignee; +pub mod comment; +pub mod label; +pub mod merge; +pub mod pull_request; +pub mod reaction; +pub mod review; +pub mod types; + +use db::sqlx; +use git::rpc::proto::branch_service_client::BranchServiceClient; +use model::{pull_request::PullRequestModel, repos::RepoModel}; +use session::Session; + +use crate::{AppService, error::AppError, git::rpc_err}; + +impl AppService { + pub(crate) async fn pr_resolve( + &self, + repo_id: uuid::Uuid, + number: i64, + ) -> Result { + sqlx::query_as::<_, PullRequestModel>( + "SELECT id, repo, number, title, body, state, draft, author, \ + source_repo, source_branch, source_sha, target_branch, target_sha, \ + merged_by, merged_at, closed_by, closed_at, created_at, updated_at, deleted_at \ + FROM pull_request WHERE repo = $1 AND number = $2 AND deleted_at IS NULL", + ) + .bind(repo_id) + .bind(number) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::PullRequestNotFound) + } + + pub(crate) async fn pr_resolve_repo( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result<(uuid::Uuid, RepoModel), AppError> { + let repo = self.git_require_member(ctx, wk_name, repo_name).await?; + Ok((repo.id, repo)) + } + + pub(crate) async fn pr_resolve_repo_admin( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + ) -> Result<(uuid::Uuid, RepoModel), AppError> { + let repo = self.git_require_admin(ctx, wk_name, repo_name).await?; + Ok((repo.id, repo)) + } + + pub(crate) async fn branch_head_sha( + &self, + repo_id: uuid::Uuid, + branch: &str, + ) -> Result { + let mut client = BranchServiceClient::new(self.git.clone()); + let resp = client + .branch_info(tonic::Request::new( + git::rpc::proto::BranchInfoRequest { + repo_id: repo_id.to_string(), + branch: branch.to_string(), + }, + )) + .await + .map_err(rpc_err)? + .into_inner(); + let sha = resp.branch.and_then(|b| b.oid).map(|oid| oid.value).ok_or( + AppError::NotFound(format!("branch '{}' not found", branch)), + )?; + Ok(sha) + } +} diff --git a/lib/service/pull_request/pull_request.rs b/lib/service/pull_request/pull_request.rs new file mode 100644 index 0000000..7c2df1a --- /dev/null +++ b/lib/service/pull_request/pull_request.rs @@ -0,0 +1,409 @@ +use db::{sqlx, sqlx::AssertSqlSafe}; +use model::pull_request::PullRequestModel; +use serde::Deserialize; +use session::Session; + +use crate::{ + AppService, Pagination, + error::AppError, + issues::types::issue_author, + pull_request::types::{PullRequestFilter, PullRequestResponse}, + session_user, +}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreatePullRequest { + pub title: String, + pub body: Option, + pub source_branch: String, + pub target_branch: Option, + pub source_repo: Option, + pub draft: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdatePullRequest { + pub title: Option, + pub body: Option>, + pub draft: Option, + pub state: Option, +} + +impl AppService { + pub async fn pr_create( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + params: CreatePullRequest, + ) -> Result { + let user_uid = session_user(ctx)?; + let (repo_id, repo) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + + let title = params.title.trim(); + if title.is_empty() { + return Err(AppError::BadRequest( + "pull request title is required".to_string(), + )); + } + + let target_branch = params + .target_branch + .unwrap_or_else(|| repo.default_branch.clone()); + let source_repo_name = + params.source_repo.unwrap_or_else(|| repo_name.to_string()); + + let source_repo_id = if source_repo_name == repo_name { + repo_id + } else { + let wk = self.workspace_resolve(wk_name).await?; + let source_repo = + self.repo_resolve(wk.id, &source_repo_name).await?; + source_repo.id + }; + + let source_sha = self + .branch_head_sha(source_repo_id, ¶ms.source_branch) + .await?; + let target_sha = self.branch_head_sha(repo_id, &target_branch).await?; + + if source_sha == target_sha { + return Err(AppError::Conflict( + "source and target branches are the same".to_string(), + )); + } + + let now = chrono::Utc::now(); + let id = uuid::Uuid::now_v7(); + + let pr = sqlx::query_as::<_, PullRequestModel>( + "INSERT INTO pull_request (id, repo, number, title, body, state, draft, author, \ + source_repo, source_branch, source_sha, target_branch, target_sha, \ + merged_by, merged_at, closed_by, closed_at, created_at, updated_at) \ + VALUES ($1, $2, (SELECT COALESCE(MAX(number), 0) + 1 FROM pull_request WHERE repo = $2 AND deleted_at IS NULL), \ + $3, $4, 'open', $5, $6, $7, $8, $9, $10, $11, NULL, NULL, NULL, NULL, $12, $12) \ + RETURNING id, repo, number, title, body, state, draft, author, \ + source_repo, source_branch, source_sha, target_branch, target_sha, \ + merged_by, merged_at, closed_by, closed_at, created_at, updated_at, deleted_at", + ) + .bind(id) + .bind(repo_id) + .bind(title) + .bind(¶ms.body) + .bind(params.draft.unwrap_or(false)) + .bind(user_uid) + .bind(source_repo_id) + .bind(¶ms.source_branch) + .bind(&source_sha) + .bind(&target_branch) + .bind(&target_sha) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.pr_build_response(pr).await + } + + pub async fn pr_list( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + filter: PullRequestFilter, + pagination: Pagination, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(wk_name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let repo = self.repo_resolve(wk.id, repo_name).await?; + + let mut conditions = vec![ + "pr.repo = $1".to_string(), + "pr.deleted_at IS NULL".to_string(), + ]; + let mut param_idx = 2; + + if filter.state.is_some() { + conditions.push(format!("pr.state = ${param_idx}")); + param_idx += 1; + } + if filter.author.is_some() { + conditions.push(format!( + "EXISTS(SELECT 1 FROM \"user\" u WHERE u.id = pr.author AND u.username = ${param_idx})" + )); + param_idx += 1; + } + if filter.assignee.is_some() { + conditions.push(format!( + "EXISTS(SELECT 1 FROM pull_request_assignee pa INNER JOIN \"user\" u ON u.id = pa.\"user\" \ + WHERE pa.pull_request = pr.id AND u.username = ${param_idx})" + )); + param_idx += 1; + } + if filter.label.is_some() { + conditions.push(format!( + "EXISTS(SELECT 1 FROM pull_request_label pl INNER JOIN label l ON l.id = pl.label \ + WHERE pl.pull_request = pr.id AND l.name = ${param_idx})" + )); + param_idx += 1; + } + + let where_clause = conditions.join(" AND "); + let limit_idx = param_idx; + let offset_idx = param_idx + 1; + + let query = format!( + "SELECT pr.id, pr.repo, pr.number, pr.title, pr.body, pr.state, pr.draft, pr.author, \ + pr.source_repo, pr.source_branch, pr.source_sha, pr.target_branch, pr.target_sha, \ + pr.merged_by, pr.merged_at, pr.closed_by, pr.closed_at, pr.created_at, pr.updated_at, pr.deleted_at \ + FROM pull_request pr WHERE {where_clause} \ + ORDER BY pr.created_at DESC LIMIT ${limit_idx} OFFSET ${offset_idx}" + ); + + let mut q = sqlx::query_as::<_, PullRequestModel>(AssertSqlSafe(query)) + .bind(repo.id); + if let Some(state) = &filter.state { + q = q.bind(state); + } + if let Some(author) = &filter.author { + q = q.bind(author); + } + if let Some(assignee) = &filter.assignee { + q = q.bind(assignee); + } + if let Some(label) = &filter.label { + q = q.bind(label); + } + q = q + .bind(pagination.limit() as i64) + .bind(pagination.offset() as i64); + + let prs = q + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut results = Vec::new(); + for pr in prs { + results.push(self.pr_build_response(pr).await?); + } + Ok(results) + } + + pub async fn pr_get( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + ) -> Result { + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + self.pr_build_response(pr).await + } + + pub async fn pr_update( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + params: UpdatePullRequest, + ) -> Result { + if let Some(ref state) = params.state { + return match state.as_str() { + "closed" => self.pr_close(ctx, wk_name, repo_name, number).await, + "open" => self.pr_reopen(ctx, wk_name, repo_name, number).await, + other => Err(AppError::BadRequest(format!( + "invalid state '{}': must be 'open' or 'closed'", other + ))), + }; + } + + let user_uid = session_user(ctx)?; + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let mut pr = self.pr_resolve(repo_id, number).await?; + + if pr.author != user_uid { + return Err(AppError::Forbidden( + "only the author can update this pull request".to_string(), + )); + } + if pr.state != "open" { + return Err(AppError::BadRequest( + "cannot update a closed or merged pull request".to_string(), + )); + } + + let next_title = params + .title + .map(|t| t.trim().to_string()) + .unwrap_or(pr.title.clone()); + if next_title.is_empty() { + return Err(AppError::BadRequest( + "pull request title is required".to_string(), + )); + } + let next_body = params.body.map(Some).unwrap_or(Some(pr.body.clone())); + let next_draft = params.draft.unwrap_or(pr.draft); + let now = chrono::Utc::now(); + + pr = sqlx::query_as::<_, PullRequestModel>( + "UPDATE pull_request SET title = $1, body = $2, draft = $3, updated_at = $4 WHERE id = $5 \ + RETURNING id, repo, number, title, body, state, draft, author, \ + source_repo, source_branch, source_sha, target_branch, target_sha, \ + merged_by, merged_at, closed_by, closed_at, created_at, updated_at, deleted_at", + ) + .bind(&next_title) + .bind(&next_body) + .bind(next_draft) + .bind(now) + .bind(pr.id) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.pr_build_response(pr).await + } + + pub async fn pr_close( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + ) -> Result { + let user_uid = session_user(ctx)?; + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + if pr.state != "open" { + return Err(AppError::BadRequest( + "pull request is already closed or merged".to_string(), + )); + } + + let now = chrono::Utc::now(); + sqlx::query( + "UPDATE pull_request SET state = 'closed', closed_by = $1, closed_at = $2, updated_at = $2 WHERE id = $3", + ) + .bind(user_uid) + .bind(now) + .bind(pr.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let pr = self.pr_resolve(repo_id, number).await?; + self.pr_build_response(pr).await + } + + pub async fn pr_reopen( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + ) -> Result { + let _user_uid = session_user(ctx)?; + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + if pr.state != "closed" { + return Err(AppError::BadRequest( + "pull request is not closed".to_string(), + )); + } + if pr.merged_by.is_some() { + return Err(AppError::BadRequest( + "merged pull request cannot be reopened".to_string(), + )); + } + + let now = chrono::Utc::now(); + sqlx::query( + "UPDATE pull_request SET state = 'open', closed_by = NULL, closed_at = NULL, updated_at = $1 WHERE id = $2", + ) + .bind(now) + .bind(pr.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let pr = self.pr_resolve(repo_id, number).await?; + self.pr_build_response(pr).await + } + + pub async fn pr_delete( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + ) -> Result<(), AppError> { + let (repo_id, _) = + self.pr_resolve_repo_admin(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + sqlx::query("UPDATE pull_request SET deleted_at = $1 WHERE id = $2") + .bind(chrono::Utc::now()) + .bind(pr.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } + + pub async fn pr_build_response( + &self, + pr: PullRequestModel, + ) -> Result { + let author = self.users_find_by_id(pr.author).await?; + + let merged_by = if let Some(uid) = pr.merged_by { + Some(issue_author(self.users_find_by_id(uid).await?)) + } else { + None + }; + + let closed_by = if let Some(uid) = pr.closed_by { + Some(issue_author(self.users_find_by_id(uid).await?)) + } else { + None + }; + + let labels = self.pr_labels(pr.id).await?; + let assignees = self.pr_assignees_list(pr.id).await?; + let reviews = self.pr_reviews_list(pr.id).await?; + + Ok(PullRequestResponse { + number: pr.number, + title: pr.title, + body: pr.body, + state: pr.state, + draft: pr.draft, + author: issue_author(author), + source_repo: pr.source_repo, + source_branch: pr.source_branch, + source_sha: pr.source_sha, + target_branch: pr.target_branch, + target_sha: pr.target_sha, + merged_by, + merged_at: pr.merged_at, + closed_by, + closed_at: pr.closed_at, + created_at: pr.created_at, + updated_at: pr.updated_at, + labels, + assignees, + reviews, + }) + } +} diff --git a/lib/service/pull_request/reaction.rs b/lib/service/pull_request/reaction.rs new file mode 100644 index 0000000..7ab775d --- /dev/null +++ b/lib/service/pull_request/reaction.rs @@ -0,0 +1,140 @@ +use db::sqlx; +use model::pull_request::PullRequestReactionModel; +use serde::Deserialize; +use session::Session; +use uuid::Uuid; + +use crate::{ + AppService, error::AppError, + pull_request::types::PullRequestReactionResponse, session_user, +}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct AddPrReaction { + pub reaction: String, +} + +impl AppService { + pub async fn pr_add_reaction( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + params: AddPrReaction, + ) -> Result { + let user_uid = session_user(ctx)?; + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + let now = chrono::Utc::now(); + let reaction = sqlx::query_as::<_, PullRequestReactionModel>( + "INSERT INTO pull_request_reaction (id, pull_request, comment, \"user\", reaction, created_at) \ + VALUES ($1, $2, NULL, $3, $4, $5) \ + RETURNING id, pull_request, comment, \"user\", reaction, created_at", + ) + .bind(uuid::Uuid::now_v7()) + .bind(pr.id) + .bind(user_uid) + .bind(¶ms.reaction) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let user = self.users_find_by_id(reaction.user).await?; + Ok(PullRequestReactionResponse { + id: reaction.id, + user: crate::issues::types::issue_author(user), + reaction: reaction.reaction, + created_at: reaction.created_at, + }) + } + + pub async fn pr_remove_reaction( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + _number: i64, + reaction_id: Uuid, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let (_repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + + sqlx::query( + "DELETE FROM pull_request_reaction WHERE id = $1 AND \"user\" = $2", + ) + .bind(reaction_id) + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } + + pub async fn pr_comment_add_reaction( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + _number: i64, + comment_id: Uuid, + params: AddPrReaction, + ) -> Result { + let user_uid = session_user(ctx)?; + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + + let now = chrono::Utc::now(); + let reaction = sqlx::query_as::<_, PullRequestReactionModel>( + "INSERT INTO pull_request_reaction (id, pull_request, comment, \"user\", reaction, created_at) \ + VALUES ($1, $2, $3, $4, $5, $6) \ + RETURNING id, pull_request, comment, \"user\", reaction, created_at", + ) + .bind(uuid::Uuid::now_v7()) + .bind(repo_id) + .bind(comment_id) + .bind(user_uid) + .bind(¶ms.reaction) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let user = self.users_find_by_id(reaction.user).await?; + Ok(PullRequestReactionResponse { + id: reaction.id, + user: crate::issues::types::issue_author(user), + reaction: reaction.reaction, + created_at: reaction.created_at, + }) + } + + pub async fn pr_comment_remove_reaction( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + _number: i64, + reaction_id: Uuid, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let (_repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + + sqlx::query( + "DELETE FROM pull_request_reaction WHERE id = $1 AND \"user\" = $2", + ) + .bind(reaction_id) + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } +} diff --git a/lib/service/pull_request/review.rs b/lib/service/pull_request/review.rs new file mode 100644 index 0000000..5a94d2b --- /dev/null +++ b/lib/service/pull_request/review.rs @@ -0,0 +1,240 @@ +use db::sqlx; +use model::pull_request::{ + PullRequestReviewCommentModel, PullRequestReviewModel, +}; +use serde::Deserialize; +use session::Session; +use uuid::Uuid; + +use crate::{ + AppService, + error::AppError, + pull_request::types::{ + PullRequestReviewCommentResponse, PullRequestReviewResponse, + }, + session_user, +}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreatePrReview { + pub state: String, + pub body: Option, + pub commit_sha: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreatePrReviewComment { + pub body: String, + pub path: String, + pub line: Option, + pub side: Option, + pub commit_sha: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct DismissPrReview { + pub dismiss_reason: Option, +} + +impl AppService { + pub async fn pr_review_create( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + params: CreatePrReview, + ) -> Result { + let user_uid = session_user(ctx)?; + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + let valid_states = + ["pending", "commented", "approved", "changes_requested"]; + if !valid_states.contains(¶ms.state.as_str()) { + return Err(AppError::BadRequest(format!( + "invalid review state: {}", + params.state + ))); + } + + let now = chrono::Utc::now(); + let review = sqlx::query_as::<_, PullRequestReviewModel>( + "INSERT INTO pull_request_review (id, pull_request, reviewer, state, body, commit_sha, \ + submitted_at, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $7, $7) \ + RETURNING id, pull_request, reviewer, state, body, commit_sha, submitted_at, \ + created_at, updated_at, dismissed_by, dismissed_at, dismiss_reason", + ) + .bind(uuid::Uuid::now_v7()) + .bind(pr.id) + .bind(user_uid) + .bind(¶ms.state) + .bind(¶ms.body) + .bind(¶ms.commit_sha) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let reviewer = self.users_find_by_id(review.reviewer).await?; + Ok(PullRequestReviewResponse { + id: review.id, + reviewer: crate::issues::types::issue_author(reviewer), + state: review.state, + body: review.body, + commit_sha: review.commit_sha, + submitted_at: review.submitted_at, + created_at: review.created_at, + }) + } + + pub async fn pr_review_list( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + ) -> Result, AppError> { + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + let reviews = sqlx::query_as::<_, PullRequestReviewModel>( + "SELECT id, pull_request, reviewer, state, body, commit_sha, submitted_at, \ + created_at, updated_at, dismissed_by, dismissed_at, dismiss_reason \ + FROM pull_request_review WHERE pull_request = $1 \ + ORDER BY created_at DESC", + ) + .bind(pr.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut results = Vec::new(); + for review in reviews { + let reviewer = self.users_find_by_id(review.reviewer).await?; + results.push(PullRequestReviewResponse { + id: review.id, + reviewer: crate::issues::types::issue_author(reviewer), + state: review.state, + body: review.body, + commit_sha: review.commit_sha, + submitted_at: review.submitted_at, + created_at: review.created_at, + }); + } + Ok(results) + } + + pub async fn pr_review_dismiss( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + _number: i64, + review_id: Uuid, + params: DismissPrReview, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let _repo_id = self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + + sqlx::query( + "UPDATE pull_request_review SET state = 'dismissed', dismissed_by = $1, dismissed_at = $2, \ + dismiss_reason = $3, updated_at = $2 WHERE id = $4", + ) + .bind(user_uid) + .bind(chrono::Utc::now()) + .bind(¶ms.dismiss_reason) + .bind(review_id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(()) + } + + pub async fn pr_review_comment_create( + &self, + ctx: &Session, + wk_name: &str, + repo_name: &str, + number: i64, + review_id: Option, + params: CreatePrReviewComment, + ) -> Result { + let user_uid = session_user(ctx)?; + let (repo_id, _) = + self.pr_resolve_repo(ctx, wk_name, repo_name).await?; + let pr = self.pr_resolve(repo_id, number).await?; + + let now = chrono::Utc::now(); + let comment = sqlx::query_as::<_, PullRequestReviewCommentModel>( + "INSERT INTO pull_request_review_comment (id, pull_request, review, author, body, path, \ + commit_sha, original_commit_sha, line, original_line, side, resolved, \ + created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, NULL, $8, $8, $9, false, $10, $10) \ + RETURNING id, pull_request, review, author, body, path, commit_sha, \ + original_commit_sha, line, original_line, side, resolved, resolved_by, resolved_at, \ + created_at, updated_at, deleted_at", + ) + .bind(uuid::Uuid::now_v7()) + .bind(pr.id) + .bind(review_id) + .bind(user_uid) + .bind(¶ms.body) + .bind(¶ms.path) + .bind(params.commit_sha.unwrap_or(pr.source_sha.clone())) + .bind(params.line) + .bind(¶ms.side) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let author = self.users_find_by_id(comment.author).await?; + Ok(PullRequestReviewCommentResponse { + id: comment.id, + author: crate::issues::types::issue_author(author), + path: comment.path, + line: comment.line, + body: comment.body, + commit_sha: comment.commit_sha, + resolved: comment.resolved, + created_at: comment.created_at, + updated_at: comment.updated_at, + }) + } + + pub(crate) async fn pr_reviews_list( + &self, + pr_id: uuid::Uuid, + ) -> Result, AppError> { + let reviews = sqlx::query_as::<_, PullRequestReviewModel>( + "SELECT id, pull_request, reviewer, state, body, commit_sha, submitted_at, \ + created_at, updated_at, dismissed_by, dismissed_at, dismiss_reason \ + FROM pull_request_review WHERE pull_request = $1 \ + ORDER BY created_at DESC", + ) + .bind(pr_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut results = Vec::new(); + for review in reviews { + let reviewer = self.users_find_by_id(review.reviewer).await?; + results.push(PullRequestReviewResponse { + id: review.id, + reviewer: crate::issues::types::issue_author(reviewer), + state: review.state, + body: review.body, + commit_sha: review.commit_sha, + submitted_at: review.submitted_at, + created_at: review.created_at, + }); + } + Ok(results) + } +} diff --git a/lib/service/pull_request/types.rs b/lib/service/pull_request/types.rs new file mode 100644 index 0000000..8f2371c --- /dev/null +++ b/lib/service/pull_request/types.rs @@ -0,0 +1,93 @@ +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::issues::types::{IssueAuthor, LabelResponse}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct PullRequestResponse { + pub number: i64, + pub title: String, + pub body: Option, + pub state: String, + pub draft: bool, + pub author: IssueAuthor, + #[schema(value_type = String)] + pub source_repo: Uuid, + pub source_branch: String, + pub source_sha: String, + pub target_branch: String, + pub target_sha: String, + pub merged_by: Option, + #[schema(value_type = Option)] + pub merged_at: Option>, + pub closed_by: Option, + #[schema(value_type = Option)] + pub closed_at: Option>, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, + pub labels: Vec, + pub assignees: Vec, + pub reviews: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct PullRequestCommentResponse { + #[schema(value_type = String)] + pub id: Uuid, + pub author: IssueAuthor, + pub body: String, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct PullRequestReviewResponse { + #[schema(value_type = String)] + pub id: Uuid, + pub reviewer: IssueAuthor, + pub state: String, + pub body: Option, + pub commit_sha: Option, + #[schema(value_type = Option)] + pub submitted_at: Option>, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct PullRequestReviewCommentResponse { + #[schema(value_type = String)] + pub id: Uuid, + pub author: IssueAuthor, + pub path: String, + pub line: Option, + pub body: String, + pub commit_sha: String, + pub resolved: bool, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct PullRequestReactionResponse { + #[schema(value_type = String)] + pub id: Uuid, + pub user: IssueAuthor, + pub reaction: String, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Deserialize, utoipa::IntoParams)] +pub struct PullRequestFilter { + pub state: Option, + pub author: Option, + pub assignee: Option, + pub label: Option, +} diff --git a/lib/service/user/access_token.rs b/lib/service/user/access_token.rs new file mode 100644 index 0000000..ff46e83 --- /dev/null +++ b/lib/service/user/access_token.rs @@ -0,0 +1,230 @@ +use argon2::{Argon2, password_hash::PasswordHasher}; +use db::sqlx; +use model::users::UserTokenModel; +use rand::{RngExt, distr::Alphanumeric}; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, error::AppError, session_user}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct UserAccessToken { + pub id: i64, + pub name: String, + pub scopes: Vec, + #[schema(value_type = Option)] + pub expires_at: Option>, + pub is_revoked: bool, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct CreatedUserAccessToken { + pub token: String, + pub access_token: UserAccessToken, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreateUserAccessToken { + pub name: String, + pub scopes: Vec, + #[schema(value_type = Option)] + pub expires_at: Option>, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateUserAccessToken { + pub name: Option, + pub scopes: Option>, + #[schema(value_type = Option)] + pub expires_at: Option>>, +} + +impl AppService { + pub async fn user_access_tokens( + &self, + ctx: &Session, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let rows = sqlx::query_as::<_, UserTokenModel>( + "SELECT id, \"user\", name, token_hash, scopes, expires_at, is_revoked, created_at, updated_at \ + FROM user_token WHERE \"user\" = $1 ORDER BY created_at DESC", + ) + .bind(user_uid) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(Into::into).collect()) + } + + pub async fn user_create_access_token( + &self, + ctx: &Session, + params: CreateUserAccessToken, + ) -> Result { + let user_uid = session_user(ctx)?; + let name = params.name.trim(); + if name.is_empty() { + return Err(AppError::BadRequest( + "token name is required".to_string(), + )); + } + + let token = generate_access_token(); + let token_hash = Argon2::default() + .hash_password(token.as_bytes()) + .map_err(|e| AppError::PasswordHashError(e.to_string()))? + .to_string(); + let scopes = normalize_scopes(params.scopes); + let scopes_str = scopes.join("."); + let now = chrono::Utc::now(); + + let row = sqlx::query_as::<_, UserTokenModel>( + "INSERT INTO user_token (\"user\", name, token_hash, scopes, expires_at, is_revoked, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, false, $6, $6) \ + RETURNING id, \"user\", name, token_hash, scopes, expires_at, is_revoked, created_at, updated_at", + ) + .bind(user_uid) + .bind(name) + .bind(&token_hash) + .bind(&scopes_str) + .bind(params.expires_at) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(CreatedUserAccessToken { + token, + access_token: row.into(), + }) + } + + pub async fn user_revoke_access_token( + &self, + ctx: &Session, + token_id: i64, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let affected = sqlx::query( + "UPDATE user_token SET is_revoked = true, updated_at = $1 \ + WHERE id = $2 AND \"user\" = $3 AND is_revoked = false", + ) + .bind(chrono::Utc::now()) + .bind(token_id) + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .rows_affected(); + + if affected == 0 { + return Err(AppError::NotFound( + "access token not found".to_string(), + )); + } + Ok(()) + } + + pub async fn user_update_access_token( + &self, + ctx: &Session, + token_id: i64, + params: UpdateUserAccessToken, + ) -> Result { + let user_uid = session_user(ctx)?; + let mut token = + self.user_access_token_by_id(user_uid, token_id).await?; + + if let Some(name) = params.name { + let name = name.trim(); + if name.is_empty() { + return Err(AppError::BadRequest( + "token name is required".to_string(), + )); + } + token.name = name.to_string(); + } + if let Some(scopes) = params.scopes { + token.scopes = normalize_scopes(scopes); + } + if let Some(expires_at) = params.expires_at { + token.expires_at = expires_at; + } + + let row = sqlx::query_as::<_, UserTokenModel>( + "UPDATE user_token SET name = $1, scopes = $2, expires_at = $3, updated_at = $4 \ + WHERE id = $5 AND \"user\" = $6 AND is_revoked = false \ + RETURNING id, \"user\", name, token_hash, scopes, expires_at, is_revoked, created_at, updated_at", + ) + .bind(&token.name) + .bind(token.scopes.join(".")) + .bind(token.expires_at) + .bind(chrono::Utc::now()) + .bind(token_id) + .bind(user_uid) + .fetch_optional(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("access token not found".to_string()))?; + + Ok(row.into()) + } + + async fn user_access_token_by_id( + &self, + user_uid: uuid::Uuid, + token_id: i64, + ) -> Result { + sqlx::query_as::<_, UserTokenModel>( + "SELECT id, \"user\", name, token_hash, scopes, expires_at, is_revoked, created_at, updated_at \ + FROM user_token WHERE id = $1 AND \"user\" = $2 AND is_revoked = false", + ) + .bind(token_id) + .bind(user_uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .map(Into::into) + .ok_or_else(|| AppError::NotFound("access token not found".to_string())) + } +} + +fn normalize_scopes(scopes: Vec) -> Vec { + scopes + .into_iter() + .map(|scope| scope.trim().to_string()) + .filter(|scope| !scope.is_empty() && !scope.contains('.')) + .collect() +} + +fn generate_access_token() -> String { + #[allow(deprecated)] + let mut rng = rand::rng(); + let token: String = + (0..64).map(|_| rng.sample(Alphanumeric) as char).collect(); + format!("gda_{}", token) +} + +impl From for UserAccessToken { + fn from(value: UserTokenModel) -> Self { + Self { + id: value.id, + name: value.name, + scopes: value + .scopes + .split('.') + .filter(|scope| !scope.is_empty()) + .map(ToString::to_string) + .collect(), + expires_at: value.expires_at, + is_revoked: value.is_revoked, + created_at: value.created_at, + updated_at: value.updated_at, + } + } +} diff --git a/lib/service/user/accessibility.rs b/lib/service/user/accessibility.rs new file mode 100644 index 0000000..5d92601 --- /dev/null +++ b/lib/service/user/accessibility.rs @@ -0,0 +1,110 @@ +use db::sqlx; +use model::users::UserAccessibilityModel; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, error::AppError, session_user}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct UserAccessibilityConfig { + pub reduce_motion: bool, + pub high_contrast: bool, + pub screen_reader_optimized: bool, + pub font_scale_percent: i32, + pub color_blind_mode: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateUserAccessibilityConfig { + pub reduce_motion: Option, + pub high_contrast: Option, + pub screen_reader_optimized: Option, + pub font_scale_percent: Option, + pub color_blind_mode: Option>, +} + +impl AppService { + pub async fn user_update_accessibility_config( + &self, + ctx: &Session, + params: UpdateUserAccessibilityConfig, + ) -> Result { + let user_uid = session_user(ctx)?; + let mut config = self.user_accessibility_config(user_uid).await?; + if let Some(reduce_motion) = params.reduce_motion { + config.reduce_motion = reduce_motion; + } + if let Some(high_contrast) = params.high_contrast { + config.high_contrast = high_contrast; + } + if let Some(screen_reader_optimized) = params.screen_reader_optimized { + config.screen_reader_optimized = screen_reader_optimized; + } + if let Some(font_scale_percent) = params.font_scale_percent { + config.font_scale_percent = font_scale_percent.clamp(50, 200); + } + if let Some(color_blind_mode) = params.color_blind_mode { + config.color_blind_mode = color_blind_mode; + } + let now = chrono::Utc::now(); + sqlx::query( + "INSERT INTO user_accessibility \ + (\"user\", reduce_motion, high_contrast, screen_reader_optimized, font_scale_percent, color_blind_mode, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $7) \ + ON CONFLICT (\"user\") DO UPDATE SET \ + reduce_motion = EXCLUDED.reduce_motion, high_contrast = EXCLUDED.high_contrast, \ + screen_reader_optimized = EXCLUDED.screen_reader_optimized, font_scale_percent = EXCLUDED.font_scale_percent, \ + color_blind_mode = EXCLUDED.color_blind_mode, updated_at = EXCLUDED.updated_at", + ) + .bind(user_uid) + .bind(config.reduce_motion) + .bind(config.high_contrast) + .bind(config.screen_reader_optimized) + .bind(config.font_scale_percent) + .bind(&config.color_blind_mode) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(config) + } + + pub(crate) async fn user_accessibility_config( + &self, + user_uid: uuid::Uuid, + ) -> Result { + let row = sqlx::query_as::<_, UserAccessibilityModel>( + "SELECT \"user\", reduce_motion, high_contrast, screen_reader_optimized, font_scale_percent, color_blind_mode, created_at, updated_at \ + FROM user_accessibility WHERE \"user\" = $1", + ) + .bind(user_uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(row.map(Into::into).unwrap_or_default()) + } +} + +impl Default for UserAccessibilityConfig { + fn default() -> Self { + Self { + reduce_motion: false, + high_contrast: false, + screen_reader_optimized: false, + font_scale_percent: 100, + color_blind_mode: None, + } + } +} + +impl From for UserAccessibilityConfig { + fn from(value: UserAccessibilityModel) -> Self { + Self { + reduce_motion: value.reduce_motion, + high_contrast: value.high_contrast, + screen_reader_optimized: value.screen_reader_optimized, + font_scale_percent: value.font_scale_percent, + color_blind_mode: value.color_blind_mode, + } + } +} diff --git a/lib/service/user/appearance.rs b/lib/service/user/appearance.rs new file mode 100644 index 0000000..d33a5b2 --- /dev/null +++ b/lib/service/user/appearance.rs @@ -0,0 +1,110 @@ +use db::sqlx; +use model::users::UserAppearanceModel; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, error::AppError, session_user}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct UserAppearanceConfig { + pub theme: String, + pub code_theme: String, + pub layout_density: String, + pub sidebar_collapsed: bool, + pub show_line_numbers: bool, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateUserAppearanceConfig { + pub theme: Option, + pub code_theme: Option, + pub layout_density: Option, + pub sidebar_collapsed: Option, + pub show_line_numbers: Option, +} + +impl AppService { + pub async fn user_update_appearance_config( + &self, + ctx: &Session, + params: UpdateUserAppearanceConfig, + ) -> Result { + let user_uid = session_user(ctx)?; + let mut config = self.user_appearance_config(user_uid).await?; + if let Some(theme) = params.theme { + config.theme = theme; + } + if let Some(code_theme) = params.code_theme { + config.code_theme = code_theme; + } + if let Some(layout_density) = params.layout_density { + config.layout_density = layout_density; + } + if let Some(sidebar_collapsed) = params.sidebar_collapsed { + config.sidebar_collapsed = sidebar_collapsed; + } + if let Some(show_line_numbers) = params.show_line_numbers { + config.show_line_numbers = show_line_numbers; + } + let now = chrono::Utc::now(); + sqlx::query( + "INSERT INTO user_appearance \ + (\"user\", theme, code_theme, layout_density, sidebar_collapsed, show_line_numbers, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $7) \ + ON CONFLICT (\"user\") DO UPDATE SET \ + theme = EXCLUDED.theme, code_theme = EXCLUDED.code_theme, layout_density = EXCLUDED.layout_density, \ + sidebar_collapsed = EXCLUDED.sidebar_collapsed, show_line_numbers = EXCLUDED.show_line_numbers, \ + updated_at = EXCLUDED.updated_at", + ) + .bind(user_uid) + .bind(&config.theme) + .bind(&config.code_theme) + .bind(&config.layout_density) + .bind(config.sidebar_collapsed) + .bind(config.show_line_numbers) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(config) + } + + pub(crate) async fn user_appearance_config( + &self, + user_uid: uuid::Uuid, + ) -> Result { + let row = sqlx::query_as::<_, UserAppearanceModel>( + "SELECT \"user\", theme, code_theme, layout_density, sidebar_collapsed, show_line_numbers, created_at, updated_at \ + FROM user_appearance WHERE \"user\" = $1", + ) + .bind(user_uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(row.map(Into::into).unwrap_or_default()) + } +} + +impl Default for UserAppearanceConfig { + fn default() -> Self { + Self { + theme: "system".to_string(), + code_theme: "github-dark".to_string(), + layout_density: "comfortable".to_string(), + sidebar_collapsed: false, + show_line_numbers: true, + } + } +} + +impl From for UserAppearanceConfig { + fn from(value: UserAppearanceModel) -> Self { + Self { + theme: value.theme, + code_theme: value.code_theme, + layout_density: value.layout_density, + sidebar_collapsed: value.sidebar_collapsed, + show_line_numbers: value.show_line_numbers, + } + } +} diff --git a/lib/service/user/chpc.rs b/lib/service/user/chpc.rs new file mode 100644 index 0000000..1771fdf --- /dev/null +++ b/lib/service/user/chpc.rs @@ -0,0 +1,187 @@ +use std::collections::HashMap; + +use chrono::{Duration, NaiveDate, Utc}; +use db::sqlx; +use serde::{Deserialize, Serialize}; +use session::Session; +use utoipa::{IntoParams, ToSchema}; + +use crate::{AppService, error::AppError}; + +const HEATMAP_CACHE_PREFIX: &str = "user:heatmap"; + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct ContributionHeatmapItem { + pub date: String, + pub count: i32, +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct ContributionHeatmapResponse { + pub username: String, + pub total_contributions: i64, + pub heatmap: Vec, + pub start_date: String, + pub end_date: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, IntoParams)] +pub struct ContributionHeatmapQuery { + pub start_date: Option, + pub end_date: Option, +} + +impl AppService { + pub async fn user_chpc( + &self, + ctx: &Session, + query: ContributionHeatmapQuery, + ) -> Result { + let user_uid = ctx.user().ok_or(AppError::Unauthorized)?; + let user = self.auth_find_user_by_uid(user_uid).await?; + self.user_contribution_heatmap_for_user(user.id, user.username, query) + .await + } + + pub(crate) async fn user_contribution_heatmap_for_user( + &self, + user_uid: uuid::Uuid, + username: String, + query: ContributionHeatmapQuery, + ) -> Result { + let (start_date, end_date) = + parse_date_range(query.start_date, query.end_date)?; + let cache_key = build_heatmap_cache_key(user_uid, start_date, end_date); + + if let Ok(Some(cached)) = self + .cache + .get::(&cache_key) + .await + { + return Ok(cached); + } + + let start_dt = start_date + .and_hms_opt(0, 0, 0) + .ok_or(AppError::InternalError)? + .and_utc(); + let end_dt = end_date + .and_hms_opt(23, 59, 59) + .ok_or(AppError::InternalError)? + .and_utc(); + + let rows = sqlx::query_as::<_, (chrono::DateTime,)>( + "SELECT created_at FROM repo_commit \ + WHERE (author = $1 OR committer = $1) AND created_at >= $2 AND created_at <= $3", + ) + .bind(user_uid) + .bind(start_dt) + .bind(end_dt) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let mut sparse = HashMap::::new(); + for (created_at,) in &rows { + let date = created_at.format("%Y-%m-%d").to_string(); + *sparse.entry(date).or_insert(0) += 1; + } + + let mut heatmap = Vec::new(); + let mut current = start_date; + while current <= end_date { + let date = current.format("%Y-%m-%d").to_string(); + let count = sparse.get(&date).copied().unwrap_or(0) as i32; + heatmap.push(ContributionHeatmapItem { date, count }); + current += Duration::days(1); + } + + let response = ContributionHeatmapResponse { + username, + total_contributions: rows.len() as i64, + heatmap, + start_date: start_date.format("%Y-%m-%d").to_string(), + end_date: end_date.format("%Y-%m-%d").to_string(), + }; + + self.cache + .set(&cache_key, &response) + .await + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + Ok(response) + } + + pub async fn user_invalidate_chpc_cache( + &self, + ctx: &Session, + ) -> Result<(), AppError> { + let user_uid = ctx.user().ok_or(AppError::Unauthorized)?; + self.invalidate_user_heatmap_cache(user_uid).await + } + + pub(crate) async fn invalidate_user_heatmap_cache( + &self, + user_uid: uuid::Uuid, + ) -> Result<(), AppError> { + let pattern = format!("{}:{}:*", HEATMAP_CACHE_PREFIX, user_uid); + self.cache + .delete_pattern(&pattern) + .await + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + Ok(()) + } +} + +fn build_heatmap_cache_key( + user_uid: uuid::Uuid, + start_date: NaiveDate, + end_date: NaiveDate, +) -> String { + format!( + "{}:{}:{}:{}", + HEATMAP_CACHE_PREFIX, + user_uid, + start_date.format("%Y-%m-%d"), + end_date.format("%Y-%m-%d"), + ) +} + +fn parse_date_range( + start_date_str: Option, + end_date_str: Option, +) -> Result<(NaiveDate, NaiveDate), AppError> { + let today = Utc::now().date_naive(); + let one_year_ago = today - Duration::days(365); + + let start_date = match start_date_str { + Some(date) => parse_date(&date, "start_date")?, + None => one_year_ago, + }; + let end_date = match end_date_str { + Some(date) => parse_date(&date, "end_date")?, + None => today, + }; + + if start_date > end_date { + return Err(AppError::BadRequest( + "start_date cannot be later than end_date".to_string(), + )); + } + + if (end_date - start_date).num_days() > 730 { + return Err(AppError::BadRequest( + "date range cannot exceed 2 years".to_string(), + )); + } + + Ok((start_date, end_date)) +} + +fn parse_date(value: &str, field: &str) -> Result { + NaiveDate::parse_from_str(value, "%Y-%m-%d").map_err(|_| { + AppError::BadRequest(format!( + "invalid {field} format, expected YYYY-MM-DD" + )) + }) +} diff --git a/lib/service/user/config.rs b/lib/service/user/config.rs new file mode 100644 index 0000000..67bc809 --- /dev/null +++ b/lib/service/user/config.rs @@ -0,0 +1,34 @@ +use serde::{Deserialize, Serialize}; +use session::Session; + +use super::{ + accessibility::UserAccessibilityConfig, appearance::UserAppearanceConfig, + notification::UserNotificationConfig, privacy::UserPrivacyConfig, + profile::UserProfileConfig, +}; +use crate::{AppService, error::AppError, session_user}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct UserConfigResponse { + pub profile: UserProfileConfig, + pub appearance: UserAppearanceConfig, + pub accessibility: UserAccessibilityConfig, + pub privacy: UserPrivacyConfig, + pub notifications: UserNotificationConfig, +} + +impl AppService { + pub async fn user_config( + &self, + ctx: &Session, + ) -> Result { + let user_uid = session_user(ctx)?; + Ok(UserConfigResponse { + profile: self.user_profile_config(user_uid).await?, + appearance: self.user_appearance_config(user_uid).await?, + accessibility: self.user_accessibility_config(user_uid).await?, + privacy: self.user_privacy_config(user_uid).await?, + notifications: self.user_notification_config(user_uid).await?, + }) + } +} diff --git a/lib/service/user/mod.rs b/lib/service/user/mod.rs new file mode 100644 index 0000000..e9d915a --- /dev/null +++ b/lib/service/user/mod.rs @@ -0,0 +1,9 @@ +pub mod access_token; +pub mod accessibility; +pub mod appearance; +pub mod chpc; +pub mod config; +pub mod notification; +pub mod privacy; +pub mod profile; +pub mod sshkey; diff --git a/lib/service/user/notification.rs b/lib/service/user/notification.rs new file mode 100644 index 0000000..5cdf8e6 --- /dev/null +++ b/lib/service/user/notification.rs @@ -0,0 +1,252 @@ +use db::sqlx; +use model::notify::UserAppNotifyModel; +use model::users::UserNotificationModel; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, Pagination, error::AppError, session_user}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct UserNotificationConfig { + pub email_enabled: bool, + pub in_app_enabled: bool, + pub push_enabled: bool, + pub digest_mode: String, + pub dnd_enabled: bool, + pub dnd_start_minute: Option, + pub dnd_end_minute: Option, + pub marketing_enabled: bool, + pub security_enabled: bool, + pub product_enabled: bool, + pub push_subscription_endpoint: Option, + pub push_subscription_keys_p256dh: Option, + pub push_subscription_keys_auth: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct AppNotificationItem { + pub id: uuid::Uuid, + pub title: String, + pub body: String, + pub notify_type: String, + #[schema(value_type = Option)] + pub read_at: Option>, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, +} + +impl From for AppNotificationItem { + fn from(n: UserAppNotifyModel) -> Self { + Self { + id: n.id, + title: n.title, + body: n.body, + notify_type: n.notify_type, + read_at: n.read_at, + created_at: n.created_at, + } + } +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateUserNotificationConfig { + pub email_enabled: Option, + pub in_app_enabled: Option, + pub push_enabled: Option, + pub digest_mode: Option, + pub dnd_enabled: Option, + pub dnd_start_minute: Option>, + pub dnd_end_minute: Option>, + pub marketing_enabled: Option, + pub security_enabled: Option, + pub product_enabled: Option, + pub push_subscription_endpoint: Option>, + pub push_subscription_keys_p256dh: Option>, + pub push_subscription_keys_auth: Option>, +} + +impl AppService { + pub async fn user_update_notification_config( + &self, + ctx: &Session, + params: UpdateUserNotificationConfig, + ) -> Result { + let user_uid = session_user(ctx)?; + let mut config = self.user_notification_config(user_uid).await?; + if let Some(email_enabled) = params.email_enabled { + config.email_enabled = email_enabled; + } + if let Some(in_app_enabled) = params.in_app_enabled { + config.in_app_enabled = in_app_enabled; + } + if let Some(push_enabled) = params.push_enabled { + config.push_enabled = push_enabled; + } + if let Some(digest_mode) = params.digest_mode { + config.digest_mode = digest_mode; + } + if let Some(dnd_enabled) = params.dnd_enabled { + config.dnd_enabled = dnd_enabled; + } + if let Some(dnd_start_minute) = params.dnd_start_minute { + config.dnd_start_minute = dnd_start_minute; + } + if let Some(dnd_end_minute) = params.dnd_end_minute { + config.dnd_end_minute = dnd_end_minute; + } + if let Some(marketing_enabled) = params.marketing_enabled { + config.marketing_enabled = marketing_enabled; + } + if let Some(security_enabled) = params.security_enabled { + config.security_enabled = security_enabled; + } + if let Some(product_enabled) = params.product_enabled { + config.product_enabled = product_enabled; + } + if let Some(push_subscription_endpoint) = + params.push_subscription_endpoint + { + config.push_subscription_endpoint = push_subscription_endpoint; + } + if let Some(push_subscription_keys_p256dh) = + params.push_subscription_keys_p256dh + { + config.push_subscription_keys_p256dh = + push_subscription_keys_p256dh; + } + if let Some(push_subscription_keys_auth) = + params.push_subscription_keys_auth + { + config.push_subscription_keys_auth = push_subscription_keys_auth; + } + let now = chrono::Utc::now(); + sqlx::query( + "INSERT INTO user_notification \ + (\"user\", email_enabled, in_app_enabled, push_enabled, digest_mode, dnd_enabled, dnd_start_minute, dnd_end_minute, \ + marketing_enabled, security_enabled, product_enabled, push_subscription_endpoint, push_subscription_keys_p256dh, \ + push_subscription_keys_auth, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $15) \ + ON CONFLICT (\"user\") DO UPDATE SET \ + email_enabled = EXCLUDED.email_enabled, in_app_enabled = EXCLUDED.in_app_enabled, push_enabled = EXCLUDED.push_enabled, \ + digest_mode = EXCLUDED.digest_mode, dnd_enabled = EXCLUDED.dnd_enabled, dnd_start_minute = EXCLUDED.dnd_start_minute, \ + dnd_end_minute = EXCLUDED.dnd_end_minute, marketing_enabled = EXCLUDED.marketing_enabled, security_enabled = EXCLUDED.security_enabled, \ + product_enabled = EXCLUDED.product_enabled, push_subscription_endpoint = EXCLUDED.push_subscription_endpoint, \ + push_subscription_keys_p256dh = EXCLUDED.push_subscription_keys_p256dh, push_subscription_keys_auth = EXCLUDED.push_subscription_keys_auth, \ + updated_at = EXCLUDED.updated_at", + ) + .bind(user_uid) + .bind(config.email_enabled) + .bind(config.in_app_enabled) + .bind(config.push_enabled) + .bind(&config.digest_mode) + .bind(config.dnd_enabled) + .bind(config.dnd_start_minute) + .bind(config.dnd_end_minute) + .bind(config.marketing_enabled) + .bind(config.security_enabled) + .bind(config.product_enabled) + .bind(&config.push_subscription_endpoint) + .bind(&config.push_subscription_keys_p256dh) + .bind(&config.push_subscription_keys_auth) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(config) + } + + pub async fn list_notifications( + &self, + ctx: &Session, + pagination: Pagination, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let rows = sqlx::query_as::<_, UserAppNotifyModel>( + "SELECT id, \"user\", title, body, notify_type, target_type, target_id, metadata, read_at, archived_at, created_at, updated_at \ + FROM user_app_notify WHERE \"user\" = $1 AND archived_at IS NULL \ + ORDER BY created_at DESC \ + OFFSET $2 LIMIT $3", + ) + .bind(user_uid) + .bind(pagination.offset() as i64) + .bind(pagination.limit() as i64) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(Into::into).collect()) + } + + pub(crate) async fn unread_notifications_count( + &self, + user_uid: uuid::Uuid, + ) -> Result { + let row: (Option,) = sqlx::query_as( + "SELECT COUNT(*) FROM user_app_notify \ + WHERE \"user\" = $1 AND read_at IS NULL AND archived_at IS NULL", + ) + .bind(user_uid) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(row.0.unwrap_or(0)) + } + + pub(crate) async fn user_notification_config( + &self, + user_uid: uuid::Uuid, + ) -> Result { + let row = sqlx::query_as::<_, UserNotificationModel>( + "SELECT \"user\", email_enabled, in_app_enabled, push_enabled, digest_mode, dnd_enabled, dnd_start_minute, dnd_end_minute, \ + marketing_enabled, security_enabled, product_enabled, push_subscription_endpoint, push_subscription_keys_p256dh, \ + push_subscription_keys_auth, created_at, updated_at \ + FROM user_notification WHERE \"user\" = $1", + ) + .bind(user_uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(row.map(Into::into).unwrap_or_default()) + } +} + +impl Default for UserNotificationConfig { + fn default() -> Self { + Self { + email_enabled: true, + in_app_enabled: true, + push_enabled: false, + digest_mode: "daily".to_string(), + dnd_enabled: false, + dnd_start_minute: None, + dnd_end_minute: None, + marketing_enabled: false, + security_enabled: true, + product_enabled: true, + push_subscription_endpoint: None, + push_subscription_keys_p256dh: None, + push_subscription_keys_auth: None, + } + } +} + +impl From for UserNotificationConfig { + fn from(value: UserNotificationModel) -> Self { + Self { + email_enabled: value.email_enabled, + in_app_enabled: value.in_app_enabled, + push_enabled: value.push_enabled, + digest_mode: value.digest_mode, + dnd_enabled: value.dnd_enabled, + dnd_start_minute: value.dnd_start_minute, + dnd_end_minute: value.dnd_end_minute, + marketing_enabled: value.marketing_enabled, + security_enabled: value.security_enabled, + product_enabled: value.product_enabled, + push_subscription_endpoint: value.push_subscription_endpoint, + push_subscription_keys_p256dh: value.push_subscription_keys_p256dh, + push_subscription_keys_auth: value.push_subscription_keys_auth, + } + } +} diff --git a/lib/service/user/privacy.rs b/lib/service/user/privacy.rs new file mode 100644 index 0000000..86a9564 --- /dev/null +++ b/lib/service/user/privacy.rs @@ -0,0 +1,119 @@ +use db::sqlx; +use model::users::UserPrivacyModel; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, error::AppError, session_user}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct UserPrivacyConfig { + pub profile_visibility: String, + pub email_visibility: String, + pub activity_visibility: String, + pub allow_search_indexing: bool, + pub allow_direct_messages: bool, + pub show_online_status: bool, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateUserPrivacyConfig { + pub profile_visibility: Option, + pub email_visibility: Option, + pub activity_visibility: Option, + pub allow_search_indexing: Option, + pub allow_direct_messages: Option, + pub show_online_status: Option, +} + +impl AppService { + pub async fn user_update_privacy_config( + &self, + ctx: &Session, + params: UpdateUserPrivacyConfig, + ) -> Result { + let user_uid = session_user(ctx)?; + let mut config = self.user_privacy_config(user_uid).await?; + if let Some(profile_visibility) = params.profile_visibility { + config.profile_visibility = profile_visibility; + } + if let Some(email_visibility) = params.email_visibility { + config.email_visibility = email_visibility; + } + if let Some(activity_visibility) = params.activity_visibility { + config.activity_visibility = activity_visibility; + } + if let Some(allow_search_indexing) = params.allow_search_indexing { + config.allow_search_indexing = allow_search_indexing; + } + if let Some(allow_direct_messages) = params.allow_direct_messages { + config.allow_direct_messages = allow_direct_messages; + } + if let Some(show_online_status) = params.show_online_status { + config.show_online_status = show_online_status; + } + let now = chrono::Utc::now(); + sqlx::query( + "INSERT INTO user_privacy \ + (\"user\", profile_visibility, email_visibility, activity_visibility, allow_search_indexing, allow_direct_messages, show_online_status, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $8) \ + ON CONFLICT (\"user\") DO UPDATE SET \ + profile_visibility = EXCLUDED.profile_visibility, email_visibility = EXCLUDED.email_visibility, \ + activity_visibility = EXCLUDED.activity_visibility, allow_search_indexing = EXCLUDED.allow_search_indexing, \ + allow_direct_messages = EXCLUDED.allow_direct_messages, show_online_status = EXCLUDED.show_online_status, \ + updated_at = EXCLUDED.updated_at", + ) + .bind(user_uid) + .bind(&config.profile_visibility) + .bind(&config.email_visibility) + .bind(&config.activity_visibility) + .bind(config.allow_search_indexing) + .bind(config.allow_direct_messages) + .bind(config.show_online_status) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(config) + } + + pub(crate) async fn user_privacy_config( + &self, + user_uid: uuid::Uuid, + ) -> Result { + let row = sqlx::query_as::<_, UserPrivacyModel>( + "SELECT \"user\", profile_visibility, email_visibility, activity_visibility, allow_search_indexing, allow_direct_messages, show_online_status, created_at, updated_at \ + FROM user_privacy WHERE \"user\" = $1", + ) + .bind(user_uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(row.map(Into::into).unwrap_or_default()) + } +} + +impl Default for UserPrivacyConfig { + fn default() -> Self { + Self { + profile_visibility: "public".to_string(), + email_visibility: "private".to_string(), + activity_visibility: "public".to_string(), + allow_search_indexing: true, + allow_direct_messages: true, + show_online_status: true, + } + } +} + +impl From for UserPrivacyConfig { + fn from(value: UserPrivacyModel) -> Self { + Self { + profile_visibility: value.profile_visibility, + email_visibility: value.email_visibility, + activity_visibility: value.activity_visibility, + allow_search_indexing: value.allow_search_indexing, + allow_direct_messages: value.allow_direct_messages, + show_online_status: value.show_online_status, + } + } +} diff --git a/lib/service/user/profile.rs b/lib/service/user/profile.rs new file mode 100644 index 0000000..bff7f3c --- /dev/null +++ b/lib/service/user/profile.rs @@ -0,0 +1,182 @@ +use db::sqlx; +use model::users::UserProfileModel; +use serde::{Deserialize, Serialize}; +use session::Session; +use storage::{ObjectStorage, PutObjectOptions}; +use uuid::Uuid; + +use crate::{AppService, error::AppError, session_user}; + +/// Allowed image MIME types for avatars. +const ALLOWED_AVATAR_TYPES: &[&str] = &[ + "image/png", + "image/jpeg", + "image/webp", + "image/gif", +]; + +/// Maximum avatar file size: 5 MB. +const MAX_AVATAR_SIZE: usize = 5 * 1024 * 1024; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct UserProfileConfig { + pub language: String, + pub theme: String, + pub timezone: String, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateUserProfileConfig { + pub language: Option, + pub theme: Option, + pub timezone: Option, + pub avatar_url: Option, +} + +#[derive(Debug, Clone, Serialize, utoipa::ToSchema)] +pub struct AvatarUploadResponse { + pub avatar_url: String, +} + +/// Derive file extension from MIME content type. +fn extension_from_content_type(content_type: &str) -> &str { + match content_type { + "image/png" => "png", + "image/jpeg" => "jpg", + "image/webp" => "webp", + "image/gif" => "gif", + _ => "bin", + } +} + +impl AppService { + pub async fn user_update_profile_config( + &self, + ctx: &Session, + params: UpdateUserProfileConfig, + ) -> Result { + let user_uid = session_user(ctx)?; + let mut config = self.user_profile_config(user_uid).await?; + if let Some(language) = params.language { + config.language = language; + } + if let Some(theme) = params.theme { + config.theme = theme; + } + if let Some(timezone) = params.timezone { + config.timezone = timezone; + } + if let Some(avatar_url) = params.avatar_url { + sqlx::query("UPDATE users SET avatar_url = $1 WHERE id = $2") + .bind(&avatar_url) + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + } + let now = chrono::Utc::now(); + sqlx::query( + "INSERT INTO user_profile (\"user\", language, theme, timezone, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $5) \ + ON CONFLICT (\"user\") DO UPDATE SET \ + language = EXCLUDED.language, theme = EXCLUDED.theme, timezone = EXCLUDED.timezone, updated_at = EXCLUDED.updated_at", + ) + .bind(user_uid) + .bind(&config.language) + .bind(&config.theme) + .bind(&config.timezone) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(config) + } + + /// Upload a user avatar image, store it, and update the user's avatar_url. + pub async fn user_upload_avatar( + &self, + ctx: &Session, + bytes: Vec, + content_type: &str, + ) -> Result { + let user_uid = session_user(ctx)?; + + if bytes.len() > MAX_AVATAR_SIZE { + return Err(AppError::AvatarUploadError( + "file size exceeds 5 MB limit".to_string(), + )); + } + if !ALLOWED_AVATAR_TYPES.contains(&content_type) { + return Err(AppError::AvatarUploadError(format!( + "unsupported image type: {content_type}. Allowed: png, jpeg, webp, gif" + ))); + } + + let ext = extension_from_content_type(content_type); + let key = format!( + "avatars/users/{user_uid}-{}.{ext}", + uuid::Uuid::now_v7() + ); + + let stored = self + .storage + .put_bytes( + &key, + bytes, + PutObjectOptions { + content_type: Some(content_type.to_string()), + ..PutObjectOptions::default() + }, + ) + .await + .map_err(|e| { + AppError::AvatarUploadError(format!("storage error: {e}")) + })?; + + sqlx::query("UPDATE users SET avatar_url = $1 WHERE id = $2") + .bind(&stored.url) + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(AvatarUploadResponse { + avatar_url: stored.url, + }) + } + + pub(crate) async fn user_profile_config( + &self, + user_uid: Uuid, + ) -> Result { + let row = sqlx::query_as::<_, UserProfileModel>( + "SELECT \"user\", language, theme, timezone, created_at, updated_at \ + FROM user_profile WHERE \"user\" = $1", + ) + .bind(user_uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(row.map(Into::into).unwrap_or_default()) + } +} + +impl Default for UserProfileConfig { + fn default() -> Self { + Self { + language: "en".to_string(), + theme: "system".to_string(), + timezone: "UTC".to_string(), + } + } +} + +impl From for UserProfileConfig { + fn from(value: UserProfileModel) -> Self { + Self { + language: value.language, + theme: value.theme, + timezone: value.timezone, + } + } +} diff --git a/lib/service/user/sshkey.rs b/lib/service/user/sshkey.rs new file mode 100644 index 0000000..9163400 --- /dev/null +++ b/lib/service/user/sshkey.rs @@ -0,0 +1,248 @@ +use base64::{Engine as _, engine::general_purpose}; +use db::sqlx; +use model::users::UserSshKeyModel; +use serde::{Deserialize, Serialize}; +use session::Session; +use sha2::{Digest, Sha256}; + +use crate::{AppService, error::AppError, session_user}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct UserSshKey { + pub id: i64, + pub title: String, + pub public_key: String, + pub fingerprint: String, + pub key_type: String, + pub key_bits: Option, + pub is_verified: bool, + #[schema(value_type = Option)] + pub last_used_at: Option>, + #[schema(value_type = Option)] + pub expires_at: Option>, + pub is_revoked: bool, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreateUserSshKey { + pub title: String, + pub public_key: String, + #[schema(value_type = Option)] + pub expires_at: Option>, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateUserSshKey { + pub title: Option, + #[schema(value_type = Option)] + pub expires_at: Option>>, +} + +impl AppService { + pub async fn user_ssh_keys( + &self, + ctx: &Session, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let rows = sqlx::query_as::<_, UserSshKeyModel>( + "SELECT id, \"user\", title, public_key, fingerprint, key_type, key_bits, is_verified, last_used_at, expires_at, is_revoked, created_at, updated_at \ + FROM user_ssh_key WHERE \"user\" = $1 ORDER BY created_at DESC", + ) + .bind(user_uid) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(Into::into).collect()) + } + + pub async fn user_add_ssh_key( + &self, + ctx: &Session, + params: CreateUserSshKey, + ) -> Result { + let user_uid = session_user(ctx)?; + let title = params.title.trim(); + if title.is_empty() { + return Err(AppError::BadRequest( + "ssh key title is required".to_string(), + )); + } + + let parsed = parse_public_key(¶ms.public_key)?; + let now = chrono::Utc::now(); + let row = sqlx::query_as::<_, UserSshKeyModel>( + "INSERT INTO user_ssh_key \ + (\"user\", title, public_key, fingerprint, key_type, key_bits, is_verified, last_used_at, expires_at, is_revoked, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, true, NULL, $7, false, $8, $8) \ + RETURNING id, \"user\", title, public_key, fingerprint, key_type, key_bits, is_verified, last_used_at, expires_at, is_revoked, created_at, updated_at", + ) + .bind(user_uid) + .bind(title) + .bind(parsed.public_key) + .bind(parsed.fingerprint) + .bind(parsed.key_type) + .bind(parsed.key_bits) + .bind(params.expires_at) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(row.into()) + } + + pub async fn user_revoke_ssh_key( + &self, + ctx: &Session, + key_id: i64, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let affected = sqlx::query( + "UPDATE user_ssh_key SET is_revoked = true, updated_at = $1 \ + WHERE id = $2 AND \"user\" = $3 AND is_revoked = false", + ) + .bind(chrono::Utc::now()) + .bind(key_id) + .bind(user_uid) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .rows_affected(); + + if affected == 0 { + return Err(AppError::NotFound("ssh key not found".to_string())); + } + Ok(()) + } + + pub async fn user_update_ssh_key( + &self, + ctx: &Session, + key_id: i64, + params: UpdateUserSshKey, + ) -> Result { + let user_uid = session_user(ctx)?; + let mut key = self.user_ssh_key_by_id(user_uid, key_id).await?; + + if let Some(title) = params.title { + let title = title.trim(); + if title.is_empty() { + return Err(AppError::BadRequest( + "ssh key title is required".to_string(), + )); + } + key.title = title.to_string(); + } + if let Some(expires_at) = params.expires_at { + key.expires_at = expires_at; + } + + let row = sqlx::query_as::<_, UserSshKeyModel>( + "UPDATE user_ssh_key SET title = $1, expires_at = $2, updated_at = $3 \ + WHERE id = $4 AND \"user\" = $5 AND is_revoked = false \ + RETURNING id, \"user\", title, public_key, fingerprint, key_type, key_bits, is_verified, last_used_at, expires_at, is_revoked, created_at, updated_at", + ) + .bind(&key.title) + .bind(key.expires_at) + .bind(chrono::Utc::now()) + .bind(key_id) + .bind(user_uid) + .fetch_optional(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or_else(|| AppError::NotFound("ssh key not found".to_string()))?; + + Ok(row.into()) + } + + async fn user_ssh_key_by_id( + &self, + user_uid: uuid::Uuid, + key_id: i64, + ) -> Result { + sqlx::query_as::<_, UserSshKeyModel>( + "SELECT id, \"user\", title, public_key, fingerprint, key_type, key_bits, is_verified, last_used_at, expires_at, is_revoked, created_at, updated_at \ + FROM user_ssh_key WHERE id = $1 AND \"user\" = $2 AND is_revoked = false", + ) + .bind(key_id) + .bind(user_uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .map(Into::into) + .ok_or_else(|| AppError::NotFound("ssh key not found".to_string())) + } +} + +struct ParsedPublicKey { + public_key: String, + fingerprint: String, + key_type: String, + key_bits: Option, +} + +fn parse_public_key(public_key: &str) -> Result { + let public_key = public_key.trim(); + let mut parts = public_key.split_whitespace(); + let key_type = parts.next().ok_or_else(|| { + AppError::BadRequest("invalid ssh public key".to_string()) + })?; + let key_data_base64 = parts.next().ok_or_else(|| { + AppError::BadRequest("invalid ssh public key".to_string()) + })?; + + let key_data = + general_purpose::STANDARD + .decode(key_data_base64) + .map_err(|_| { + AppError::BadRequest("invalid ssh public key".to_string()) + })?; + + let mut hasher = Sha256::new(); + hasher.update(&key_data); + let fingerprint = format!( + "SHA256:{}", + general_purpose::STANDARD_NO_PAD.encode(hasher.finalize()) + ); + + Ok(ParsedPublicKey { + public_key: public_key.to_string(), + fingerprint, + key_type: key_type.to_string(), + key_bits: key_bits(key_type), + }) +} + +fn key_bits(key_type: &str) -> Option { + match key_type { + "ssh-ed25519" => Some(256), + "ecdsa-sha2-nistp256" => Some(256), + "ecdsa-sha2-nistp384" => Some(384), + "ecdsa-sha2-nistp521" => Some(521), + _ => None, + } +} + +impl From for UserSshKey { + fn from(value: UserSshKeyModel) -> Self { + Self { + id: value.id, + title: value.title, + public_key: value.public_key, + fingerprint: value.fingerprint, + key_type: value.key_type, + key_bits: value.key_bits, + is_verified: value.is_verified, + last_used_at: value.last_used_at, + expires_at: value.expires_at, + is_revoked: value.is_revoked, + created_at: value.created_at, + updated_at: value.updated_at, + } + } +} diff --git a/lib/service/users/chpc.rs b/lib/service/users/chpc.rs new file mode 100644 index 0000000..4e0b162 --- /dev/null +++ b/lib/service/users/chpc.rs @@ -0,0 +1,17 @@ +use crate::{ + AppService, + error::AppError, + user::chpc::{ContributionHeatmapQuery, ContributionHeatmapResponse}, +}; + +impl AppService { + pub async fn users_chpc_by_username( + &self, + username: &str, + query: ContributionHeatmapQuery, + ) -> Result { + let user = self.users_find_active_user_by_username(username).await?; + self.user_contribution_heatmap_for_user(user.id, user.username, query) + .await + } +} diff --git a/lib/service/users/mod.rs b/lib/service/users/mod.rs new file mode 100644 index 0000000..4aa6448 --- /dev/null +++ b/lib/service/users/mod.rs @@ -0,0 +1,4 @@ +pub mod chpc; +pub mod public; +pub mod relation; +pub mod summary; diff --git a/lib/service/users/public.rs b/lib/service/users/public.rs new file mode 100644 index 0000000..bf50298 --- /dev/null +++ b/lib/service/users/public.rs @@ -0,0 +1,72 @@ +use db::sqlx; +use model::users::{UserPrivacyModel, UserProfileModel}; +use serde::{Deserialize, Serialize}; + +use crate::{AppService, error::AppError}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct PublicUserResponse { + pub username: String, + pub display_name: String, + pub avatar_url: String, + pub website_url: String, + pub language: String, + pub timezone: String, + pub allow_direct_messages: bool, + pub show_online_status: bool, +} + +impl AppService { + pub async fn users_public_by_username( + &self, + username: &str, + ) -> Result { + let user = self.users_find_active_user_by_username(username).await?; + + let privacy = sqlx::query_as::<_, UserPrivacyModel>( + "SELECT \"user\", profile_visibility, email_visibility, activity_visibility, allow_search_indexing, allow_direct_messages, show_online_status, created_at, updated_at \ + FROM user_privacy WHERE \"user\" = $1", + ) + .bind(user.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if privacy + .as_ref() + .is_some_and(|privacy| privacy.profile_visibility == "private") + { + return Err(AppError::Forbidden("profile is private".to_string())); + } + + let profile = sqlx::query_as::<_, UserProfileModel>( + "SELECT \"user\", language, theme, timezone, created_at, updated_at \ + FROM user_profile WHERE \"user\" = $1", + ) + .bind(user.id) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(PublicUserResponse { + username: user.username, + display_name: user.display_name, + avatar_url: user.avatar_url, + website_url: user.website_url, + language: profile + .as_ref() + .map(|profile| profile.language.clone()) + .unwrap_or_else(|| "en".to_string()), + timezone: profile + .map(|profile| profile.timezone) + .unwrap_or_else(|| "UTC".to_string()), + allow_direct_messages: privacy + .as_ref() + .map(|privacy| privacy.allow_direct_messages) + .unwrap_or(true), + show_online_status: privacy + .map(|privacy| privacy.show_online_status) + .unwrap_or(true), + }) + } +} diff --git a/lib/service/users/relation.rs b/lib/service/users/relation.rs new file mode 100644 index 0000000..8a5629c --- /dev/null +++ b/lib/service/users/relation.rs @@ -0,0 +1,401 @@ +use std::collections::HashMap; + +use db::sqlx; +use model::users::UserModel; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, Pagination, error::AppError, non_empty, session_user}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct UserRelationStatus { + pub username: String, + pub avatar_url: Option, + pub is_following: bool, + pub is_followed_by: bool, + pub is_blocked: bool, + pub has_blocked_me: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct UserRelationCard { + pub username: String, + pub display_name: Option, + pub avatar_url: Option, + pub is_following: bool, + pub is_blocked: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct UserRelationCounts { + pub followers: i64, + pub following: i64, + pub blocked: i64, +} + +impl AppService { + pub async fn users_follow_by_username( + &self, + ctx: &Session, + username: &str, + ) -> Result { + let current_uid = session_user(ctx)?; + let target = self.users_relation_target(current_uid, username).await?; + + if self.users_is_blocked(current_uid, target.id).await? + || self.users_is_blocked(target.id, current_uid).await? + { + return Err(AppError::Forbidden( + "cannot follow a blocked user".to_string(), + )); + } + + if !self.users_is_following(current_uid, target.id).await? { + sqlx::query( + "INSERT INTO user_favorite (\"user\", target, created_at) VALUES ($1, $2, $3)", + ) + .bind(current_uid) + .bind(target.id) + .bind(chrono::Utc::now()) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + } + + self.users_relation_status_for(current_uid, target).await + } + + pub async fn users_unfollow_by_username( + &self, + ctx: &Session, + username: &str, + ) -> Result { + let current_uid = session_user(ctx)?; + let target = self.users_relation_target(current_uid, username).await?; + + sqlx::query( + "DELETE FROM user_favorite WHERE \"user\" = $1 AND target = $2", + ) + .bind(current_uid) + .bind(target.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.users_relation_status_for(current_uid, target).await + } + + pub async fn users_block_by_username( + &self, + ctx: &Session, + username: &str, + ) -> Result { + let current_uid = session_user(ctx)?; + let target = self.users_relation_target(current_uid, username).await?; + + if !self.users_is_blocked(current_uid, target.id).await? { + sqlx::query( + "INSERT INTO user_blacklist (\"user\", black, created_at) VALUES ($1, $2, $3)", + ) + .bind(current_uid) + .bind(target.id) + .bind(chrono::Utc::now()) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + } + + sqlx::query( + "DELETE FROM user_favorite \ + WHERE (\"user\" = $1 AND target = $2) OR (\"user\" = $2 AND target = $1)", + ) + .bind(current_uid) + .bind(target.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.users_relation_status_for(current_uid, target).await + } + + pub async fn users_unblock_by_username( + &self, + ctx: &Session, + username: &str, + ) -> Result { + let current_uid = session_user(ctx)?; + let target = self.users_relation_target(current_uid, username).await?; + + sqlx::query( + "DELETE FROM user_blacklist WHERE \"user\" = $1 AND black = $2", + ) + .bind(current_uid) + .bind(target.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.users_relation_status_for(current_uid, target).await + } + + pub async fn users_relation_status_by_username( + &self, + ctx: &Session, + username: &str, + ) -> Result { + let current_uid = session_user(ctx)?; + let target = self.users_relation_target(current_uid, username).await?; + self.users_relation_status_for(current_uid, target).await + } + + pub async fn users_followers_by_username( + &self, + ctx: Option<&Session>, + username: &str, + pagination: Pagination, + ) -> Result, AppError> { + let target = self.users_find_active_user_by_username(username).await?; + let current_uid = ctx.and_then(Session::user); + let users = sqlx::query_as::<_, UserModel>( + "SELECT u.id, u.username, u.display_name, u.avatar_url, u.website_url, u.allow_use, u.can_search, \ + u.last_sign_in_at, u.created_at, u.updated_at \ + FROM user_favorite f \ + INNER JOIN \"user\" u ON u.id = f.\"user\" \ + WHERE f.target = $1 AND u.allow_use = true \ + ORDER BY f.created_at DESC \ + LIMIT $2 OFFSET $3", + ) + .bind(target.id) + .bind(pagination.limit() as i64) + .bind(pagination.offset() as i64) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.users_relation_cards(current_uid, users).await + } + + pub async fn users_following_by_username( + &self, + ctx: Option<&Session>, + username: &str, + pagination: Pagination, + ) -> Result, AppError> { + let target = self.users_find_active_user_by_username(username).await?; + let current_uid = ctx.and_then(Session::user); + let users = sqlx::query_as::<_, UserModel>( + "SELECT u.id, u.username, u.display_name, u.avatar_url, u.website_url, u.allow_use, u.can_search, \ + u.last_sign_in_at, u.created_at, u.updated_at \ + FROM user_favorite f \ + INNER JOIN \"user\" u ON u.id = f.target \ + WHERE f.\"user\" = $1 AND u.allow_use = true \ + ORDER BY f.created_at DESC \ + LIMIT $2 OFFSET $3", + ) + .bind(target.id) + .bind(pagination.limit() as i64) + .bind(pagination.offset() as i64) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.users_relation_cards(current_uid, users).await + } + + pub async fn users_blocked( + &self, + ctx: &Session, + ) -> Result, AppError> { + let current_uid = session_user(ctx)?; + let users = sqlx::query_as::<_, UserModel>( + "SELECT u.id, u.username, u.display_name, u.avatar_url, u.website_url, u.allow_use, u.can_search, \ + u.last_sign_in_at, u.created_at, u.updated_at \ + FROM user_blacklist b \ + INNER JOIN \"user\" u ON u.id = b.black \ + WHERE b.\"user\" = $1 AND u.allow_use = true \ + ORDER BY b.created_at DESC", + ) + .bind(current_uid) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + self.users_relation_cards(Some(current_uid), users).await + } + + pub async fn users_relation_counts_by_username( + &self, + username: &str, + ) -> Result { + let target = self.users_find_active_user_by_username(username).await?; + let followers = sqlx::query_scalar::<_, i64>( + "SELECT COUNT(*) FROM user_favorite f \ + INNER JOIN \"user\" u ON u.id = f.\"user\" \ + WHERE f.target = $1 AND u.allow_use = true", + ) + .bind(target.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let following = sqlx::query_scalar::<_, i64>( + "SELECT COUNT(*) FROM user_favorite f \ + INNER JOIN \"user\" u ON u.id = f.target \ + WHERE f.\"user\" = $1 AND u.allow_use = true", + ) + .bind(target.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let blocked = sqlx::query_scalar::<_, i64>( + "SELECT COUNT(*) FROM user_blacklist b \ + INNER JOIN \"user\" u ON u.id = b.black \ + WHERE b.\"user\" = $1 AND u.allow_use = true", + ) + .bind(target.id) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(UserRelationCounts { + followers, + following, + blocked, + }) + } + + async fn users_relation_target( + &self, + current_uid: uuid::Uuid, + username: &str, + ) -> Result { + let target = self.users_find_active_user_by_username(username).await?; + if target.id == current_uid { + return Err(AppError::BadRequest( + "cannot operate on yourself".to_string(), + )); + } + Ok(target) + } + + async fn users_relation_status_for( + &self, + current_uid: uuid::Uuid, + target: UserModel, + ) -> Result { + Ok(UserRelationStatus { + username: target.username, + avatar_url: non_empty(target.avatar_url), + is_following: self + .users_is_following(current_uid, target.id) + .await?, + is_followed_by: self + .users_is_following(target.id, current_uid) + .await?, + is_blocked: self.users_is_blocked(current_uid, target.id).await?, + has_blocked_me: self + .users_is_blocked(target.id, current_uid) + .await?, + }) + } + + async fn users_relation_cards( + &self, + current_uid: Option, + users: Vec, + ) -> Result, AppError> { + let user_ids = users.iter().map(|user| user.id).collect::>(); + let following = match current_uid { + Some(current_uid) => { + self.users_following_set(current_uid, &user_ids).await? + } + None => HashMap::new(), + }; + let blocked = match current_uid { + Some(current_uid) => { + self.users_blocked_set(current_uid, &user_ids).await? + } + None => HashMap::new(), + }; + + Ok(users + .into_iter() + .map(|user| UserRelationCard { + username: user.username, + display_name: non_empty(user.display_name), + avatar_url: non_empty(user.avatar_url), + is_following: following.contains_key(&user.id), + is_blocked: blocked.contains_key(&user.id), + }) + .collect()) + } + + async fn users_is_following( + &self, + user_uid: uuid::Uuid, + target_uid: uuid::Uuid, + ) -> Result { + let exists = sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM user_favorite WHERE \"user\" = $1 AND target = $2)", + ) + .bind(user_uid) + .bind(target_uid) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(exists) + } + + async fn users_is_blocked( + &self, + user_uid: uuid::Uuid, + target_uid: uuid::Uuid, + ) -> Result { + let exists = sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM user_blacklist WHERE \"user\" = $1 AND black = $2)", + ) + .bind(user_uid) + .bind(target_uid) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(exists) + } + + async fn users_following_set( + &self, + user_uid: uuid::Uuid, + target_uids: &[uuid::Uuid], + ) -> Result, AppError> { + if target_uids.is_empty() { + return Ok(HashMap::new()); + } + let rows = sqlx::query_as::<_, (uuid::Uuid,)>( + "SELECT target FROM user_favorite WHERE \"user\" = $1 AND target = ANY($2)", + ) + .bind(user_uid) + .bind(target_uids) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(rows.into_iter().map(|(uid,)| (uid, ())).collect()) + } + + async fn users_blocked_set( + &self, + user_uid: uuid::Uuid, + target_uids: &[uuid::Uuid], + ) -> Result, AppError> { + if target_uids.is_empty() { + return Ok(HashMap::new()); + } + let rows = sqlx::query_as::<_, (uuid::Uuid,)>( + "SELECT black FROM user_blacklist WHERE \"user\" = $1 AND black = ANY($2)", + ) + .bind(user_uid) + .bind(target_uids) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(rows.into_iter().map(|(uid,)| (uid, ())).collect()) + } +} diff --git a/lib/service/users/summary.rs b/lib/service/users/summary.rs new file mode 100644 index 0000000..98a4895 --- /dev/null +++ b/lib/service/users/summary.rs @@ -0,0 +1,70 @@ +use db::sqlx; +use model::users::UserModel; +use serde::{Deserialize, Serialize}; + +use crate::{AppService, error::AppError}; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct UserSummaryResponse { + pub username: String, + pub display_name: String, + pub avatar_url: String, + pub website_url: String, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, +} + +impl AppService { + pub async fn users_summary_by_username( + &self, + username: &str, + ) -> Result { + let user = self.users_find_active_user_by_username(username).await?; + Ok(user.into()) + } + + /// Get a user's avatar URL by username. + pub async fn users_get_avatar_url( + &self, + username: &str, + ) -> Result { + let user = self.users_find_active_user_by_username(username).await?; + if user.avatar_url.is_empty() { + return Err(AppError::NotFound("avatar not found".to_string())); + } + Ok(user.avatar_url) + } + + pub(crate) async fn users_find_active_user_by_username( + &self, + username: &str, + ) -> Result { + let user = sqlx::query_as::<_, UserModel>( + "SELECT id, username, display_name, avatar_url, website_url, allow_use, can_search, \ + last_sign_in_at, created_at, updated_at \ + FROM \"user\" WHERE username = $1", + ) + .bind(username) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::UserNotFound)?; + + if !user.allow_use { + return Err(AppError::UserNotFound); + } + Ok(user) + } +} + +impl From for UserSummaryResponse { + fn from(value: UserModel) -> Self { + Self { + username: value.username, + display_name: value.display_name, + avatar_url: value.avatar_url, + website_url: value.website_url, + created_at: value.created_at, + } + } +} diff --git a/lib/service/workspace/group.rs b/lib/service/workspace/group.rs new file mode 100644 index 0000000..a3f35f4 --- /dev/null +++ b/lib/service/workspace/group.rs @@ -0,0 +1,227 @@ +use db::sqlx; +use model::workspace::WkGroupModel; +use serde::Deserialize; +use session::Session; + +use super::types::{ + WorkspaceGroupMemberRow, WorkspaceGroupResponse, WorkspaceMemberResponse, + group_response, +}; +use crate::{AppService, error::AppError, session_user}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreateWorkspaceGroup { + pub name: String, + pub avatar_url: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateWorkspaceGroup { + pub name: Option, + pub avatar_url: Option>, +} + +impl AppService { + pub async fn workspace_groups( + &self, + ctx: &Session, + name: &str, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + + let rows = sqlx::query_as::<_, WkGroupModel>( + "SELECT id, name, wk, created_at, avatar_url, is_deleted FROM wk_group \ + WHERE wk = $1 ORDER BY created_at DESC", + ) + .bind(wk.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows.into_iter().map(group_response).collect()) + } + + pub async fn workspace_create_group( + &self, + ctx: &Session, + name: &str, + params: CreateWorkspaceGroup, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + + use super::types::normalize_name; + let group_name = normalize_name(¶ms.name)?; + + let row = sqlx::query_as::<_, WkGroupModel>( + "INSERT INTO wk_group (id, name, wk, created_at, avatar_url, is_deleted) \ + VALUES ($1, $2, $3, $4, $5, false) \ + RETURNING id, name, wk, created_at, avatar_url, is_deleted", + ) + .bind(uuid::Uuid::now_v7()) + .bind(group_name) + .bind(wk.id) + .bind(chrono::Utc::now()) + .bind(params.avatar_url) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(group_response(row)) + } + + pub async fn workspace_update_group( + &self, + ctx: &Session, + name: &str, + group_name: &str, + params: UpdateWorkspaceGroup, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let group = self.workspace_group_by_name(wk.id, group_name).await?; + + use super::types::normalize_name; + let group_name = match params.name { + Some(name) => normalize_name(&name)?, + None => group.name, + }; + let avatar_url = params.avatar_url.unwrap_or(group.avatar_url); + + let row = sqlx::query_as::<_, WkGroupModel>( + "UPDATE wk_group SET name = $1, avatar_url = $2 WHERE id = $3 AND wk = $4 \ + RETURNING id, name, wk, created_at, avatar_url, is_deleted", + ) + .bind(group_name) + .bind(avatar_url) + .bind(group.id) + .bind(wk.id) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(group_response(row)) + } + + pub async fn workspace_delete_group( + &self, + ctx: &Session, + name: &str, + group_name: &str, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let group = self.workspace_group_by_name(wk.id, group_name).await?; + sqlx::query( + "UPDATE wk_group SET is_deleted = true WHERE id = $1 AND wk = $2", + ) + .bind(group.id) + .bind(wk.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(()) + } + + pub async fn workspace_add_group_member( + &self, + ctx: &Session, + name: &str, + group_name: &str, + username: &str, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let group = self.workspace_group_by_name(wk.id, group_name).await?; + let target = self.users_find_active_user_by_username(username).await?; + self.workspace_require_member(wk.id, target.id).await?; + + sqlx::query( + "INSERT INTO wk_gp_member (\"user\", gp, join_at, leave_at) VALUES ($1, $2, $3, NULL) \ + ON CONFLICT (\"user\", gp) DO UPDATE SET leave_at = NULL", + ) + .bind(target.id) + .bind(group.id) + .bind(chrono::Utc::now()) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(()) + } + + pub async fn workspace_remove_group_member( + &self, + ctx: &Session, + name: &str, + group_name: &str, + username: &str, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let group = self.workspace_group_by_name(wk.id, group_name).await?; + let target = self.users_find_active_user_by_username(username).await?; + + sqlx::query("UPDATE wk_gp_member SET leave_at = $1 WHERE \"user\" = $2 AND gp = $3") + .bind(chrono::Utc::now()) + .bind(target.id) + .bind(group.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(()) + } + + pub async fn workspace_group_members( + &self, + ctx: &Session, + name: &str, + group_name: &str, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + let group = self.workspace_group_by_name(wk.id, group_name).await?; + + let rows = sqlx::query_as::<_, WorkspaceGroupMemberRow>( + "SELECT u.id, u.username, u.display_name, u.avatar_url, u.website_url, u.allow_use, u.can_search, \ + u.last_sign_in_at, u.created_at, u.updated_at, gm.join_at \ + FROM wk_gp_member gm \ + INNER JOIN \"user\" u ON u.id = gm.\"user\" \ + WHERE gm.gp = $1 AND gm.leave_at IS NULL \ + ORDER BY gm.join_at ASC", + ) + .bind(group.id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows + .into_iter() + .map(WorkspaceMemberResponse::from) + .collect()) + } + + pub(crate) async fn workspace_group_by_name( + &self, + wk_id: uuid::Uuid, + group_name: &str, + ) -> Result { + sqlx::query_as::<_, WkGroupModel>( + "SELECT id, name, wk, created_at, avatar_url, is_deleted \ + FROM wk_group WHERE wk = $1 AND name = $2 AND is_deleted = false", + ) + .bind(wk_id) + .bind(group_name) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::NotFound("workspace group not found".to_string())) + } +} diff --git a/lib/service/workspace/join.rs b/lib/service/workspace/join.rs new file mode 100644 index 0000000..46898b2 --- /dev/null +++ b/lib/service/workspace/join.rs @@ -0,0 +1,609 @@ +use db::sqlx; +use model::workspace::{ + WkApplyJoinModel, WkJoinApprovalModel, WkJoinStrategyModel, WorkspaceModel, +}; +use serde::{Deserialize, Serialize}; +use session::Session; + +use crate::{AppService, error::AppError, session_user}; + +const JOIN_STATUS_PENDING: &str = "pending"; +const JOIN_STATUS_APPROVED: &str = "approved"; +const JOIN_STATUS_REJECTED: &str = "rejected"; +const JOIN_STATUS_CANCELLED: &str = "cancelled"; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct WorkspaceJoinStrategyResponse { + pub workspace_name: String, + pub workspace_avatar_url: String, + pub require_approval: bool, + pub require_question: bool, + pub question: Option, + pub has_answer: bool, + pub enabled: bool, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateWorkspaceJoinStrategy { + pub require_approval: Option, + pub require_question: Option, + pub question: Option>, + pub answer: Option>, + pub enabled: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct WorkspaceJoinApplyResponse { + pub workspace_name: String, + pub workspace_avatar_url: String, + pub username: String, + pub avatar_url: Option, + pub status: String, + pub question: Option, + pub answer: Option, + pub message: Option, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + #[schema(value_type = String)] + pub updated_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreateWorkspaceJoinApply { + pub answer: Option, + pub message: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema, utoipa::IntoParams)] +pub struct ListWorkspaceJoinApply { + pub status: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct WorkspaceJoinApprovalResponse { + pub workspace_name: String, + pub workspace_avatar_url: String, + pub username: String, + pub avatar_url: Option, + pub approver_username: String, + pub approver_avatar_url: Option, + pub approved: bool, + pub reason: Option, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct ApproveWorkspaceJoinApply { + pub approved: bool, + pub reason: Option, +} + +impl AppService { + pub async fn workspace_join_strategy( + &self, + name: &str, + ) -> Result { + let wk = self.workspace_resolve(name).await?; + let strategy = self.workspace_join_strategy_by_wk(wk.id).await?; + Ok(strategy_response(strategy, &wk)) + } + + pub async fn workspace_update_join_strategy( + &self, + ctx: &Session, + name: &str, + params: UpdateWorkspaceJoinStrategy, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + + let existing = self.workspace_join_strategy_by_wk(wk.id).await?; + let require_approval = + params.require_approval.unwrap_or(existing.require_approval); + let mut require_question = + params.require_question.unwrap_or(existing.require_question); + let question = params.question.unwrap_or(existing.question); + let answer = params.answer.unwrap_or(existing.answer); + let enabled = params.enabled.unwrap_or(existing.enabled); + + if require_question + && question + .as_ref() + .is_none_or(|question| question.trim().is_empty()) + { + return Err(AppError::BadRequest( + "join question is required when require_question is true" + .to_string(), + )); + } + if question.is_none() { + require_question = false; + } + + let now = chrono::Utc::now(); + let saved = if self.workspace_join_strategy_exists(wk.id).await? { + sqlx::query_as::<_, WkJoinStrategyModel>( + "UPDATE wk_join_strategy SET require_approval = $1, require_question = $2, \ + question = $3, answer = $4, enabled = $5, updated_at = $6 \ + WHERE wk = $7 \ + RETURNING wk, require_approval, require_question, question, answer, enabled, created_at, updated_at", + ) + .bind(require_approval) + .bind(require_question) + .bind(clean_optional(question)) + .bind(clean_optional(answer)) + .bind(enabled) + .bind(now) + .bind(wk.id) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + } else { + sqlx::query_as::<_, WkJoinStrategyModel>( + "INSERT INTO wk_join_strategy \ + (wk, require_approval, require_question, question, answer, enabled, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $7) \ + RETURNING wk, require_approval, require_question, question, answer, enabled, created_at, updated_at", + ) + .bind(wk.id) + .bind(require_approval) + .bind(require_question) + .bind(clean_optional(question)) + .bind(clean_optional(answer)) + .bind(enabled) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + }; + + Ok(strategy_response(saved, &wk)) + } + + pub async fn workspace_apply_join( + &self, + ctx: &Session, + name: &str, + params: CreateWorkspaceJoinApply, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + if self.workspace_member(wk.id, user_uid).await.is_ok() { + return Err(AppError::Conflict( + "user is already a workspace member".to_string(), + )); + } + if self + .workspace_has_pending_join_apply(wk.id, user_uid) + .await? + { + return Err(AppError::Conflict( + "join application is already pending".to_string(), + )); + } + + let strategy = self.workspace_join_strategy_by_wk(wk.id).await?; + let answer = clean_optional(params.answer); + if strategy.enabled && strategy.require_question { + let expected = strategy + .answer + .as_ref() + .map(|answer| answer.trim()) + .unwrap_or_default(); + let actual = answer + .as_ref() + .map(|answer| answer.trim()) + .unwrap_or_default(); + if actual.is_empty() { + return Err(AppError::BadRequest( + "join answer is required".to_string(), + )); + } + if !expected.is_empty() && actual != expected { + return Err(AppError::PermissionDenied); + } + } + + let status = if strategy.enabled && strategy.require_approval { + JOIN_STATUS_PENDING + } else { + JOIN_STATUS_APPROVED + }; + + let now = chrono::Utc::now(); + let apply = sqlx::query_as::<_, WkApplyJoinModel>( + "INSERT INTO wk_apply_join \ + (id, wk, \"user\", status, question, answer, message, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $8) \ + RETURNING id, wk, \"user\", status, question, answer, message, created_at, updated_at", + ) + .bind(uuid::Uuid::now_v7()) + .bind(wk.id) + .bind(user_uid) + .bind(status) + .bind(strategy.question.clone()) + .bind(answer) + .bind(clean_optional(params.message)) + .bind(now) + .fetch_one(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if status == JOIN_STATUS_APPROVED { + self.workspace_join_add_member(wk.id, user_uid).await?; + } + + let current_user = self.auth_find_user_by_uid(user_uid).await?; + Ok(apply_response( + apply, + &wk, + current_user.username, + clean_optional(Some(current_user.avatar_url)), + )) + } + + pub async fn workspace_my_join_applies( + &self, + ctx: &Session, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let rows = sqlx::query_as::<_, WorkspaceJoinApplyRow>( + "SELECT a.id, a.wk, w.name AS workspace_name, w.avatar_url AS workspace_avatar_url, \ + a.\"user\", u.username, u.avatar_url, a.status, a.question, a.answer, a.message, a.created_at, a.updated_at \ + FROM wk_apply_join a \ + INNER JOIN workspace w ON w.id = a.wk \ + INNER JOIN \"user\" u ON u.id = a.\"user\" \ + WHERE a.\"user\" = $1 ORDER BY a.created_at DESC", + ) + .bind(user_uid) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows + .into_iter() + .map(WorkspaceJoinApplyResponse::from) + .collect()) + } + + pub async fn workspace_cancel_join_apply( + &self, + ctx: &Session, + name: &str, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + let user = self.auth_find_user_by_uid(user_uid).await?; + let apply = sqlx::query_as::<_, WkApplyJoinModel>( + "UPDATE wk_apply_join SET status = $1, updated_at = $2 \ + WHERE wk = $3 AND \"user\" = $4 AND status = $5 \ + RETURNING id, wk, \"user\", status, question, answer, message, created_at, updated_at", + ) + .bind(JOIN_STATUS_CANCELLED) + .bind(chrono::Utc::now()) + .bind(wk.id) + .bind(user_uid) + .bind(JOIN_STATUS_PENDING) + .fetch_optional(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::NotFound("join application not found".to_string()))?; + + Ok(apply_response( + apply, + &wk, + user.username, + clean_optional(Some(user.avatar_url)), + )) + } + + pub async fn workspace_join_applies( + &self, + ctx: &Session, + name: &str, + query: ListWorkspaceJoinApply, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let status = query + .status + .unwrap_or_else(|| JOIN_STATUS_PENDING.to_string()); + + let rows = sqlx::query_as::<_, WorkspaceJoinApplyRow>( + "SELECT a.id, a.wk, w.name AS workspace_name, w.avatar_url AS workspace_avatar_url, \ + a.\"user\", u.username, u.avatar_url, a.status, a.question, a.answer, a.message, a.created_at, a.updated_at \ + FROM wk_apply_join a \ + INNER JOIN workspace w ON w.id = a.wk \ + INNER JOIN \"user\" u ON u.id = a.\"user\" \ + WHERE a.wk = $1 AND a.status = $2 \ + ORDER BY a.created_at ASC", + ) + .bind(wk.id) + .bind(status) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows + .into_iter() + .map(WorkspaceJoinApplyResponse::from) + .collect()) + } + + pub async fn workspace_approve_join_apply( + &self, + ctx: &Session, + name: &str, + username: &str, + params: ApproveWorkspaceJoinApply, + ) -> Result { + let approver = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_admin(wk.id, approver).await?; + let applicant = + self.users_find_active_user_by_username(username).await?; + let approver_user = self.auth_find_user_by_uid(approver).await?; + let apply = self + .workspace_pending_join_apply_by_user(wk.id, applicant.id) + .await?; + if apply.status != JOIN_STATUS_PENDING { + return Err(AppError::Conflict( + "join application has already been processed".to_string(), + )); + } + + let now = chrono::Utc::now(); + let approval_id = uuid::Uuid::now_v7(); + let next_status = if params.approved { + JOIN_STATUS_APPROVED + } else { + JOIN_STATUS_REJECTED + }; + + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + let approval = sqlx::query_as::<_, WkJoinApprovalModel>( + "INSERT INTO wk_join_approval \ + (id, apply, wk, \"user\", approver, approved, reason, created_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) \ + RETURNING id, apply, wk, \"user\", approver, approved, reason, created_at", + ) + .bind(approval_id) + .bind(apply.id) + .bind(wk.id) + .bind(apply.user) + .bind(approver) + .bind(params.approved) + .bind(clean_optional(params.reason)) + .bind(now) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query("UPDATE wk_apply_join SET status = $1, updated_at = $2 WHERE id = $3") + .bind(next_status) + .bind(now) + .bind(apply.id) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if params.approved { + sqlx::query( + "INSERT INTO wk_member (wk, \"user\", owner, admin, join_at, leave_at) \ + VALUES ($1, $2, false, false, $3, NULL) \ + ON CONFLICT (wk, \"user\") DO UPDATE SET leave_at = NULL", + ) + .bind(wk.id) + .bind(apply.user) + .bind(now) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + } + + txn.commit().await.map_err(|_| AppError::TxnError)?; + Ok(approval_response( + approval, + &wk, + applicant.username, + clean_optional(Some(applicant.avatar_url)), + approver_user.username, + clean_optional(Some(approver_user.avatar_url)), + )) + } + + async fn workspace_join_strategy_by_wk( + &self, + wk: uuid::Uuid, + ) -> Result { + let strategy = sqlx::query_as::<_, WkJoinStrategyModel>( + "SELECT wk, require_approval, require_question, question, answer, enabled, created_at, updated_at \ + FROM wk_join_strategy WHERE wk = $1", + ) + .bind(wk) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let now = chrono::Utc::now(); + Ok(strategy.unwrap_or(WkJoinStrategyModel { + wk, + require_approval: false, + require_question: false, + question: None, + answer: None, + enabled: false, + created_at: now, + updated_at: now, + })) + } + + async fn workspace_join_strategy_exists( + &self, + wk: uuid::Uuid, + ) -> Result { + sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM wk_join_strategy WHERE wk = $1)", + ) + .bind(wk) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string())) + } + + async fn workspace_has_pending_join_apply( + &self, + wk: uuid::Uuid, + user: uuid::Uuid, + ) -> Result { + sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM wk_apply_join WHERE wk = $1 AND \"user\" = $2 AND status = $3)", + ) + .bind(wk) + .bind(user) + .bind(JOIN_STATUS_PENDING) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string())) + } + + async fn workspace_pending_join_apply_by_user( + &self, + wk: uuid::Uuid, + user: uuid::Uuid, + ) -> Result { + sqlx::query_as::<_, WkApplyJoinModel>( + "SELECT id, wk, \"user\", status, question, answer, message, created_at, updated_at \ + FROM wk_apply_join WHERE wk = $1 AND \"user\" = $2 AND status = $3", + ) + .bind(wk) + .bind(user) + .bind(JOIN_STATUS_PENDING) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::NotFound("join application not found".to_string())) + } + + async fn workspace_join_add_member( + &self, + wk: uuid::Uuid, + user: uuid::Uuid, + ) -> Result<(), AppError> { + sqlx::query( + "INSERT INTO wk_member (wk, \"user\", owner, admin, join_at, leave_at) \ + VALUES ($1, $2, false, false, $3, NULL) \ + ON CONFLICT (wk, \"user\") DO UPDATE SET leave_at = NULL", + ) + .bind(wk) + .bind(user) + .bind(chrono::Utc::now()) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(()) + } +} + +#[derive(db::sqlx::FromRow)] +struct WorkspaceJoinApplyRow { + workspace_name: String, + workspace_avatar_url: String, + username: String, + avatar_url: String, + status: String, + question: Option, + answer: Option, + message: Option, + created_at: chrono::DateTime, + updated_at: chrono::DateTime, +} + +fn strategy_response( + value: WkJoinStrategyModel, + wk: &WorkspaceModel, +) -> WorkspaceJoinStrategyResponse { + WorkspaceJoinStrategyResponse { + workspace_name: wk.name.clone(), + workspace_avatar_url: wk.avatar_url.clone(), + require_approval: value.require_approval, + require_question: value.require_question, + question: value.question, + has_answer: value.answer.is_some(), + enabled: value.enabled, + created_at: value.created_at, + updated_at: value.updated_at, + } +} + +fn apply_response( + value: WkApplyJoinModel, + wk: &WorkspaceModel, + username: String, + avatar_url: Option, +) -> WorkspaceJoinApplyResponse { + WorkspaceJoinApplyResponse { + workspace_name: wk.name.clone(), + workspace_avatar_url: wk.avatar_url.clone(), + username, + avatar_url, + status: value.status, + question: value.question, + answer: value.answer, + message: value.message, + created_at: value.created_at, + updated_at: value.updated_at, + } +} + +fn approval_response( + value: WkJoinApprovalModel, + wk: &WorkspaceModel, + username: String, + avatar_url: Option, + approver_username: String, + approver_avatar_url: Option, +) -> WorkspaceJoinApprovalResponse { + WorkspaceJoinApprovalResponse { + workspace_name: wk.name.clone(), + workspace_avatar_url: wk.avatar_url.clone(), + username, + avatar_url, + approver_username, + approver_avatar_url, + approved: value.approved, + reason: value.reason, + created_at: value.created_at, + } +} + +fn clean_optional(value: Option) -> Option { + value.and_then(|value| { + let value = value.trim().to_string(); + if value.is_empty() { None } else { Some(value) } + }) +} + +impl From for WorkspaceJoinApplyResponse { + fn from(value: WorkspaceJoinApplyRow) -> Self { + Self { + workspace_name: value.workspace_name, + workspace_avatar_url: value.workspace_avatar_url, + username: value.username, + avatar_url: clean_optional(Some(value.avatar_url)), + status: value.status, + question: value.question, + answer: value.answer, + message: value.message, + created_at: value.created_at, + updated_at: value.updated_at, + } + } +} diff --git a/lib/service/workspace/member.rs b/lib/service/workspace/member.rs new file mode 100644 index 0000000..2bc8a8f --- /dev/null +++ b/lib/service/workspace/member.rs @@ -0,0 +1,155 @@ +use db::sqlx; +use serde::Deserialize; +use session::Session; + +use super::types::{ + WorkspaceMemberResponse, WorkspaceMemberRow, member_response, +}; +use crate::{AppService, Pagination, error::AppError, session_user}; + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct AddWorkspaceMember { + pub username: String, + pub admin: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateWorkspaceMember { + pub admin: bool, +} + +impl AppService { + pub async fn workspace_members( + &self, + ctx: &Session, + name: &str, + pagination: Pagination, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_member(wk.id, user_uid).await?; + + let rows = sqlx::query_as::<_, WorkspaceMemberRow>( + "SELECT u.id, u.username, u.display_name, u.avatar_url, u.website_url, u.allow_use, u.can_search, \ + u.last_sign_in_at, u.created_at, u.updated_at, m.owner, m.admin, m.join_at \ + FROM wk_member m \ + INNER JOIN \"user\" u ON u.id = m.\"user\" \ + WHERE m.wk = $1 AND m.leave_at IS NULL \ + ORDER BY m.owner DESC, m.admin DESC, m.join_at ASC \ + LIMIT $2 OFFSET $3", + ) + .bind(wk.id) + .bind(pagination.limit() as i64) + .bind(pagination.offset() as i64) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows + .into_iter() + .map(WorkspaceMemberResponse::from) + .collect()) + } + + pub async fn workspace_add_member( + &self, + ctx: &Session, + name: &str, + params: AddWorkspaceMember, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + let target = self + .users_find_active_user_by_username(¶ms.username) + .await?; + + let now = chrono::Utc::now(); + sqlx::query( + "INSERT INTO wk_member (wk, \"user\", owner, admin, join_at, leave_at) \ + VALUES ($1, $2, false, $3, $4, NULL) \ + ON CONFLICT (wk, \"user\") DO UPDATE SET admin = EXCLUDED.admin, leave_at = NULL", + ) + .bind(wk.id) + .bind(target.id) + .bind(params.admin.unwrap_or(false)) + .bind(now) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let member = self.workspace_member(wk.id, target.id).await?; + Ok(member_response( + target, + member.owner, + member.admin, + member.join_at, + )) + } + + pub async fn workspace_update_member( + &self, + ctx: &Session, + name: &str, + username: &str, + params: UpdateWorkspaceMember, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_owner(wk.id, user_uid).await?; + let target = self.users_find_active_user_by_username(username).await?; + let member = self.workspace_member(wk.id, target.id).await?; + if member.owner { + return Err(AppError::BadRequest( + "cannot update workspace owner role".to_string(), + )); + } + + sqlx::query( + "UPDATE wk_member SET admin = $1 WHERE wk = $2 AND \"user\" = $3 AND leave_at IS NULL", + ) + .bind(params.admin) + .bind(wk.id) + .bind(target.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + let member = self.workspace_member(wk.id, target.id).await?; + Ok(member_response( + target, + member.owner, + member.admin, + member.join_at, + )) + } + + pub async fn workspace_remove_member( + &self, + ctx: &Session, + name: &str, + username: &str, + ) -> Result<(), AppError> { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + let target = self.users_find_active_user_by_username(username).await?; + let target_member = self.workspace_member(wk.id, target.id).await?; + if target_member.owner { + return Err(AppError::BadRequest( + "cannot remove workspace owner".to_string(), + )); + } + if user_uid != target.id { + self.workspace_require_admin(wk.id, user_uid).await?; + } + + sqlx::query("UPDATE wk_member SET leave_at = $1 WHERE wk = $2 AND \"user\" = $3") + .bind(chrono::Utc::now()) + .bind(wk.id) + .bind(target.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(()) + } +} diff --git a/lib/service/workspace/mod.rs b/lib/service/workspace/mod.rs new file mode 100644 index 0000000..bf0bfee --- /dev/null +++ b/lib/service/workspace/mod.rs @@ -0,0 +1,5 @@ +pub mod group; +pub mod join; +pub mod member; +pub mod types; +pub mod workspace; diff --git a/lib/service/workspace/types.rs b/lib/service/workspace/types.rs new file mode 100644 index 0000000..22492e7 --- /dev/null +++ b/lib/service/workspace/types.rs @@ -0,0 +1,167 @@ +use model::{ + users::UserModel, + workspace::{WkGroupModel, WorkspaceModel}, +}; +use serde::{Deserialize, Serialize}; + +use crate::{error::AppError, non_empty}; +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct WorkspaceResponse { + pub name: String, + pub description: String, + pub avatar_url: String, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, + pub owner: bool, + pub admin: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct WorkspaceMemberResponse { + pub username: String, + pub display_name: Option, + pub avatar_url: Option, + pub owner: bool, + pub admin: bool, + #[schema(value_type = String)] + pub join_at: chrono::DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct WorkspaceGroupResponse { + pub name: String, + pub avatar_url: Option, + pub is_deleted: bool, + #[schema(value_type = String)] + pub created_at: chrono::DateTime, +} + +pub(crate) fn normalize_name(name: &str) -> Result { + let name = name.trim(); + if name.is_empty() { + return Err(AppError::BadRequest( + "workspace name is required".to_string(), + )); + } + if name.len() > 64 { + return Err(AppError::BadRequest( + "workspace name is too long".to_string(), + )); + } + if !name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') + { + return Err(AppError::BadRequest( + "workspace name can only contain letters, numbers, '-' and '_'" + .to_string(), + )); + } + Ok(name.to_string()) +} + +pub(crate) fn workspace_response( + wk: WorkspaceModel, + owner: bool, + admin: bool, +) -> WorkspaceResponse { + WorkspaceResponse { + name: wk.name, + description: wk.description, + avatar_url: wk.avatar_url, + created_at: wk.created_at, + owner, + admin, + } +} + +pub(crate) fn member_response( + user: UserModel, + owner: bool, + admin: bool, + join_at: chrono::DateTime, +) -> WorkspaceMemberResponse { + WorkspaceMemberResponse { + username: user.username, + display_name: non_empty(user.display_name), + avatar_url: non_empty(user.avatar_url), + owner, + admin, + join_at, + } +} + +pub(crate) fn group_response(group: WkGroupModel) -> WorkspaceGroupResponse { + WorkspaceGroupResponse { + name: group.name, + avatar_url: group.avatar_url, + is_deleted: group.is_deleted, + created_at: group.created_at, + } +} +impl From for WorkspaceModel { + fn from(value: WorkspaceListRow) -> Self { + Self { + id: value.id, + name: value.name, + description: value.description, + avatar_url: value.avatar_url, + created_at: value.created_at, + } + } +} + +#[derive(db::sqlx::FromRow)] +pub(crate) struct WorkspaceListRow { + id: uuid::Uuid, + pub name: String, + pub description: String, + pub avatar_url: String, + pub created_at: chrono::DateTime, + pub owner: bool, + pub admin: bool, +} + +impl From for WorkspaceMemberResponse { + fn from(value: WorkspaceMemberRow) -> Self { + Self { + username: value.username, + display_name: non_empty(value.display_name), + avatar_url: non_empty(value.avatar_url), + owner: value.owner, + admin: value.admin, + join_at: value.join_at, + } + } +} + +#[derive(db::sqlx::FromRow)] +pub(crate) struct WorkspaceMemberRow { + pub username: String, + pub display_name: String, + pub avatar_url: String, + pub owner: bool, + pub admin: bool, + pub join_at: chrono::DateTime, +} + +impl From for WorkspaceMemberResponse { + fn from(value: WorkspaceGroupMemberRow) -> Self { + Self { + username: value.username, + display_name: non_empty(value.display_name), + avatar_url: non_empty(value.avatar_url), + owner: false, + admin: false, + join_at: value.join_at, + } + } +} + +#[derive(db::sqlx::FromRow)] +pub(crate) struct WorkspaceGroupMemberRow { + pub username: String, + pub display_name: String, + pub avatar_url: String, + pub join_at: chrono::DateTime, +} diff --git a/lib/service/workspace/workspace.rs b/lib/service/workspace/workspace.rs new file mode 100644 index 0000000..683ee57 --- /dev/null +++ b/lib/service/workspace/workspace.rs @@ -0,0 +1,397 @@ +use db::sqlx; +use model::workspace::{WkMemberModel, WorkspaceModel}; +use serde::{Deserialize, Serialize}; +use session::Session; +use storage::{ObjectStorage, PutObjectOptions}; + +use super::types::{ + WorkspaceListRow, WorkspaceResponse, normalize_name, workspace_response, +}; +use crate::{AppService, error::AppError, session_user}; + +const ALLOWED_AVATAR_TYPES: &[&str] = &[ + "image/png", + "image/jpeg", + "image/webp", + "image/gif", +]; +const MAX_AVATAR_SIZE: usize = 5 * 1024 * 1024; + +#[derive(Debug, Clone, Serialize, utoipa::ToSchema)] +pub struct AvatarUploadResponse { + pub avatar_url: String, +} + +fn extension_from_content_type(content_type: &str) -> &str { + match content_type { + "image/png" => "png", + "image/jpeg" => "jpg", + "image/webp" => "webp", + "image/gif" => "gif", + _ => "bin", + } +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct CreateWorkspace { + pub name: String, + pub description: Option, + pub avatar_url: Option, +} + +#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] +pub struct UpdateWorkspace { + pub name: Option, + pub description: Option, + pub avatar_url: Option, +} + +impl AppService { + pub async fn workspace_create( + &self, + ctx: &Session, + params: CreateWorkspace, + ) -> Result { + let user_uid = session_user(ctx)?; + let name = normalize_name(¶ms.name)?; + self.workspace_ensure_name_available(&name).await?; + + let wk_id = uuid::Uuid::now_v7(); + let now = chrono::Utc::now(); + let description = params.description.unwrap_or_default(); + let avatar_url = params.avatar_url.unwrap_or_default(); + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + + let workspace = sqlx::query_as::<_, WorkspaceModel>( + "INSERT INTO workspace (id, name, description, avatar_url, created_at) \ + VALUES ($1, $2, $3, $4, $5) \ + RETURNING id, name, description, avatar_url, created_at", + ) + .bind(wk_id) + .bind(&name) + .bind(&description) + .bind(&avatar_url) + .bind(now) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + sqlx::query( + "INSERT INTO wk_member (wk, \"user\", owner, admin, join_at, leave_at) \ + VALUES ($1, $2, true, true, $3, NULL)", + ) + .bind(wk_id) + .bind(user_uid) + .bind(now) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + txn.commit().await.map_err(|_| AppError::TxnError)?; + Ok(workspace_response(workspace, true, true)) + } + + pub async fn workspace_my( + &self, + ctx: &Session, + ) -> Result, AppError> { + let user_uid = session_user(ctx)?; + + let rows = sqlx::query_as::<_, WorkspaceListRow>( + "SELECT w.id, w.name, w.description, w.avatar_url, w.created_at, m.owner, m.admin \ + FROM wk_member m \ + INNER JOIN workspace w ON w.id = m.wk \ + WHERE m.\"user\" = $1 AND m.leave_at IS NULL \ + ORDER BY w.created_at DESC", + ) + .bind(user_uid) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(rows + .into_iter() + .map(|row| { + let owner = row.owner; + let admin = row.admin; + workspace_response(row.into(), owner, admin) + }) + .collect()) + } + + pub async fn workspace_get( + &self, + ctx: &Session, + name: &str, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + let member = self.workspace_member(wk.id, user_uid).await?; + Ok(workspace_response(wk, member.owner, member.admin)) + } + + pub async fn workspace_update( + &self, + ctx: &Session, + name: &str, + params: UpdateWorkspace, + ) -> Result { + let user_uid = session_user(ctx)?; + let mut wk = self.workspace_resolve(name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + + let next_name = match params.name { + Some(name) => { + let name = normalize_name(&name)?; + if name != wk.name { + self.workspace_ensure_name_available(&name).await?; + Some(name) + } else { + None + } + } + None => None, + }; + let next_avatar_url = + params.avatar_url.unwrap_or_else(|| wk.avatar_url.clone()); + let next_description = + params.description.unwrap_or_else(|| wk.description.clone()); + + let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?; + if let Some(next_name) = &next_name { + sqlx::query( + "INSERT INTO wk_history_name (id, wk, name, changed_by, created_at) \ + VALUES ($1, $2, $3, $4, $5)", + ) + .bind(uuid::Uuid::now_v7()) + .bind(wk.id) + .bind(&wk.name) + .bind(user_uid) + .bind(chrono::Utc::now()) + .execute(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + wk.name = next_name.clone(); + } + + wk = sqlx::query_as::<_, WorkspaceModel>( + "UPDATE workspace SET name = $1, description = $2, avatar_url = $3 WHERE id = $4 \ + RETURNING id, name, description, avatar_url, created_at", + ) + .bind(&wk.name) + .bind(&next_description) + .bind(&next_avatar_url) + .bind(wk.id) + .fetch_one(&mut **txn.inner_mut()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + txn.commit().await.map_err(|_| AppError::TxnError)?; + + let member = self.workspace_member(wk.id, user_uid).await?; + Ok(workspace_response(wk, member.owner, member.admin)) + } + + /// Get a workspace's avatar URL by workspace name. + pub async fn workspace_get_avatar_url( + &self, + name: &str, + ) -> Result { + let wk = self.workspace_resolve(name).await?; + if wk.avatar_url.is_empty() { + return Err(AppError::NotFound("avatar not found".to_string())); + } + Ok(wk.avatar_url) + } + + /// Upload a workspace avatar image, store it, and update the workspace's avatar_url. + pub async fn workspace_upload_avatar( + &self, + ctx: &Session, + name: &str, + bytes: Vec, + content_type: &str, + ) -> Result { + let user_uid = session_user(ctx)?; + let wk = self.workspace_resolve(name).await?; + self.workspace_require_admin(wk.id, user_uid).await?; + + if bytes.len() > MAX_AVATAR_SIZE { + return Err(AppError::AvatarUploadError( + "file size exceeds 5 MB limit".to_string(), + )); + } + if !ALLOWED_AVATAR_TYPES.contains(&content_type) { + return Err(AppError::AvatarUploadError(format!( + "unsupported image type: {content_type}. Allowed: png, jpeg, webp, gif" + ))); + } + + let ext = extension_from_content_type(content_type); + let key = format!( + "avatars/workspaces/{wk_id}-{ts}.{ext}", + wk_id = wk.id, + ts = uuid::Uuid::now_v7() + ); + + let stored = self + .storage + .put_bytes( + &key, + bytes, + PutObjectOptions { + content_type: Some(content_type.to_string()), + ..PutObjectOptions::default() + }, + ) + .await + .map_err(|e| { + AppError::AvatarUploadError(format!("storage error: {e}")) + })?; + + sqlx::query( + "UPDATE workspace SET avatar_url = $1 WHERE id = $2", + ) + .bind(&stored.url) + .bind(wk.id) + .execute(self.db.writer()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + Ok(AvatarUploadResponse { + avatar_url: stored.url, + }) + } + + pub(crate) async fn workspace_resolve( + &self, + name: &str, + ) -> Result { + if let Some(wk) = sqlx::query_as::<_, WorkspaceModel>( + "SELECT id, name, description, avatar_url, created_at FROM workspace WHERE name = $1", + ) + .bind(name) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + { + return Ok(wk); + } + + if let Some(history) = sqlx::query_as::<_, (uuid::Uuid,)>( + "SELECT wk FROM wk_history_name WHERE name = $1 ORDER BY created_at DESC LIMIT 1", + ) + .bind(name) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + { + return sqlx::query_as::<_, WorkspaceModel>( + "SELECT id, name, description, avatar_url, created_at FROM workspace WHERE id = $1", + ) + .bind(history.0) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::NotFound("workspace not found".to_string())); + } + + Err(AppError::NotFound("workspace not found".to_string())) + } + + pub(crate) async fn workspace_member( + &self, + wk_id: uuid::Uuid, + user_uid: uuid::Uuid, + ) -> Result { + sqlx::query_as::<_, WkMemberModel>( + "SELECT wk, \"user\", owner, admin, join_at, leave_at \ + FROM wk_member WHERE wk = $1 AND \"user\" = $2 AND leave_at IS NULL", + ) + .bind(wk_id) + .bind(user_uid) + .fetch_optional(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))? + .ok_or(AppError::PermissionDenied) + } + + pub(crate) async fn workspace_require_member( + &self, + wk_id: uuid::Uuid, + user_uid: uuid::Uuid, + ) -> Result { + self.workspace_member(wk_id, user_uid).await + } + + pub(crate) async fn workspace_require_admin( + &self, + wk_id: uuid::Uuid, + user_uid: uuid::Uuid, + ) -> Result { + let member = self.workspace_member(wk_id, user_uid).await?; + if member.owner || member.admin { + Ok(member) + } else { + Err(AppError::PermissionDenied) + } + } + + pub(crate) async fn workspace_require_owner( + &self, + wk_id: uuid::Uuid, + user_uid: uuid::Uuid, + ) -> Result { + let member = self.workspace_member(wk_id, user_uid).await?; + if member.owner { + Ok(member) + } else { + Err(AppError::PermissionDenied) + } + } + + async fn workspace_ensure_name_available( + &self, + name: &str, + ) -> Result<(), AppError> { + let current = sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM workspace WHERE name = $1)", + ) + .bind(name) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + let history = sqlx::query_scalar::<_, bool>( + "SELECT EXISTS(SELECT 1 FROM wk_history_name WHERE name = $1)", + ) + .bind(name) + .fetch_one(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + + if current || history { + Err(AppError::Conflict( + "workspace name already exists".to_string(), + )) + } else { + Ok(()) + } + } + + pub async fn workspace_my_inner( + &self, + user_id: uuid::Uuid, + ) -> Result)>, AppError> { + let rows = sqlx::query_as::<_, (String, Option)>( + "SELECT w.name, w.description \ + FROM wk_member m \ + INNER JOIN workspace w ON w.id = m.wk \ + WHERE m.\"user\" = $1 AND m.leave_at IS NULL \ + ORDER BY w.created_at DESC", + ) + .bind(user_id) + .fetch_all(self.db.reader()) + .await + .map_err(|e| AppError::DatabaseError(e.to_string()))?; + Ok(rows) + } +} diff --git a/lib/session/Cargo.toml b/lib/session/Cargo.toml new file mode 100644 index 0000000..cde9dde --- /dev/null +++ b/lib/session/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "session" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true +[lib] +path = "lib.rs" +name = "session" +[dependencies] +actix-service = { workspace = true } +actix-utils = { workspace = true } +actix-web = { workspace = true, features = ["cookies", "secure-cookies"] } + +anyhow = { workspace = true } +derive_more = { workspace = true, features = ["display", "error", "from"] } +rand = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +uuid = { workspace = true,features = ["serde","v7","v4"] } +redis = { workspace = true, features = ["tokio-comp", "connection-manager", "cluster", "cluster-async"] } +deadpool-redis = { workspace = true } +tokio = { workspace = true, features = ["rt-multi-thread", "sync"] } +tracing = { workspace = true } +[lints] +workspace = true diff --git a/lib/session/config.rs b/lib/session/config.rs new file mode 100644 index 0000000..a184d1f --- /dev/null +++ b/lib/session/config.rs @@ -0,0 +1,223 @@ +use actix_web::cookie::{Key, SameSite, time::Duration}; +use derive_more::derive::From; + +use crate::{SessionMiddleware, storage::SessionStore}; + +#[derive(Debug, Clone, From)] +#[non_exhaustive] +pub enum SessionLifecycle { + BrowserSession(BrowserSession), + PersistentSession(PersistentSession), +} + +#[derive(Debug, Clone)] +pub struct BrowserSession { + state_ttl: Duration, + state_ttl_extension_policy: TtlExtensionPolicy, +} + +impl BrowserSession { + pub fn state_ttl(mut self, ttl: Duration) -> Self { + self.state_ttl = ttl; + self + } + + pub fn state_ttl_extension_policy( + mut self, + ttl_extension_policy: TtlExtensionPolicy, + ) -> Self { + self.state_ttl_extension_policy = ttl_extension_policy; + self + } +} + +impl Default for BrowserSession { + fn default() -> Self { + Self { + state_ttl: default_ttl(), + state_ttl_extension_policy: default_ttl_extension_policy(), + } + } +} + +#[derive(Debug, Clone)] +pub struct PersistentSession { + session_ttl: Duration, + ttl_extension_policy: TtlExtensionPolicy, +} + +impl PersistentSession { + #[doc(alias = "max_age", alias = "max age", alias = "expires")] + pub fn session_ttl(mut self, session_ttl: Duration) -> Self { + self.session_ttl = session_ttl; + self + } + + pub fn session_ttl_extension_policy( + mut self, + ttl_extension_policy: TtlExtensionPolicy, + ) -> Self { + self.ttl_extension_policy = ttl_extension_policy; + self + } +} + +impl Default for PersistentSession { + fn default() -> Self { + Self { + session_ttl: default_ttl(), + ttl_extension_policy: default_ttl_extension_policy(), + } + } +} + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum TtlExtensionPolicy { + OnEveryRequest, + OnStateChanges, +} + +#[derive(Debug, Clone, Copy)] +pub enum CookieContentSecurity { + Private, + Signed, +} + +pub(crate) const fn default_ttl() -> Duration { + Duration::days(1) +} + +pub(crate) const fn default_ttl_extension_policy() -> TtlExtensionPolicy { + TtlExtensionPolicy::OnStateChanges +} + +#[must_use] +pub struct SessionMiddlewareBuilder { + storage_backend: Store, + configuration: Configuration, +} + +impl SessionMiddlewareBuilder { + pub(crate) fn new(store: Store, configuration: Configuration) -> Self { + Self { + storage_backend: store, + configuration, + } + } + + pub fn cookie_name(mut self, name: String) -> Self { + self.configuration.cookie.name = name; + self + } + + pub fn cookie_secure(mut self, secure: bool) -> Self { + self.configuration.cookie.secure = secure; + self + } + + pub fn session_lifecycle>( + mut self, + session_lifecycle: S, + ) -> Self { + match session_lifecycle.into() { + SessionLifecycle::BrowserSession(BrowserSession { + state_ttl, + state_ttl_extension_policy, + }) => { + self.configuration.cookie.max_age = None; + self.configuration.session.state_ttl = state_ttl; + self.configuration.ttl_extension_policy = + state_ttl_extension_policy; + } + SessionLifecycle::PersistentSession(PersistentSession { + session_ttl, + ttl_extension_policy, + }) => { + self.configuration.cookie.max_age = Some(session_ttl); + self.configuration.session.state_ttl = session_ttl; + self.configuration.ttl_extension_policy = ttl_extension_policy; + } + } + + self + } + + pub fn cookie_same_site(mut self, same_site: SameSite) -> Self { + self.configuration.cookie.same_site = same_site; + self + } + + pub fn cookie_path(mut self, path: String) -> Self { + self.configuration.cookie.path = path; + self + } + + pub fn cookie_domain(mut self, domain: Option) -> Self { + self.configuration.cookie.domain = domain; + self + } + + pub fn cookie_content_security( + mut self, + content_security: CookieContentSecurity, + ) -> Self { + self.configuration.cookie.content_security = content_security; + self + } + + pub fn cookie_http_only(mut self, http_only: bool) -> Self { + self.configuration.cookie.http_only = http_only; + self + } + + #[must_use] + pub fn build(self) -> SessionMiddleware { + SessionMiddleware::from_parts(self.storage_backend, self.configuration) + } +} + +#[derive(Clone)] +pub(crate) struct Configuration { + pub(crate) cookie: CookieConfiguration, + pub(crate) session: SessionConfiguration, + pub(crate) ttl_extension_policy: TtlExtensionPolicy, +} + +#[derive(Clone)] +pub(crate) struct SessionConfiguration { + pub(crate) state_ttl: Duration, +} + +#[derive(Clone)] +pub(crate) struct CookieConfiguration { + pub(crate) secure: bool, + pub(crate) http_only: bool, + pub(crate) name: String, + pub(crate) same_site: SameSite, + pub(crate) path: String, + pub(crate) domain: Option, + pub(crate) max_age: Option, + pub(crate) content_security: CookieContentSecurity, + pub(crate) key: Key, +} + +pub(crate) fn default_configuration(key: Key) -> Configuration { + Configuration { + cookie: CookieConfiguration { + secure: true, + http_only: true, + name: "id".into(), + same_site: SameSite::Lax, + path: "/".into(), + domain: None, + max_age: None, + content_security: CookieContentSecurity::Private, + key, + }, + session: SessionConfiguration { + state_ttl: default_ttl(), + }, + ttl_extension_policy: default_ttl_extension_policy(), + } +} diff --git a/lib/session/lib.rs b/lib/session/lib.rs new file mode 100644 index 0000000..c306c16 --- /dev/null +++ b/lib/session/lib.rs @@ -0,0 +1,16 @@ +#![forbid(unsafe_code)] + +pub mod config; +mod middleware; +mod session; +mod session_ext; +pub mod storage; + +pub use self::{ + middleware::SessionMiddleware, + session::{ + Session, SessionGetError, SessionInsertError, SessionStatus, + SessionUser, + }, + session_ext::SessionExt, +}; diff --git a/lib/session/middleware.rs b/lib/session/middleware.rs new file mode 100644 index 0000000..4473700 --- /dev/null +++ b/lib/session/middleware.rs @@ -0,0 +1,360 @@ +use std::{fmt, future::Future, pin::Pin, rc::Rc}; + +use actix_utils::future::{Ready, ready}; +use actix_web::{ + HttpResponse, + body::MessageBody, + cookie::{Cookie, CookieJar, Key}, + dev::{ + ResponseHead, Service, ServiceRequest, ServiceResponse, Transform, + forward_ready, + }, + http::header::{HeaderValue, SET_COOKIE}, +}; +use anyhow::Context; +use serde_json::{Map, Value}; + +use crate::{ + Session, SessionStatus, + config::{ + self, Configuration, CookieConfiguration, CookieContentSecurity, + SessionMiddlewareBuilder, TtlExtensionPolicy, + }, + storage::{LoadError, SessionKey, SessionStore}, +}; + +#[derive(Clone)] +pub struct SessionMiddleware { + storage_backend: Rc, + configuration: Rc, +} + +impl SessionMiddleware { + pub fn new(store: Store, key: Key) -> Self { + Self::builder(store, key).build() + } + + pub fn builder(store: Store, key: Key) -> SessionMiddlewareBuilder { + SessionMiddlewareBuilder::new(store, config::default_configuration(key)) + } + + pub(crate) fn from_parts( + store: Store, + configuration: Configuration, + ) -> Self { + Self { + storage_backend: Rc::new(store), + configuration: Rc::new(configuration), + } + } +} + +impl Transform for SessionMiddleware +where + S: Service< + ServiceRequest, + Response = ServiceResponse, + Error = actix_web::Error, + > + 'static, + S::Future: 'static, + B: MessageBody + 'static, + Store: SessionStore + 'static, +{ + type Response = ServiceResponse; + type Error = actix_web::Error; + type Transform = InnerSessionMiddleware; + type InitError = (); + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(InnerSessionMiddleware { + service: Rc::new(service), + configuration: Rc::clone(&self.configuration), + storage_backend: Rc::clone(&self.storage_backend), + })) + } +} + +fn e500(err: E) -> actix_web::Error { + actix_web::error::InternalError::from_response( + err, + HttpResponse::InternalServerError().finish(), + ) + .into() +} + +#[doc(hidden)] +#[non_exhaustive] +pub struct InnerSessionMiddleware { + service: Rc, + configuration: Rc, + storage_backend: Rc, +} + +impl Service for InnerSessionMiddleware +where + S: Service< + ServiceRequest, + Response = ServiceResponse, + Error = actix_web::Error, + > + 'static, + S::Future: 'static, + Store: SessionStore + 'static, +{ + type Response = ServiceResponse; + type Error = actix_web::Error; + #[allow(clippy::type_complexity)] + type Future = + Pin>>>; + + forward_ready!(service); + + fn call(&self, mut req: ServiceRequest) -> Self::Future { + let service = Rc::clone(&self.service); + let storage_backend = Rc::clone(&self.storage_backend); + let configuration = Rc::clone(&self.configuration); + + Box::pin(async move { + let session_key = extract_session_key(&req, &configuration.cookie); + let (session_key, session_state) = + load_session_state(session_key, storage_backend.as_ref()) + .await?; + + Session::set_session(&mut req, session_state); + + let mut res = service.call(req).await?; + let (status, session_state) = Session::get_changes(&mut res); + + match session_key { + None => { + if !session_state.is_empty() { + let session_key = storage_backend + .save( + session_state, + &configuration.session.state_ttl, + ) + .await + .map_err(e500)?; + + set_session_cookie( + res.response_mut().head_mut(), + session_key, + &configuration.cookie, + ) + .map_err(e500)?; + } + } + + Some(session_key) => { + match status { + SessionStatus::Changed => { + let session_key = storage_backend + .update( + session_key, + session_state, + &configuration.session.state_ttl, + ) + .await + .map_err(e500)?; + + set_session_cookie( + res.response_mut().head_mut(), + session_key, + &configuration.cookie, + ) + .map_err(e500)?; + } + + SessionStatus::Purged => { + storage_backend + .delete(&session_key) + .await + .map_err(e500)?; + + delete_session_cookie( + res.response_mut().head_mut(), + &configuration.cookie, + ) + .map_err(e500)?; + } + + SessionStatus::Renewed => { + storage_backend + .delete(&session_key) + .await + .map_err(e500)?; + + let session_key = storage_backend + .save( + session_state, + &configuration.session.state_ttl, + ) + .await + .map_err(e500)?; + + set_session_cookie( + res.response_mut().head_mut(), + session_key, + &configuration.cookie, + ) + .map_err(e500)?; + } + + SessionStatus::Unchanged => { + if matches!( + configuration.ttl_extension_policy, + TtlExtensionPolicy::OnEveryRequest + ) { + storage_backend + .update_ttl( + &session_key, + &configuration.session.state_ttl, + ) + .await + .map_err(e500)?; + + if configuration.cookie.max_age.is_some() { + set_session_cookie( + res.response_mut().head_mut(), + session_key, + &configuration.cookie, + ) + .map_err(e500)?; + } + } + } + }; + } + } + + Ok(res) + }) + } +} + +fn extract_session_key( + req: &ServiceRequest, + config: &CookieConfiguration, +) -> Option { + let cookies = match req.cookies() { + Ok(cookies) => cookies, + Err(_e) => { + return None; + } + }; + let session_cookie = cookies + .iter() + .find(|&cookie| cookie.name() == config.name)?; + + let mut jar = CookieJar::new(); + jar.add_original(session_cookie.clone()); + + let verification_result = match config.content_security { + CookieContentSecurity::Signed => { + jar.signed(&config.key).get(&config.name) + } + CookieContentSecurity::Private => { + jar.private(&config.key).get(&config.name) + } + }; + + verification_result?.value().to_owned().try_into().ok() +} + +async fn load_session_state( + session_key: Option, + storage_backend: &Store, +) -> Result<(Option, Map), actix_web::Error> { + if let Some(session_key) = session_key { + match storage_backend.load(&session_key).await { + Ok(state) => { + if let Some(state) = state { + Ok((Some(session_key), state)) + } else { + Ok((None, Map::new())) + } + } + + Err(_err) => match _err { + LoadError::Deserialization(_err) => { + Ok((Some(session_key), Map::new())) + } + + LoadError::Other(err) => Err(e500(err)), + }, + } + } else { + Ok((None, Map::new())) + } +} + +fn set_session_cookie( + response: &mut ResponseHead, + session_key: SessionKey, + config: &CookieConfiguration, +) -> Result<(), anyhow::Error> { + let value: String = session_key.into(); + let mut cookie = Cookie::new(config.name.clone(), value); + + cookie.set_secure(config.secure); + cookie.set_http_only(config.http_only); + cookie.set_same_site(config.same_site); + cookie.set_path(config.path.clone()); + + if let Some(max_age) = config.max_age { + cookie.set_max_age(max_age); + } + + if let Some(ref domain) = config.domain { + cookie.set_domain(domain.clone()); + } + + let mut jar = CookieJar::new(); + match config.content_security { + CookieContentSecurity::Signed => { + jar.signed_mut(&config.key).add(cookie) + } + CookieContentSecurity::Private => { + jar.private_mut(&config.key).add(cookie) + } + } + + let cookie = jar + .delta() + .next() + .context("Failed to build session cookie")?; + let val = HeaderValue::from_str(&cookie.encoded().to_string()).context( + "Failed to attach a session cookie to the outgoing response", + )?; + + response.headers_mut().append(SET_COOKIE, val); + + Ok(()) +} + +fn delete_session_cookie( + response: &mut ResponseHead, + config: &CookieConfiguration, +) -> Result<(), anyhow::Error> { + let removal_cookie = Cookie::build(config.name.clone(), "") + .path(config.path.clone()) + .secure(config.secure) + .http_only(config.http_only) + .same_site(config.same_site); + + let mut removal_cookie = if let Some(ref domain) = config.domain { + removal_cookie.domain(domain) + } else { + removal_cookie + } + .finish(); + + removal_cookie.make_removal(); + + let val = HeaderValue::from_str(&removal_cookie.to_string()).context( + "Failed to attach a session removal cookie to the outgoing response", + )?; + response.headers_mut().append(SET_COOKIE, val); + + Ok(()) +} diff --git a/lib/session/session.rs b/lib/session/session.rs new file mode 100644 index 0000000..0ec888f --- /dev/null +++ b/lib/session/session.rs @@ -0,0 +1,438 @@ +use std::{ + cell::{Ref, RefCell}, + convert::Infallible, + error::Error as StdError, + future::Future, + mem, + pin::Pin, + rc::Rc, +}; + +use actix_utils::future::{Ready, ready}; +use actix_web::{ + FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError, + body::BoxBody, + dev::{Extensions, Payload, ServiceRequest, ServiceResponse}, +}; +use anyhow::Context; +use derive_more::derive::{Display, From}; +use serde::{Serialize, de::DeserializeOwned}; +use serde_json::{Map, Value}; +use uuid::Uuid; + +const SESSION_USER_KEY: &str = "session:user_uid"; + +#[derive(Clone)] +pub struct Session(Rc>); + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum SessionStatus { + Changed, + Purged, + Renewed, + #[default] + Unchanged, +} + +#[derive(Default)] +struct SessionInner { + state: Map, + status: SessionStatus, +} + +impl Session { + pub fn get( + &self, + key: &str, + ) -> Result, SessionGetError> { + if let Some(value) = self.0.borrow().state.get(key) { + Ok(Some( + serde_json::from_value::(value.clone()) + .with_context(|| { + format!( + "Failed to deserialize the JSON-encoded session data attached to key \ + `{}` as a `{}` type", + key, + std::any::type_name::() + ) + }) + .map_err(SessionGetError)?, + )) + } else { + Ok(None) + } + } + + pub fn contains_key(&self, key: &str) -> bool { + self.0.borrow().state.contains_key(key) + } + + pub fn entries(&self) -> Ref<'_, Map> { + Ref::map(self.0.borrow(), |inner| &inner.state) + } + + pub fn status(&self) -> SessionStatus { + Ref::map(self.0.borrow(), |inner| &inner.status).clone() + } + + pub fn insert( + &self, + key: impl Into, + value: T, + ) -> Result<(), SessionInsertError> { + let mut inner = self.0.borrow_mut(); + + if inner.status != SessionStatus::Purged { + if inner.status != SessionStatus::Renewed { + inner.status = SessionStatus::Changed; + } + + let key = key.into(); + let val = serde_json::to_value(&value) + .with_context(|| { + format!( + "Failed to serialize the provided `{}` type instance as JSON in order to \ + attach as session data to the `{key}` key", + std::any::type_name::(), + ) + }) + .map_err(SessionInsertError)?; + + inner.state.insert(key, val); + } + + Ok(()) + } + + pub fn update( + &self, + key: impl Into, + updater: F, + ) -> Result<(), SessionUpdateError> + where + F: FnOnce(T) -> T, + { + let mut inner = self.0.borrow_mut(); + let key_str = key.into(); + + if let Some(val) = inner.state.get(&key_str) { + if inner.status == SessionStatus::Purged { + return Ok(()); + } + + let value = serde_json::from_value(val.clone()) + .with_context(|| { + format!( + "Failed to deserialize the JSON-encoded session data attached to key \ + `{key_str}` as a `{}` type", + std::any::type_name::() + ) + }) + .map_err(SessionUpdateError)?; + + let val = serde_json::to_value(updater(value)) + .with_context(|| { + format!( + "Failed to serialize the provided `{}` type instance as JSON in order to \ + attach as session data to the `{key_str}` key", + std::any::type_name::(), + ) + }) + .map_err(SessionUpdateError)?; + + if inner.status != SessionStatus::Renewed { + inner.status = SessionStatus::Changed; + } + inner.state.insert(key_str, val); + } + + Ok(()) + } + + pub fn update_or( + &self, + key: &str, + default_value: T, + updater: F, + ) -> Result<(), SessionUpdateError> + where + F: FnOnce(T) -> T, + { + if self.contains_key(key) { + self.update(key, updater) + } else { + self.insert(key, default_value) + .map_err(|err| SessionUpdateError(err.into())) + } + } + + pub fn remove(&self, key: &str) -> Option { + let mut inner = self.0.borrow_mut(); + + if inner.status != SessionStatus::Purged { + if inner.status != SessionStatus::Renewed { + inner.status = SessionStatus::Changed; + } + return inner.state.remove(key); + } + + None + } + + pub fn remove_as( + &self, + key: &str, + ) -> Option> { + self.remove(key).map(|value| { + match serde_json::from_value::(value.clone()) { + Ok(val) => Ok(val), + Err(_err) => Err(value), + } + }) + } + + pub fn clear(&self) { + let mut inner = self.0.borrow_mut(); + + if inner.status != SessionStatus::Purged { + if inner.status != SessionStatus::Renewed { + inner.status = SessionStatus::Changed; + } + inner.state.clear() + } + } + + pub fn purge(&self) { + let mut inner = self.0.borrow_mut(); + inner.status = SessionStatus::Purged; + inner.state.clear(); + } + + pub fn renew(&self) { + let mut inner = self.0.borrow_mut(); + + if inner.status != SessionStatus::Purged { + inner.status = SessionStatus::Renewed; + } + } + + pub fn user(&self) -> Option { + self.get::(SESSION_USER_KEY).ok().flatten() + } + + pub fn set_user(&self, uid: Uuid) { + let _ = self.insert(SESSION_USER_KEY, uid); + } + + pub fn clear_user(&self) { + let _ = self.remove(SESSION_USER_KEY); + } + + pub fn ip_address(&self) -> Option { + self.get::("session:ip_address").ok().flatten() + } + + pub fn user_agent(&self) -> Option { + self.get::("session:user_agent").ok().flatten() + } + + pub fn set_request_info(req: &HttpRequest) { + let extensions = req.extensions_mut(); + if let Some(inner) = extensions.get::>>() { + let mut inner = inner.borrow_mut(); + let mut changed = false; + if let Some(ua) = req.headers().get("user-agent") + && let Ok(ua) = ua.to_str() + { + let _ = inner.state.insert( + "session:user_agent".to_string(), + serde_json::json!(ua), + ); + changed = true; + } + let addr = req + .connection_info() + .realip_remote_addr() + .map(|s| s.to_string()); + if let Some(ip) = addr { + let _ = inner.state.insert( + "session:ip_address".to_string(), + serde_json::json!(ip), + ); + changed = true; + } + if changed && inner.status != SessionStatus::Renewed { + inner.status = SessionStatus::Changed; + } + } + } + + #[allow(clippy::needless_pass_by_ref_mut)] + pub(crate) fn set_session( + req: &mut ServiceRequest, + data: impl IntoIterator, + ) { + let session = Session::get_session(&mut req.extensions_mut()); + let mut inner = session.0.borrow_mut(); + inner.state.extend(data); + } + + #[allow(clippy::needless_pass_by_ref_mut)] + pub(crate) fn get_changes( + res: &mut ServiceResponse, + ) -> (SessionStatus, Map) { + if let Some(s_impl) = res + .request() + .extensions() + .get::>>() + { + let state = mem::take(&mut s_impl.borrow_mut().state); + (s_impl.borrow().status.clone(), state) + } else { + (SessionStatus::Unchanged, Map::new()) + } + } + pub fn no_op() -> Self { + Self(Rc::new(RefCell::new(SessionInner::default()))) + } + pub fn get_session(extensions: &mut Extensions) -> Session { + if let Some(s_impl) = extensions.get::>>() { + return Session(Rc::clone(s_impl)); + } + + let inner = Rc::new(RefCell::new(SessionInner::default())); + extensions.insert(inner.clone()); + + Session(inner) + } +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::{Session, SessionStatus, SessionUpdateError}; + + #[test] + fn update_marks_session_as_changed() -> Result<(), SessionUpdateError> { + let session = Session::no_op(); + { + let mut inner = session.0.borrow_mut(); + inner.state.insert("counter".to_string(), json!(1_u64)); + inner.status = SessionStatus::Unchanged; + } + + session.update("counter", |counter: u64| counter + 1)?; + + assert_eq!(session.status(), SessionStatus::Changed); + assert_eq!( + session.0.borrow().state.get("counter"), + Some(&json!(2_u64)) + ); + Ok(()) + } + + #[test] + fn update_preserves_renewed_status() -> Result<(), SessionUpdateError> { + let session = Session::no_op(); + { + let mut inner = session.0.borrow_mut(); + inner.state.insert("counter".to_string(), json!(1_u64)); + inner.status = SessionStatus::Renewed; + } + + session.update("counter", |counter: u64| counter + 1)?; + + assert_eq!(session.status(), SessionStatus::Renewed); + assert_eq!( + session.0.borrow().state.get("counter"), + Some(&json!(2_u64)) + ); + Ok(()) + } +} + +impl FromRequest for Session { + type Error = Infallible; + type Future = Ready>; + + #[inline] + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + ready(Ok(Session::get_session(&mut req.extensions_mut()))) + } +} +#[derive(Clone, Copy)] +pub struct SessionUser(pub Uuid); + +impl FromRequest for SessionUser { + type Error = SessionGetError; + type Future = Pin>>>; + + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + let req = req.clone(); + Box::pin(async move { + let uid = { + let mut extensions = req.extensions_mut(); + let session = Session::get_session(&mut extensions); + session.user().ok_or_else(|| { + SessionGetError(anyhow::anyhow!("not authenticated")) + })? + }; + Ok(SessionUser(uid)) + }) + } +} + +#[derive(Debug, Display, From)] +#[display("{_0}")] +pub struct SessionGetError(anyhow::Error); + +impl StdError for SessionGetError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + Some(self.0.as_ref()) + } +} + +impl ResponseError for SessionGetError { + fn error_response(&self) -> HttpResponse { + HttpResponse::build(self.status_code()) + .content_type("text/plain") + .body(self.to_string()) + } +} + +#[derive(Debug, Display, From)] +#[display("{_0}")] +pub struct SessionInsertError(anyhow::Error); + +impl StdError for SessionInsertError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + Some(self.0.as_ref()) + } +} + +impl ResponseError for SessionInsertError { + fn error_response(&self) -> HttpResponse { + HttpResponse::build(self.status_code()) + .content_type("text/plain") + .body(self.to_string()) + } +} + +#[derive(Debug, Display, From)] +#[display("{_0}")] +pub struct SessionUpdateError(anyhow::Error); + +impl StdError for SessionUpdateError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + Some(self.0.as_ref()) + } +} + +impl ResponseError for SessionUpdateError { + fn error_response(&self) -> HttpResponse { + HttpResponse::build(self.status_code()) + .content_type("text/plain") + .body(self.to_string()) + } +} diff --git a/lib/session/session_ext.rs b/lib/session/session_ext.rs new file mode 100644 index 0000000..318fb2f --- /dev/null +++ b/lib/session/session_ext.rs @@ -0,0 +1,35 @@ +use actix_web::{ + HttpMessage, HttpRequest, + dev::{ServiceRequest, ServiceResponse}, + guard::GuardContext, +}; + +use crate::Session; + +pub trait SessionExt { + fn get_session(&self) -> Session; +} + +impl SessionExt for HttpRequest { + fn get_session(&self) -> Session { + Session::get_session(&mut self.extensions_mut()) + } +} + +impl SessionExt for ServiceRequest { + fn get_session(&self) -> Session { + Session::get_session(&mut self.extensions_mut()) + } +} + +impl SessionExt for ServiceResponse { + fn get_session(&self) -> Session { + self.request().get_session() + } +} + +impl SessionExt for GuardContext<'_> { + fn get_session(&self) -> Session { + Session::get_session(&mut self.req_data_mut()) + } +} diff --git a/lib/session/storage/format.rs b/lib/session/storage/format.rs new file mode 100644 index 0000000..b6e8df1 --- /dev/null +++ b/lib/session/storage/format.rs @@ -0,0 +1,79 @@ +use std::collections::HashMap; + +use serde::ser::{Serialize, SerializeMap, Serializer}; +use serde_json::{Map, Value}; + +use super::interface::SessionState; + +const SESSION_STATE_FORMAT_VERSION: u8 = 1; + +#[derive(Debug)] +struct StoredSessionStateRef<'a> { + state: &'a SessionState, +} + +impl Serialize for StoredSessionStateRef<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("v", &SESSION_STATE_FORMAT_VERSION)?; + map.serialize_entry("state", self.state)?; + map.end() + } +} + +pub(crate) fn serialize_session_state( + session_state: &SessionState, +) -> Result { + let stored = StoredSessionStateRef { + state: session_state, + }; + + serde_json::to_string(&stored).map_err(anyhow::Error::new) +} + +pub(crate) fn deserialize_session_state( + value: &str, +) -> Result { + let value = serde_json::from_str::(value)?; + + let Value::Object(mut obj) = value else { + anyhow::bail!("Session state is not a JSON object"); + }; + + if matches!(obj.get("state"), Some(Value::Object(_))) + && let Some(Value::Number(v)) = obj.get("v") + { + let v = v.as_u64().ok_or_else(|| { + anyhow::anyhow!("Invalid session state format version") + })?; + let v = u8::try_from(v).map_err(|_| { + anyhow::anyhow!("Invalid session state format version") + })?; + anyhow::ensure!( + v == SESSION_STATE_FORMAT_VERSION, + "Unsupported session state format version: {}", + v + ); + + let Some(Value::Object(state)) = obj.remove("state") else { + unreachable!("`state` was checked to be an object above"); + }; + return Ok(state); + } + + if obj.values().all(Value::is_string) { + let legacy: HashMap = + serde_json::from_value(Value::Object(obj))?; + let mut migrated: Map = Map::new(); + for (key, json_encoded) in legacy { + migrated.insert(key, serde_json::from_str::(&json_encoded)?); + } + + return Ok(migrated); + } + + Ok(obj) +} diff --git a/lib/session/storage/interface.rs b/lib/session/storage/interface.rs new file mode 100644 index 0000000..d3f6b6f --- /dev/null +++ b/lib/session/storage/interface.rs @@ -0,0 +1,94 @@ +use std::future::Future; + +use actix_web::cookie::time::Duration; +use derive_more::derive::Display; +use serde_json::{Map, Value}; + +use super::SessionKey; + +pub(crate) type SessionState = Map; + +pub trait SessionStore { + fn load( + &self, + session_key: &SessionKey, + ) -> impl Future, LoadError>>; + + fn save( + &self, + session_state: SessionState, + ttl: &Duration, + ) -> impl Future>; + + fn update( + &self, + session_key: SessionKey, + session_state: SessionState, + ttl: &Duration, + ) -> impl Future>; + + fn update_ttl( + &self, + session_key: &SessionKey, + ttl: &Duration, + ) -> impl Future>; + + fn delete( + &self, + session_key: &SessionKey, + ) -> impl Future>; +} + +#[derive(Debug, Display)] +pub enum LoadError { + #[display("Failed to deserialize session state")] + Deserialization(anyhow::Error), + + #[display("Something went wrong when retrieving the session state")] + Other(anyhow::Error), +} + +impl std::error::Error for LoadError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Deserialization(err) => Some(err.as_ref()), + Self::Other(err) => Some(err.as_ref()), + } + } +} + +#[derive(Debug, Display)] +pub enum SaveError { + #[display("Failed to serialize session state")] + Serialization(anyhow::Error), + + #[display("Something went wrong when persisting the session state")] + Other(anyhow::Error), +} + +impl std::error::Error for SaveError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Serialization(err) => Some(err.as_ref()), + Self::Other(err) => Some(err.as_ref()), + } + } +} + +#[derive(Debug, Display)] +pub enum UpdateError { + #[display("Failed to serialize session state")] + Serialization(anyhow::Error), + + #[display("Something went wrong when updating the session state.")] + Other(anyhow::Error), +} + +impl std::error::Error for UpdateError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Serialization(err) => Some(err.as_ref()), + Self::Other(err) => Some(err.as_ref()), + } + } +} diff --git a/lib/session/storage/mod.rs b/lib/session/storage/mod.rs new file mode 100644 index 0000000..3f169fc --- /dev/null +++ b/lib/session/storage/mod.rs @@ -0,0 +1,12 @@ +mod format; +mod interface; +mod redis_cluster; +mod session_key; +mod utils; + +pub use self::{ + interface::{LoadError, SaveError, SessionStore, UpdateError}, + redis_cluster::RedisClusterSessionStore, + session_key::SessionKey, + utils::generate_session_key, +}; diff --git a/lib/session/storage/redis_cluster.rs b/lib/session/storage/redis_cluster.rs new file mode 100644 index 0000000..71454ca --- /dev/null +++ b/lib/session/storage/redis_cluster.rs @@ -0,0 +1,297 @@ +use std::{sync::Arc, time::Duration as StdDuration}; + +use actix_web::cookie::time::Duration; +use redis::{ + cluster::{ClusterClient, ClusterClientBuilder}, + cluster_async::ClusterConnection, +}; +use tokio::sync::Mutex; + +use super::SessionKey; +use crate::storage::{ + SessionStore, + format::{deserialize_session_state, serialize_session_state}, + interface::{LoadError, SaveError, SessionState, UpdateError}, + utils::generate_session_key, +}; + +#[derive(Clone)] +pub struct RedisClusterSessionStore { + client: ClusterClient, + connection: Arc>, + configuration: CacheConfiguration, +} + +#[derive(Clone)] +struct CacheConfiguration { + cache_keygen: Arc String + Send + Sync>, +} + +impl Default for CacheConfiguration { + fn default() -> Self { + Self { + cache_keygen: Arc::new(str::to_owned), + } + } +} + +impl RedisClusterSessionStore { + const DEFAULT_CONNECTION_TIMEOUT: StdDuration = StdDuration::from_secs(2); + const DEFAULT_RESPONSE_TIMEOUT: StdDuration = StdDuration::from_secs(2); + const DEFAULT_COMMAND_TIMEOUT: StdDuration = StdDuration::from_secs(3); + const DEFAULT_RETRIES: u32 = 1; + const DEFAULT_RETRY_MIN_WAIT_MS: u64 = 25; + const DEFAULT_RETRY_MAX_WAIT_MS: u64 = 100; + const DEFAULT_RETRY_FACTOR: u64 = 10; + const DEFAULT_RETRY_EXPONENT_BASE: u64 = 2; + + pub fn builder( + connection_strings: Vec, + ) -> RedisClusterSessionStoreBuilder { + RedisClusterSessionStoreBuilder { + configuration: CacheConfiguration::default(), + connection_strings, + } + } + + pub async fn new( + connection_strings: Vec, + ) -> anyhow::Result { + Self::builder(connection_strings).build().await + } + + fn client_builder(connection_strings: Vec) -> ClusterClientBuilder { + ClusterClient::builder(connection_strings) + .connection_timeout(Self::DEFAULT_CONNECTION_TIMEOUT) + .response_timeout(Self::DEFAULT_RESPONSE_TIMEOUT) + .retries(Self::DEFAULT_RETRIES) + .min_retry_wait(Self::DEFAULT_RETRY_MIN_WAIT_MS) + .max_retry_wait(Self::DEFAULT_RETRY_MAX_WAIT_MS) + .retry_wait_formula( + Self::DEFAULT_RETRY_FACTOR, + Self::DEFAULT_RETRY_EXPONENT_BASE, + ) + } + + async fn connect( + client: &ClusterClient, + ) -> anyhow::Result { + let started = std::time::Instant::now(); + let connection = tokio::time::timeout( + Self::DEFAULT_COMMAND_TIMEOUT, + client.get_async_connection(), + ) + .await + .map_err(|_| anyhow::anyhow!("session redis async connect timed out"))? + .map_err(|e| anyhow::anyhow!(e))?; + + tracing::debug!( + elapsed_ms = started.elapsed().as_millis() as u64, + "session redis async connect finished" + ); + Ok(connection) + } + + async fn execute_cmd( + &self, + op_name: &'static str, + make_cmd: F, + ) -> anyhow::Result + where + T: redis::FromRedisValue, + F: Fn() -> redis::Cmd, + { + let first_try: anyhow::Result = { + let mut connection = self.connection.lock().await; + let started = std::time::Instant::now(); + tracing::debug!(op = op_name, "session redis command start"); + match tokio::time::timeout( + Self::DEFAULT_COMMAND_TIMEOUT, + make_cmd().query_async(&mut *connection), + ) + .await + .map_err(|_| anyhow::anyhow!("session redis command timed out"))? + { + Ok(value) => { + tracing::debug!( + op = op_name, + elapsed_ms = started.elapsed().as_millis() as u64, + "session redis command finished" + ); + return Ok(value); + } + Err(err) => Err(anyhow::anyhow!(err)), + } + }; + + if let Err(error) = &first_try { + tracing::warn!(op = op_name, error = %error, "session redis command failed, reconnecting"); + } + + let new_connection = Self::connect(&self.client).await?; + { + let mut connection = self.connection.lock().await; + *connection = new_connection; + } + + let mut connection = self.connection.lock().await; + let started = std::time::Instant::now(); + tracing::debug!(op = op_name, "session redis command retry start"); + let result = tokio::time::timeout( + Self::DEFAULT_COMMAND_TIMEOUT, + make_cmd().query_async(&mut *connection), + ) + .await + .map_err(|_| anyhow::anyhow!("session redis command retry timed out"))? + .map_err(|e| anyhow::anyhow!(e))?; + tracing::debug!( + op = op_name, + elapsed_ms = started.elapsed().as_millis() as u64, + "session redis command retry finished" + ); + Ok(result) + } + + fn ttl_seconds(ttl: &Duration) -> anyhow::Result { + let ttl_secs = ttl.whole_seconds(); + if ttl_secs <= 0 { + anyhow::bail!("session TTL must be positive"); + } + u64::try_from(ttl_secs).map_err(anyhow::Error::new) + } +} + +#[must_use] +pub struct RedisClusterSessionStoreBuilder { + configuration: CacheConfiguration, + connection_strings: Vec, +} + +impl RedisClusterSessionStoreBuilder { + pub fn cache_keygen(mut self, keygen: F) -> Self + where + F: Fn(&str) -> String + 'static + Send + Sync, + { + self.configuration.cache_keygen = Arc::new(keygen); + self + } + + pub async fn build(self) -> anyhow::Result { + let client = + RedisClusterSessionStore::client_builder(self.connection_strings) + .build()?; + let connection = RedisClusterSessionStore::connect(&client).await?; + + Ok(RedisClusterSessionStore { + client, + connection: Arc::new(Mutex::new(connection)), + configuration: self.configuration, + }) + } +} + +impl SessionStore for RedisClusterSessionStore { + async fn load( + &self, + session_key: &SessionKey, + ) -> Result, LoadError> { + let cache_key = + self.configuration.cache_keygen.as_ref()(session_key.as_ref()); + let value: Option = self + .execute_cmd("get", move || { + let mut cmd = redis::cmd("GET"); + cmd.arg(&cache_key); + cmd + }) + .await + .map_err(LoadError::Other)?; + + match value { + None => Ok(None), + Some(value) => Ok(Some( + deserialize_session_state(&value) + .map_err(LoadError::Deserialization)?, + )), + } + } + + async fn save( + &self, + session_state: SessionState, + ttl: &Duration, + ) -> Result { + let body = serialize_session_state(&session_state) + .map_err(SaveError::Serialization)?; + let session_key = generate_session_key(); + let cache_key = + self.configuration.cache_keygen.as_ref()(session_key.as_ref()); + let ttl_secs = Self::ttl_seconds(ttl).map_err(SaveError::Other)?; + + self.execute_cmd::<(), _>("set_ex", move || { + let mut cmd = redis::cmd("SETEX"); + cmd.arg(&cache_key).arg(ttl_secs).arg(&body); + cmd + }) + .await + .map_err(SaveError::Other)?; + + Ok(session_key) + } + + async fn update( + &self, + session_key: SessionKey, + session_state: SessionState, + ttl: &Duration, + ) -> Result { + let body = serialize_session_state(&session_state) + .map_err(UpdateError::Serialization)?; + let cache_key = + self.configuration.cache_keygen.as_ref()(session_key.as_ref()); + let ttl_secs = Self::ttl_seconds(ttl).map_err(UpdateError::Other)?; + + self.execute_cmd::<(), _>("set_ex", move || { + let mut cmd = redis::cmd("SETEX"); + cmd.arg(&cache_key).arg(ttl_secs).arg(&body); + cmd + }) + .await + .map_err(UpdateError::Other)?; + + Ok(session_key) + } + + async fn update_ttl( + &self, + session_key: &SessionKey, + ttl: &Duration, + ) -> anyhow::Result<()> { + let cache_key = + self.configuration.cache_keygen.as_ref()(session_key.as_ref()); + let ttl_secs = Self::ttl_seconds(ttl)?; + + self.execute_cmd("expire", move || { + let mut cmd = redis::cmd("EXPIRE"); + cmd.arg(&cache_key).arg(ttl_secs); + cmd + }) + .await + .map(|_: bool| ()) + } + + async fn delete( + &self, + session_key: &SessionKey, + ) -> Result<(), anyhow::Error> { + let cache_key = + self.configuration.cache_keygen.as_ref()(session_key.as_ref()); + + self.execute_cmd("del", move || { + let mut cmd = redis::cmd("DEL"); + cmd.arg(&cache_key); + cmd + }) + .await + .map(|_: i64| ()) + } +} diff --git a/lib/session/storage/session_key.rs b/lib/session/storage/session_key.rs new file mode 100644 index 0000000..0704df4 --- /dev/null +++ b/lib/session/storage/session_key.rs @@ -0,0 +1,48 @@ +use derive_more::derive::{Display, From}; + +#[derive(Debug, PartialEq, Eq)] +pub struct SessionKey(String); + +impl TryFrom for SessionKey { + type Error = InvalidSessionKeyError; + + fn try_from(val: String) -> Result { + if val.len() > 4064 { + return Err(anyhow::anyhow!( + "The session key is bigger than 4064 bytes, the upper limit on cookie content." + ) + .into()); + } + + if val.contains('\0') { + return Err(anyhow::anyhow!( + "The session key contains null bytes which are not allowed." + ) + .into()); + } + + Ok(SessionKey(val)) + } +} + +impl AsRef for SessionKey { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl From for String { + fn from(key: SessionKey) -> Self { + key.0 + } +} + +#[derive(Debug, Display, From)] +#[display("The provided string is not a valid session key")] +pub struct InvalidSessionKeyError(anyhow::Error); + +impl std::error::Error for InvalidSessionKeyError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(self.0.as_ref()) + } +} diff --git a/lib/session/storage/utils.rs b/lib/session/storage/utils.rs new file mode 100644 index 0000000..0129b71 --- /dev/null +++ b/lib/session/storage/utils.rs @@ -0,0 +1,12 @@ +use rand::distr::{Alphanumeric, SampleString as _}; + +use crate::storage::SessionKey; + +pub fn generate_session_key() -> SessionKey { + match Alphanumeric.sample_string(&mut rand::rng(), 64).try_into() { + Ok(session_key) => session_key, + Err(_error) => unreachable!( + "64 alphanumeric characters are always a valid session key" + ), + } +} diff --git a/lib/socketio/Cargo.toml b/lib/socketio/Cargo.toml new file mode 100644 index 0000000..58f3373 --- /dev/null +++ b/lib/socketio/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "socketio" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + +[lib] +path = "lib.rs" +name = "socketio" +[dependencies] +actix-web = { workspace = true } +actix-ws = "0.4.0" +async-nats = { workspace = true } +async-trait = { workspace = true } +base64 = { workspace = true } +deadpool-redis = { workspace = true } +futures-util = { workspace = true } +redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp"] } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +session = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["macros", "sync", "time"] } +tracing = { workspace = true } +uuid = { workspace = true, features = ["v4"] } + +[lints] +workspace = true diff --git a/lib/socketio/actix.rs b/lib/socketio/actix.rs new file mode 100644 index 0000000..a3e6585 --- /dev/null +++ b/lib/socketio/actix.rs @@ -0,0 +1,394 @@ +use std::sync::Arc; +use std::sync::atomic::Ordering; + +use actix_web::{ + Error, HttpRequest, HttpResponse, + error::{ErrorBadRequest, ErrorInternalServerError, ErrorNotFound}, + http::header, + web::{self, Bytes, Data, Payload, ServiceConfig}, +}; +use futures_util::StreamExt; +use serde_json::json; +use session::SessionExt; + +use crate::{ + engine_packet::{ + EnginePacket, SocketPayload, decode_engine_payload, + decode_engine_text_packet, encode_engine_packet, encode_engine_payload, + }, + error::SocketIoError, + server::SocketIo, + session::{Session, Transport}, + socket::DisconnectReason, +}; +struct ActiveGuard(Arc); + +impl Drop for ActiveGuard { + fn drop(&mut self) { + self.0.store(false, Ordering::Release); + } +} + +#[derive(Debug, serde::Deserialize)] +struct EngineQuery { + #[serde(rename = "EIO")] + eio: Option, + transport: Option, + sid: Option, +} + +pub fn configure(cfg: &mut ServiceConfig, io: SocketIo) { + let path = io.config().path.clone(); + configure_at(cfg, path, io); +} + +pub fn configure_at( + cfg: &mut ServiceConfig, + path: impl Into, + io: SocketIo, +) { + cfg.app_data(Data::new(io)).service( + web::resource(path.into()) + .route(web::get().to(engine_get)) + .route(web::post().to(engine_post)), + ); +} + +async fn engine_get( + io: Data, + req: HttpRequest, + stream: Payload, + query: web::Query, +) -> Result { + validate_eio(&query)?; + + match query.transport.as_deref() { + Some("polling") if query.sid.is_none() => polling_open(io, &req).await, + Some("polling") => polling_get(io, query.sid.as_deref()).await, + Some("websocket") => { + websocket_open(io, req, stream, query.sid.clone()).await + } + _ => Err(ErrorBadRequest("unsupported transport")), + } +} + +async fn engine_post( + io: Data, + query: web::Query, + body: Bytes, +) -> Result { + validate_eio(&query)?; + if query.transport.as_deref() != Some("polling") { + return Err(ErrorBadRequest("unsupported transport")); + } + if body.len() > io.config().max_payload { + return Err(ErrorBadRequest("payload too large")); + } + + let sid = query + .sid + .as_deref() + .ok_or_else(|| ErrorBadRequest("missing sid"))?; + validate_sid(sid)?; + let session = io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))?; + if session.post_active.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed).is_err() { + io.remove_session(&session, DisconnectReason::TransportClosed) + .await; + return Err(ErrorBadRequest("overlapping post request")); + } + + let _guard = ActiveGuard(session.post_active.clone()); + handle_polling_body(&io, session, &body).await?; + + Ok(HttpResponse::Ok() + .insert_header((header::CONTENT_TYPE, "text/plain; charset=UTF-8")) + .body("ok")) +} + +async fn handle_polling_body( + io: &SocketIo, + session: Arc, + body: &Bytes, +) -> Result<(), Error> { + touch_session(&session).await; + let payload = std::str::from_utf8(body).map_err(ErrorBadRequest)?; + let packets = decode_engine_payload(payload).map_err(map_socket_error)?; + for packet in packets { + handle_engine_packet(io, session.clone(), packet).await?; + } + Ok(()) +} + +async fn polling_open( + io: Data, + req: &HttpRequest, +) -> Result { + let session = Session::new(req.get_session().user()); + let sid = session.engine_id.clone(); + io.insert_session(session).await; + + let open = json!({ + "sid": sid, + "upgrades": ["websocket"], + "pingInterval": io.config().ping_interval.as_millis(), + "pingTimeout": io.config().ping_timeout.as_millis(), + "maxPayload": io.config().max_payload + }); + + Ok(text_response(encode_engine_packet( + &EnginePacket::Open(open), + true, + ))) +} + +async fn polling_get( + io: Data, + sid: Option<&str>, +) -> Result { + let sid = sid.ok_or_else(|| ErrorBadRequest("missing sid"))?; + validate_sid(sid)?; + let session = io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))?; + if session.get_active.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed).is_err() { + io.remove_session(&session, DisconnectReason::TransportClosed) + .await; + return Err(ErrorBadRequest("overlapping get request")); + } + + let _guard = ActiveGuard(session.get_active.clone()); + loop { + let packets = session.drain().await; + if !packets.is_empty() { + return Ok(text_response(encode_engine_payload(&packets, true))); + } + tokio::select! { + () = session.notify.notified() => { + } + () = tokio::time::sleep(io.config().ping_interval) => { + return Ok(text_response(encode_engine_payload( + &[EnginePacket::Ping(None)], true, + ))); + } + } + } +} + +async fn websocket_open( + io: Data, + req: HttpRequest, + stream: Payload, + sid: Option, +) -> Result { + let direct_websocket = sid.is_none(); + let session = match sid { + Some(ref sid) => { + validate_sid(sid)?; + io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))? + } + None => { + let session = Session::new(req.get_session().user()); + io.insert_session(session.clone()).await; + session + } + }; + + let (response, mut ws_session, messages) = actix_ws::handle(&req, stream)?; + let io = io.get_ref().clone(); + + actix_web::rt::spawn(async move { + if direct_websocket { + *session.transport.lock().await = Transport::WebSocket; + let open = json!({ + "sid": session.engine_id, + "upgrades": [], + "pingInterval": io.config().ping_interval.as_millis(), + "pingTimeout": io.config().ping_timeout.as_millis(), + "maxPayload": io.config().max_payload + }); + if ws_session + .text(encode_engine_packet(&EnginePacket::Open(open), false)) + .await + .is_err() + { + io.remove_session(&session, DisconnectReason::TransportClosed) + .await; + return; + } + } + + websocket_loop(io, session, ws_session, messages, direct_websocket) + .await; + }); + + Ok(response) +} + +async fn websocket_loop( + io: SocketIo, + session: Arc, + mut ws_session: actix_ws::Session, + mut messages: actix_ws::MessageStream, + mut upgraded: bool, +) { + let mut heartbeat = tokio::time::interval_at( + tokio::time::Instant::now() + io.config().ping_interval, + io.config().ping_interval, + ); + + loop { + tokio::select! { + message = messages.next() => { + match message { + Some(Ok(actix_ws::Message::Text(text))) => { + touch_session(&session).await; + match decode_engine_text_packet(text.as_ref()) { + Ok(EnginePacket::Ping(Some(value))) if value == "probe" && !upgraded => { + if ws_session.text("3probe").await.is_err() { + break; + } + } + Ok(EnginePacket::Upgrade) if !upgraded => { + *session.transport.lock().await = Transport::WebSocket; + upgraded = true; + } + Ok(packet) => { + if handle_engine_packet(&io, session.clone(), packet).await.is_err() { + break; + } + } + Err(_) => break, + } + } + Some(Ok(actix_ws::Message::Binary(bytes))) => { + touch_session(&session).await; + if handle_engine_packet( + &io, + session.clone(), + EnginePacket::Message(SocketPayload::Binary(bytes.to_vec())), + ) + .await + .is_err() + { + break; + } + } + Some(Ok(actix_ws::Message::Ping(bytes))) => { + touch_session(&session).await; + if ws_session.pong(&bytes).await.is_err() { + break; + } + } + Some(Ok(actix_ws::Message::Pong(_))) => { + touch_session(&session).await; + } + Some(Ok(actix_ws::Message::Close(reason))) => { + let _ = ws_session.close(reason).await; + break; + } + Some(Ok(actix_ws::Message::Continuation(_))) => {} + Some(Ok(actix_ws::Message::Nop)) => {} + Some(Err(_)) | None => break, + } + } + () = session.notify.notified() => { + for packet in session.drain().await { + if send_ws_packet(&mut ws_session, packet).await.is_err() { + io.remove_session(&session, DisconnectReason::TransportClosed).await; + return; + } + } + } + _ = heartbeat.tick() => { + if session.last_pong.lock().await.elapsed() + > io.config().ping_interval + io.config().ping_timeout + { + io.remove_session(&session, DisconnectReason::PingTimeout).await; + return; + } + if ws_session.text("2").await.is_err() { + io.remove_session(&session, DisconnectReason::TransportClosed).await; + return; + } + } + } + } + + io.remove_session(&session, DisconnectReason::TransportClosed) + .await; +} + +async fn touch_session(session: &Session) { + *session.last_pong.lock().await = std::time::Instant::now(); +} + +async fn send_ws_packet( + ws_session: &mut actix_ws::Session, + packet: EnginePacket, +) -> std::result::Result<(), actix_ws::Closed> { + match packet { + EnginePacket::Message(SocketPayload::Binary(bytes)) => { + ws_session.binary(bytes).await + } + packet => ws_session.text(encode_engine_packet(&packet, false)).await, + } +} + +async fn handle_engine_packet( + io: &SocketIo, + session: Arc, + packet: EnginePacket, +) -> Result<(), Error> { + match packet { + EnginePacket::Ping(data) => { + session.enqueue(EnginePacket::Pong(data)).await; + } + EnginePacket::Pong(_) => { + *session.last_pong.lock().await = std::time::Instant::now(); + } + EnginePacket::Message(payload) => { + io.handle_socket_payload(session, payload) + .await + .map_err(map_socket_error)?; + } + EnginePacket::Close => { + io.remove_session(&session, DisconnectReason::Client).await; + } + EnginePacket::Open(_) => { + tracing::warn!("client sent unexpected Open packet"); + } + EnginePacket::Upgrade | EnginePacket::Noop => {} + } + Ok(()) +} + +fn validate_eio(query: &EngineQuery) -> Result<(), Error> { + match query.eio.as_deref() { + Some("4") => Ok(()), + _ => Err(ErrorBadRequest("unsupported EIO version")), + } +} + +fn validate_sid(sid: &str) -> Result<(), Error> { + if sid.is_empty() + || sid.len() > 128 + || !sid.bytes().all(|byte| byte.is_ascii_graphic()) + { + return Err(ErrorBadRequest("invalid sid")); + } + Ok(()) +} + +fn text_response(body: String) -> HttpResponse { + HttpResponse::Ok() + .insert_header((header::CONTENT_TYPE, "text/plain; charset=UTF-8")) + .body(body) +} + +fn map_socket_error(err: SocketIoError) -> Error { + match err { + SocketIoError::UnknownSession | SocketIoError::UnknownNamespace(_) => { + ErrorNotFound(err) + } + SocketIoError::InvalidPacket(_) => ErrorBadRequest(err), + _ => ErrorInternalServerError(err), + } +} diff --git a/lib/socketio/adapter.rs b/lib/socketio/adapter.rs new file mode 100644 index 0000000..fb84c5a --- /dev/null +++ b/lib/socketio/adapter.rs @@ -0,0 +1,201 @@ +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use async_trait::async_trait; +use tokio::sync::RwLock; + +use crate::{error::Result, packet::Packet}; + +#[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize)] +pub struct BroadcastOptions { + pub namespace: String, + pub rooms: HashSet, + pub except: HashSet, + pub skip_sid: Option, +} + +#[async_trait] +pub trait Adapter: Send + Sync { + async fn add_socket(&self, namespace: &str, sid: &str) -> Result<()>; + async fn remove_socket(&self, namespace: &str, sid: &str) -> Result<()>; + async fn add_to_room( + &self, + namespace: &str, + sid: &str, + room: &str, + ) -> Result<()>; + async fn remove_from_room( + &self, + namespace: &str, + sid: &str, + room: &str, + ) -> Result<()>; + async fn sockets( + &self, + namespace: &str, + opts: &BroadcastOptions, + ) -> Result>; + async fn publish( + &self, + packet: &Packet, + opts: &BroadcastOptions, + ) -> Result<()>; +} + +#[derive(Default)] +pub struct MemoryAdapter { + state: RwLock>, +} + +#[derive(Default)] +struct NamespaceRooms { + sockets: HashSet, + rooms: HashMap>, +} + +impl MemoryAdapter { + pub fn new() -> Arc { + Arc::new(Self::default()) + } +} + +#[async_trait] +impl Adapter for MemoryAdapter { + async fn add_socket(&self, namespace: &str, sid: &str) -> Result<()> { + let mut state = self.state.write().await; + state + .entry(namespace.to_owned()) + .or_default() + .sockets + .insert(sid.to_owned()); + Ok(()) + } + + async fn remove_socket(&self, namespace: &str, sid: &str) -> Result<()> { + let mut state = self.state.write().await; + if let Some(ns) = state.get_mut(namespace) { + ns.sockets.remove(sid); + ns.rooms.retain(|_, sockets| { + sockets.remove(sid); + !sockets.is_empty() + }); + } + Ok(()) + } + + async fn add_to_room( + &self, + namespace: &str, + sid: &str, + room: &str, + ) -> Result<()> { + let mut state = self.state.write().await; + let ns = state.entry(namespace.to_owned()).or_default(); + ns.sockets.insert(sid.to_owned()); + ns.rooms + .entry(room.to_owned()) + .or_default() + .insert(sid.to_owned()); + Ok(()) + } + + async fn remove_from_room( + &self, + namespace: &str, + sid: &str, + room: &str, + ) -> Result<()> { + let mut state = self.state.write().await; + if let Some(ns) = state.get_mut(namespace) + && let Some(sockets) = ns.rooms.get_mut(room) + { + sockets.remove(sid); + } + Ok(()) + } + + async fn sockets( + &self, + namespace: &str, + opts: &BroadcastOptions, + ) -> Result> { + let state = self.state.read().await; + let Some(ns) = state.get(namespace) else { + return Ok(Vec::new()); + }; + + let mut selected: HashSet = if opts.rooms.is_empty() { + ns.sockets.iter().cloned().collect() + } else { + opts.rooms + .iter() + .filter_map(|room| ns.rooms.get(room)) + .flat_map(|sockets| sockets.iter().cloned()) + .collect() + }; + + for room in &opts.except { + if let Some(excluded) = ns.rooms.get(room) { + for sid in excluded { + selected.remove(sid); + } + } + } + if let Some(skip_sid) = &opts.skip_sid { + selected.remove(skip_sid); + } + + Ok(selected.into_iter().collect()) + } + + async fn publish( + &self, + _packet: &Packet, + _opts: &BroadcastOptions, + ) -> Result<()> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn memory_adapter_filters_rooms_and_except() { + let adapter = MemoryAdapter::new(); + adapter.add_socket("/", "s1").await.unwrap(); + adapter.add_socket("/", "s2").await.unwrap(); + adapter.add_socket("/", "s3").await.unwrap(); + adapter.add_to_room("/", "s1", "room-a").await.unwrap(); + adapter.add_to_room("/", "s2", "room-a").await.unwrap(); + adapter.add_to_room("/", "s2", "muted").await.unwrap(); + + let opts = BroadcastOptions { + namespace: "/".to_owned(), + rooms: HashSet::from(["room-a".to_owned()]), + except: HashSet::from(["muted".to_owned()]), + skip_sid: None, + }; + + assert_eq!(adapter.sockets("/", &opts).await.unwrap(), vec!["s1"]); + } + + #[tokio::test] + async fn memory_adapter_can_skip_sender() { + let adapter = MemoryAdapter::new(); + adapter.add_socket("/", "s1").await.unwrap(); + adapter.add_socket("/", "s2").await.unwrap(); + + let opts = BroadcastOptions { + namespace: "/".to_owned(), + skip_sid: Some("s1".to_owned()), + ..BroadcastOptions::default() + }; + let sockets = adapter.sockets("/", &opts).await.unwrap(); + + assert_eq!(sockets, vec!["s2"]); + } +} diff --git a/lib/socketio/config.rs b/lib/socketio/config.rs new file mode 100644 index 0000000..4aa8562 --- /dev/null +++ b/lib/socketio/config.rs @@ -0,0 +1,24 @@ +use std::time::Duration; + +#[derive(Clone, Debug)] +pub struct SocketIoConfig { + pub path: String, + pub ping_interval: Duration, + pub ping_timeout: Duration, + pub connect_timeout: Duration, + pub ack_timeout: Duration, + pub max_payload: usize, +} + +impl Default for SocketIoConfig { + fn default() -> Self { + Self { + path: "/socket.io/".to_owned(), + ping_interval: Duration::from_millis(25_000), + ping_timeout: Duration::from_millis(20_000), + connect_timeout: Duration::from_millis(45_000), + ack_timeout: Duration::from_secs(10), + max_payload: 1_000_000, + } + } +} diff --git a/lib/socketio/engine_packet.rs b/lib/socketio/engine_packet.rs new file mode 100644 index 0000000..013d7d5 --- /dev/null +++ b/lib/socketio/engine_packet.rs @@ -0,0 +1,121 @@ +use base64::{Engine, engine::general_purpose::STANDARD}; +use serde_json::Value; + +use crate::error::{Result, SocketIoError}; + +const RECORD_SEPARATOR: char = '\x1e'; + +#[derive(Clone, Debug, PartialEq)] +pub(crate) enum EnginePacket { + Open(Value), + Close, + Ping(Option), + Pong(Option), + Message(SocketPayload), + Upgrade, + Noop, +} + +#[derive(Clone, Debug, PartialEq)] +pub(crate) enum SocketPayload { + Text(String), + Binary(Vec), +} + +pub(crate) fn encode_engine_payload( + packets: &[EnginePacket], + polling: bool, +) -> String { + packets + .iter() + .map(|packet| encode_engine_packet(packet, polling)) + .collect::>() + .join(&RECORD_SEPARATOR.to_string()) +} + +pub(crate) fn decode_engine_payload( + payload: &str, +) -> Result> { + payload + .split(RECORD_SEPARATOR) + .filter(|item| !item.is_empty()) + .map(decode_engine_text_packet) + .collect() +} + +pub(crate) fn encode_engine_packet( + packet: &EnginePacket, + _polling: bool, +) -> String { + match packet { + EnginePacket::Open(data) => format!("0{data}"), + EnginePacket::Close => "1".to_owned(), + EnginePacket::Ping(data) => { + format!("2{}", data.as_deref().unwrap_or_default()) + } + EnginePacket::Pong(data) => { + format!("3{}", data.as_deref().unwrap_or_default()) + } + EnginePacket::Message(SocketPayload::Text(text)) => format!("4{text}"), + EnginePacket::Message(SocketPayload::Binary(bytes)) => { + format!("b{}", STANDARD.encode(bytes)) + } + EnginePacket::Upgrade => "5".to_owned(), + EnginePacket::Noop => "6".to_owned(), + } +} + +pub(crate) fn decode_engine_text_packet(input: &str) -> Result { + if let Some(encoded) = input.strip_prefix('b') { + return Ok(EnginePacket::Message(SocketPayload::Binary( + STANDARD.decode(encoded).map_err(|_| { + SocketIoError::InvalidPacket( + "invalid base64 payload".to_owned(), + ) + })?, + ))); + } + + let mut chars = input.chars(); + let packet_type = chars.next().ok_or_else(|| { + SocketIoError::InvalidPacket("empty engine packet".to_owned()) + })?; + let rest = chars.as_str(); + match packet_type { + '0' => Ok(EnginePacket::Open(serde_json::from_str(rest)?)), + '1' => Ok(EnginePacket::Close), + '2' => Ok(EnginePacket::Ping(non_empty(rest))), + '3' => Ok(EnginePacket::Pong(non_empty(rest))), + '4' => Ok(EnginePacket::Message(SocketPayload::Text(rest.to_owned()))), + '5' => Ok(EnginePacket::Upgrade), + '6' => Ok(EnginePacket::Noop), + _ => Err(SocketIoError::InvalidPacket(format!( + "unknown engine packet type {packet_type}" + ))), + } +} + +fn non_empty(value: &str) -> Option { + if value.is_empty() { + None + } else { + Some(value.to_owned()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn polling_payload_uses_record_separator() { + let packets = vec![ + EnginePacket::Message(SocketPayload::Text("40".to_owned())), + EnginePacket::Message(SocketPayload::Text( + "42[\"ready\",null]".to_owned(), + )), + ]; + let encoded = encode_engine_payload(&packets, true); + assert_eq!(decode_engine_payload(&encoded).unwrap(), packets); + } +} diff --git a/lib/socketio/error.rs b/lib/socketio/error.rs new file mode 100644 index 0000000..1d27a94 --- /dev/null +++ b/lib/socketio/error.rs @@ -0,0 +1,19 @@ +#[derive(Debug, thiserror::Error)] +pub enum SocketIoError { + #[error("invalid packet: {0}")] + InvalidPacket(String), + #[error("unknown session")] + UnknownSession, + #[error("unknown namespace: {0}")] + UnknownNamespace(String), + #[error("ack timeout")] + AckTimeout, + #[error("serialization failed: {0}")] + Serialization(#[from] serde_json::Error), + #[error("redis failed: {0}")] + Redis(#[from] redis::RedisError), + #[error("adapter failed: {0}")] + Adapter(String), +} + +pub type Result = std::result::Result; diff --git a/lib/socketio/lib.rs b/lib/socketio/lib.rs new file mode 100644 index 0000000..d2f6f4e --- /dev/null +++ b/lib/socketio/lib.rs @@ -0,0 +1,23 @@ +mod actix; +mod adapter; +mod config; +mod engine_packet; +mod error; +mod namespace; +mod nats; +mod packet; +mod redis; +mod server; +mod session; +mod socket; + +pub use actix::{configure, configure_at}; +pub use adapter::{Adapter, BroadcastOptions, MemoryAdapter}; +pub use config::SocketIoConfig; +pub use error::{Result, SocketIoError}; +pub use namespace::Broadcast; +pub use nats::{NatsJetStreamAdapter, NatsJetStreamAdapterConfig}; +pub use packet::{EventPayload, Packet, PacketType}; +pub use redis::{RedisClusterAdapter, RedisClusterAdapterConfig}; +pub use server::{Namespace, SocketIo, SocketIoBuilder}; +pub use socket::{AckSender, DisconnectReason, Socket}; diff --git a/lib/socketio/namespace.rs b/lib/socketio/namespace.rs new file mode 100644 index 0000000..c9e06e3 --- /dev/null +++ b/lib/socketio/namespace.rs @@ -0,0 +1,152 @@ +use std::{ + collections::HashMap, collections::HashSet, future::Future, sync::Arc, +}; + +use serde::Serialize; +use serde_json::Value; +use tokio::sync::RwLock; + +use crate::{ + adapter::{Adapter, BroadcastOptions}, + config::SocketIoConfig, + error::Result, + packet::EventPayload, + server::{Inner, Namespace, NamespaceState, SocketIo, SocketIoBuilder}, + socket::{DisconnectReason, Socket}, +}; + +impl SocketIoBuilder { + pub fn config(mut self, config: SocketIoConfig) -> Self { + self.config = config; + self + } + + pub fn adapter(mut self, adapter: Arc) -> Self { + self.adapter = adapter; + self + } + + pub fn build(self) -> SocketIo { + let mut namespaces = HashMap::new(); + namespaces.insert("/".to_owned(), Arc::new(NamespaceState::default())); + SocketIo { + inner: Arc::new(Inner { + config: self.config, + sessions: RwLock::new(HashMap::new()), + namespaces: RwLock::new(namespaces), + adapter: self.adapter, + next_ack_id: std::sync::atomic::AtomicU64::new(1), + }), + } + } +} + +impl Namespace { + pub async fn on_connect(&self, handler: F) + where + F: Fn(Socket) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + let state = self.io.ensure_namespace(&self.name).await; + *state.connect_handler.write().await = + Some(Arc::new(move |socket| Box::pin(handler(socket)))); + } + + pub async fn on_disconnect(&self, handler: F) + where + F: Fn(Socket, DisconnectReason) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + let state = self.io.ensure_namespace(&self.name).await; + *state.disconnect_handler.write().await = + Some(Arc::new(move |socket, reason| { + Box::pin(handler(socket, reason)) + })); + } + + pub async fn use_middleware(&self, middleware: F) + where + F: Fn(Socket, Option) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + let state = self.io.ensure_namespace(&self.name).await; + state + .middleware + .write() + .await + .push(Arc::new(move |socket, auth| { + Box::pin(middleware(socket, auth)) + })); + } + + pub async fn on(&self, event: impl Into, handler: F) + where + F: Fn(Socket, EventPayload) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + let state = self.io.ensure_namespace(&self.name).await; + state.event_handlers.write().await.insert( + event.into(), + Arc::new(move |socket, payload| Box::pin(handler(socket, payload))), + ); + } + + pub async fn emit(&self, event: &str, data: T) -> Result<()> { + self.to_all().emit(event, data).await + } + + pub fn to(&self, room: impl Into) -> Broadcast { + let mut rooms = HashSet::new(); + rooms.insert(room.into()); + Broadcast { + io: self.io.clone(), + opts: BroadcastOptions { + namespace: self.name.clone(), + rooms, + ..BroadcastOptions::default() + }, + } + } + + pub fn to_all(&self) -> Broadcast { + Broadcast { + io: self.io.clone(), + opts: BroadcastOptions { + namespace: self.name.clone(), + ..BroadcastOptions::default() + }, + } + } +} + +pub struct Broadcast { + io: SocketIo, + opts: BroadcastOptions, +} + +impl Broadcast { + pub fn to(mut self, room: impl Into) -> Self { + self.opts.rooms.insert(room.into()); + self + } + + pub fn except(mut self, room: impl Into) -> Self { + self.opts.except.insert(room.into()); + self + } + + pub async fn emit(self, event: &str, data: T) -> Result<()> { + self.io.broadcast_with_opts(self.opts, event, data).await + } +} + +impl Default for NamespaceState { + fn default() -> Self { + Self { + connect_handler: RwLock::new(None), + disconnect_handler: RwLock::new(None), + event_handlers: RwLock::new(HashMap::new()), + middleware: RwLock::new(Vec::new()), + } + } +} diff --git a/lib/socketio/nats.rs b/lib/socketio/nats.rs new file mode 100644 index 0000000..b262a85 --- /dev/null +++ b/lib/socketio/nats.rs @@ -0,0 +1,211 @@ +use std::{sync::Arc, time::Duration}; + +use async_nats::jetstream; +use async_trait::async_trait; +use futures_util::StreamExt; +use tracing::warn; + +use crate::{ + adapter::{Adapter, BroadcastOptions, MemoryAdapter}, + error::{Result, SocketIoError}, + packet::Packet, + server::SocketIo, +}; + +#[derive(Clone, Debug)] +pub struct NatsJetStreamAdapterConfig { + pub stream_name: String, + pub subject: String, + pub durable_name: String, + pub node_id: String, + pub max_age: Duration, + pub ack_wait: Duration, +} + +impl Default for NatsJetStreamAdapterConfig { + fn default() -> Self { + Self { + stream_name: "SOCKETIO_ADAPTER".to_owned(), + subject: "socketio.adapter.broadcast".to_owned(), + durable_name: format!("socketio-adapter-{}", uuid::Uuid::new_v4()), + node_id: uuid::Uuid::new_v4().to_string(), + max_age: Duration::from_secs(60 * 60), + ack_wait: Duration::from_secs(30), + } + } +} + +pub struct NatsJetStreamAdapter { + local: Arc, + jetstream: jetstream::Context, + config: NatsJetStreamAdapterConfig, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +struct NatsMessage { + origin: String, + packet: NatsPacket, + opts: BroadcastOptions, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +struct NatsPacket { + encoded: String, + attachments: Vec>, +} + +impl NatsJetStreamAdapter { + pub fn new( + jetstream: jetstream::Context, + config: NatsJetStreamAdapterConfig, + ) -> Arc { + Arc::new(Self { + local: MemoryAdapter::new(), + jetstream, + config, + }) + } + + pub async fn attach(self: Arc, io: SocketIo) -> Result<()> { + let stream = self + .jetstream + .get_or_create_stream(jetstream::stream::Config { + name: self.config.stream_name.clone(), + subjects: vec![self.config.subject.clone()], + max_age: self.config.max_age, + ..Default::default() + }) + .await + .map_err(|err| { + SocketIoError::Adapter(format!("nats stream failed: {err}")) + })?; + + let consumer = stream + .get_or_create_consumer( + &self.config.durable_name, + jetstream::consumer::pull::Config { + durable_name: Some(self.config.durable_name.clone()), + filter_subject: self.config.subject.clone(), + ack_wait: self.config.ack_wait, + ..Default::default() + }, + ) + .await + .map_err(|err| { + SocketIoError::Adapter(format!("nats consumer failed: {err}")) + })?; + + let node_id = self.config.node_id.clone(); + actix_web::rt::spawn(async move { + let messages = consumer.messages().await; + let mut messages = match messages { + Ok(messages) => messages, + Err(err) => { + warn!(error = %err, "failed to open nats jetstream adapter consumer"); + return; + } + }; + + while let Some(message) = messages.next().await { + let Ok(message) = message else { + warn!("failed to receive nats jetstream adapter message"); + continue; + }; + let payload = message.payload.to_vec(); + let Ok(message_data) = + serde_json::from_slice::(&payload) + else { + warn!("failed to parse nats jetstream adapter message"); + let _ = message.ack().await; + continue; + }; + if message_data.origin == node_id { + let _ = message.ack().await; + continue; + } + let Ok(mut packet) = + Packet::decode(&message_data.packet.encoded) + else { + warn!( + "failed to decode nats jetstream adapter socket.io packet" + ); + let _ = message.ack().await; + continue; + }; + packet.attachments = message_data.packet.attachments; + if let Err(err) = + io.deliver_remote_packet(message_data.opts, packet).await + { + warn!(error = %err, "failed to deliver nats jetstream adapter packet"); + } + let _ = message.ack().await; + } + }); + + Ok(()) + } +} + +#[async_trait] +impl Adapter for NatsJetStreamAdapter { + async fn add_socket(&self, namespace: &str, sid: &str) -> Result<()> { + self.local.add_socket(namespace, sid).await + } + + async fn remove_socket(&self, namespace: &str, sid: &str) -> Result<()> { + self.local.remove_socket(namespace, sid).await + } + + async fn add_to_room( + &self, + namespace: &str, + sid: &str, + room: &str, + ) -> Result<()> { + self.local.add_to_room(namespace, sid, room).await + } + + async fn remove_from_room( + &self, + namespace: &str, + sid: &str, + room: &str, + ) -> Result<()> { + self.local.remove_from_room(namespace, sid, room).await + } + + async fn sockets( + &self, + namespace: &str, + opts: &BroadcastOptions, + ) -> Result> { + self.local.sockets(namespace, opts).await + } + + async fn publish( + &self, + packet: &Packet, + opts: &BroadcastOptions, + ) -> Result<()> { + let message = NatsMessage { + origin: self.config.node_id.clone(), + packet: NatsPacket { + encoded: packet.encode(), + attachments: packet.attachments.clone(), + }, + opts: opts.clone(), + }; + let payload = serde_json::to_vec(&message)?; + let ack = self + .jetstream + .publish(self.config.subject.clone(), payload.into()) + .await + .map_err(|err| { + SocketIoError::Adapter(format!("nats publish failed: {err}")) + })?; + ack.await.map_err(|err| { + SocketIoError::Adapter(format!("nats publish ack failed: {err}")) + })?; + Ok(()) + } +} diff --git a/lib/socketio/packet.rs b/lib/socketio/packet.rs new file mode 100644 index 0000000..bf58dc8 --- /dev/null +++ b/lib/socketio/packet.rs @@ -0,0 +1,354 @@ +use serde_json::{Value, json}; + +use crate::error::{Result, SocketIoError}; + +#[derive( + Clone, Copy, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize, +)] +pub enum PacketType { + Connect, + Disconnect, + Event, + Ack, + ConnectError, + BinaryEvent, + BinaryAck, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Packet { + pub packet_type: PacketType, + pub namespace: String, + pub id: Option, + pub data: Option, + pub attachments: Vec>, + pub expected_attachments: usize, +} + +#[derive(Clone, Debug)] +pub struct EventPayload { + pub event: String, + pub args: Vec, + pub binary: Vec>, + pub ack_id: Option, + pub ack: Option, +} + +impl Packet { + pub fn connect(namespace: impl Into, data: Option) -> Self { + Self::new(PacketType::Connect, namespace, None, data) + } + + pub fn event( + namespace: impl Into, + event: &str, + args: Vec, + ) -> Self { + let mut data = Vec::with_capacity(args.len() + 1); + data.push(Value::String(event.to_owned())); + data.extend(args); + Self::new(PacketType::Event, namespace, None, Some(Value::Array(data))) + } + + pub fn ack( + namespace: impl Into, + id: u64, + args: Vec, + ) -> Self { + Self::new( + PacketType::Ack, + namespace, + Some(id), + Some(Value::Array(args)), + ) + } + + pub fn connect_error( + namespace: impl Into, + message: impl Into, + ) -> Self { + Self::new( + PacketType::ConnectError, + namespace, + None, + Some(json!({ "message": message.into() })), + ) + } + + pub fn new( + packet_type: PacketType, + namespace: impl Into, + id: Option, + data: Option, + ) -> Self { + Self { + packet_type, + namespace: namespace.into(), + id, + data, + attachments: Vec::new(), + expected_attachments: 0, + } + } + + pub fn with_binary(mut self, attachments: Vec>) -> Self { + self.expected_attachments = attachments.len(); + if self.expected_attachments > 0 + && let Some(Value::Array(values)) = &mut self.data + { + for num in 0..self.expected_attachments { + values.push(json!({ "_placeholder": true, "num": num })); + } + } + self.attachments = attachments; + self.packet_type = match self.packet_type { + PacketType::Ack => PacketType::BinaryAck, + PacketType::Event => PacketType::BinaryEvent, + other => other, + }; + self + } + + pub fn encode(&self) -> String { + let mut out = String::new(); + out.push(packet_type_digit(self.packet_type)); + + if matches!( + self.packet_type, + PacketType::BinaryEvent | PacketType::BinaryAck + ) { + out.push_str(&self.expected_attachments.to_string()); + out.push('-'); + } + + if self.namespace != "/" { + out.push_str(&self.namespace); + out.push(','); + } + + if let Some(id) = self.id { + out.push_str(&id.to_string()); + } + + if let Some(data) = &self.data { + out.push_str(&data.to_string()); + } + + out + } + + pub fn decode(input: &str) -> Result { + let mut chars = input.char_indices(); + let (_, first) = chars.next().ok_or_else(|| { + SocketIoError::InvalidPacket("empty packet".to_owned()) + })?; + let packet_type = packet_type_from_digit(first)?; + let mut index = first.len_utf8(); + let mut expected_attachments = 0; + + if matches!( + packet_type, + PacketType::BinaryEvent | PacketType::BinaryAck + ) { + let rest = &input[index..]; + let dash = rest.find('-').ok_or_else(|| { + SocketIoError::InvalidPacket( + "binary packet missing attachment count".to_owned(), + ) + })?; + expected_attachments = rest[..dash].parse().map_err(|_| { + SocketIoError::InvalidPacket( + "invalid attachment count".to_owned(), + ) + })?; + index += dash + 1; + } + + let namespace = if input[index..].starts_with('/') { + let rest = &input[index..]; + if let Some(comma) = rest.find(',') { + index += comma + 1; + rest[..comma].to_owned() + } else { + index = input.len(); + rest.to_owned() + } + } else { + "/".to_owned() + }; + + let id_start = index; + while let Some(ch) = input[index..].chars().next() { + if ch.is_ascii_digit() { + index += ch.len_utf8(); + } else { + break; + } + } + let id = if index > id_start { + Some(input[id_start..index].parse().map_err(|_| { + SocketIoError::InvalidPacket( + "invalid acknowledgment id".to_owned(), + ) + })?) + } else { + None + }; + + let data = if index < input.len() { + Some(serde_json::from_str(&input[index..])?) + } else { + None + }; + + Ok(Self { + packet_type, + namespace, + id, + data, + attachments: Vec::new(), + expected_attachments, + }) + } + + pub(crate) fn into_event_payload( + self, + ack: Option, + ) -> Result { + let values = match self.data { + Some(Value::Array(values)) if !values.is_empty() => values, + _ => { + return Err(SocketIoError::InvalidPacket( + "event payload must be a non-empty array".to_owned(), + )); + } + }; + let mut values = values.into_iter(); + let event = values + .next() + .and_then(|value| value.as_str().map(ToOwned::to_owned)) + .ok_or_else(|| { + SocketIoError::InvalidPacket( + "event name must be a string".to_owned(), + ) + })?; + + let args = values.filter(|value| !is_placeholder(value)).collect(); + + Ok(EventPayload { + event, + args, + binary: self.attachments, + ack_id: self.id, + ack, + }) + } +} + +fn is_placeholder(value: &Value) -> bool { + value + .as_object() + .and_then(|object| object.get("_placeholder")) + .and_then(Value::as_bool) + .unwrap_or(false) +} + +fn packet_type_digit(packet_type: PacketType) -> char { + match packet_type { + PacketType::Connect => '0', + PacketType::Disconnect => '1', + PacketType::Event => '2', + PacketType::Ack => '3', + PacketType::ConnectError => '4', + PacketType::BinaryEvent => '5', + PacketType::BinaryAck => '6', + } +} + +fn packet_type_from_digit(value: char) -> Result { + match value { + '0' => Ok(PacketType::Connect), + '1' => Ok(PacketType::Disconnect), + '2' => Ok(PacketType::Event), + '3' => Ok(PacketType::Ack), + '4' => Ok(PacketType::ConnectError), + '5' => Ok(PacketType::BinaryEvent), + '6' => Ok(PacketType::BinaryAck), + _ => Err(SocketIoError::InvalidPacket(format!( + "unknown socket packet type {value}" + ))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn event_packet_round_trips() { + let packet = + Packet::event("/", "message", vec![json!({ "body": "hello" })]); + assert_eq!(packet.encode(), "2[\"message\",{\"body\":\"hello\"}]"); + + let decoded = Packet::decode(&packet.encode()).unwrap(); + assert_eq!(decoded.packet_type, PacketType::Event); + assert_eq!(decoded.namespace, "/"); + assert_eq!(decoded.data, packet.data); + } + + #[test] + fn event_packet_accepts_namespace_and_ack_id() { + let decoded = + Packet::decode("2/admin,17[\"save\",{\"ok\":true}]").unwrap(); + assert_eq!(decoded.packet_type, PacketType::Event); + assert_eq!(decoded.namespace, "/admin"); + assert_eq!(decoded.id, Some(17)); + } + + #[test] + fn binary_packet_accepts_attachment_count() { + let decoded = Packet::decode( + "51-/admin,13[\"file\",{\"_placeholder\":true,\"num\":0}]", + ) + .unwrap(); + assert_eq!(decoded.packet_type, PacketType::BinaryEvent); + assert_eq!(decoded.expected_attachments, 1); + assert_eq!(decoded.namespace, "/admin"); + assert_eq!(decoded.id, Some(13)); + } + + #[test] + fn binary_emit_adds_placeholders() { + let packet = Packet::event("/", "file", vec![json!("meta")]) + .with_binary(vec![vec![1, 2]]); + assert_eq!( + packet.encode(), + "51-[\"file\",\"meta\",{\"_placeholder\":true,\"num\":0}]" + ); + } + + #[test] + fn decode_empty_packet_returns_error() { + assert!(Packet::decode("").is_err()); + } + + #[test] + fn decode_unknown_type_returns_error() { + assert!(Packet::decode("9[]").is_err()); + } + + #[test] + fn default_namespace_omits_prefix_in_encode() { + let packet = Packet::event("/", "ping", vec![]); + assert_eq!(packet.encode(), "2[\"ping\"]"); + } + + #[test] + fn custom_namespace_includes_trailing_comma() { + let packet = Packet::event("/chat", "msg", vec![json!("hi")]); + let encoded = packet.encode(); + assert!(encoded.starts_with("2/chat,")); + let decoded = Packet::decode(&encoded).unwrap(); + assert_eq!(decoded.namespace, "/chat"); + } +} diff --git a/lib/socketio/redis.rs b/lib/socketio/redis.rs new file mode 100644 index 0000000..dad2bbe --- /dev/null +++ b/lib/socketio/redis.rs @@ -0,0 +1,177 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use redis::{AsyncCommands, Msg, PushInfo}; +use tracing::warn; + +use crate::{ + adapter::{Adapter, BroadcastOptions, MemoryAdapter}, + error::Result, + packet::Packet, + server::SocketIo, +}; + +#[derive(Clone, Debug)] +pub struct RedisClusterAdapterConfig { + pub channel_prefix: String, + pub node_id: String, +} + +impl Default for RedisClusterAdapterConfig { + fn default() -> Self { + Self { + channel_prefix: "socket.io:{adapter}".to_owned(), + node_id: uuid::Uuid::new_v4().to_string(), + } + } +} + +pub struct RedisClusterAdapter { + local: Arc, + pool: deadpool_redis::cluster::Pool, + config: RedisClusterAdapterConfig, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +struct RedisMessage { + origin: String, + packet: RedisPacket, + opts: BroadcastOptions, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +struct RedisPacket { + encoded: String, + attachments: Vec>, +} + +impl RedisClusterAdapter { + pub fn new( + pool: deadpool_redis::cluster::Pool, + config: RedisClusterAdapterConfig, + ) -> Arc { + Arc::new(Self { + local: MemoryAdapter::new(), + pool, + config, + }) + } + + pub async fn attach_with_push_receiver( + self: Arc, + io: SocketIo, + mut rx: tokio::sync::mpsc::UnboundedReceiver, + ) -> Result<()> { + let channel = self.channel(); + let mut conn = self.pool.get().await.map_err(|err| { + redis::RedisError::from(( + redis::ErrorKind::Io, + "redis pool failed", + err.to_string(), + )) + })?; + let _: () = conn.subscribe(&channel).await?; + let node_id = self.config.node_id.clone(); + + actix_web::rt::spawn(async move { + let mut _subscribed_conn = conn; + while let Some(push) = rx.recv().await { + let Some(message) = Msg::from_push_info(push) else { + continue; + }; + let Ok(payload) = message.get_payload::() else { + warn!("failed to decode redis adapter payload"); + continue; + }; + let Ok(message) = + serde_json::from_str::(&payload) + else { + warn!("failed to parse redis adapter message"); + continue; + }; + if message.origin == node_id { + continue; + } + let Ok(mut packet) = Packet::decode(&message.packet.encoded) + else { + warn!("failed to decode redis adapter socket.io packet"); + continue; + }; + packet.attachments = message.packet.attachments; + if let Err(err) = + io.deliver_remote_packet(message.opts, packet).await + { + warn!(error = %err, "failed to deliver redis adapter packet"); + } + } + }); + + Ok(()) + } + + fn channel(&self) -> String { + format!("{}:broadcast", self.config.channel_prefix) + } +} + +#[async_trait] +impl Adapter for RedisClusterAdapter { + async fn add_socket(&self, namespace: &str, sid: &str) -> Result<()> { + self.local.add_socket(namespace, sid).await + } + + async fn remove_socket(&self, namespace: &str, sid: &str) -> Result<()> { + self.local.remove_socket(namespace, sid).await + } + + async fn add_to_room( + &self, + namespace: &str, + sid: &str, + room: &str, + ) -> Result<()> { + self.local.add_to_room(namespace, sid, room).await + } + + async fn remove_from_room( + &self, + namespace: &str, + sid: &str, + room: &str, + ) -> Result<()> { + self.local.remove_from_room(namespace, sid, room).await + } + + async fn sockets( + &self, + namespace: &str, + opts: &BroadcastOptions, + ) -> Result> { + self.local.sockets(namespace, opts).await + } + + async fn publish( + &self, + packet: &Packet, + opts: &BroadcastOptions, + ) -> Result<()> { + let message = RedisMessage { + origin: self.config.node_id.clone(), + packet: RedisPacket { + encoded: packet.encode(), + attachments: packet.attachments.clone(), + }, + opts: opts.clone(), + }; + let payload = serde_json::to_string(&message)?; + let mut conn = self.pool.get().await.map_err(|err| { + redis::RedisError::from(( + redis::ErrorKind::Io, + "redis pool failed", + err.to_string(), + )) + })?; + let _: usize = conn.publish(self.channel(), payload).await?; + Ok(()) + } +} diff --git a/lib/socketio/server.rs b/lib/socketio/server.rs new file mode 100644 index 0000000..6d3b913 --- /dev/null +++ b/lib/socketio/server.rs @@ -0,0 +1,625 @@ +use std::{ + collections::{HashMap, HashSet}, + future::Future, + pin::Pin, + sync::{ + Arc, + atomic::{AtomicU64, Ordering}, + }, +}; + +use serde::Serialize; +use serde_json::{Value, json}; +use tokio::sync::RwLock; + +use crate::{ + adapter::{Adapter, BroadcastOptions, MemoryAdapter}, + config::SocketIoConfig, + engine_packet::SocketPayload, + error::{Result, SocketIoError}, + packet::{EventPayload, Packet, PacketType}, + session::{PendingBinary, Session, SocketState}, + socket::{AckSender, DisconnectReason, Socket}, +}; + +pub(crate) type BoxFuture = Pin + Send>>; +pub(crate) type ConnectHandler = Arc BoxFuture + Send + Sync>; +pub(crate) type DisconnectHandler = + Arc BoxFuture + Send + Sync>; +pub(crate) type EventHandler = + Arc BoxFuture + Send + Sync>; +pub(crate) type Middleware = Arc< + dyn Fn( + Socket, + Option, + ) -> Pin> + Send>> + + Send + + Sync, +>; + +#[derive(Clone)] +pub struct SocketIo { + pub(crate) inner: Arc, +} + +#[derive(Clone)] +pub struct Namespace { + pub(crate) io: SocketIo, + pub(crate) name: String, +} + +pub struct SocketIoBuilder { + pub(crate) config: SocketIoConfig, + pub(crate) adapter: Arc, +} + +pub(crate) struct Inner { + pub(crate) config: SocketIoConfig, + pub(crate) sessions: RwLock>>, + pub(crate) namespaces: RwLock>>, + pub(crate) adapter: Arc, + pub(crate) next_ack_id: AtomicU64, +} + +pub(crate) struct NamespaceState { + pub(crate) connect_handler: RwLock>, + pub(crate) disconnect_handler: RwLock>, + pub(crate) event_handlers: RwLock>, + pub(crate) middleware: RwLock>, +} + +impl Default for SocketIo { + fn default() -> Self { + Self::new() + } +} + +impl SocketIo { + pub fn builder() -> SocketIoBuilder { + SocketIoBuilder { + config: SocketIoConfig::default(), + adapter: MemoryAdapter::new(), + } + } + + pub fn new() -> Self { + Self::builder().build() + } + + pub fn config(&self) -> &SocketIoConfig { + &self.inner.config + } + + pub async fn namespace(&self, name: impl Into) -> Namespace { + let name = normalize_namespace(name.into()); + self.ensure_namespace(&name).await; + Namespace { + io: self.clone(), + name, + } + } + + pub async fn on_connect(&self, handler: F) + where + F: Fn(Socket) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.namespace("/").await.on_connect(handler).await; + } + + pub async fn on_disconnect(&self, handler: F) + where + F: Fn(Socket, DisconnectReason) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.namespace("/").await.on_disconnect(handler).await; + } + + pub async fn on(&self, event: impl Into, handler: F) + where + F: Fn(Socket, EventPayload) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.namespace("/").await.on(event, handler).await; + } + + pub async fn emit(&self, event: &str, data: T) -> Result<()> { + self.namespace("/").await.emit(event, data).await + } + + pub async fn emit_to_room( + &self, + room: &str, + event: &str, + data: T, + ) -> Result<()> { + self.namespace("/") + .await + .to(room.to_owned()) + .emit(event, data) + .await + } + + pub(crate) async fn session(&self, sid: &str) -> Option> { + self.inner.sessions.read().await.get(sid).cloned() + } + + pub(crate) async fn insert_session(&self, session: Arc) { + self.inner + .sessions + .write() + .await + .insert(session.engine_id.clone(), session); + } + + pub(crate) async fn remove_session( + &self, + session: &Arc, + reason: DisconnectReason, + ) { + self.inner.sessions.write().await.remove(&session.engine_id); + let namespaces = session + .namespaces + .lock() + .await + .keys() + .cloned() + .collect::>(); + for namespace in namespaces { + let _ = self + .disconnect_socket(&namespace, session, reason.clone()) + .await; + } + } + + pub(crate) async fn handle_socket_payload( + &self, + session: Arc, + payload: SocketPayload, + ) -> Result<()> { + match payload { + SocketPayload::Text(text) => { + let packet = Packet::decode(&text)?; + if packet.expected_attachments == 0 { + self.handle_socket_packet(session, packet).await + } else { + *session.pending_binary.lock().await = + Some(PendingBinary { packet }); + Ok(()) + } + } + SocketPayload::Binary(bytes) => { + let mut pending = session.pending_binary.lock().await; + let Some(mut pending_binary) = pending.take() else { + return Err(SocketIoError::InvalidPacket( + "unexpected binary attachment".to_owned(), + )); + }; + pending_binary.packet.attachments.push(bytes); + if pending_binary.packet.attachments.len() + == pending_binary.packet.expected_attachments + { + drop(pending); + self.handle_socket_packet(session, pending_binary.packet) + .await + } else { + *pending = Some(pending_binary); + Ok(()) + } + } + } + } + + async fn handle_socket_packet( + &self, + session: Arc, + packet: Packet, + ) -> Result<()> { + match packet.packet_type { + PacketType::Connect => { + self.connect_namespace(session, packet).await + } + PacketType::Disconnect => { + self.disconnect_socket( + &packet.namespace, + &session, + DisconnectReason::Client, + ) + .await + } + PacketType::Event | PacketType::BinaryEvent => { + self.dispatch_event(session, packet).await + } + PacketType::Ack | PacketType::BinaryAck => { + self.resolve_ack(session, packet).await + } + PacketType::ConnectError => Ok(()), + } + } + + async fn connect_namespace( + &self, + session: Arc, + packet: Packet, + ) -> Result<()> { + let namespace = normalize_namespace(packet.namespace); + let state = self.namespace_state(&namespace).await?; + let sid = uuid::Uuid::new_v4().to_string(); + let socket = Socket { + io: self.clone(), + session: session.clone(), + namespace: namespace.clone(), + sid: sid.clone(), + }; + + for middleware in state.middleware.read().await.iter().cloned() { + if let Err(err) = + middleware(socket.clone(), packet.data.clone()).await + { + session + .enqueue_socket_packet(Packet::connect_error( + &namespace, + err.to_string(), + )) + .await; + return Ok(()); + } + } + + self.inner + .adapter + .add_socket(&namespace, &session.engine_id) + .await?; + session.namespaces.lock().await.insert( + namespace.clone(), + SocketState { + sid, + rooms: HashSet::new(), + auth: packet.data, + }, + ); + session + .enqueue_socket_packet(Packet::connect( + &namespace, + Some(json!({ "sid": socket.sid })), + )) + .await; + + if let Some(handler) = state.connect_handler.read().await.clone() { + handler(socket).await; + } + Ok(()) + } + + pub(crate) async fn disconnect_socket( + &self, + namespace: &str, + session: &Arc, + reason: DisconnectReason, + ) -> Result<()> { + let namespace = normalize_namespace(namespace.to_owned()); + let removed = session.namespaces.lock().await.remove(&namespace); + if let Some(socket_state) = removed { + self.inner + .adapter + .remove_socket(&namespace, &session.engine_id) + .await?; + if let Ok(state) = self.namespace_state(&namespace).await + && let Some(handler) = + state.disconnect_handler.read().await.clone() + { + handler( + Socket { + io: self.clone(), + session: session.clone(), + namespace, + sid: socket_state.sid, + }, + reason, + ) + .await; + } + } + Ok(()) + } + + async fn dispatch_event( + &self, + session: Arc, + packet: Packet, + ) -> Result<()> { + let namespace = normalize_namespace(packet.namespace.clone()); + let state = self.namespace_state(&namespace).await?; + let socket_state = session + .namespaces + .lock() + .await + .get(&namespace) + .map(|state| state.sid.clone()) + .ok_or_else(|| { + SocketIoError::UnknownNamespace(namespace.clone()) + })?; + let ack = packet + .id + .map(|id| AckSender::new(session.clone(), namespace.clone(), id)); + let payload = packet.into_event_payload(ack)?; + let handler = state + .event_handlers + .read() + .await + .get(&payload.event) + .cloned(); + if let Some(handler) = handler { + handler( + Socket { + io: self.clone(), + session, + namespace, + sid: socket_state, + }, + payload, + ) + .await; + } + Ok(()) + } + + async fn resolve_ack( + &self, + session: Arc, + packet: Packet, + ) -> Result<()> { + let Some(id) = packet.id else { + return Ok(()); + }; + let args = match packet.data { + Some(Value::Array(values)) => values, + Some(value) => vec![value], + None => Vec::new(), + }; + if let Some(sender) = session + .ack_waiters + .lock() + .await + .remove(&(normalize_namespace(packet.namespace), id)) + { + let _ = sender.send(args); + } + Ok(()) + } + + pub(crate) async fn join( + &self, + namespace: &str, + engine_id: &str, + room: String, + ) -> Result<()> { + let namespace = normalize_namespace(namespace.to_owned()); + let session = self + .session(engine_id) + .await + .ok_or(SocketIoError::UnknownSession)?; + if let Some(state) = session.namespaces.lock().await.get_mut(&namespace) + { + state.rooms.insert(room.clone()); + } + self.inner + .adapter + .add_to_room(&namespace, engine_id, &room) + .await + } + + pub(crate) async fn leave( + &self, + namespace: &str, + engine_id: &str, + room: &str, + ) -> Result<()> { + let namespace = normalize_namespace(namespace.to_owned()); + let session = self + .session(engine_id) + .await + .ok_or(SocketIoError::UnknownSession)?; + if let Some(state) = session.namespaces.lock().await.get_mut(&namespace) + { + state.rooms.remove(room); + } + self.inner + .adapter + .remove_from_room(&namespace, engine_id, room) + .await + } + + pub(crate) async fn emit_to_sid( + &self, + namespace: &str, + engine_id: &str, + event: &str, + data: T, + ) -> Result<()> { + let args = value_to_args(serde_json::to_value(data)?); + self.emit_packet_to_sid( + engine_id, + Packet::event(namespace, event, args), + ) + .await + } + + pub(crate) async fn emit_binary_to_sid( + &self, + namespace: &str, + engine_id: &str, + event: &str, + args: Vec, + binary: Vec>, + ) -> Result<()> { + self.emit_packet_to_sid( + engine_id, + Packet::event(namespace, event, args).with_binary(binary), + ) + .await + } + + pub(crate) async fn emit_to_sid_with_ack( + &self, + namespace: &str, + engine_id: &str, + event: &str, + data: T, + ) -> Result> { + let session = self + .session(engine_id) + .await + .ok_or(SocketIoError::UnknownSession)?; + let id = self.inner.next_ack_id.fetch_add(1, Ordering::Relaxed); + let (tx, rx) = tokio::sync::oneshot::channel(); + session + .ack_waiters + .lock() + .await + .insert((normalize_namespace(namespace.to_owned()), id), tx); + let mut packet = Packet::event( + namespace, + event, + value_to_args(serde_json::to_value(data)?), + ); + packet.id = Some(id); + session.enqueue_socket_packet(packet).await; + + match tokio::time::timeout(self.inner.config.ack_timeout, rx).await { + Ok(Ok(values)) => Ok(values), + _ => { + session + .ack_waiters + .lock() + .await + .remove(&(normalize_namespace(namespace.to_owned()), id)); + Err(SocketIoError::AckTimeout) + } + } + } + + pub(crate) async fn broadcast_with_opts( + &self, + mut opts: BroadcastOptions, + event: &str, + data: T, + ) -> Result<()> { + opts.namespace = normalize_namespace(opts.namespace); + let packet = Packet::event( + &opts.namespace, + event, + value_to_args(serde_json::to_value(data)?), + ); + self.broadcast_packet(opts, packet).await + } + + async fn broadcast_packet( + &self, + opts: BroadcastOptions, + packet: Packet, + ) -> Result<()> { + let sockets = + self.inner.adapter.sockets(&opts.namespace, &opts).await?; + let mut failures = Vec::new(); + for engine_id in sockets { + if let Err(err) = + self.emit_packet_to_sid(&engine_id, packet.clone()).await + { + failures.push(format!("{engine_id}: {err}")); + } + } + if let Err(err) = self.inner.adapter.publish(&packet, &opts).await { + failures.push(format!("adapter publish: {err}")); + } + if failures.is_empty() { + Ok(()) + } else { + Err(SocketIoError::Adapter(format!( + "broadcast partially failed: {}", + failures.join("; ") + ))) + } + } + + pub(crate) async fn deliver_remote_packet( + &self, + opts: BroadcastOptions, + packet: Packet, + ) -> Result<()> { + let sockets = + self.inner.adapter.sockets(&opts.namespace, &opts).await?; + let mut failures = Vec::new(); + for engine_id in sockets { + if let Err(err) = + self.emit_packet_to_sid(&engine_id, packet.clone()).await + { + failures.push(format!("{engine_id}: {err}")); + } + } + if failures.is_empty() { + Ok(()) + } else { + Err(SocketIoError::Adapter(format!( + "remote broadcast partially failed: {}", + failures.join("; ") + ))) + } + } + + async fn emit_packet_to_sid( + &self, + engine_id: &str, + packet: Packet, + ) -> Result<()> { + let session = self + .session(engine_id) + .await + .ok_or(SocketIoError::UnknownSession)?; + session.enqueue_socket_packet(packet).await; + Ok(()) + } + + pub(crate) async fn ensure_namespace( + &self, + namespace: &str, + ) -> Arc { + let mut namespaces = self.inner.namespaces.write().await; + namespaces + .entry(namespace.to_owned()) + .or_insert_with(|| Arc::new(NamespaceState::default())) + .clone() + } + + async fn namespace_state( + &self, + namespace: &str, + ) -> Result> { + self.inner + .namespaces + .read() + .await + .get(namespace) + .cloned() + .ok_or_else(|| { + SocketIoError::UnknownNamespace(namespace.to_owned()) + }) + } +} + +fn value_to_args(value: Value) -> Vec { + match value { + Value::Array(values) => values, + value => vec![value], + } +} + +fn normalize_namespace(namespace: String) -> String { + if namespace.is_empty() || namespace == "/" { + "/".to_owned() + } else if namespace.starts_with('/') { + namespace + } else { + format!("/{namespace}") + } +} diff --git a/lib/socketio/session.rs b/lib/socketio/session.rs new file mode 100644 index 0000000..707f2ba --- /dev/null +++ b/lib/socketio/session.rs @@ -0,0 +1,90 @@ +use std::{ + collections::{HashMap, HashSet, VecDeque}, + sync::Arc, + sync::Mutex as StdMutex, + sync::atomic::AtomicBool, + time::Instant, +}; + +use serde_json::Value; +use tokio::sync::{Mutex, Notify, oneshot}; +use uuid::Uuid; + +use crate::{ + engine_packet::{EnginePacket, SocketPayload}, + packet::Packet, +}; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum Transport { + Polling, + WebSocket, +} + +#[derive(Debug)] +pub(crate) struct SocketState { + pub(crate) sid: String, + pub(crate) rooms: HashSet, + pub(crate) auth: Option, +} + +pub(crate) struct PendingBinary { + pub(crate) packet: Packet, +} + +pub(crate) struct Session { + pub(crate) engine_id: String, + pub(crate) user: StdMutex>, + pub(crate) transport: Mutex, + pub(crate) namespaces: Mutex>, + pub(crate) pending_binary: Mutex>, + pub(crate) ack_waiters: Mutex, + pub(crate) last_pong: Mutex, + queue: Mutex>, + pub(crate) get_active: Arc, + pub(crate) post_active: Arc, + pub(crate) notify: Notify, +} + +pub(crate) type AckWaiters = + HashMap<(String, u64), oneshot::Sender>>; + +impl Session { + pub(crate) fn new(user: Option) -> Arc { + Arc::new(Self { + engine_id: Uuid::new_v4().to_string(), + user: StdMutex::new(user), + transport: Mutex::new(Transport::Polling), + namespaces: Mutex::new(HashMap::new()), + pending_binary: Mutex::new(None), + ack_waiters: Mutex::new(HashMap::new()), + last_pong: Mutex::new(Instant::now()), + queue: Mutex::new(VecDeque::new()), + get_active: Arc::new(AtomicBool::new(false)), + post_active: Arc::new(AtomicBool::new(false)), + notify: Notify::new(), + }) + } + + pub(crate) async fn enqueue(&self, packet: EnginePacket) { + self.queue.lock().await.push_back(packet); + self.notify.notify_waiters(); + } + + pub(crate) async fn enqueue_socket_packet(&self, packet: Packet) { + self.enqueue(EnginePacket::Message(SocketPayload::Text( + packet.encode(), + ))) + .await; + for attachment in packet.attachments { + self.enqueue(EnginePacket::Message(SocketPayload::Binary( + attachment, + ))) + .await; + } + } + + pub(crate) async fn drain(&self) -> Vec { + self.queue.lock().await.drain(..).collect() + } +} diff --git a/lib/socketio/socket.rs b/lib/socketio/socket.rs new file mode 100644 index 0000000..5e57be1 --- /dev/null +++ b/lib/socketio/socket.rs @@ -0,0 +1,185 @@ +use std::fmt; +use std::{collections::HashSet, sync::Arc}; + +use serde::Serialize; +use serde_json::Value; + +use crate::{ + adapter::BroadcastOptions, error::Result, packet::Packet, server::SocketIo, + session::Session, +}; + +#[derive(Clone)] +pub struct Socket { + pub(crate) io: SocketIo, + pub(crate) session: Arc, + pub(crate) namespace: String, + pub(crate) sid: String, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum DisconnectReason { + Client, + TransportClosed, + PingTimeout, + Server, +} + +#[derive(Clone)] +pub struct AckSender { + session: Arc, + namespace: String, + id: u64, +} + +impl fmt::Debug for AckSender { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AckSender") + .field("namespace", &self.namespace) + .field("id", &self.id) + .finish_non_exhaustive() + } +} + +impl AckSender { + pub(crate) fn new( + session: Arc, + namespace: String, + id: u64, + ) -> Self { + Self { + session, + namespace, + id, + } + } + + pub async fn send(&self, data: T) -> Result<()> { + let args = match serde_json::to_value(data)? { + Value::Array(values) => values, + value => vec![value], + }; + self.session + .enqueue_socket_packet(Packet::ack(&self.namespace, self.id, args)) + .await; + Ok(()) + } +} + +impl Socket { + pub fn id(&self) -> &str { + &self.sid + } + + pub fn namespace(&self) -> &str { + &self.namespace + } + + pub fn session_user(&self) -> Option { + self.session.user.lock().unwrap_or_else(|e| e.into_inner()).clone() + } + + pub fn set_user(&self, user: uuid::Uuid) { + *self.session.user.lock().unwrap_or_else(|e| e.into_inner()) = Some(user); + } + + pub async fn rooms(&self) -> HashSet { + self.session + .namespaces + .lock() + .await + .get(&self.namespace) + .map(|state| state.rooms.clone()) + .unwrap_or_default() + } + + pub async fn auth(&self) -> Option { + self.session + .namespaces + .lock() + .await + .get(&self.namespace) + .and_then(|state| state.auth.clone()) + } + + pub async fn emit(&self, event: &str, data: T) -> Result<()> { + self.io + .emit_to_sid(&self.namespace, &self.session.engine_id, event, data) + .await + } + + pub async fn emit_with_ack( + &self, + event: &str, + data: T, + ) -> Result> { + self.io + .emit_to_sid_with_ack( + &self.namespace, + &self.session.engine_id, + event, + data, + ) + .await + } + + pub async fn emit_binary( + &self, + event: &str, + args: Vec, + binary: Vec>, + ) -> Result<()> { + self.io + .emit_binary_to_sid( + &self.namespace, + &self.session.engine_id, + event, + args, + binary, + ) + .await + } + + pub async fn broadcast( + &self, + event: &str, + data: T, + ) -> Result<()> { + let opts = BroadcastOptions { + namespace: self.namespace.clone(), + skip_sid: Some(self.session.engine_id.clone()), + ..BroadcastOptions::default() + }; + self.io.broadcast_with_opts(opts, event, data).await + } + + pub async fn join(&self, room: impl Into) -> Result<()> { + self.io + .join(&self.namespace, &self.session.engine_id, room.into()) + .await + } + + pub async fn leave(&self, room: &str) -> Result<()> { + self.io + .leave(&self.namespace, &self.session.engine_id, room) + .await + } + + pub async fn disconnect(&self) -> Result<()> { + self.session + .enqueue_socket_packet(Packet::new( + crate::packet::PacketType::Disconnect, + &self.namespace, + None, + None, + )) + .await; + self.io + .disconnect_socket( + &self.namespace, + &self.session, + DisconnectReason::Server, + ) + .await + } +} diff --git a/lib/storage/Cargo.toml b/lib/storage/Cargo.toml new file mode 100644 index 0000000..0c3b4b4 --- /dev/null +++ b/lib/storage/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "storage" +version.workspace = true +edition.workspace = true +authors.workspace = true +description.workspace = true +repository.workspace = true +readme.workspace = true +homepage.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true +documentation.workspace = true + + +[lib] +path = "lib.rs" +name = "storage" +[dependencies] +async-trait = { workspace = true } +aws-config = { workspace = true } +aws-sdk-s3 = { workspace = true } +config = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["fs", "io-util", "macros", "rt-multi-thread"] } + +[lints] +workspace = true diff --git a/lib/storage/error.rs b/lib/storage/error.rs new file mode 100644 index 0000000..9bb77c4 --- /dev/null +++ b/lib/storage/error.rs @@ -0,0 +1,17 @@ +pub type StorageResult = Result; + +#[derive(Debug, thiserror::Error)] +pub enum StorageError { + #[error("storage config error: {0}")] + Config(String), + #[error("invalid storage key: {0}")] + InvalidKey(String), + #[error("storage object not found: {0}")] + NotFound(String), + #[error("local storage error: {0}")] + Local(String), + #[error("s3 error: {0}")] + S3(String), + #[error("stream error: {0}")] + Stream(String), +} diff --git a/lib/storage/lib.rs b/lib/storage/lib.rs new file mode 100644 index 0000000..dd89ef6 --- /dev/null +++ b/lib/storage/lib.rs @@ -0,0 +1,202 @@ +pub mod error; +pub mod local; +pub mod s3; + +use std::time::Duration; + +use async_trait::async_trait; +pub use aws_sdk_s3::primitives::ByteStream; +use aws_sdk_s3::primitives::ByteStreamError; +pub use error::{StorageError, StorageResult}; +pub use local::{LocalStorage, LocalStorageConfig}; +pub use s3::{S3Storage, S3StorageConfig}; + +#[derive(Clone, Debug)] +pub enum AppStorageConfig { + Local(LocalStorageConfig), + S3(S3StorageConfig), +} + +#[derive(Clone)] +pub enum AppStorage { + Local(LocalStorage), + S3(S3Storage), +} + +#[derive(Clone, Debug, Default)] +pub struct PutObjectOptions { + pub content_type: Option, + pub content_length: Option, + pub cache_control: Option, +} + +#[derive(Clone, Debug)] +pub struct StoredObject { + pub key: String, + pub url: String, + pub e_tag: Option, + pub version_id: Option, +} + +#[derive(Debug)] +pub struct StorageObjectStream { + pub body: ByteStream, + pub content_length: Option, + pub content_type: Option, + pub e_tag: Option, +} + +#[derive(Clone, Debug)] +pub struct StorageObject { + pub bytes: Vec, + pub content_length: Option, + pub content_type: Option, + pub e_tag: Option, +} + +#[async_trait] +pub trait ObjectStorage: Send + Sync { + async fn put_stream( + &self, + key: &str, + body: ByteStream, + options: PutObjectOptions, + ) -> StorageResult; + + async fn put_bytes( + &self, + key: &str, + bytes: Vec, + options: PutObjectOptions, + ) -> StorageResult; + + async fn get_stream(&self, key: &str) + -> StorageResult; + + async fn get_bytes(&self, key: &str) -> StorageResult; + + async fn delete(&self, key: &str) -> StorageResult<()>; + + fn public_url(&self, key: &str) -> StorageResult>; + + async fn presigned_get_url( + &self, + key: &str, + expires_in: Duration, + ) -> StorageResult; +} + +impl AppStorage { + pub async fn init(config: AppStorageConfig) -> StorageResult { + match config { + AppStorageConfig::Local(config) => { + Ok(Self::Local(LocalStorage::connect(config).await?)) + } + AppStorageConfig::S3(config) => { + Ok(Self::S3(S3Storage::connect(config).await?)) + } + } + } +} + +#[async_trait] +impl ObjectStorage for AppStorage { + async fn put_stream( + &self, + key: &str, + body: ByteStream, + options: PutObjectOptions, + ) -> StorageResult { + match self { + Self::Local(storage) => { + storage.put_stream(key, body, options).await + } + Self::S3(storage) => storage.put_stream(key, body, options).await, + } + } + + async fn put_bytes( + &self, + key: &str, + bytes: Vec, + options: PutObjectOptions, + ) -> StorageResult { + match self { + Self::Local(storage) => { + storage.put_bytes(key, bytes, options).await + } + Self::S3(storage) => storage.put_bytes(key, bytes, options).await, + } + } + + async fn get_stream( + &self, + key: &str, + ) -> StorageResult { + match self { + Self::Local(storage) => storage.get_stream(key).await, + Self::S3(storage) => storage.get_stream(key).await, + } + } + + async fn get_bytes(&self, key: &str) -> StorageResult { + match self { + Self::Local(storage) => storage.get_bytes(key).await, + Self::S3(storage) => storage.get_bytes(key).await, + } + } + + async fn delete(&self, key: &str) -> StorageResult<()> { + match self { + Self::Local(storage) => storage.delete(key).await, + Self::S3(storage) => storage.delete(key).await, + } + } + + fn public_url(&self, key: &str) -> StorageResult> { + match self { + Self::Local(storage) => storage.public_url(key), + Self::S3(storage) => storage.public_url(key), + } + } + + async fn presigned_get_url( + &self, + key: &str, + expires_in: Duration, + ) -> StorageResult { + match self { + Self::Local(storage) => { + storage.presigned_get_url(key, expires_in).await + } + Self::S3(storage) => { + storage.presigned_get_url(key, expires_in).await + } + } + } +} + +pub(crate) async fn collect_byte_stream( + body: ByteStream, +) -> Result, ByteStreamError> { + body.collect().await.map(|data| data.to_vec()) +} + +impl TryFrom<&config::AppConfig> for AppStorageConfig { + type Error = StorageError; + + fn try_from(config: &config::AppConfig) -> Result { + let backend = config + .storage_backend() + .map_err(|error| StorageError::Config(error.to_string()))?; + match backend.as_str() { + "local" | "fs" | "filesystem" => { + Ok(Self::Local(LocalStorageConfig::try_from(config)?)) + } + "s3" => Ok(Self::S3(S3StorageConfig::try_from(config)?)), + backend => Err(StorageError::Config(format!( + "unsupported storage backend: {backend}" + ))), + } + } +} diff --git a/lib/storage/local.rs b/lib/storage/local.rs new file mode 100644 index 0000000..75271da --- /dev/null +++ b/lib/storage/local.rs @@ -0,0 +1,380 @@ +use std::{ + path::{Component, Path, PathBuf}, + time::Duration, +}; + +use async_trait::async_trait; +use aws_sdk_s3::primitives::ByteStream; +use serde::{Deserialize, Serialize}; +use tokio::fs; + +use crate::{ + ObjectStorage, PutObjectOptions, StorageError, StorageObject, + StorageObjectStream, StorageResult, StoredObject, +}; + +#[derive(Clone, Debug)] +pub struct LocalStorageConfig { + pub root_path: PathBuf, + pub public_url: Option, +} + +#[derive(Clone)] +pub struct LocalStorage { + config: LocalStorageConfig, +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +struct LocalObjectMetadata { + content_type: Option, +} + +impl LocalStorage { + pub async fn connect(config: LocalStorageConfig) -> StorageResult { + if config.root_path.as_os_str().is_empty() { + return Err(StorageError::Config( + "local storage root path cannot be empty".to_string(), + )); + } + + fs::create_dir_all(&config.root_path) + .await + .map_err(|error| StorageError::Local(error.to_string()))?; + + Ok(Self { config }) + } + + pub async fn put_bytes( + &self, + key: &str, + bytes: Vec, + options: PutObjectOptions, + ) -> StorageResult { + self.put_stream(key, ByteStream::from(bytes), options).await + } + + fn normalize_key(key: &str) -> StorageResult { + let key = key.trim().trim_start_matches('/'); + if key.is_empty() { + return Err(StorageError::InvalidKey(key.to_string())); + } + + let mut parts = Vec::new(); + for component in Path::new(key).components() { + match component { + Component::Normal(part) => { + let part = part.to_str().ok_or_else(|| { + StorageError::InvalidKey(key.to_string()) + })?; + parts.push(part); + } + Component::CurDir + | Component::ParentDir + | Component::RootDir + | Component::Prefix(_) => { + return Err(StorageError::InvalidKey(key.to_string())); + } + } + } + + if parts.is_empty() { + return Err(StorageError::InvalidKey(key.to_string())); + } + + Ok(parts.join("/")) + } + + fn object_path(&self, key: &str) -> PathBuf { + self.config.root_path.join(key) + } + + fn metadata_path(&self, key: &str) -> PathBuf { + self.config + .root_path + .join(".metadata") + .join(format!("{key}.json")) + } + + fn public_url_for_config( + config: &LocalStorageConfig, + key: &str, + ) -> Option { + let base_url = config.public_url.as_ref()?.trim_end_matches('/'); + Some(format!("{base_url}/{key}")) + } + + async fn write_metadata( + &self, + key: &str, + metadata: &LocalObjectMetadata, + ) -> StorageResult<()> { + let path = self.metadata_path(key); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .await + .map_err(|error| StorageError::Local(error.to_string()))?; + } + + let bytes = serde_json::to_vec(metadata) + .map_err(|error| StorageError::Local(error.to_string()))?; + fs::write(path, bytes) + .await + .map_err(|error| StorageError::Local(error.to_string())) + } + + async fn read_metadata( + &self, + key: &str, + ) -> StorageResult { + let path = self.metadata_path(key); + let bytes = fs::read(path).await.map_err(|error| { + if error.kind() == std::io::ErrorKind::NotFound { + StorageError::NotFound(key.to_string()) + } else { + StorageError::Local(error.to_string()) + } + })?; + serde_json::from_slice(&bytes) + .map_err(|error| StorageError::Local(error.to_string())) + } +} + +#[async_trait] +impl ObjectStorage for LocalStorage { + async fn put_stream( + &self, + key: &str, + body: ByteStream, + options: PutObjectOptions, + ) -> StorageResult { + let key = Self::normalize_key(key)?; + let path = self.object_path(&key); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .await + .map_err(|error| StorageError::Local(error.to_string()))?; + } + + let bytes = crate::collect_byte_stream(body) + .await + .map_err(|error| StorageError::Stream(error.to_string()))?; + if let Some(content_length) = options.content_length + && content_length >= 0 + && bytes.len() as i64 != content_length + { + return Err(StorageError::Local(format!( + "content length mismatch for {key}: expected {content_length}, got {}", + bytes.len() + ))); + } + + let metadata = LocalObjectMetadata { + content_type: options.content_type, + }; + + fs::write(&path, bytes) + .await + .map_err(|error| StorageError::Local(error.to_string()))?; + self.write_metadata(&key, &metadata).await?; + + Ok(StoredObject { + url: self.public_url(&key)?.unwrap_or_else(|| key.clone()), + key, + e_tag: None, + version_id: None, + }) + } + + async fn put_bytes( + &self, + key: &str, + bytes: Vec, + options: PutObjectOptions, + ) -> StorageResult { + LocalStorage::put_bytes(self, key, bytes, options).await + } + + async fn get_stream( + &self, + key: &str, + ) -> StorageResult { + let key = Self::normalize_key(key)?; + let path = self.object_path(&key); + let metadata = fs::metadata(&path).await.map_err(|error| { + if error.kind() == std::io::ErrorKind::NotFound { + StorageError::NotFound(key.clone()) + } else { + StorageError::Local(error.to_string()) + } + })?; + let object_metadata = match self.read_metadata(&key).await { + Ok(metadata) => metadata, + Err(StorageError::NotFound(_)) => LocalObjectMetadata::default(), + Err(error) => return Err(error), + }; + let body = ByteStream::from_path(&path) + .await + .map_err(|error| StorageError::Stream(error.to_string()))?; + + Ok(StorageObjectStream { + body, + content_length: Some(metadata.len() as i64), + content_type: object_metadata.content_type, + e_tag: None, + }) + } + + async fn get_bytes(&self, key: &str) -> StorageResult { + let stream = self.get_stream(key).await?; + let bytes = crate::collect_byte_stream(stream.body) + .await + .map_err(|error| StorageError::Stream(error.to_string()))?; + + Ok(StorageObject { + bytes, + content_length: stream.content_length, + content_type: stream.content_type, + e_tag: stream.e_tag, + }) + } + + async fn delete(&self, key: &str) -> StorageResult<()> { + let key = Self::normalize_key(key)?; + let path = self.object_path(&key); + fs::remove_file(path).await.map_or_else( + |error| { + if error.kind() == std::io::ErrorKind::NotFound { + Ok(()) + } else { + Err(StorageError::Local(error.to_string())) + } + }, + |_| Ok(()), + )?; + fs::remove_file(self.metadata_path(&key)).await.map_or_else( + |error| { + if error.kind() == std::io::ErrorKind::NotFound { + Ok(()) + } else { + Err(StorageError::Local(error.to_string())) + } + }, + |_| Ok(()), + ) + } + + fn public_url(&self, key: &str) -> StorageResult> { + let key = Self::normalize_key(key)?; + Ok(Self::public_url_for_config(&self.config, &key)) + } + + async fn presigned_get_url( + &self, + key: &str, + _expires_in: Duration, + ) -> StorageResult { + self.public_url(key)?.ok_or_else(|| { + StorageError::Config( + "local storage public URL is not configured".to_string(), + ) + }) + } +} + +impl TryFrom<&config::AppConfig> for LocalStorageConfig { + type Error = StorageError; + + fn try_from(config: &config::AppConfig) -> Result { + Ok(Self { + root_path: PathBuf::from(config.storage_path()), + public_url: Some(config.storage_public_url()), + }) + } +} + +#[cfg(test)] +mod tests { + use std::time::{SystemTime, UNIX_EPOCH}; + + use tokio::fs; + + use super::{LocalStorage, LocalStorageConfig}; + use crate::{ObjectStorage, PutObjectOptions, StorageError}; + + fn temp_root() -> Result { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|error| StorageError::Local(error.to_string()))? + .as_nanos(); + Ok(std::env::temp_dir().join(format!( + "gitdataai-local-storage-{}-{nanos}", + std::process::id() + ))) + } + + #[tokio::test] + async fn stores_reads_and_deletes_bytes() -> Result<(), StorageError> { + let root = temp_root()?; + let storage = LocalStorage::connect(LocalStorageConfig { + root_path: root.clone(), + public_url: Some("/files".to_string()), + }) + .await?; + + let stored = storage + .put_bytes( + "avatars/user-1.txt", + b"hello".to_vec(), + PutObjectOptions { + content_type: Some("text/plain".to_string()), + content_length: Some(5), + ..PutObjectOptions::default() + }, + ) + .await?; + assert_eq!(stored.key, "avatars/user-1.txt"); + assert_eq!(stored.url, "/files/avatars/user-1.txt"); + + let object = storage.get_bytes("avatars/user-1.txt").await?; + assert_eq!(object.bytes, b"hello"); + assert_eq!(object.content_length, Some(5)); + assert_eq!(object.content_type.as_deref(), Some("text/plain")); + + storage.delete("avatars/user-1.txt").await?; + assert!(matches!( + storage.get_bytes("avatars/user-1.txt").await, + Err(StorageError::NotFound(_)) + )); + + fs::remove_dir_all(root) + .await + .map_err(|error| StorageError::Local(error.to_string()))?; + Ok(()) + } + + #[tokio::test] + async fn rejects_path_traversal_keys() -> Result<(), StorageError> { + let storage = LocalStorage::connect(LocalStorageConfig { + root_path: temp_root()?, + public_url: Some("/files".to_string()), + }) + .await?; + + assert!(matches!( + storage + .put_bytes( + "../escape.txt", + b"bad".to_vec(), + PutObjectOptions::default() + ) + .await, + Err(StorageError::InvalidKey(_)) + )); + assert!(matches!( + storage.public_url("nested/../escape.txt"), + Err(StorageError::InvalidKey(_)) + )); + + Ok(()) + } +} diff --git a/lib/storage/s3.rs b/lib/storage/s3.rs new file mode 100644 index 0000000..da700c0 --- /dev/null +++ b/lib/storage/s3.rs @@ -0,0 +1,265 @@ +use std::time::Duration; + +use async_trait::async_trait; +use aws_config::BehaviorVersion; +use aws_sdk_s3::{ + Client, + config::{Credentials, Region}, + presigning::PresigningConfig, + primitives::ByteStream, +}; + +use crate::{ + ObjectStorage, PutObjectOptions, StorageError, StorageObject, + StorageObjectStream, StorageResult, StoredObject, +}; + +#[derive(Clone, Debug)] +pub struct S3StorageConfig { + pub bucket: String, + pub region: String, + pub endpoint_url: Option, + pub access_key_id: Option, + pub secret_access_key: Option, + pub session_token: Option, + pub force_path_style: bool, + pub public_url: Option, + pub presigned_url_ttl: Duration, +} + +#[derive(Clone)] +pub struct S3Storage { + client: Client, + config: S3StorageConfig, +} + +impl S3Storage { + pub async fn connect(config: S3StorageConfig) -> StorageResult { + let mut sdk_config = aws_config::defaults(BehaviorVersion::latest()) + .region(Region::new(config.region.clone())); + + match (&config.access_key_id, &config.secret_access_key) { + (Some(access_key_id), Some(secret_access_key)) => { + sdk_config = sdk_config.credentials_provider(Credentials::new( + access_key_id, + secret_access_key, + config.session_token.clone(), + None, + "app-storage-config", + )); + } + (None, None) => {} + _ => { + return Err(StorageError::Config( + "APP_STORAGE_S3_ACCESS_KEY_ID and APP_STORAGE_S3_SECRET_ACCESS_KEY must be set together" + .to_string(), + )); + } + } + + let sdk_config = sdk_config.load().await; + let mut s3_config = aws_sdk_s3::config::Builder::from(&sdk_config) + .force_path_style(config.force_path_style); + + if let Some(endpoint_url) = &config.endpoint_url { + s3_config = s3_config.endpoint_url(endpoint_url); + } + + Ok(Self { + client: Client::from_conf(s3_config.build()), + config, + }) + } + + pub async fn put_bytes( + &self, + key: &str, + bytes: Vec, + options: PutObjectOptions, + ) -> StorageResult { + self.put_stream(key, ByteStream::from(bytes), options).await + } + + pub async fn put_file( + &self, + key: &str, + path: impl AsRef, + options: PutObjectOptions, + ) -> StorageResult { + let body = ByteStream::read_from() + .path(path.as_ref()) + .build() + .await + .map_err(|error| StorageError::Stream(error.to_string()))?; + self.put_stream(key, body, options).await + } + + fn normalize_key(key: &str) -> StorageResult { + let key = key.trim().trim_start_matches('/'); + if key.is_empty() || key.contains("..") { + return Err(StorageError::InvalidKey(key.to_string())); + } + Ok(key.to_string()) + } + + fn public_url_for_config( + config: &S3StorageConfig, + key: &str, + ) -> Option { + let base_url = config.public_url.as_ref()?.trim_end_matches('/'); + Some(format!("{base_url}/{key}")) + } +} + +#[async_trait] +impl ObjectStorage for S3Storage { + async fn put_stream( + &self, + key: &str, + body: ByteStream, + options: PutObjectOptions, + ) -> StorageResult { + let key = Self::normalize_key(key)?; + + let mut request = self + .client + .put_object() + .bucket(&self.config.bucket) + .key(&key) + .body(body); + + if let Some(content_type) = options.content_type { + request = request.content_type(content_type); + } + if let Some(content_length) = options.content_length { + request = request.content_length(content_length); + } + if let Some(cache_control) = options.cache_control { + request = request.cache_control(cache_control); + } + + let output = request + .send() + .await + .map_err(|error| StorageError::S3(error.to_string()))?; + let url = match self.public_url(&key)? { + Some(url) => url, + None => { + self.presigned_get_url(&key, self.config.presigned_url_ttl) + .await? + } + }; + + Ok(StoredObject { + key, + url, + e_tag: output.e_tag, + version_id: output.version_id, + }) + } + + async fn put_bytes( + &self, + key: &str, + bytes: Vec, + options: PutObjectOptions, + ) -> StorageResult { + S3Storage::put_bytes(self, key, bytes, options).await + } + + async fn get_stream( + &self, + key: &str, + ) -> StorageResult { + let key = Self::normalize_key(key)?; + let output = self + .client + .get_object() + .bucket(&self.config.bucket) + .key(&key) + .send() + .await + .map_err(|error| StorageError::S3(error.to_string()))?; + + Ok(StorageObjectStream { + body: output.body, + content_length: output.content_length, + content_type: output.content_type, + e_tag: output.e_tag, + }) + } + + async fn get_bytes(&self, key: &str) -> StorageResult { + let stream = self.get_stream(key).await?; + let bytes = crate::collect_byte_stream(stream.body) + .await + .map_err(|error| StorageError::Stream(error.to_string()))?; + + Ok(StorageObject { + bytes, + content_length: stream.content_length, + content_type: stream.content_type, + e_tag: stream.e_tag, + }) + } + + async fn delete(&self, key: &str) -> StorageResult<()> { + let key = Self::normalize_key(key)?; + self.client + .delete_object() + .bucket(&self.config.bucket) + .key(key) + .send() + .await + .map_err(|error| StorageError::S3(error.to_string()))?; + Ok(()) + } + + fn public_url(&self, key: &str) -> StorageResult> { + let key = Self::normalize_key(key)?; + Ok(Self::public_url_for_config(&self.config, &key)) + } + + async fn presigned_get_url( + &self, + key: &str, + expires_in: Duration, + ) -> StorageResult { + let key = Self::normalize_key(key)?; + let config = PresigningConfig::expires_in(expires_in) + .map_err(|error| StorageError::Config(error.to_string()))?; + let request = self + .client + .get_object() + .bucket(&self.config.bucket) + .key(key) + .presigned(config) + .await + .map_err(|error| StorageError::S3(error.to_string()))?; + Ok(request.uri().to_string()) + } +} + +impl TryFrom<&config::AppConfig> for S3StorageConfig { + type Error = StorageError; + + fn try_from(config: &config::AppConfig) -> Result { + Ok(Self { + bucket: config + .storage_s3_bucket() + .map_err(|error| StorageError::Config(error.to_string()))?, + region: config.storage_s3_region(), + endpoint_url: config.storage_s3_endpoint_url(), + access_key_id: config.storage_s3_access_key_id(), + secret_access_key: config.storage_s3_secret_access_key(), + session_token: config.storage_s3_session_token(), + force_path_style: config + .storage_s3_force_path_style() + .map_err(|error| StorageError::Config(error.to_string()))?, + public_url: config.storage_public_url_base(), + presigned_url_ttl: config + .storage_presigned_url_ttl() + .map_err(|error| StorageError::Config(error.to_string()))?, + }) + } +} diff --git a/public/favicon.svg b/public/favicon.svg new file mode 100644 index 0000000..6893eb1 --- /dev/null +++ b/public/favicon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/public/icons.svg b/public/icons.svg new file mode 100644 index 0000000..e952219 --- /dev/null +++ b/public/icons.svg @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..6a8e4c3 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,18 @@ +# Rustfmt configuration +edition = "2024" +max_width = 80 +hard_tabs = false +tab_spaces = 4 +newline_style = "Unix" +use_small_heuristics = "Default" +reorder_imports = true +reorder_modules = true +remove_nested_parens = true +merge_derives = true +use_try_shorthand = false +use_field_init_shorthand = false +force_explicit_abi = true +imports_granularity = "Crate" +group_imports = "StdExternalCrate" +format_macro_bodies = true +format_macro_matchers = true \ No newline at end of file diff --git a/src/App.css b/src/App.css new file mode 100644 index 0000000..f90339d --- /dev/null +++ b/src/App.css @@ -0,0 +1,184 @@ +.counter { + font-size: 16px; + padding: 5px 10px; + border-radius: 5px; + color: var(--accent); + background: var(--accent-bg); + border: 2px solid transparent; + transition: border-color 0.3s; + margin-bottom: 24px; + + &:hover { + border-color: var(--accent-border); + } + &:focus-visible { + outline: 2px solid var(--accent); + outline-offset: 2px; + } +} + +.hero { + position: relative; + + .base, + .framework, + .vite { + inset-inline: 0; + margin: 0 auto; + } + + .base { + width: 170px; + position: relative; + z-index: 0; + } + + .framework, + .vite { + position: absolute; + } + + .framework { + z-index: 1; + top: 34px; + height: 28px; + transform: perspective(2000px) rotateZ(300deg) rotateX(44deg) rotateY(39deg) + scale(1.4); + } + + .vite { + z-index: 0; + top: 107px; + height: 26px; + width: auto; + transform: perspective(2000px) rotateZ(300deg) rotateX(40deg) rotateY(39deg) + scale(0.8); + } +} + +#center { + display: flex; + flex-direction: column; + gap: 25px; + place-content: center; + place-items: center; + flex-grow: 1; + + @media (max-width: 1024px) { + padding: 32px 20px 24px; + gap: 18px; + } +} + +#next-steps { + display: flex; + border-top: 1px solid var(--border); + text-align: left; + + & > div { + flex: 1 1 0; + padding: 32px; + @media (max-width: 1024px) { + padding: 24px 20px; + } + } + + .icon { + margin-bottom: 16px; + width: 22px; + height: 22px; + } + + @media (max-width: 1024px) { + flex-direction: column; + text-align: center; + } +} + +#docs { + border-right: 1px solid var(--border); + + @media (max-width: 1024px) { + border-right: none; + border-bottom: 1px solid var(--border); + } +} + +#next-steps ul { + list-style: none; + padding: 0; + display: flex; + gap: 8px; + margin: 32px 0 0; + + .logo { + height: 18px; + } + + a { + color: var(--text-h); + font-size: 16px; + border-radius: 6px; + background: var(--social-bg); + display: flex; + padding: 6px 12px; + align-items: center; + gap: 8px; + text-decoration: none; + transition: box-shadow 0.3s; + + &:hover { + box-shadow: var(--shadow); + } + .button-icon { + height: 18px; + width: 18px; + } + } + + @media (max-width: 1024px) { + margin-top: 20px; + flex-wrap: wrap; + justify-content: center; + + li { + flex: 1 1 calc(50% - 8px); + } + + a { + width: 100%; + justify-content: center; + box-sizing: border-box; + } + } +} + +#spacer { + height: 88px; + border-top: 1px solid var(--border); + @media (max-width: 1024px) { + height: 48px; + } +} + +.ticks { + position: relative; + width: 100%; + + &::before, + &::after { + content: ''; + position: absolute; + top: -4.5px; + border: 5px solid transparent; + } + + &::before { + left: 0; + border-left-color: var(--border); + } + &::after { + right: 0; + border-right-color: var(--border); + } +} diff --git a/src/assets/hero.png b/src/assets/hero.png new file mode 100644 index 0000000..02251f4 Binary files /dev/null and b/src/assets/hero.png differ diff --git a/src/assets/vite.svg b/src/assets/vite.svg new file mode 100644 index 0000000..5101b67 --- /dev/null +++ b/src/assets/vite.svg @@ -0,0 +1 @@ +Vite diff --git a/src/client/endpoints.ts b/src/client/endpoints.ts new file mode 100644 index 0000000..94c627a --- /dev/null +++ b/src/client/endpoints.ts @@ -0,0 +1,3293 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import * as axios from 'axios'; +import type { + AxiosInstance, + AxiosRequestConfig, + AxiosResponse +} from 'axios'; + +import type { + AccessRequest, + AddIssueLabel, + AddPrLabel, + AddPrReaction, + AddReaction, + AddWorkspaceMember, + AgentListAllConversationsParams, + AgentListMessagesParams, + AgentRunRequest, + AgentRunResponse, + AgentSessionResponse, + AiAddRequest, + AiDiscussionResponse, + AiLikeResponse, + AiListDiscussionsParams, + AiListModelsParams, + AiModelCardResponse, + AiModelListItem, + AiModelResponse, + AiModelVersionResponse, + AiProviderResponse, + AppNotificationItem, + AssignIssueUser, + AssignPrUser, + AuthCaptchaParams, + AvatarUploadResponse, + BanCreateRequest, + BindIssuePullRequest, + BindIssueRepo, + BlameFileResponseDto, + BlobInfoResponse, + BlobUploadBody, + BlobUploadResponseDto, + BranchAheadBehindResponseDto, + BranchInfoResponseDto, + BranchListResponseDto, + BranchUpstreamResponseDto, + CaptchaResponse, + CategoryCreateRequest, + CategoryUpdateRequest, + ChannelListMessagesParams, + ChannelMessagesAroundParams, + ChannelMissedMessagesParams, + ChannelSearchParams, + CherryPickResponseDto, + CloneRepo, + CombinedCommitStatus, + CommitHistoryResponseDto, + CommitInfoResponseDto, + CommitStatusResponse, + CompareResponse, + ContentResponse, + ContextMe, + ContributionHeatmapResponse, + ContributorDto, + ConversationResponse, + ConversationWithSessionResponse, + CreateAgentSession, + CreateComment, + CreateCommitStatus, + CreateContent, + CreateConversation, + CreateFork, + CreateIssue, + CreateLabel, + CreateMessageRequest, + CreateMilestone, + CreatePrComment, + CreatePrReview, + CreatePrReviewComment, + CreateProtect, + CreatePullRequest, + CreateRelease, + CreateRepo, + CreateUserAccessToken, + CreateUserSshKey, + CreateWebhook, + CreateWorkspace, + CreateWorkspaceGroup, + CreateWorkspaceJoinApply, + CreatedUserAccessToken, + CustomStatusRequest, + DiffResultDto, + Disable2FAParams, + DismissPrReview, + DndRequest, + DraftSaveRequest, + EmailChangeRequest, + EmailResponse, + EmailVerifyRequest, + Enable2FAResponse, + ForkResponse, + Get2FAStatusResponse, + GitAheadBehindParams, + GitArchiveParams, + GitBlameFileParams, + GitBlobInfoParams, + GitCherryPickBody, + GitCommitHistoryParams, + GitCommitWalkBody, + GitDeleteBranchParams, + GitDeleteContentsParams, + GitDiffBranchesParams, + GitDiffParams, + GitForkBranchBody, + GitGetContentsParams, + GitInitTagBody, + GitListBranchesParams, + GitListCommitsParams, + GitListContributorsParams, + GitListDeliveriesParams, + GitListForksParams, + GitListProtectsParams, + GitListRefsParams, + GitListReposParams, + GitListTagsParams, + GitListWebhooksParams, + GitRefResponse, + GitTreeEntriesParams, + GitTreeEntryByPathFromCommitParams, + GitTreeEntryByPathParams, + GitUpdateTagBody, + GitWatchRepoBody, + InviteAcceptRequest, + InviteCreateRequest, + IssueAuthor, + IssueCommentResponse, + IssueEventResponse, + IssuePullRequestResponse, + IssueReactionResponse, + IssueRepoResponse, + IssueResponse, + IssuesListIssuesParams, + LabelResponse, + LanguageStatDto, + LoginParams, + MergePullRequest, + MessageResponse, + MilestoneResponse, + NotificationMarkAllReadRequest, + PinRequest, + PresenceUpdateRequest, + ProtectResponse, + PublicUserResponse, + PullRequestCommentResponse, + PullRequestListPrsParams, + PullRequestReactionResponse, + PullRequestResponse, + PullRequestReviewCommentResponse, + PullRequestReviewResponse, + ReactionRequest, + ReadReceiptRequest, + ReadmeDto, + RegisterParams, + ReleaseResponse, + RenameBranchBody, + RepoResponse, + ResetPasswordRequest, + ResetPasswordVerifyParams, + RoomCreateRequest, + RoomUpdateRequest, + RsaResponse, + ScreenShareRequest, + SearchParams, + SearchResponse, + SetIssueMilestone, + TagInfoResponseDto, + TagInitResponseDto, + TagSummaryResponseDto, + ThreadCreateRequest, + TokenRequest, + TokenResponse, + TransferRepo, + TreeEntriesResponseDto, + TreeEntryByPathResponseDto, + TypingRequest, + UpdateAgentSession, + UpdateComment, + UpdateContent, + UpdateConversation, + UpdateIssue, + UpdateLabel, + UpdateMessageRequest, + UpdateMilestone, + UpdatePrComment, + UpdateProtect, + UpdatePullRequest, + UpdateRelease, + UpdateRepo, + UpdateUserAccessToken, + UpdateUserAccessibilityConfig, + UpdateUserAppearanceConfig, + UpdateUserNotificationConfig, + UpdateUserPrivacyConfig, + UpdateUserProfileConfig, + UpdateUserSshKey, + UpdateWebhook, + UpdateWorkspace, + UpdateWorkspaceGroup, + UpdateWorkspaceJoinStrategy, + UpdateWorkspaceMember, + UserAccessToken, + UserAccessibilityConfig, + UserAppearanceConfig, + UserConfigResponse, + UserContributionHeatmapParams, + UserNotificationConfig, + UserPrivacyConfig, + UserProfileConfig, + UserRelationCard, + UserRelationCounts, + UserRelationStatus, + UserSshKey, + UserSummaryResponse, + UsersFollowersParams, + UsersFollowingParams, + UsersUserChpcParams, + Verify2FAParams, + VoiceDeafRequest, + VoiceMuteRequest, + WebhookDeliveryResponse, + WebhookResponse, + WorkspaceGroupResponse, + WorkspaceJoinApplyResponse, + WorkspaceJoinApprovalResponse, + WorkspaceJoinStrategyResponse, + WorkspaceListJoinAppliesParams, + WorkspaceListMembersParams, + WorkspaceMemberResponse, + WorkspaceResponse +} from './models'; + +export const getGitDataAIAPI = (axiosInstance: AxiosInstance = axios.default) => { +const agentListAllConversations = ( + params?: AgentListAllConversationsParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/agent/conversations`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const agentGetConversation = ( + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/agent/conversations/${id}`,options + ); + } + +const agentDeleteConversation = ( + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/agent/conversations/${id}`,options + ); + } + +const agentUpdateConversation = ( + id: string, + updateConversation: UpdateConversation, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.patch( + `/api/v1/agent/conversations/${id}`, + updateConversation,options + ); + } + +const agentListMessages = ( + id: string, + params?: AgentListMessagesParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/agent/conversations/${id}/messages`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const agentSendMessage = ( + id: string, + agentRunRequest: AgentRunRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/agent/conversations/${id}/messages`, + agentRunRequest,options + ); + } + +const agentStreamAgent = ( + id: string, + agentRunRequest: AgentRunRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/agent/conversations/${id}/stream`, + agentRunRequest,options + ); + } + +const agentListSessions = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/agent/sessions`,options + ); + } + +const agentCreateSession = ( + createAgentSession: CreateAgentSession, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/agent/sessions`, + createAgentSession,options + ); + } + +const agentGetSession = ( + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/agent/sessions/${id}`,options + ); + } + +const agentDeleteSession = ( + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/agent/sessions/${id}`,options + ); + } + +const agentUpdateSession = ( + id: string, + updateAgentSession: UpdateAgentSession, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.patch( + `/api/v1/agent/sessions/${id}`, + updateAgentSession,options + ); + } + +const agentListConversations = ( + sessionId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/agent/sessions/${sessionId}/conversations`,options + ); + } + +const agentCreateConversation = ( + sessionId: string, + createConversation: CreateConversation, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/agent/sessions/${sessionId}/conversations`, + createConversation,options + ); + } + +const aiListModels = ( + params?: AiListModelsParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ai/models`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const aiGetModel = ( + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ai/models/${id}`,options + ); + } + +const aiGetCard = ( + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ai/models/${id}/card`,options + ); + } + +const aiListDiscussions = ( + id: string, + params?: AiListDiscussionsParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ai/models/${id}/discussions`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const aiListLikes = ( + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ai/models/${id}/likes`,options + ); + } + +const aiListTags = ( + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ai/models/${id}/tags`,options + ); + } + +const aiListVersions = ( + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ai/models/${id}/versions`,options + ); + } + +const aiListProviders = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ai/providers`,options + ); + } + +const aiGetProvider = ( + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ai/providers/${id}`,options + ); + } + +const authStatus2fa = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/auth/2fa`,options + ); + } + +const authDisable2fa = ( + disable2FAParams: Disable2FAParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/auth/2fa`,{data: disable2FAParams,...options} + ); + } + +const authRegenerateBackupCodes = ( + authRegenerateBackupCodesBody: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/auth/2fa/backup-codes`, + authRegenerateBackupCodesBody,options + ); + } + +const authEnable2fa = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/auth/2fa/enable`, + undefined,options + ); + } + +const authVerify2fa = ( + verify2FAParams: Verify2FAParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/auth/2fa/verify`, + verify2FAParams,options + ); + } + +const authCaptcha = ( + params: AuthCaptchaParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/auth/captcha`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const authGetEmail = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/auth/email`,options + ); + } + +const authEmailChangeRequest = ( + emailChangeRequest: EmailChangeRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/auth/email`, + emailChangeRequest,options + ); + } + +const authEmailVerify = ( + emailVerifyRequest: EmailVerifyRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/auth/email/verify`, + emailVerifyRequest,options + ); + } + +const authLogin = ( + loginParams: LoginParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/auth/login`, + loginParams,options + ); + } + +const authLogout = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/auth/logout`, + undefined,options + ); + } + +const authMe = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/auth/me`,options + ); + } + +const authRsa = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/auth/public-key`,options + ); + } + +const authRegister = ( + registerParams: RegisterParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/auth/register`, + registerParams,options + ); + } + +const authResetPasswordRequest = ( + resetPasswordRequest: ResetPasswordRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/auth/reset-password/request`, + resetPasswordRequest,options + ); + } + +const authResetPasswordVerify = ( + resetPasswordVerifyParams: ResetPasswordVerifyParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/auth/reset-password/verify`, + resetPasswordVerifyParams,options + ); + } + +const search = ( + params?: SearchParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/search`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const userListAccessTokens = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/user/access-tokens`,options + ); + } + +const userCreateAccessToken = ( + createUserAccessToken: CreateUserAccessToken, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/user/access-tokens`, + createUserAccessToken,options + ); + } + +const userUpdateAccessToken = ( + id: number, + updateUserAccessToken: UpdateUserAccessToken, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/user/access-tokens/${id}`, + updateUserAccessToken,options + ); + } + +const userRevokeAccessToken = ( + id: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/user/access-tokens/${id}`,options + ); + } + +const userUploadAvatar = ( + userUploadAvatarBody: Blob, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/user/avatar`, + userUploadAvatarBody,options + ); + } + +const userConfig = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/user/config`,options + ); + } + +const userListNotifications = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/user/notifications`,options + ); + } + +const userUpdateAccessibility = ( + updateUserAccessibilityConfig: UpdateUserAccessibilityConfig, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/user/config/accessibility`, + updateUserAccessibilityConfig,options + ); + } + +const userUpdateAppearance = ( + updateUserAppearanceConfig: UpdateUserAppearanceConfig, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/user/config/appearance`, + updateUserAppearanceConfig,options + ); + } + +const userUpdateNotification = ( + updateUserNotificationConfig: UpdateUserNotificationConfig, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/user/config/notification`, + updateUserNotificationConfig,options + ); + } + +const userUpdatePrivacy = ( + updateUserPrivacyConfig: UpdateUserPrivacyConfig, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/user/config/privacy`, + updateUserPrivacyConfig,options + ); + } + +const userUpdateProfile = ( + updateUserProfileConfig: UpdateUserProfileConfig, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/user/config/profile`, + updateUserProfileConfig,options + ); + } + +const userContributionHeatmap = ( + params?: UserContributionHeatmapParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/user/contribution-heatmap`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const userInvalidateChpcCache = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/user/contribution-heatmap/cache`,options + ); + } + +const userListSshKeys = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/user/ssh-keys`,options + ); + } + +const userAddSshKey = ( + createUserSshKey: CreateUserSshKey, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/user/ssh-keys`, + createUserSshKey,options + ); + } + +const userUpdateSshKey = ( + id: number, + updateUserSshKey: UpdateUserSshKey, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/user/ssh-keys/${id}`, + updateUserSshKey,options + ); + } + +const userRevokeSshKey = ( + id: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/user/ssh-keys/${id}`,options + ); + } + +const usersUserAvatar = ( + username: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/users/avatar/${username}`,options + ); + } + +const usersBlockedList = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/users/blocked`,options + ); + } + +const usersBlockUser = ( + username: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/users/${username}/block`, + undefined,options + ); + } + +const usersUserChpc = ( + username: string, + params?: UsersUserChpcParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/users/${username}/contribution-heatmap`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const usersFollowUser = ( + username: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/users/${username}/follow`, + undefined,options + ); + } + +const usersFollowers = ( + username: string, + params?: UsersFollowersParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/users/${username}/followers`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const usersFollowing = ( + username: string, + params?: UsersFollowingParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/users/${username}/following`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const usersUserPublic = ( + username: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/users/${username}/public`,options + ); + } + +const usersRelationStatus = ( + username: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/users/${username}/relation`,options + ); + } + +const usersRelationCounts = ( + username: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/users/${username}/relation-counts`,options + ); + } + +const usersUserSummary = ( + username: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/users/${username}/summary`,options + ); + } + +const usersUnblockUser = ( + username: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/users/${username}/unblock`, + undefined,options + ); + } + +const usersUnfollowUser = ( + username: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/users/${username}/unfollow`, + undefined,options + ); + } + +const workspaceCreateWorkspace = ( + createWorkspace: CreateWorkspace, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace`, + createWorkspace,options + ); + } + +const workspaceMyJoinApplies = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/join/my-applies`,options + ); + } + +const workspaceMyWorkspaces = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/my`,options + ); + } + +const workspaceGetWorkspace = ( + wk: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}`,options + ); + } + +const workspaceUpdateWorkspace = ( + wk: string, + updateWorkspace: UpdateWorkspace, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}`, + updateWorkspace,options + ); + } + +const workspaceGetAvatar = ( + wk: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/avatar`,options + ); + } + +const workspaceUploadAvatar = ( + wk: string, + workspaceUploadAvatarBody: Blob, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/avatar`, + workspaceUploadAvatarBody,options + ); + } + +const workspaceListGroups = ( + wk: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/groups`,options + ); + } + +const workspaceCreateGroup = ( + wk: string, + createWorkspaceGroup: CreateWorkspaceGroup, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/groups`, + createWorkspaceGroup,options + ); + } + +const workspaceUpdateGroup = ( + wk: string, + groupName: string, + updateWorkspaceGroup: UpdateWorkspaceGroup, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}/groups/${groupName}`, + updateWorkspaceGroup,options + ); + } + +const workspaceDeleteGroup = ( + wk: string, + groupName: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/groups/${groupName}`,options + ); + } + +const workspaceListGroupMembers = ( + wk: string, + groupName: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/groups/${groupName}/members`,options + ); + } + +const workspaceAddGroupMember = ( + wk: string, + groupName: string, + username: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/groups/${groupName}/members/${username}`, + undefined,options + ); + } + +const workspaceRemoveGroupMember = ( + wk: string, + groupName: string, + username: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/groups/${groupName}/members/${username}`,options + ); + } + +const issuesListIssues = ( + wk: string, + params?: IssuesListIssuesParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/issues`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const issuesCreateIssue = ( + wk: string, + createIssue: CreateIssue, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/issues`, + createIssue,options + ); + } + +const issuesGetIssue = ( + wk: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/issues/${number}`,options + ); + } + +const issuesUpdateIssue = ( + wk: string, + number: number, + updateIssue: UpdateIssue, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}/issues/${number}`, + updateIssue,options + ); + } + +const issuesDeleteIssue = ( + wk: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/issues/${number}`,options + ); + } + +const issuesAssignUser = ( + wk: string, + number: number, + assignIssueUser: AssignIssueUser, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/issues/${number}/assignees`, + assignIssueUser,options + ); + } + +const issuesUnassignUser = ( + wk: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/issues/${number}/assignees`,options + ); + } + +const issuesCloseIssue = ( + wk: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/issues/${number}/close`, + undefined,options + ); + } + +const issuesListComments = ( + wk: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/issues/${number}/comments`,options + ); + } + +const issuesCreateComment = ( + wk: string, + number: number, + createComment: CreateComment, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/issues/${number}/comments`, + createComment,options + ); + } + +const issuesUpdateComment = ( + wk: string, + number: number, + commentId: string, + updateComment: UpdateComment, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}/issues/${number}/comments/${commentId}`, + updateComment,options + ); + } + +const issuesDeleteComment = ( + wk: string, + number: number, + commentId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/issues/${number}/comments/${commentId}`,options + ); + } + +const issuesAddCommentReaction = ( + wk: string, + number: number, + commentId: string, + addReaction: AddReaction, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/issues/${number}/comments/${commentId}/reactions`, + addReaction,options + ); + } + +const issuesRemoveCommentReaction = ( + wk: string, + number: number, + commentId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/issues/${number}/comments/${commentId}/reactions`,options + ); + } + +const issuesListEvents = ( + wk: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/issues/${number}/events`,options + ); + } + +const issuesAddIssueLabel = ( + wk: string, + number: number, + addIssueLabel: AddIssueLabel, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/issues/${number}/labels`, + addIssueLabel,options + ); + } + +const issuesRemoveIssueLabel = ( + wk: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/issues/${number}/labels`,options + ); + } + +const issuesSetIssueMilestone = ( + wk: string, + number: number, + setIssueMilestone: SetIssueMilestone, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/issues/${number}/milestone`, + setIssueMilestone,options + ); + } + +const issuesClearIssueMilestone = ( + wk: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/issues/${number}/milestone`,options + ); + } + +const issuesBindPullRequest = ( + wk: string, + number: number, + bindIssuePullRequest: BindIssuePullRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/issues/${number}/pull-requests`, + bindIssuePullRequest,options + ); + } + +const issuesUnbindPullRequest = ( + wk: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/issues/${number}/pull-requests`,options + ); + } + +const issuesAddReaction = ( + wk: string, + number: number, + addReaction: AddReaction, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/issues/${number}/reactions`, + addReaction,options + ); + } + +const issuesRemoveReaction = ( + wk: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/issues/${number}/reactions`,options + ); + } + +const issuesReopenIssue = ( + wk: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/issues/${number}/reopen`, + undefined,options + ); + } + +const issuesBindRepo = ( + wk: string, + number: number, + bindIssueRepo: BindIssueRepo, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/issues/${number}/repos`, + bindIssueRepo,options + ); + } + +const issuesUnbindRepo = ( + wk: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/issues/${number}/repos`,options + ); + } + +const workspaceJoinStrategy = ( + wk: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/join-strategy`,options + ); + } + +const workspaceUpdateJoinStrategy = ( + wk: string, + updateWorkspaceJoinStrategy: UpdateWorkspaceJoinStrategy, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}/join-strategy`, + updateWorkspaceJoinStrategy,options + ); + } + +const workspaceListJoinApplies = ( + wk: string, + params?: WorkspaceListJoinAppliesParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/join/applies`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const workspaceApproveJoin = ( + wk: string, + username: string, + approveWorkspaceJoinApply: ApproveWorkspaceJoinApply, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/join/applies/${username}/approve`, + approveWorkspaceJoinApply,options + ); + } + +const workspaceApplyJoin = ( + wk: string, + createWorkspaceJoinApply: CreateWorkspaceJoinApply, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/join/apply`, + createWorkspaceJoinApply,options + ); + } + +const workspaceCancelJoin = ( + wk: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/join/cancel`, + undefined,options + ); + } + +const issuesListLabels = ( + wk: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/labels`,options + ); + } + +const issuesCreateLabel = ( + wk: string, + createLabel: CreateLabel, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/labels`, + createLabel,options + ); + } + +const issuesUpdateLabel = ( + wk: string, + labelId: string, + updateLabel: UpdateLabel, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}/labels/${labelId}`, + updateLabel,options + ); + } + +const issuesDeleteLabel = ( + wk: string, + labelId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/labels/${labelId}`,options + ); + } + +const workspaceListMembers = ( + wk: string, + params?: WorkspaceListMembersParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/members`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const workspaceAddMember = ( + wk: string, + addWorkspaceMember: AddWorkspaceMember, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/members`, + addWorkspaceMember,options + ); + } + +const workspaceUpdateMember = ( + wk: string, + username: string, + updateWorkspaceMember: UpdateWorkspaceMember, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}/members/${username}`, + updateWorkspaceMember,options + ); + } + +const workspaceRemoveMember = ( + wk: string, + username: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/members/${username}`,options + ); + } + +const issuesListMilestones = ( + wk: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/milestones`,options + ); + } + +const issuesCreateMilestone = ( + wk: string, + createMilestone: CreateMilestone, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/milestones`, + createMilestone,options + ); + } + +const issuesUpdateMilestone = ( + wk: string, + milestoneId: string, + updateMilestone: UpdateMilestone, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}/milestones/${milestoneId}`, + updateMilestone,options + ); + } + +const issuesDeleteMilestone = ( + wk: string, + milestoneId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/milestones/${milestoneId}`,options + ); + } + +const gitListRepos = ( + wk: string, + params?: GitListReposParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitCreateRepo = ( + wk: string, + createRepo: CreateRepo, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos`, + createRepo,options + ); + } + +const gitCloneRepo = ( + wk: string, + cloneRepo: CloneRepo, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/clone`, + cloneRepo,options + ); + } + +const gitGetRepo = ( + wk: string, + repo: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}`,options + ); + } + +const gitUpdateRepo = ( + wk: string, + repo: string, + updateRepo: UpdateRepo, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}/repos/${repo}`, + updateRepo,options + ); + } + +const gitDeleteRepo = ( + wk: string, + repo: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}`,options + ); + } + +const gitArchiveRepo = ( + wk: string, + repo: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/archive`, + undefined,options + ); + } + +const gitCombinedStatus = ( + wk: string, + repo: string, + sha: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/commits/${sha}/status`,options + ); + } + +const gitListStatuses = ( + wk: string, + repo: string, + sha: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/commits/${sha}/statuses`,options + ); + } + +const gitCompare = ( + wk: string, + repo: string, + basehead: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/compare/${basehead}`,options + ); + } + +const gitGetContents = ( + wk: string, + repo: string, + path: string, + params?: GitGetContentsParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/contents/${path}`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitUpdateContents = ( + wk: string, + repo: string, + path: string, + updateContent: UpdateContent, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}/repos/${repo}/contents/${path}`, + updateContent,options + ); + } + +const gitCreateContents = ( + wk: string, + repo: string, + path: string, + createContent: CreateContent, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/contents/${path}`, + createContent,options + ); + } + +const gitDeleteContents = ( + wk: string, + repo: string, + path: string, + params: GitDeleteContentsParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/contents/${path}`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitListForks = ( + wk: string, + repo: string, + params?: GitListForksParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/forks`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitCreateFork = ( + wk: string, + repo: string, + createFork: CreateFork, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/forks`, + createFork,options + ); + } + +const gitArchive = ( + wk: string, + repo: string, + params?: GitArchiveParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/archive`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitBlameFile = ( + wk: string, + repo: string, + params: GitBlameFileParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/blame`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitBlobUpload = ( + wk: string, + repo: string, + blobUploadBody: BlobUploadBody, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/git/blobs`, + blobUploadBody,options + ); + } + +const gitBlobInfo = ( + wk: string, + repo: string, + oid: string, + params?: GitBlobInfoParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/blobs/${oid}`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitListBranches = ( + wk: string, + repo: string, + params?: GitListBranchesParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/branches`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +/** + * BranchForkParams { name, oid, force } + */ +const gitForkBranch = ( + wk: string, + repo: string, + gitForkBranchBody: GitForkBranchBody, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/git/branches`, + gitForkBranchBody,options + ); + } + +const gitBranchInfo = ( + wk: string, + repo: string, + name: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/branches/${name}`,options + ); + } + +const gitDeleteBranch = ( + wk: string, + repo: string, + name: string, + params?: GitDeleteBranchParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/git/branches/${name}`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitRenameBranch = ( + wk: string, + repo: string, + name: string, + renameBranchBody: RenameBranchBody, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.patch( + `/api/v1/workspace/${wk}/repos/${repo}/git/branches/${name}`, + renameBranchBody,options + ); + } + +const gitAheadBehind = ( + wk: string, + repo: string, + name: string, + params: GitAheadBehindParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/branches/${name}/ahead-behind`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitBranchUpstream = ( + wk: string, + repo: string, + name: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/branches/${name}/upstream`,options + ); + } + +const gitListCommits = ( + wk: string, + repo: string, + params?: GitListCommitsParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/commits`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +/** + * CommitCherryPickParams + */ +const gitCherryPick = ( + wk: string, + repo: string, + gitCherryPickBody: GitCherryPickBody, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/git/commits/cherry-pick`, + gitCherryPickBody,options + ); + } + +const gitCommitHistory = ( + wk: string, + repo: string, + params?: GitCommitHistoryParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/commits/history`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +/** + * CommitWalkParams + */ +const gitCommitWalk = ( + wk: string, + repo: string, + gitCommitWalkBody: GitCommitWalkBody, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/git/commits/walk`, + gitCommitWalkBody,options + ); + } + +const gitCommitInfo = ( + wk: string, + repo: string, + oid: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/commits/${oid}`,options + ); + } + +const gitTreeEntryByPathFromCommit = ( + wk: string, + repo: string, + oid: string, + params: GitTreeEntryByPathFromCommitParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/commits/${oid}/tree`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitListContributors = ( + wk: string, + repo: string, + params?: GitListContributorsParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/contributors`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitDiff = ( + wk: string, + repo: string, + params?: GitDiffParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/diff`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitDiffBranches = ( + wk: string, + repo: string, + params: GitDiffBranchesParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/diff/branches`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitGetLanguages = ( + wk: string, + repo: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/languages`,options + ); + } + +const gitGetReadme = ( + wk: string, + repo: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/readme`,options + ); + } + +const gitListRefs = ( + wk: string, + repo: string, + params?: GitListRefsParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/refs`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitStarStatus = ( + wk: string, + repo: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/star`,options + ); + } + +const gitStarRepo = ( + wk: string, + repo: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/git/star`, + undefined,options + ); + } + +const gitUnstarRepo = ( + wk: string, + repo: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/git/star`,options + ); + } + +const gitListTags = ( + wk: string, + repo: string, + params?: GitListTagsParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/tags`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +/** + * TagInitParams { name, oid, message, tagger, force } + */ +const gitInitTag = ( + wk: string, + repo: string, + gitInitTagBody: GitInitTagBody, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/git/tags`, + gitInitTagBody,options + ); + } + +const gitTagInfo = ( + wk: string, + repo: string, + name: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/tags/${name}`,options + ); + } + +const gitDeleteTag = ( + wk: string, + repo: string, + name: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/git/tags/${name}`,options + ); + } + +const gitUpdateTag = ( + wk: string, + repo: string, + name: string, + gitUpdateTagBody: GitUpdateTagBody, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.patch( + `/api/v1/workspace/${wk}/repos/${repo}/git/tags/${name}`, + gitUpdateTagBody,options + ); + } + +const gitTreeEntries = ( + wk: string, + repo: string, + oid: string, + params?: GitTreeEntriesParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/trees/${oid}`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitTreeEntryByPath = ( + wk: string, + repo: string, + treeOid: string, + params: GitTreeEntryByPathParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/trees/${treeOid}/entries`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitWatchStatus = ( + wk: string, + repo: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/git/watch`,options + ); + } + +/** + * WatchLevel {level: String} + */ +const gitWatchRepo = ( + wk: string, + repo: string, + gitWatchRepoBody: GitWatchRepoBody, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/git/watch`, + gitWatchRepoBody,options + ); + } + +const gitUnwatchRepo = ( + wk: string, + repo: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/git/watch`,options + ); + } + +const gitListProtects = ( + wk: string, + repo: string, + params?: GitListProtectsParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/protect`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitCreateProtect = ( + wk: string, + repo: string, + createProtect: CreateProtect, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/protect`, + createProtect,options + ); + } + +const gitUpdateProtect = ( + wk: string, + repo: string, + protectId: string, + updateProtect: UpdateProtect, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}/repos/${repo}/protect/${protectId}`, + updateProtect,options + ); + } + +const gitDeleteProtect = ( + wk: string, + repo: string, + protectId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/protect/${protectId}`,options + ); + } + +const pullRequestListPrs = ( + wk: string, + repo: string, + params?: PullRequestListPrsParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const pullRequestCreatePr = ( + wk: string, + repo: string, + createPullRequest: CreatePullRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests`, + createPullRequest,options + ); + } + +const pullRequestGetPr = ( + wk: string, + repo: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}`,options + ); + } + +const pullRequestDeletePr = ( + wk: string, + repo: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}`,options + ); + } + +const pullRequestUpdatePr = ( + wk: string, + repo: string, + number: number, + updatePullRequest: UpdatePullRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.patch( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}`, + updatePullRequest,options + ); + } + +const pullRequestAssignUser = ( + wk: string, + repo: string, + number: number, + assignPrUser: AssignPrUser, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/assignees`, + assignPrUser,options + ); + } + +const pullRequestUnassignUser = ( + wk: string, + repo: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/assignees`,options + ); + } + +const pullRequestListComments = ( + wk: string, + repo: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/comments`,options + ); + } + +const pullRequestCreateComment = ( + wk: string, + repo: string, + number: number, + createPrComment: CreatePrComment, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/comments`, + createPrComment,options + ); + } + +const pullRequestUpdateComment = ( + wk: string, + repo: string, + number: number, + commentId: string, + updatePrComment: UpdatePrComment, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/comments/${commentId}`, + updatePrComment,options + ); + } + +const pullRequestDeleteComment = ( + wk: string, + repo: string, + number: number, + commentId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/comments/${commentId}`,options + ); + } + +const pullRequestAddCommentReaction = ( + wk: string, + repo: string, + number: number, + commentId: string, + addPrReaction: AddPrReaction, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/comments/${commentId}/reactions`, + addPrReaction,options + ); + } + +const pullRequestRemoveCommentReaction = ( + wk: string, + repo: string, + number: number, + commentId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/comments/${commentId}/reactions`,options + ); + } + +const pullRequestAddLabel = ( + wk: string, + repo: string, + number: number, + addPrLabel: AddPrLabel, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/labels`, + addPrLabel,options + ); + } + +const pullRequestRemoveLabel = ( + wk: string, + repo: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/labels`,options + ); + } + +const pullRequestMergeAnalysis = ( + wk: string, + repo: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/merge`,options + ); + } + +const pullRequestMergePr = ( + wk: string, + repo: string, + number: number, + mergePullRequest: MergePullRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/merge`, + mergePullRequest,options + ); + } + +const pullRequestAddReaction = ( + wk: string, + repo: string, + number: number, + addPrReaction: AddPrReaction, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/reactions`, + addPrReaction,options + ); + } + +const pullRequestRemoveReaction = ( + wk: string, + repo: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/reactions`,options + ); + } + +const pullRequestCreateReviewComment = ( + wk: string, + repo: string, + number: number, + createPrReviewComment: CreatePrReviewComment, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/review-comments`, + createPrReviewComment,options + ); + } + +const pullRequestListReviews = ( + wk: string, + repo: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/reviews`,options + ); + } + +const pullRequestCreateReview = ( + wk: string, + repo: string, + number: number, + createPrReview: CreatePrReview, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/reviews`, + createPrReview,options + ); + } + +const pullRequestDismissReview = ( + wk: string, + repo: string, + number: number, + reviewId: string, + dismissPrReview: DismissPrReview, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/reviews/${reviewId}/dismiss`, + dismissPrReview,options + ); + } + +const pullRequestUpdateBranch = ( + wk: string, + repo: string, + number: number, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/pull-requests/${number}/update-branch`, + undefined,options + ); + } + +const gitListReleases = ( + wk: string, + repo: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/releases`,options + ); + } + +const gitCreateRelease = ( + wk: string, + repo: string, + createRelease: CreateRelease, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/releases`, + createRelease,options + ); + } + +const gitGetReleaseByTag = ( + wk: string, + repo: string, + tag: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/releases/tags/${tag}`,options + ); + } + +const gitDeleteReleaseByTag = ( + wk: string, + repo: string, + tag: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/releases/tags/${tag}`,options + ); + } + +const gitGetRelease = ( + wk: string, + repo: string, + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/releases/${id}`,options + ); + } + +const gitDeleteRelease = ( + wk: string, + repo: string, + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/releases/${id}`,options + ); + } + +const gitUpdateRelease = ( + wk: string, + repo: string, + id: string, + updateRelease: UpdateRelease, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.patch( + `/api/v1/workspace/${wk}/repos/${repo}/releases/${id}`, + updateRelease,options + ); + } + +const gitCreateStatus = ( + wk: string, + repo: string, + sha: string, + createCommitStatus: CreateCommitStatus, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/statuses/${sha}`, + createCommitStatus,options + ); + } + +const gitGetTopics = ( + wk: string, + repo: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/topics`,options + ); + } + +const gitUpdateTopics = ( + wk: string, + repo: string, + gitUpdateTopicsBody: string[], options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}/repos/${repo}/topics`, + gitUpdateTopicsBody,options + ); + } + +const gitTransferRepo = ( + wk: string, + repo: string, + transferRepo: TransferRepo, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/transfer`, + transferRepo,options + ); + } + +const gitListWebhooks = ( + wk: string, + repo: string, + params?: GitListWebhooksParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/webhooks`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const gitCreateWebhook = ( + wk: string, + repo: string, + createWebhook: CreateWebhook, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/workspace/${wk}/repos/${repo}/webhooks`, + createWebhook,options + ); + } + +const gitUpdateWebhook = ( + wk: string, + repo: string, + webhookId: string, + updateWebhook: UpdateWebhook, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/workspace/${wk}/repos/${repo}/webhooks/${webhookId}`, + updateWebhook,options + ); + } + +const gitDeleteWebhook = ( + wk: string, + repo: string, + webhookId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/workspace/${wk}/repos/${repo}/webhooks/${webhookId}`,options + ); + } + +const gitListDeliveries = ( + wk: string, + repo: string, + webhookId: string, + params?: GitListDeliveriesParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/workspace/${wk}/repos/${repo}/webhooks/${webhookId}/deliveries`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const channelCategoryDelete = ( + categoryId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/ws/categories/${categoryId}`,options + ); + } + +const channelCategoryUpdate = ( + categoryId: string, + categoryUpdateRequest: CategoryUpdateRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.patch( + `/api/v1/ws/categories/${categoryId}`, + categoryUpdateRequest,options + ); + } + +const channelCsrfToken = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ws/csrf`,options + ); + } + +const channelCustomStatusUpdate = ( + customStatusRequest: CustomStatusRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/custom-status`, + customStatusRequest,options + ); + } + +const channelInviteCreate = ( + inviteCreateRequest: InviteCreateRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/invites`, + inviteCreateRequest,options + ); + } + +const channelInviteAccept = ( + inviteAcceptRequest: InviteAcceptRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/invites/accept`, + inviteAcceptRequest,options + ); + } + +const channelInviteRevoke = ( + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/ws/invites/${id}`,options + ); + } + +const channelRevokeMessage = ( + messageId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/ws/messages/${messageId}`,options + ); + } + +const channelUpdateMessage = ( + messageId: string, + updateMessageRequest: UpdateMessageRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.patch( + `/api/v1/ws/messages/${messageId}`, + updateMessageRequest,options + ); + } + +const channelNotificationMarkAllRead = ( + notificationMarkAllReadRequest: NotificationMarkAllReadRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/notifications/read-all`, + notificationMarkAllReadRequest,options + ); + } + +const channelNotificationArchive = ( + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/ws/notifications/${id}`,options + ); + } + +const channelNotificationMarkRead = ( + id: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.patch( + `/api/v1/ws/notifications/${id}/read`, + undefined,options + ); + } + +const channelPing = ( + options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ws/ping`,options + ); + } + +const channelPresenceUpdate = ( + presenceUpdateRequest: PresenceUpdateRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/presence`, + presenceUpdateRequest,options + ); + } + +const channelRoomCreate = ( + roomCreateRequest: RoomCreateRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms`, + roomCreateRequest,options + ); + } + +const channelRoomGet = ( + roomId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ws/rooms/${roomId}`,options + ); + } + +const channelRoomDelete = ( + roomId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/ws/rooms/${roomId}`,options + ); + } + +const channelRoomUpdate = ( + roomId: string, + roomUpdateRequest: RoomUpdateRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.patch( + `/api/v1/ws/rooms/${roomId}`, + roomUpdateRequest,options + ); + } + +const channelAiList = ( + roomId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ws/rooms/${roomId}/ai`,options + ); + } + +const channelAiAdd = ( + roomId: string, + aiAddRequest: AiAddRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/ai`, + aiAddRequest,options + ); + } + +const channelAiStop = ( + roomId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/ai/stop`, + undefined,options + ); + } + +const channelAiRemove = ( + roomId: string, + agentSessionId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/ws/rooms/${roomId}/ai/${agentSessionId}`,options + ); + } + +const channelDndUpdate = ( + roomId: string, + dndRequest: DndRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.patch( + `/api/v1/ws/rooms/${roomId}/dnd`, + dndRequest,options + ); + } + +const channelDraftSave = ( + roomId: string, + draftSaveRequest: DraftSaveRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.put( + `/api/v1/ws/rooms/${roomId}/drafts`, + draftSaveRequest,options + ); + } + +const channelDraftClear = ( + roomId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/ws/rooms/${roomId}/drafts`,options + ); + } + +const channelAccessGrant = ( + roomId: string, + accessRequest: AccessRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/members`, + accessRequest,options + ); + } + +const channelAccessRevoke = ( + roomId: string, + userId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/ws/rooms/${roomId}/members/${userId}`,options + ); + } + +const channelListMessages = ( + roomId: string, + params?: ChannelListMessagesParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ws/rooms/${roomId}/messages`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const channelCreateMessage = ( + roomId: string, + createMessageRequest: CreateMessageRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/messages`, + createMessageRequest,options + ); + } + +const channelMessagesAround = ( + roomId: string, + params: ChannelMessagesAroundParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ws/rooms/${roomId}/messages/around`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const channelMissedMessages = ( + roomId: string, + params: ChannelMissedMessagesParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ws/rooms/${roomId}/messages/missed`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const channelPinAdd = ( + roomId: string, + pinRequest: PinRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/pins`, + pinRequest,options + ); + } + +const channelPinRemove = ( + roomId: string, + pinRequest: PinRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/ws/rooms/${roomId}/pins`,{data: pinRequest,...options} + ); + } + +const channelReactionAdd = ( + roomId: string, + reactionRequest: ReactionRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/reactions`, + reactionRequest,options + ); + } + +const channelReactionRemove = ( + roomId: string, + reactionRequest: ReactionRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/ws/rooms/${roomId}/reactions`,{data: reactionRequest,...options} + ); + } + +const channelReadReceipt = ( + roomId: string, + readReceiptRequest: ReadReceiptRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/read-receipt`, + readReceiptRequest,options + ); + } + +const channelScreenShare = ( + roomId: string, + screenShareRequest: ScreenShareRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/screen-share`, + screenShareRequest,options + ); + } + +const channelSubscribe = ( + roomId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/subscribe`, + undefined,options + ); + } + +const channelUnsubscribe = ( + roomId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/ws/rooms/${roomId}/subscribe`,options + ); + } + +const channelThreadCreate = ( + roomId: string, + threadCreateRequest: ThreadCreateRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/threads`, + threadCreateRequest,options + ); + } + +const channelTyping = ( + roomId: string, + typingRequest: TypingRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/typing`, + typingRequest,options + ); + } + +const channelVoiceDeaf = ( + roomId: string, + voiceDeafRequest: VoiceDeafRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/voice/deaf`, + voiceDeafRequest,options + ); + } + +const channelVoiceJoin = ( + roomId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/voice/join`, + undefined,options + ); + } + +const channelVoiceLeave = ( + roomId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/voice/leave`, + undefined,options + ); + } + +const channelVoiceMute = ( + roomId: string, + voiceMuteRequest: VoiceMuteRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/rooms/${roomId}/voice/mute`, + voiceMuteRequest,options + ); + } + +const channelSearch = ( + params: ChannelSearchParams, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ws/search`,{ + ...options, + params: {...params, ...options?.params},} + ); + } + +const channelThreadArchive = ( + threadId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.patch( + `/api/v1/ws/threads/${threadId}/archive`, + undefined,options + ); + } + +const channelThreadResolve = ( + threadId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.patch( + `/api/v1/ws/threads/${threadId}/resolve`, + undefined,options + ); + } + +const channelGenerateToken = ( + tokenRequest: TokenRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/token`, + tokenRequest,options + ); + } + +const channelUserSummary = ( + username: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.get( + `/api/v1/ws/users/summary/${username}`,options + ); + } + +const channelBanCreate = ( + workspaceId: string, + banCreateRequest: BanCreateRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/workspaces/${workspaceId}/bans`, + banCreateRequest,options + ); + } + +const channelBanRemove = ( + workspaceId: string, + userId: string, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.delete( + `/api/v1/ws/workspaces/${workspaceId}/bans/${userId}`,options + ); + } + +const channelCategoryCreate = ( + workspaceId: string, + categoryCreateRequest: CategoryCreateRequest, options?: AxiosRequestConfig + ): Promise> => { + return axiosInstance.post( + `/api/v1/ws/workspaces/${workspaceId}/categories`, + categoryCreateRequest,options + ); + } + +return {agentListAllConversations,agentGetConversation,agentDeleteConversation,agentUpdateConversation,agentListMessages,agentSendMessage,agentStreamAgent,agentListSessions,agentCreateSession,agentGetSession,agentDeleteSession,agentUpdateSession,agentListConversations,agentCreateConversation,aiListModels,aiGetModel,aiGetCard,aiListDiscussions,aiListLikes,aiListTags,aiListVersions,aiListProviders,aiGetProvider,authStatus2fa,authDisable2fa,authRegenerateBackupCodes,authEnable2fa,authVerify2fa,authCaptcha,authGetEmail,authEmailChangeRequest,authEmailVerify,authLogin,authLogout,authMe,authRsa,authRegister,authResetPasswordRequest,authResetPasswordVerify,search,userListAccessTokens,userCreateAccessToken,userUpdateAccessToken,userRevokeAccessToken,userUploadAvatar,userListNotifications,userConfig,userUpdateAccessibility,userUpdateAppearance,userUpdateNotification,userUpdatePrivacy,userUpdateProfile,userContributionHeatmap,userInvalidateChpcCache,userListSshKeys,userAddSshKey,userUpdateSshKey,userRevokeSshKey,usersUserAvatar,usersBlockedList,usersBlockUser,usersUserChpc,usersFollowUser,usersFollowers,usersFollowing,usersUserPublic,usersRelationStatus,usersRelationCounts,usersUserSummary,usersUnblockUser,usersUnfollowUser,workspaceCreateWorkspace,workspaceMyJoinApplies,workspaceMyWorkspaces,workspaceGetWorkspace,workspaceUpdateWorkspace,workspaceGetAvatar,workspaceUploadAvatar,workspaceListGroups,workspaceCreateGroup,workspaceUpdateGroup,workspaceDeleteGroup,workspaceListGroupMembers,workspaceAddGroupMember,workspaceRemoveGroupMember,issuesListIssues,issuesCreateIssue,issuesGetIssue,issuesUpdateIssue,issuesDeleteIssue,issuesAssignUser,issuesUnassignUser,issuesCloseIssue,issuesListComments,issuesCreateComment,issuesUpdateComment,issuesDeleteComment,issuesAddCommentReaction,issuesRemoveCommentReaction,issuesListEvents,issuesAddIssueLabel,issuesRemoveIssueLabel,issuesSetIssueMilestone,issuesClearIssueMilestone,issuesBindPullRequest,issuesUnbindPullRequest,issuesAddReaction,issuesRemoveReaction,issuesReopenIssue,issuesBindRepo,issuesUnbindRepo,workspaceJoinStrategy,workspaceUpdateJoinStrategy,workspaceListJoinApplies,workspaceApproveJoin,workspaceApplyJoin,workspaceCancelJoin,issuesListLabels,issuesCreateLabel,issuesUpdateLabel,issuesDeleteLabel,workspaceListMembers,workspaceAddMember,workspaceUpdateMember,workspaceRemoveMember,issuesListMilestones,issuesCreateMilestone,issuesUpdateMilestone,issuesDeleteMilestone,gitListRepos,gitCreateRepo,gitCloneRepo,gitGetRepo,gitUpdateRepo,gitDeleteRepo,gitArchiveRepo,gitCombinedStatus,gitListStatuses,gitCompare,gitGetContents,gitUpdateContents,gitCreateContents,gitDeleteContents,gitListForks,gitCreateFork,gitArchive,gitBlameFile,gitBlobUpload,gitBlobInfo,gitListBranches,gitForkBranch,gitBranchInfo,gitDeleteBranch,gitRenameBranch,gitAheadBehind,gitBranchUpstream,gitListCommits,gitCherryPick,gitCommitHistory,gitCommitWalk,gitCommitInfo,gitTreeEntryByPathFromCommit,gitListContributors,gitDiff,gitDiffBranches,gitGetLanguages,gitGetReadme,gitListRefs,gitStarStatus,gitStarRepo,gitUnstarRepo,gitListTags,gitInitTag,gitTagInfo,gitDeleteTag,gitUpdateTag,gitTreeEntries,gitTreeEntryByPath,gitWatchStatus,gitWatchRepo,gitUnwatchRepo,gitListProtects,gitCreateProtect,gitUpdateProtect,gitDeleteProtect,pullRequestListPrs,pullRequestCreatePr,pullRequestGetPr,pullRequestDeletePr,pullRequestUpdatePr,pullRequestAssignUser,pullRequestUnassignUser,pullRequestListComments,pullRequestCreateComment,pullRequestUpdateComment,pullRequestDeleteComment,pullRequestAddCommentReaction,pullRequestRemoveCommentReaction,pullRequestAddLabel,pullRequestRemoveLabel,pullRequestMergeAnalysis,pullRequestMergePr,pullRequestAddReaction,pullRequestRemoveReaction,pullRequestCreateReviewComment,pullRequestListReviews,pullRequestCreateReview,pullRequestDismissReview,pullRequestUpdateBranch,gitListReleases,gitCreateRelease,gitGetReleaseByTag,gitDeleteReleaseByTag,gitGetRelease,gitDeleteRelease,gitUpdateRelease,gitCreateStatus,gitGetTopics,gitUpdateTopics,gitTransferRepo,gitListWebhooks,gitCreateWebhook,gitUpdateWebhook,gitDeleteWebhook,gitListDeliveries,channelCategoryDelete,channelCategoryUpdate,channelCsrfToken,channelCustomStatusUpdate,channelInviteCreate,channelInviteAccept,channelInviteRevoke,channelRevokeMessage,channelUpdateMessage,channelNotificationMarkAllRead,channelNotificationArchive,channelNotificationMarkRead,channelPing,channelPresenceUpdate,channelRoomCreate,channelRoomGet,channelRoomDelete,channelRoomUpdate,channelAiList,channelAiAdd,channelAiStop,channelAiRemove,channelDndUpdate,channelDraftSave,channelDraftClear,channelAccessGrant,channelAccessRevoke,channelListMessages,channelCreateMessage,channelMessagesAround,channelMissedMessages,channelPinAdd,channelPinRemove,channelReactionAdd,channelReactionRemove,channelReadReceipt,channelScreenShare,channelSubscribe,channelUnsubscribe,channelThreadCreate,channelTyping,channelVoiceDeaf,channelVoiceJoin,channelVoiceLeave,channelVoiceMute,channelSearch,channelThreadArchive,channelThreadResolve,channelGenerateToken,channelUserSummary,channelBanCreate,channelBanRemove,channelCategoryCreate}}; +export type AgentListAllConversationsResult = AxiosResponse +export type AgentGetConversationResult = AxiosResponse +export type AgentDeleteConversationResult = AxiosResponse +export type AgentUpdateConversationResult = AxiosResponse +export type AgentListMessagesResult = AxiosResponse +export type AgentSendMessageResult = AxiosResponse +export type AgentStreamAgentResult = AxiosResponse +export type AgentListSessionsResult = AxiosResponse +export type AgentCreateSessionResult = AxiosResponse +export type AgentGetSessionResult = AxiosResponse +export type AgentDeleteSessionResult = AxiosResponse +export type AgentUpdateSessionResult = AxiosResponse +export type AgentListConversationsResult = AxiosResponse +export type AgentCreateConversationResult = AxiosResponse +export type AiListModelsResult = AxiosResponse +export type AiGetModelResult = AxiosResponse +export type AiGetCardResult = AxiosResponse +export type AiListDiscussionsResult = AxiosResponse +export type AiListLikesResult = AxiosResponse +export type AiListTagsResult = AxiosResponse +export type AiListVersionsResult = AxiosResponse +export type AiListProvidersResult = AxiosResponse +export type AiGetProviderResult = AxiosResponse +export type AuthStatus2faResult = AxiosResponse +export type AuthDisable2faResult = AxiosResponse +export type AuthRegenerateBackupCodesResult = AxiosResponse +export type AuthEnable2faResult = AxiosResponse +export type AuthVerify2faResult = AxiosResponse +export type AuthCaptchaResult = AxiosResponse +export type AuthGetEmailResult = AxiosResponse +export type AuthEmailChangeRequestResult = AxiosResponse +export type AuthEmailVerifyResult = AxiosResponse +export type AuthLoginResult = AxiosResponse +export type AuthLogoutResult = AxiosResponse +export type AuthMeResult = AxiosResponse +export type AuthRsaResult = AxiosResponse +export type AuthRegisterResult = AxiosResponse +export type AuthResetPasswordRequestResult = AxiosResponse +export type AuthResetPasswordVerifyResult = AxiosResponse +export type SearchResult = AxiosResponse +export type UserListAccessTokensResult = AxiosResponse +export type UserCreateAccessTokenResult = AxiosResponse +export type UserUpdateAccessTokenResult = AxiosResponse +export type UserRevokeAccessTokenResult = AxiosResponse +export type UserUploadAvatarResult = AxiosResponse +export type UserConfigResult = AxiosResponse +export type UserUpdateAccessibilityResult = AxiosResponse +export type UserUpdateAppearanceResult = AxiosResponse +export type UserUpdateNotificationResult = AxiosResponse +export type UserUpdatePrivacyResult = AxiosResponse +export type UserUpdateProfileResult = AxiosResponse +export type UserContributionHeatmapResult = AxiosResponse +export type UserInvalidateChpcCacheResult = AxiosResponse +export type UserListSshKeysResult = AxiosResponse +export type UserAddSshKeyResult = AxiosResponse +export type UserUpdateSshKeyResult = AxiosResponse +export type UserRevokeSshKeyResult = AxiosResponse +export type UsersUserAvatarResult = AxiosResponse +export type UsersBlockedListResult = AxiosResponse +export type UsersBlockUserResult = AxiosResponse +export type UsersUserChpcResult = AxiosResponse +export type UsersFollowUserResult = AxiosResponse +export type UsersFollowersResult = AxiosResponse +export type UsersFollowingResult = AxiosResponse +export type UsersUserPublicResult = AxiosResponse +export type UsersRelationStatusResult = AxiosResponse +export type UsersRelationCountsResult = AxiosResponse +export type UsersUserSummaryResult = AxiosResponse +export type UsersUnblockUserResult = AxiosResponse +export type UsersUnfollowUserResult = AxiosResponse +export type WorkspaceCreateWorkspaceResult = AxiosResponse +export type WorkspaceMyJoinAppliesResult = AxiosResponse +export type WorkspaceMyWorkspacesResult = AxiosResponse +export type WorkspaceGetWorkspaceResult = AxiosResponse +export type WorkspaceUpdateWorkspaceResult = AxiosResponse +export type WorkspaceGetAvatarResult = AxiosResponse +export type WorkspaceUploadAvatarResult = AxiosResponse +export type WorkspaceListGroupsResult = AxiosResponse +export type WorkspaceCreateGroupResult = AxiosResponse +export type WorkspaceUpdateGroupResult = AxiosResponse +export type WorkspaceDeleteGroupResult = AxiosResponse +export type WorkspaceListGroupMembersResult = AxiosResponse +export type WorkspaceAddGroupMemberResult = AxiosResponse +export type WorkspaceRemoveGroupMemberResult = AxiosResponse +export type IssuesListIssuesResult = AxiosResponse +export type IssuesCreateIssueResult = AxiosResponse +export type IssuesGetIssueResult = AxiosResponse +export type IssuesUpdateIssueResult = AxiosResponse +export type IssuesDeleteIssueResult = AxiosResponse +export type IssuesAssignUserResult = AxiosResponse +export type IssuesUnassignUserResult = AxiosResponse +export type IssuesCloseIssueResult = AxiosResponse +export type IssuesListCommentsResult = AxiosResponse +export type IssuesCreateCommentResult = AxiosResponse +export type IssuesUpdateCommentResult = AxiosResponse +export type IssuesDeleteCommentResult = AxiosResponse +export type IssuesAddCommentReactionResult = AxiosResponse +export type IssuesRemoveCommentReactionResult = AxiosResponse +export type IssuesListEventsResult = AxiosResponse +export type IssuesAddIssueLabelResult = AxiosResponse +export type IssuesRemoveIssueLabelResult = AxiosResponse +export type IssuesSetIssueMilestoneResult = AxiosResponse +export type IssuesClearIssueMilestoneResult = AxiosResponse +export type IssuesBindPullRequestResult = AxiosResponse +export type IssuesUnbindPullRequestResult = AxiosResponse +export type IssuesAddReactionResult = AxiosResponse +export type IssuesRemoveReactionResult = AxiosResponse +export type IssuesReopenIssueResult = AxiosResponse +export type IssuesBindRepoResult = AxiosResponse +export type IssuesUnbindRepoResult = AxiosResponse +export type WorkspaceJoinStrategyResult = AxiosResponse +export type WorkspaceUpdateJoinStrategyResult = AxiosResponse +export type WorkspaceListJoinAppliesResult = AxiosResponse +export type WorkspaceApproveJoinResult = AxiosResponse +export type WorkspaceApplyJoinResult = AxiosResponse +export type WorkspaceCancelJoinResult = AxiosResponse +export type IssuesListLabelsResult = AxiosResponse +export type IssuesCreateLabelResult = AxiosResponse +export type IssuesUpdateLabelResult = AxiosResponse +export type IssuesDeleteLabelResult = AxiosResponse +export type WorkspaceListMembersResult = AxiosResponse +export type WorkspaceAddMemberResult = AxiosResponse +export type WorkspaceUpdateMemberResult = AxiosResponse +export type WorkspaceRemoveMemberResult = AxiosResponse +export type IssuesListMilestonesResult = AxiosResponse +export type IssuesCreateMilestoneResult = AxiosResponse +export type IssuesUpdateMilestoneResult = AxiosResponse +export type IssuesDeleteMilestoneResult = AxiosResponse +export type GitListReposResult = AxiosResponse +export type GitCreateRepoResult = AxiosResponse +export type GitCloneRepoResult = AxiosResponse +export type GitGetRepoResult = AxiosResponse +export type GitUpdateRepoResult = AxiosResponse +export type GitDeleteRepoResult = AxiosResponse +export type GitArchiveRepoResult = AxiosResponse +export type GitCombinedStatusResult = AxiosResponse +export type GitListStatusesResult = AxiosResponse +export type GitCompareResult = AxiosResponse +export type GitGetContentsResult = AxiosResponse +export type GitUpdateContentsResult = AxiosResponse +export type GitCreateContentsResult = AxiosResponse +export type GitDeleteContentsResult = AxiosResponse +export type GitListForksResult = AxiosResponse +export type GitCreateForkResult = AxiosResponse +export type GitArchiveResult = AxiosResponse +export type GitBlameFileResult = AxiosResponse +export type GitBlobUploadResult = AxiosResponse +export type GitBlobInfoResult = AxiosResponse +export type GitListBranchesResult = AxiosResponse +export type GitForkBranchResult = AxiosResponse +export type GitBranchInfoResult = AxiosResponse +export type GitDeleteBranchResult = AxiosResponse +export type GitRenameBranchResult = AxiosResponse +export type GitAheadBehindResult = AxiosResponse +export type GitBranchUpstreamResult = AxiosResponse +export type GitListCommitsResult = AxiosResponse +export type GitCherryPickResult = AxiosResponse +export type GitCommitHistoryResult = AxiosResponse +export type GitCommitWalkResult = AxiosResponse +export type GitCommitInfoResult = AxiosResponse +export type GitTreeEntryByPathFromCommitResult = AxiosResponse +export type GitListContributorsResult = AxiosResponse +export type GitDiffResult = AxiosResponse +export type GitDiffBranchesResult = AxiosResponse +export type GitGetLanguagesResult = AxiosResponse +export type GitGetReadmeResult = AxiosResponse +export type GitListRefsResult = AxiosResponse +export type GitStarStatusResult = AxiosResponse +export type GitStarRepoResult = AxiosResponse +export type GitUnstarRepoResult = AxiosResponse +export type GitListTagsResult = AxiosResponse +export type GitInitTagResult = AxiosResponse +export type GitTagInfoResult = AxiosResponse +export type GitDeleteTagResult = AxiosResponse +export type GitUpdateTagResult = AxiosResponse +export type GitTreeEntriesResult = AxiosResponse +export type GitTreeEntryByPathResult = AxiosResponse +export type GitWatchStatusResult = AxiosResponse +export type GitWatchRepoResult = AxiosResponse +export type GitUnwatchRepoResult = AxiosResponse +export type GitListProtectsResult = AxiosResponse +export type GitCreateProtectResult = AxiosResponse +export type GitUpdateProtectResult = AxiosResponse +export type GitDeleteProtectResult = AxiosResponse +export type PullRequestListPrsResult = AxiosResponse +export type PullRequestCreatePrResult = AxiosResponse +export type PullRequestGetPrResult = AxiosResponse +export type PullRequestDeletePrResult = AxiosResponse +export type PullRequestUpdatePrResult = AxiosResponse +export type PullRequestAssignUserResult = AxiosResponse +export type PullRequestUnassignUserResult = AxiosResponse +export type PullRequestListCommentsResult = AxiosResponse +export type PullRequestCreateCommentResult = AxiosResponse +export type PullRequestUpdateCommentResult = AxiosResponse +export type PullRequestDeleteCommentResult = AxiosResponse +export type PullRequestAddCommentReactionResult = AxiosResponse +export type PullRequestRemoveCommentReactionResult = AxiosResponse +export type PullRequestAddLabelResult = AxiosResponse +export type PullRequestRemoveLabelResult = AxiosResponse +export type PullRequestMergeAnalysisResult = AxiosResponse +export type PullRequestMergePrResult = AxiosResponse +export type PullRequestAddReactionResult = AxiosResponse +export type PullRequestRemoveReactionResult = AxiosResponse +export type PullRequestCreateReviewCommentResult = AxiosResponse +export type PullRequestListReviewsResult = AxiosResponse +export type PullRequestCreateReviewResult = AxiosResponse +export type PullRequestDismissReviewResult = AxiosResponse +export type PullRequestUpdateBranchResult = AxiosResponse +export type GitListReleasesResult = AxiosResponse +export type GitCreateReleaseResult = AxiosResponse +export type GitGetReleaseByTagResult = AxiosResponse +export type GitDeleteReleaseByTagResult = AxiosResponse +export type GitGetReleaseResult = AxiosResponse +export type GitDeleteReleaseResult = AxiosResponse +export type GitUpdateReleaseResult = AxiosResponse +export type GitCreateStatusResult = AxiosResponse +export type GitGetTopicsResult = AxiosResponse +export type GitUpdateTopicsResult = AxiosResponse +export type GitTransferRepoResult = AxiosResponse +export type GitListWebhooksResult = AxiosResponse +export type GitCreateWebhookResult = AxiosResponse +export type GitUpdateWebhookResult = AxiosResponse +export type GitDeleteWebhookResult = AxiosResponse +export type GitListDeliveriesResult = AxiosResponse +export type ChannelCategoryDeleteResult = AxiosResponse +export type ChannelCategoryUpdateResult = AxiosResponse +export type ChannelCsrfTokenResult = AxiosResponse +export type ChannelCustomStatusUpdateResult = AxiosResponse +export type ChannelInviteCreateResult = AxiosResponse +export type ChannelInviteAcceptResult = AxiosResponse +export type ChannelInviteRevokeResult = AxiosResponse +export type ChannelRevokeMessageResult = AxiosResponse +export type ChannelUpdateMessageResult = AxiosResponse +export type ChannelNotificationMarkAllReadResult = AxiosResponse +export type ChannelNotificationArchiveResult = AxiosResponse +export type ChannelNotificationMarkReadResult = AxiosResponse +export type ChannelPingResult = AxiosResponse +export type ChannelPresenceUpdateResult = AxiosResponse +export type ChannelRoomCreateResult = AxiosResponse +export type ChannelRoomGetResult = AxiosResponse +export type ChannelRoomDeleteResult = AxiosResponse +export type ChannelRoomUpdateResult = AxiosResponse +export type ChannelAiListResult = AxiosResponse +export type ChannelAiAddResult = AxiosResponse +export type ChannelAiStopResult = AxiosResponse +export type ChannelAiRemoveResult = AxiosResponse +export type ChannelDndUpdateResult = AxiosResponse +export type ChannelDraftSaveResult = AxiosResponse +export type ChannelDraftClearResult = AxiosResponse +export type ChannelAccessGrantResult = AxiosResponse +export type ChannelAccessRevokeResult = AxiosResponse +export type ChannelListMessagesResult = AxiosResponse +export type ChannelCreateMessageResult = AxiosResponse +export type ChannelMessagesAroundResult = AxiosResponse +export type ChannelMissedMessagesResult = AxiosResponse +export type ChannelPinAddResult = AxiosResponse +export type ChannelPinRemoveResult = AxiosResponse +export type ChannelReactionAddResult = AxiosResponse +export type ChannelReactionRemoveResult = AxiosResponse +export type ChannelReadReceiptResult = AxiosResponse +export type ChannelScreenShareResult = AxiosResponse +export type ChannelSubscribeResult = AxiosResponse +export type ChannelUnsubscribeResult = AxiosResponse +export type ChannelThreadCreateResult = AxiosResponse +export type ChannelTypingResult = AxiosResponse +export type ChannelVoiceDeafResult = AxiosResponse +export type ChannelVoiceJoinResult = AxiosResponse +export type ChannelVoiceLeaveResult = AxiosResponse +export type ChannelVoiceMuteResult = AxiosResponse +export type ChannelSearchResult = AxiosResponse +export type ChannelThreadArchiveResult = AxiosResponse +export type ChannelThreadResolveResult = AxiosResponse +export type ChannelGenerateTokenResult = AxiosResponse +export type ChannelUserSummaryResult = AxiosResponse +export type ChannelBanCreateResult = AxiosResponse +export type ChannelBanRemoveResult = AxiosResponse +export type ChannelCategoryCreateResult = AxiosResponse diff --git a/src/client/index.ts b/src/client/index.ts new file mode 100644 index 0000000..6f60b10 --- /dev/null +++ b/src/client/index.ts @@ -0,0 +1,7 @@ +import axios from "axios"; +import { getGitDataAIAPI } from "./endpoints"; + +export const api = axios.create({ baseURL: "", withCredentials: true, headers: { "Content-Type": "application/json" } }); +export const client = getGitDataAIAPI(api); +export * from "./models"; +export type * from "./endpoints"; diff --git a/src/client/models/accessRequest.ts b/src/client/models/accessRequest.ts new file mode 100644 index 0000000..7992525 --- /dev/null +++ b/src/client/models/accessRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AccessRequest { + user: string; +} diff --git a/src/client/models/addIssueLabel.ts b/src/client/models/addIssueLabel.ts new file mode 100644 index 0000000..e9e8271 --- /dev/null +++ b/src/client/models/addIssueLabel.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AddIssueLabel { + label_id: string; +} diff --git a/src/client/models/addPrLabel.ts b/src/client/models/addPrLabel.ts new file mode 100644 index 0000000..b5f5b7c --- /dev/null +++ b/src/client/models/addPrLabel.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AddPrLabel { + label_id: string; +} diff --git a/src/client/models/addPrReaction.ts b/src/client/models/addPrReaction.ts new file mode 100644 index 0000000..5b28689 --- /dev/null +++ b/src/client/models/addPrReaction.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AddPrReaction { + reaction: string; +} diff --git a/src/client/models/addReaction.ts b/src/client/models/addReaction.ts new file mode 100644 index 0000000..4f46f7c --- /dev/null +++ b/src/client/models/addReaction.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AddReaction { + reaction: string; +} diff --git a/src/client/models/addWorkspaceMember.ts b/src/client/models/addWorkspaceMember.ts new file mode 100644 index 0000000..22225d2 --- /dev/null +++ b/src/client/models/addWorkspaceMember.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AddWorkspaceMember { + /** @nullable */ + admin?: boolean | null; + username: string; +} diff --git a/src/client/models/agentCostInfo.ts b/src/client/models/agentCostInfo.ts new file mode 100644 index 0000000..c57a160 --- /dev/null +++ b/src/client/models/agentCostInfo.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AgentCostInfo { + amount: string; + currency: string; +} diff --git a/src/client/models/agentListAllConversationsParams.ts b/src/client/models/agentListAllConversationsParams.ts new file mode 100644 index 0000000..7597adc --- /dev/null +++ b/src/client/models/agentListAllConversationsParams.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type AgentListAllConversationsParams = { +/** + * Filter by workspace name + */ +wk?: string; +}; diff --git a/src/client/models/agentListMessagesParams.ts b/src/client/models/agentListMessagesParams.ts new file mode 100644 index 0000000..d419c35 --- /dev/null +++ b/src/client/models/agentListMessagesParams.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type AgentListMessagesParams = { +before?: string; +/** + * @minimum 0 + */ +limit?: number; +}; diff --git a/src/client/models/agentRunRequest.ts b/src/client/models/agentRunRequest.ts new file mode 100644 index 0000000..2e95c07 --- /dev/null +++ b/src/client/models/agentRunRequest.ts @@ -0,0 +1,25 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AgentRunRequest { + /** @nullable */ + conversation_id?: string | null; + input: string; + /** + * @minimum 0 + * @nullable + */ + max_steps?: number | null; + session_id: string; + stream?: boolean; + /** + * @minimum 0 + * @nullable + */ + timeout_secs?: number | null; +} diff --git a/src/client/models/agentRunResponse.ts b/src/client/models/agentRunResponse.ts new file mode 100644 index 0000000..c92f760 --- /dev/null +++ b/src/client/models/agentRunResponse.ts @@ -0,0 +1,19 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { AgentCostInfo } from './agentCostInfo'; +import type { AgentStepInfo } from './agentStepInfo'; +import type { AgentUsageInfo } from './agentUsageInfo'; + +export interface AgentRunResponse { + conversation_id: string; + cost?: null | AgentCostInfo; + message_id: string; + output: string; + steps: AgentStepInfo[]; + usage: AgentUsageInfo; +} diff --git a/src/client/models/agentSessionResponse.ts b/src/client/models/agentSessionResponse.ts new file mode 100644 index 0000000..0de9d31 --- /dev/null +++ b/src/client/models/agentSessionResponse.ts @@ -0,0 +1,48 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AgentSessionResponse { + agent_kind: string; + created_at: string; + /** @nullable */ + description?: string | null; + enabled: boolean; + id: string; + /** @nullable */ + iteration_budget?: number | null; + /** @nullable */ + max_output_tokens?: number | null; + /** @nullable */ + memory_provider?: string | null; + /** @nullable */ + model_version?: string | null; + name: string; + /** @nullable */ + parent_session_id?: string | null; + /** @nullable */ + published_at?: string | null; + /** @nullable */ + source?: string | null; + /** @nullable */ + system_prompt?: string | null; + /** @nullable */ + temperature?: number | null; + /** @nullable */ + tool_policy?: string | null; + /** @nullable */ + toolset_json?: string | null; + updated_at: string; + /** @nullable */ + user?: string | null; + /** @nullable */ + variables?: string | null; + version: number; + visibility: string; + /** @nullable */ + wk?: string | null; +} diff --git a/src/client/models/agentStepInfo.ts b/src/client/models/agentStepInfo.ts new file mode 100644 index 0000000..67b5b08 --- /dev/null +++ b/src/client/models/agentStepInfo.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { AgentToolCallInfo } from './agentToolCallInfo'; + +export interface AgentStepInfo { + /** @nullable */ + assistant?: string | null; + /** @minimum 0 */ + index: number; + /** @nullable */ + reflection?: string | null; + tool_calls: AgentToolCallInfo[]; +} diff --git a/src/client/models/agentToolCallInfo.ts b/src/client/models/agentToolCallInfo.ts new file mode 100644 index 0000000..bba15b1 --- /dev/null +++ b/src/client/models/agentToolCallInfo.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AgentToolCallInfo { + arguments: unknown; + /** @nullable */ + elapsed_ms?: number | null; + /** @nullable */ + error?: string | null; + id: string; + name: string; + output?: unknown; +} diff --git a/src/client/models/agentUsageInfo.ts b/src/client/models/agentUsageInfo.ts new file mode 100644 index 0000000..ba1e260 --- /dev/null +++ b/src/client/models/agentUsageInfo.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AgentUsageInfo { + input_tokens: number; + output_tokens: number; + total_tokens: number; +} diff --git a/src/client/models/aiAddRequest.ts b/src/client/models/aiAddRequest.ts new file mode 100644 index 0000000..6f5b9e7 --- /dev/null +++ b/src/client/models/aiAddRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AiAddRequest { + agent_session: string; +} diff --git a/src/client/models/aiDiscussionResponse.ts b/src/client/models/aiDiscussionResponse.ts new file mode 100644 index 0000000..24ab2ed --- /dev/null +++ b/src/client/models/aiDiscussionResponse.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { IssueAuthor } from './issueAuthor'; + +export interface AiDiscussionResponse { + author: IssueAuthor; + body: string; + created_at: string; + id: string; + /** @nullable */ + parent?: string | null; + updated_at: string; +} diff --git a/src/client/models/aiLikeResponse.ts b/src/client/models/aiLikeResponse.ts new file mode 100644 index 0000000..db18cbc --- /dev/null +++ b/src/client/models/aiLikeResponse.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { IssueAuthor } from './issueAuthor'; + +export interface AiLikeResponse { + created_at: string; + user: IssueAuthor; +} diff --git a/src/client/models/aiListDiscussionsParams.ts b/src/client/models/aiListDiscussionsParams.ts new file mode 100644 index 0000000..0a05196 --- /dev/null +++ b/src/client/models/aiListDiscussionsParams.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type AiListDiscussionsParams = { +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/aiListModelsParams.ts b/src/client/models/aiListModelsParams.ts new file mode 100644 index 0000000..1106af9 --- /dev/null +++ b/src/client/models/aiListModelsParams.ts @@ -0,0 +1,36 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type AiListModelsParams = { +/** + * @nullable + */ +enabled?: boolean | null; +/** + * @nullable + */ +provider?: string | null; +/** + * @nullable + */ +modality?: string | null; +/** + * @nullable + */ +name?: string | null; +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/aiModelCardResponse.ts b/src/client/models/aiModelCardResponse.ts new file mode 100644 index 0000000..9ed369e --- /dev/null +++ b/src/client/models/aiModelCardResponse.ts @@ -0,0 +1,22 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AiModelCardResponse { + /** @nullable */ + eval_summary?: string | null; + /** @nullable */ + limitations?: string | null; + /** @nullable */ + metadata?: string | null; + /** @nullable */ + overview?: string | null; + /** @nullable */ + safety_notes?: string | null; + /** @nullable */ + strengths?: string | null; +} diff --git a/src/client/models/aiModelListItem.ts b/src/client/models/aiModelListItem.ts new file mode 100644 index 0000000..3295f11 --- /dev/null +++ b/src/client/models/aiModelListItem.ts @@ -0,0 +1,28 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AiModelListItem { + /** @nullable */ + context_window?: number | null; + created_at: string; + /** @nullable */ + description?: string | null; + display_name: string; + enabled: boolean; + id: string; + /** @nullable */ + input_token_limit?: number | null; + modality: string; + name: string; + /** @nullable */ + output_token_limit?: number | null; + /** @nullable */ + provider_logo_url?: string | null; + provider_name: string; + updated_at: string; +} diff --git a/src/client/models/aiModelResponse.ts b/src/client/models/aiModelResponse.ts new file mode 100644 index 0000000..7ad4f46 --- /dev/null +++ b/src/client/models/aiModelResponse.ts @@ -0,0 +1,35 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { AiModelCardResponse } from './aiModelCardResponse'; +import type { AiModelVersionResponse } from './aiModelVersionResponse'; + +export interface AiModelResponse { + card?: null | AiModelCardResponse; + /** @nullable */ + context_window?: number | null; + created_at: string; + /** @nullable */ + description?: string | null; + display_name: string; + enabled: boolean; + id: string; + /** @nullable */ + input_token_limit?: number | null; + like_count: number; + modality: string; + name: string; + /** @nullable */ + output_token_limit?: number | null; + /** @nullable */ + provider_logo_url?: string | null; + provider_name: string; + public: boolean; + tags: string[]; + updated_at: string; + versions: AiModelVersionResponse[]; +} diff --git a/src/client/models/aiModelVersionResponse.ts b/src/client/models/aiModelVersionResponse.ts new file mode 100644 index 0000000..c907f1d --- /dev/null +++ b/src/client/models/aiModelVersionResponse.ts @@ -0,0 +1,26 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AiModelVersionResponse { + /** @nullable */ + cached_input_price_per_million?: string | null; + /** @nullable */ + deprecated_at?: string | null; + enabled: boolean; + id: string; + /** @nullable */ + input_price_per_million?: string | null; + /** @nullable */ + output_price_per_million?: string | null; + provider_model_name: string; + /** @nullable */ + released_at?: string | null; + /** @nullable */ + training_cutoff?: string | null; + version: string; +} diff --git a/src/client/models/aiProviderResponse.ts b/src/client/models/aiProviderResponse.ts new file mode 100644 index 0000000..84b8df2 --- /dev/null +++ b/src/client/models/aiProviderResponse.ts @@ -0,0 +1,21 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AiProviderResponse { + /** @nullable */ + base_url?: string | null; + created_at: string; + enabled: boolean; + id: string; + /** @nullable */ + logo_url?: string | null; + name: string; + updated_at: string; + /** @nullable */ + website_url?: string | null; +} diff --git a/src/client/models/appNotificationItem.ts b/src/client/models/appNotificationItem.ts new file mode 100644 index 0000000..4624cad --- /dev/null +++ b/src/client/models/appNotificationItem.ts @@ -0,0 +1,9 @@ +export interface AppNotificationItem { + body: string; + created_at: string; + id: string; + notify_type: string; + /** @nullable */ + read_at?: string | null; + title: string; +} diff --git a/src/client/models/approveWorkspaceJoinApply.ts b/src/client/models/approveWorkspaceJoinApply.ts new file mode 100644 index 0000000..f18dc04 --- /dev/null +++ b/src/client/models/approveWorkspaceJoinApply.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ApproveWorkspaceJoinApply { + approved: boolean; + /** @nullable */ + reason?: string | null; +} diff --git a/src/client/models/assignIssueUser.ts b/src/client/models/assignIssueUser.ts new file mode 100644 index 0000000..b390e06 --- /dev/null +++ b/src/client/models/assignIssueUser.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AssignIssueUser { + username: string; +} diff --git a/src/client/models/assignPrUser.ts b/src/client/models/assignPrUser.ts new file mode 100644 index 0000000..790f867 --- /dev/null +++ b/src/client/models/assignPrUser.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AssignPrUser { + username: string; +} diff --git a/src/client/models/authCaptchaParams.ts b/src/client/models/authCaptchaParams.ts new file mode 100644 index 0000000..1cbe68c --- /dev/null +++ b/src/client/models/authCaptchaParams.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type AuthCaptchaParams = { +/** + * @minimum 0 + */ +w: number; +/** + * @minimum 0 + */ +h: number; +dark: boolean; +rsa: boolean; +}; diff --git a/src/client/models/avatarUploadResponse.ts b/src/client/models/avatarUploadResponse.ts new file mode 100644 index 0000000..ac764bf --- /dev/null +++ b/src/client/models/avatarUploadResponse.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface AvatarUploadResponse { + avatar_url: string; +} diff --git a/src/client/models/banCreateRequest.ts b/src/client/models/banCreateRequest.ts new file mode 100644 index 0000000..93b8b02 --- /dev/null +++ b/src/client/models/banCreateRequest.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface BanCreateRequest { + /** @nullable */ + expires_at?: string | null; + /** @nullable */ + reason?: string | null; + user: string; +} diff --git a/src/client/models/bindIssuePullRequest.ts b/src/client/models/bindIssuePullRequest.ts new file mode 100644 index 0000000..0d41482 --- /dev/null +++ b/src/client/models/bindIssuePullRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface BindIssuePullRequest { + pull_request_id: string; +} diff --git a/src/client/models/bindIssueRepo.ts b/src/client/models/bindIssueRepo.ts new file mode 100644 index 0000000..bb83681 --- /dev/null +++ b/src/client/models/bindIssueRepo.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface BindIssueRepo { + repo_id: string; +} diff --git a/src/client/models/blameFileResponseDto.ts b/src/client/models/blameFileResponseDto.ts new file mode 100644 index 0000000..cc5e348 --- /dev/null +++ b/src/client/models/blameFileResponseDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { CommitBlameHunkDto } from './commitBlameHunkDto'; + +export interface BlameFileResponseDto { + hunks: CommitBlameHunkDto[]; +} diff --git a/src/client/models/blobInfoResponse.ts b/src/client/models/blobInfoResponse.ts new file mode 100644 index 0000000..b646221 --- /dev/null +++ b/src/client/models/blobInfoResponse.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { BlobLoadResponseDto } from './blobLoadResponseDto'; + +export type BlobInfoResponse = BlobLoadResponseDto & { + is_binary: boolean; + /** @minimum 0 */ + size: number; +}; diff --git a/src/client/models/blobLoadResponseDto.ts b/src/client/models/blobLoadResponseDto.ts new file mode 100644 index 0000000..3d8910a --- /dev/null +++ b/src/client/models/blobLoadResponseDto.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface BlobLoadResponseDto { + blob: string; +} diff --git a/src/client/models/blobUploadBody.ts b/src/client/models/blobUploadBody.ts new file mode 100644 index 0000000..08f07ac --- /dev/null +++ b/src/client/models/blobUploadBody.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface BlobUploadBody { + /** @items.minimum 0 */ + blob: number[]; + path: string; +} diff --git a/src/client/models/blobUploadResponseDto.ts b/src/client/models/blobUploadResponseDto.ts new file mode 100644 index 0000000..1385000 --- /dev/null +++ b/src/client/models/blobUploadResponseDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface BlobUploadResponseDto { + /** @nullable */ + id?: string | null; +} diff --git a/src/client/models/branchAheadBehindResponseDto.ts b/src/client/models/branchAheadBehindResponseDto.ts new file mode 100644 index 0000000..eb0cedf --- /dev/null +++ b/src/client/models/branchAheadBehindResponseDto.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface BranchAheadBehindResponseDto { + /** @minimum 0 */ + ahead: number; + /** @minimum 0 */ + behind: number; +} diff --git a/src/client/models/branchInfoResponseDto.ts b/src/client/models/branchInfoResponseDto.ts new file mode 100644 index 0000000..9342c38 --- /dev/null +++ b/src/client/models/branchInfoResponseDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { BranchListItemDto } from './branchListItemDto'; + +export interface BranchInfoResponseDto { + branch?: null | BranchListItemDto; +} diff --git a/src/client/models/branchListItemDto.ts b/src/client/models/branchListItemDto.ts new file mode 100644 index 0000000..3d0baaf --- /dev/null +++ b/src/client/models/branchListItemDto.ts @@ -0,0 +1,17 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface BranchListItemDto { + is_current: boolean; + is_head: boolean; + is_remote: boolean; + name: string; + oid: string; + /** @nullable */ + upstream?: string | null; +} diff --git a/src/client/models/branchListResponseDto.ts b/src/client/models/branchListResponseDto.ts new file mode 100644 index 0000000..c913fdb --- /dev/null +++ b/src/client/models/branchListResponseDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { BranchListItemDto } from './branchListItemDto'; + +export interface BranchListResponseDto { + branches: BranchListItemDto[]; +} diff --git a/src/client/models/branchUpstreamResponseDto.ts b/src/client/models/branchUpstreamResponseDto.ts new file mode 100644 index 0000000..3f740ad --- /dev/null +++ b/src/client/models/branchUpstreamResponseDto.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface BranchUpstreamResponseDto { + upstream_name: string; +} diff --git a/src/client/models/captchaQuery.ts b/src/client/models/captchaQuery.ts new file mode 100644 index 0000000..ad2795d --- /dev/null +++ b/src/client/models/captchaQuery.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CaptchaQuery { + dark: boolean; + /** @minimum 0 */ + h: number; + rsa: boolean; + /** @minimum 0 */ + w: number; +} diff --git a/src/client/models/captchaResponse.ts b/src/client/models/captchaResponse.ts new file mode 100644 index 0000000..24fecb2 --- /dev/null +++ b/src/client/models/captchaResponse.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { CaptchaQuery } from './captchaQuery'; +import type { RsaResponse } from './rsaResponse'; + +export interface CaptchaResponse { + base64: string; + req: CaptchaQuery; + rsa?: null | RsaResponse; +} diff --git a/src/client/models/categoryCreateRequest.ts b/src/client/models/categoryCreateRequest.ts new file mode 100644 index 0000000..2d28399 --- /dev/null +++ b/src/client/models/categoryCreateRequest.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CategoryCreateRequest { + name: string; + /** @nullable */ + position?: number | null; +} diff --git a/src/client/models/categoryUpdateRequest.ts b/src/client/models/categoryUpdateRequest.ts new file mode 100644 index 0000000..7d10f04 --- /dev/null +++ b/src/client/models/categoryUpdateRequest.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CategoryUpdateRequest { + /** @nullable */ + name?: string | null; + /** @nullable */ + position?: number | null; +} diff --git a/src/client/models/channelListMessagesParams.ts b/src/client/models/channelListMessagesParams.ts new file mode 100644 index 0000000..36a4008 --- /dev/null +++ b/src/client/models/channelListMessagesParams.ts @@ -0,0 +1,23 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type ChannelListMessagesParams = { +/** + * @nullable + */ +before_seq?: number | null; +/** + * @nullable + */ +after_seq?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/channelMessagesAroundParams.ts b/src/client/models/channelMessagesAroundParams.ts new file mode 100644 index 0000000..819cf3d --- /dev/null +++ b/src/client/models/channelMessagesAroundParams.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type ChannelMessagesAroundParams = { +seq: number; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/channelMissedMessagesParams.ts b/src/client/models/channelMissedMessagesParams.ts new file mode 100644 index 0000000..0e4719d --- /dev/null +++ b/src/client/models/channelMissedMessagesParams.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type ChannelMissedMessagesParams = { +after_seq: number; +/** + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/channelSearchParams.ts b/src/client/models/channelSearchParams.ts new file mode 100644 index 0000000..4184a66 --- /dev/null +++ b/src/client/models/channelSearchParams.ts @@ -0,0 +1,25 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type ChannelSearchParams = { +q: string; +/** + * @nullable + */ +room?: string | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +}; diff --git a/src/client/models/cherryPickResponseDto.ts b/src/client/models/cherryPickResponseDto.ts new file mode 100644 index 0000000..f39b753 --- /dev/null +++ b/src/client/models/cherryPickResponseDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CherryPickResponseDto { + /** @nullable */ + oid?: string | null; +} diff --git a/src/client/models/cloneRepo.ts b/src/client/models/cloneRepo.ts new file mode 100644 index 0000000..31b55ea --- /dev/null +++ b/src/client/models/cloneRepo.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CloneRepo { + /** @nullable */ + description?: string | null; + name: string; + source_url: string; + /** @nullable */ + visibility?: string | null; +} diff --git a/src/client/models/combinedCommitStatus.ts b/src/client/models/combinedCommitStatus.ts new file mode 100644 index 0000000..5537388 --- /dev/null +++ b/src/client/models/combinedCommitStatus.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { CommitStatusResponse } from './commitStatusResponse'; + +export interface CombinedCommitStatus { + sha: string; + state: string; + statuses: CommitStatusResponse[]; + total_count: number; +} diff --git a/src/client/models/commitBlameHunkDto.ts b/src/client/models/commitBlameHunkDto.ts new file mode 100644 index 0000000..75ba84a --- /dev/null +++ b/src/client/models/commitBlameHunkDto.ts @@ -0,0 +1,23 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CommitBlameHunkDto { + boundary: boolean; + /** @nullable */ + commit_oid?: string | null; + /** @minimum 0 */ + final_lines: number; + /** @minimum 0 */ + final_start_line: number; + /** @minimum 0 */ + orig_lines: number; + /** @nullable */ + orig_path?: string | null; + /** @minimum 0 */ + orig_start_line: number; +} diff --git a/src/client/models/commitHistoryResponseDto.ts b/src/client/models/commitHistoryResponseDto.ts new file mode 100644 index 0000000..0e4bbaf --- /dev/null +++ b/src/client/models/commitHistoryResponseDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { CommitMetaDto } from './commitMetaDto'; + +export interface CommitHistoryResponseDto { + commits: CommitMetaDto[]; +} diff --git a/src/client/models/commitInfoResponseDto.ts b/src/client/models/commitInfoResponseDto.ts new file mode 100644 index 0000000..9ddf042 --- /dev/null +++ b/src/client/models/commitInfoResponseDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { CommitMetaDto } from './commitMetaDto'; + +export interface CommitInfoResponseDto { + commit?: null | CommitMetaDto; +} diff --git a/src/client/models/commitMetaDto.ts b/src/client/models/commitMetaDto.ts new file mode 100644 index 0000000..989e320 --- /dev/null +++ b/src/client/models/commitMetaDto.ts @@ -0,0 +1,21 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { CommitSignatureDto } from './commitSignatureDto'; + +export interface CommitMetaDto { + author?: null | CommitSignatureDto; + committer?: null | CommitSignatureDto; + /** @nullable */ + encoding?: string | null; + message: string; + oid: string; + parent_ids: string[]; + summary: string; + /** @nullable */ + tree_id?: string | null; +} diff --git a/src/client/models/commitSignatureDto.ts b/src/client/models/commitSignatureDto.ts new file mode 100644 index 0000000..2019af5 --- /dev/null +++ b/src/client/models/commitSignatureDto.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CommitSignatureDto { + email: string; + name: string; + offset_minutes: number; + time_secs: number; +} diff --git a/src/client/models/commitStatusResponse.ts b/src/client/models/commitStatusResponse.ts new file mode 100644 index 0000000..55eb00d --- /dev/null +++ b/src/client/models/commitStatusResponse.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CommitStatusResponse { + commit_sha: string; + context: string; + created_at: string; + creator: string; + /** @nullable */ + description?: string | null; + id: string; + state: string; + /** @nullable */ + target_url?: string | null; +} diff --git a/src/client/models/compareCommit.ts b/src/client/models/compareCommit.ts new file mode 100644 index 0000000..5627f35 --- /dev/null +++ b/src/client/models/compareCommit.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CompareCommit { + /** @nullable */ + author_email?: string | null; + /** @nullable */ + author_name?: string | null; + message: string; + sha: string; +} diff --git a/src/client/models/compareResponse.ts b/src/client/models/compareResponse.ts new file mode 100644 index 0000000..3065d4e --- /dev/null +++ b/src/client/models/compareResponse.ts @@ -0,0 +1,23 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { CompareCommit } from './compareCommit'; + +export interface CompareResponse { + ahead_by: number; + base_commit: CompareCommit; + behind_by: number; + commits: CompareCommit[]; + /** @minimum 0 */ + deletions: number; + /** @minimum 0 */ + files_changed: number; + head_commit: CompareCommit; + /** @minimum 0 */ + insertions: number; + total_commits: number; +} diff --git a/src/client/models/contentResponse.ts b/src/client/models/contentResponse.ts new file mode 100644 index 0000000..a6f3f61 --- /dev/null +++ b/src/client/models/contentResponse.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ContentResponse { + /** @nullable */ + content?: string | null; + /** @nullable */ + encoding?: string | null; + name: string; + path: string; + size: number; + type: string; +} diff --git a/src/client/models/contextMe.ts b/src/client/models/contextMe.ts new file mode 100644 index 0000000..82abae7 --- /dev/null +++ b/src/client/models/contextMe.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ContextMe { + /** @nullable */ + avatar_url?: string | null; + /** @nullable */ + display_name?: string | null; + /** @minimum 0 */ + has_unread_notifications: number; + id: string; + language: string; + timezone: string; + username: string; +} diff --git a/src/client/models/contributionHeatmapItem.ts b/src/client/models/contributionHeatmapItem.ts new file mode 100644 index 0000000..ea463eb --- /dev/null +++ b/src/client/models/contributionHeatmapItem.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ContributionHeatmapItem { + count: number; + date: string; +} diff --git a/src/client/models/contributionHeatmapResponse.ts b/src/client/models/contributionHeatmapResponse.ts new file mode 100644 index 0000000..2542136 --- /dev/null +++ b/src/client/models/contributionHeatmapResponse.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { ContributionHeatmapItem } from './contributionHeatmapItem'; + +export interface ContributionHeatmapResponse { + end_date: string; + heatmap: ContributionHeatmapItem[]; + start_date: string; + total_contributions: number; + username: string; +} diff --git a/src/client/models/contributorDto.ts b/src/client/models/contributorDto.ts new file mode 100644 index 0000000..d031848 --- /dev/null +++ b/src/client/models/contributorDto.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ContributorDto { + commit_count: number; + email: string; + name: string; + /** @nullable */ + user_id?: string | null; +} diff --git a/src/client/models/conversationResponse.ts b/src/client/models/conversationResponse.ts new file mode 100644 index 0000000..1fe7098 --- /dev/null +++ b/src/client/models/conversationResponse.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ConversationResponse { + /** @nullable */ + archived_at?: string | null; + created_at: string; + created_by: string; + id: string; + /** @nullable */ + last_message_at?: string | null; + session_id: string; + title: string; + updated_at: string; +} diff --git a/src/client/models/conversationWithSessionResponse.ts b/src/client/models/conversationWithSessionResponse.ts new file mode 100644 index 0000000..0b1a35a --- /dev/null +++ b/src/client/models/conversationWithSessionResponse.ts @@ -0,0 +1,22 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ConversationWithSessionResponse { + /** @nullable */ + archived_at?: string | null; + created_at: string; + created_by: string; + id: string; + /** @nullable */ + last_message_at?: string | null; + session_id: string; + /** @nullable */ + session_name?: string | null; + title: string; + updated_at: string; +} diff --git a/src/client/models/createAgentSession.ts b/src/client/models/createAgentSession.ts new file mode 100644 index 0000000..cc61d27 --- /dev/null +++ b/src/client/models/createAgentSession.ts @@ -0,0 +1,41 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateAgentSession { + agent_kind: string; + /** @nullable */ + description?: string | null; + /** @nullable */ + iteration_budget?: number | null; + /** @nullable */ + knowledge_base_ids?: string[] | null; + /** @nullable */ + max_output_tokens?: number | null; + /** @nullable */ + memory_provider?: string | null; + /** @nullable */ + memory_provider_config?: string | null; + model_version: string; + name: string; + /** @nullable */ + source?: string | null; + /** @nullable */ + system_prompt?: string | null; + /** @nullable */ + temperature?: number | null; + /** @nullable */ + tool_policy?: string | null; + /** @nullable */ + toolset_json?: string | null; + /** @nullable */ + variables?: string | null; + /** @nullable */ + visibility?: string | null; + /** @nullable */ + wk?: string | null; +} diff --git a/src/client/models/createComment.ts b/src/client/models/createComment.ts new file mode 100644 index 0000000..b820f10 --- /dev/null +++ b/src/client/models/createComment.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateComment { + body: string; +} diff --git a/src/client/models/createCommitStatus.ts b/src/client/models/createCommitStatus.ts new file mode 100644 index 0000000..d014060 --- /dev/null +++ b/src/client/models/createCommitStatus.ts @@ -0,0 +1,17 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateCommitStatus { + /** @nullable */ + context?: string | null; + /** @nullable */ + description?: string | null; + state: string; + /** @nullable */ + target_url?: string | null; +} diff --git a/src/client/models/createContent.ts b/src/client/models/createContent.ts new file mode 100644 index 0000000..3371c0d --- /dev/null +++ b/src/client/models/createContent.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateContent { + /** @nullable */ + branch?: string | null; + content: string; + message: string; +} diff --git a/src/client/models/createConversation.ts b/src/client/models/createConversation.ts new file mode 100644 index 0000000..07950a4 --- /dev/null +++ b/src/client/models/createConversation.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateConversation { + title: string; +} diff --git a/src/client/models/createFork.ts b/src/client/models/createFork.ts new file mode 100644 index 0000000..90755b0 --- /dev/null +++ b/src/client/models/createFork.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateFork { + /** @nullable */ + name?: string | null; + /** @nullable */ + visibility?: string | null; +} diff --git a/src/client/models/createIssue.ts b/src/client/models/createIssue.ts new file mode 100644 index 0000000..28668f1 --- /dev/null +++ b/src/client/models/createIssue.ts @@ -0,0 +1,17 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateIssue { + /** @nullable */ + body?: string | null; + /** @nullable */ + due_at?: string | null; + /** @nullable */ + priority?: string | null; + title: string; +} diff --git a/src/client/models/createLabel.ts b/src/client/models/createLabel.ts new file mode 100644 index 0000000..c579f1d --- /dev/null +++ b/src/client/models/createLabel.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateLabel { + color: string; + /** @nullable */ + description?: string | null; + name: string; +} diff --git a/src/client/models/createMessageRequest.ts b/src/client/models/createMessageRequest.ts new file mode 100644 index 0000000..df3faa9 --- /dev/null +++ b/src/client/models/createMessageRequest.ts @@ -0,0 +1,17 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateMessageRequest { + content: string; + /** @nullable */ + content_type?: string | null; + /** @nullable */ + in_reply_to?: string | null; + /** @nullable */ + thread?: string | null; +} diff --git a/src/client/models/createMilestone.ts b/src/client/models/createMilestone.ts new file mode 100644 index 0000000..3e74319 --- /dev/null +++ b/src/client/models/createMilestone.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateMilestone { + /** @nullable */ + description?: string | null; + /** @nullable */ + due_at?: string | null; + title: string; +} diff --git a/src/client/models/createPrComment.ts b/src/client/models/createPrComment.ts new file mode 100644 index 0000000..cf0c162 --- /dev/null +++ b/src/client/models/createPrComment.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreatePrComment { + body: string; +} diff --git a/src/client/models/createPrReview.ts b/src/client/models/createPrReview.ts new file mode 100644 index 0000000..52fc5a5 --- /dev/null +++ b/src/client/models/createPrReview.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreatePrReview { + /** @nullable */ + body?: string | null; + /** @nullable */ + commit_sha?: string | null; + state: string; +} diff --git a/src/client/models/createPrReviewComment.ts b/src/client/models/createPrReviewComment.ts new file mode 100644 index 0000000..9791f92 --- /dev/null +++ b/src/client/models/createPrReviewComment.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreatePrReviewComment { + body: string; + /** @nullable */ + commit_sha?: string | null; + /** @nullable */ + line?: number | null; + path: string; + /** @nullable */ + side?: string | null; +} diff --git a/src/client/models/createProtect.ts b/src/client/models/createProtect.ts new file mode 100644 index 0000000..18bdaba --- /dev/null +++ b/src/client/models/createProtect.ts @@ -0,0 +1,25 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateProtect { + /** @nullable */ + allow_deletions?: boolean | null; + /** @nullable */ + allow_force_pushes?: boolean | null; + /** @nullable */ + enforce_admins?: boolean | null; + pattern: string; + /** @nullable */ + require_pull_request?: boolean | null; + /** @nullable */ + require_status_checks?: boolean | null; + /** @nullable */ + required_approvals?: number | null; + /** @nullable */ + required_status_contexts?: string[] | null; +} diff --git a/src/client/models/createPullRequest.ts b/src/client/models/createPullRequest.ts new file mode 100644 index 0000000..a5b15dc --- /dev/null +++ b/src/client/models/createPullRequest.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreatePullRequest { + /** @nullable */ + body?: string | null; + /** @nullable */ + draft?: boolean | null; + source_branch: string; + /** @nullable */ + source_repo?: string | null; + /** @nullable */ + target_branch?: string | null; + title: string; +} diff --git a/src/client/models/createRelease.ts b/src/client/models/createRelease.ts new file mode 100644 index 0000000..e6babd1 --- /dev/null +++ b/src/client/models/createRelease.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateRelease { + /** @nullable */ + body?: string | null; + draft?: boolean; + name: string; + prerelease?: boolean; + tag_name: string; + /** @nullable */ + target_commit_sha?: string | null; +} diff --git a/src/client/models/createRepo.ts b/src/client/models/createRepo.ts new file mode 100644 index 0000000..08e3098 --- /dev/null +++ b/src/client/models/createRepo.ts @@ -0,0 +1,21 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateRepo { + /** @nullable */ + default_branch?: string | null; + /** @nullable */ + description?: string | null; + /** @nullable */ + enable_lfs?: boolean | null; + /** @nullable */ + initialize_with_readme?: boolean | null; + name: string; + /** @nullable */ + visibility?: string | null; +} diff --git a/src/client/models/createUserAccessToken.ts b/src/client/models/createUserAccessToken.ts new file mode 100644 index 0000000..c734168 --- /dev/null +++ b/src/client/models/createUserAccessToken.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateUserAccessToken { + /** @nullable */ + expires_at?: string | null; + name: string; + scopes: string[]; +} diff --git a/src/client/models/createUserSshKey.ts b/src/client/models/createUserSshKey.ts new file mode 100644 index 0000000..ce497e1 --- /dev/null +++ b/src/client/models/createUserSshKey.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateUserSshKey { + /** @nullable */ + expires_at?: string | null; + public_key: string; + title: string; +} diff --git a/src/client/models/createWebhook.ts b/src/client/models/createWebhook.ts new file mode 100644 index 0000000..6025fda --- /dev/null +++ b/src/client/models/createWebhook.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateWebhook { + /** @nullable */ + active?: boolean | null; + events: string[]; + /** @nullable */ + secret?: string | null; + url: string; +} diff --git a/src/client/models/createWorkspace.ts b/src/client/models/createWorkspace.ts new file mode 100644 index 0000000..b9cc904 --- /dev/null +++ b/src/client/models/createWorkspace.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateWorkspace { + /** @nullable */ + avatar_url?: string | null; + /** @nullable */ + description?: string | null; + name: string; +} diff --git a/src/client/models/createWorkspaceGroup.ts b/src/client/models/createWorkspaceGroup.ts new file mode 100644 index 0000000..642cd8b --- /dev/null +++ b/src/client/models/createWorkspaceGroup.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateWorkspaceGroup { + /** @nullable */ + avatar_url?: string | null; + name: string; +} diff --git a/src/client/models/createWorkspaceJoinApply.ts b/src/client/models/createWorkspaceJoinApply.ts new file mode 100644 index 0000000..0be4aa0 --- /dev/null +++ b/src/client/models/createWorkspaceJoinApply.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CreateWorkspaceJoinApply { + /** @nullable */ + answer?: string | null; + /** @nullable */ + message?: string | null; +} diff --git a/src/client/models/createdUserAccessToken.ts b/src/client/models/createdUserAccessToken.ts new file mode 100644 index 0000000..cf9c389 --- /dev/null +++ b/src/client/models/createdUserAccessToken.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { UserAccessToken } from './userAccessToken'; + +export interface CreatedUserAccessToken { + access_token: UserAccessToken; + token: string; +} diff --git a/src/client/models/customStatusRequest.ts b/src/client/models/customStatusRequest.ts new file mode 100644 index 0000000..6e76aed --- /dev/null +++ b/src/client/models/customStatusRequest.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface CustomStatusRequest { + /** @nullable */ + emoji?: string | null; + /** @nullable */ + expires_at?: string | null; + /** @nullable */ + text?: string | null; +} diff --git a/src/client/models/diffDeltaDto.ts b/src/client/models/diffDeltaDto.ts new file mode 100644 index 0000000..484c73a --- /dev/null +++ b/src/client/models/diffDeltaDto.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { DiffFileDto } from './diffFileDto'; +import type { DiffHunkDto } from './diffHunkDto'; +import type { DiffLineDto } from './diffLineDto'; + +export interface DiffDeltaDto { + hunks: DiffHunkDto[]; + lines: DiffLineDto[]; + new_file?: null | DiffFileDto; + /** @minimum 0 */ + nfiles: number; + old_file?: null | DiffFileDto; + status: number; +} diff --git a/src/client/models/diffFileDto.ts b/src/client/models/diffFileDto.ts new file mode 100644 index 0000000..4201bff --- /dev/null +++ b/src/client/models/diffFileDto.ts @@ -0,0 +1,17 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface DiffFileDto { + is_binary: boolean; + /** @nullable */ + oid?: string | null; + /** @nullable */ + path?: string | null; + /** @minimum 0 */ + size: number; +} diff --git a/src/client/models/diffHunkDto.ts b/src/client/models/diffHunkDto.ts new file mode 100644 index 0000000..32624f7 --- /dev/null +++ b/src/client/models/diffHunkDto.ts @@ -0,0 +1,19 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface DiffHunkDto { + header: string; + /** @minimum 0 */ + new_lines: number; + /** @minimum 0 */ + new_start: number; + /** @minimum 0 */ + old_lines: number; + /** @minimum 0 */ + old_start: number; +} diff --git a/src/client/models/diffLineDto.ts b/src/client/models/diffLineDto.ts new file mode 100644 index 0000000..f0b4753 --- /dev/null +++ b/src/client/models/diffLineDto.ts @@ -0,0 +1,25 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface DiffLineDto { + content: string; + content_offset: number; + /** + * @minimum 0 + * @nullable + */ + new_lineno?: number | null; + /** @minimum 0 */ + num_lines: number; + /** + * @minimum 0 + * @nullable + */ + old_lineno?: number | null; + origin: string; +} diff --git a/src/client/models/diffResultDto.ts b/src/client/models/diffResultDto.ts new file mode 100644 index 0000000..2680e58 --- /dev/null +++ b/src/client/models/diffResultDto.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { DiffDeltaDto } from './diffDeltaDto'; +import type { DiffStatsDto } from './diffStatsDto'; + +export interface DiffResultDto { + deltas: DiffDeltaDto[]; + stats?: null | DiffStatsDto; +} diff --git a/src/client/models/diffStatsDto.ts b/src/client/models/diffStatsDto.ts new file mode 100644 index 0000000..1d72299 --- /dev/null +++ b/src/client/models/diffStatsDto.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface DiffStatsDto { + /** @minimum 0 */ + deletions: number; + /** @minimum 0 */ + files_changed: number; + /** @minimum 0 */ + insertions: number; +} diff --git a/src/client/models/disable2FAParams.ts b/src/client/models/disable2FAParams.ts new file mode 100644 index 0000000..8a12072 --- /dev/null +++ b/src/client/models/disable2FAParams.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface Disable2FAParams { + code: string; + password: string; +} diff --git a/src/client/models/dismissPrReview.ts b/src/client/models/dismissPrReview.ts new file mode 100644 index 0000000..4ffd7eb --- /dev/null +++ b/src/client/models/dismissPrReview.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface DismissPrReview { + /** @nullable */ + dismiss_reason?: string | null; +} diff --git a/src/client/models/dndRequest.ts b/src/client/models/dndRequest.ts new file mode 100644 index 0000000..ee6c5ea --- /dev/null +++ b/src/client/models/dndRequest.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface DndRequest { + /** @nullable */ + dnd_end_hour?: number | null; + /** @nullable */ + dnd_start_hour?: number | null; + /** @nullable */ + do_not_disturb?: boolean | null; +} diff --git a/src/client/models/draftSaveRequest.ts b/src/client/models/draftSaveRequest.ts new file mode 100644 index 0000000..f840a53 --- /dev/null +++ b/src/client/models/draftSaveRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface DraftSaveRequest { + content: string; +} diff --git a/src/client/models/emailChangeRequest.ts b/src/client/models/emailChangeRequest.ts new file mode 100644 index 0000000..1894451 --- /dev/null +++ b/src/client/models/emailChangeRequest.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface EmailChangeRequest { + new_email: string; + password: string; +} diff --git a/src/client/models/emailResponse.ts b/src/client/models/emailResponse.ts new file mode 100644 index 0000000..f609e0f --- /dev/null +++ b/src/client/models/emailResponse.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface EmailResponse { + /** @nullable */ + email?: string | null; +} diff --git a/src/client/models/emailVerifyRequest.ts b/src/client/models/emailVerifyRequest.ts new file mode 100644 index 0000000..5bbc16a --- /dev/null +++ b/src/client/models/emailVerifyRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface EmailVerifyRequest { + token: string; +} diff --git a/src/client/models/enable2FAResponse.ts b/src/client/models/enable2FAResponse.ts new file mode 100644 index 0000000..b07396c --- /dev/null +++ b/src/client/models/enable2FAResponse.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface Enable2FAResponse { + backup_codes: string[]; + qr_code: string; + secret: string; +} diff --git a/src/client/models/forkResponse.ts b/src/client/models/forkResponse.ts new file mode 100644 index 0000000..f3a0e9e --- /dev/null +++ b/src/client/models/forkResponse.ts @@ -0,0 +1,19 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ForkResponse { + created_at: string; + default_branch: string; + /** @nullable */ + description?: string | null; + forked_by: string; + id: string; + name: string; + source_repo: string; + visibility: string; +} diff --git a/src/client/models/get2FAStatusResponse.ts b/src/client/models/get2FAStatusResponse.ts new file mode 100644 index 0000000..6790d2b --- /dev/null +++ b/src/client/models/get2FAStatusResponse.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface Get2FAStatusResponse { + has_backup_codes: boolean; + is_enabled: boolean; + /** @nullable */ + method?: string | null; +} diff --git a/src/client/models/gitAheadBehindParams.ts b/src/client/models/gitAheadBehindParams.ts new file mode 100644 index 0000000..cb3b21d --- /dev/null +++ b/src/client/models/gitAheadBehindParams.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitAheadBehindParams = { +remote_branch: string; +}; diff --git a/src/client/models/gitArchiveParams.ts b/src/client/models/gitArchiveParams.ts new file mode 100644 index 0000000..14d9443 --- /dev/null +++ b/src/client/models/gitArchiveParams.ts @@ -0,0 +1,23 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitArchiveParams = { +format?: string; +/** + * @nullable + */ +tree?: string | null; +/** + * @nullable + */ +prefix?: string | null; +/** + * @nullable + */ +pathspec?: string[] | null; +}; diff --git a/src/client/models/gitBlameFileParams.ts b/src/client/models/gitBlameFileParams.ts new file mode 100644 index 0000000..aa822f9 --- /dev/null +++ b/src/client/models/gitBlameFileParams.ts @@ -0,0 +1,25 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitBlameFileParams = { +path: string; +/** + * @nullable + */ +rev?: string | null; +/** + * @minimum 0 + * @nullable + */ +start_line?: number | null; +/** + * @minimum 0 + * @nullable + */ +end_line?: number | null; +}; diff --git a/src/client/models/gitBlobInfoParams.ts b/src/client/models/gitBlobInfoParams.ts new file mode 100644 index 0000000..8a50628 --- /dev/null +++ b/src/client/models/gitBlobInfoParams.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitBlobInfoParams = { +/** + * @nullable + */ +path?: string | null; +}; diff --git a/src/client/models/gitCherryPickBody.ts b/src/client/models/gitCherryPickBody.ts new file mode 100644 index 0000000..fc8de9c --- /dev/null +++ b/src/client/models/gitCherryPickBody.ts @@ -0,0 +1,9 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitCherryPickBody = { [key: string]: unknown }; diff --git a/src/client/models/gitCommitHistoryParams.ts b/src/client/models/gitCommitHistoryParams.ts new file mode 100644 index 0000000..1074fca --- /dev/null +++ b/src/client/models/gitCommitHistoryParams.ts @@ -0,0 +1,28 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitCommitHistoryParams = { +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +/** + * @minimum 0 + * @nullable + */ +skip?: number | null; +/** + * @nullable + */ +sort?: number | null; +/** + * @nullable + */ +branch?: string | null; +}; diff --git a/src/client/models/gitCommitWalkBody.ts b/src/client/models/gitCommitWalkBody.ts new file mode 100644 index 0000000..66226bf --- /dev/null +++ b/src/client/models/gitCommitWalkBody.ts @@ -0,0 +1,9 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitCommitWalkBody = { [key: string]: unknown }; diff --git a/src/client/models/gitDeleteBranchParams.ts b/src/client/models/gitDeleteBranchParams.ts new file mode 100644 index 0000000..a96374a --- /dev/null +++ b/src/client/models/gitDeleteBranchParams.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitDeleteBranchParams = { +force?: boolean; +}; diff --git a/src/client/models/gitDeleteContentsParams.ts b/src/client/models/gitDeleteContentsParams.ts new file mode 100644 index 0000000..1346e07 --- /dev/null +++ b/src/client/models/gitDeleteContentsParams.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitDeleteContentsParams = { +message: string; +sha: string; +/** + * @nullable + */ +branch?: string | null; +}; diff --git a/src/client/models/gitDiffBranchesParams.ts b/src/client/models/gitDiffBranchesParams.ts new file mode 100644 index 0000000..67cb5d5 --- /dev/null +++ b/src/client/models/gitDiffBranchesParams.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitDiffBranchesParams = { +old_branch: string; +new_branch: string; +}; diff --git a/src/client/models/gitDiffParams.ts b/src/client/models/gitDiffParams.ts new file mode 100644 index 0000000..5b81c04 --- /dev/null +++ b/src/client/models/gitDiffParams.ts @@ -0,0 +1,35 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitDiffParams = { +/** + * @nullable + */ +old_oid?: string | null; +/** + * @nullable + */ +new_oid?: string | null; +/** + * @nullable + */ +old_tree?: string | null; +/** + * @nullable + */ +new_tree?: string | null; +/** + * @nullable + */ +tree_oid?: string | null; +mode?: string; +/** + * @nullable + */ +path?: string | null; +}; diff --git a/src/client/models/gitForkBranchBody.ts b/src/client/models/gitForkBranchBody.ts new file mode 100644 index 0000000..a75e444 --- /dev/null +++ b/src/client/models/gitForkBranchBody.ts @@ -0,0 +1,9 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitForkBranchBody = { [key: string]: unknown }; diff --git a/src/client/models/gitGetContentsParams.ts b/src/client/models/gitGetContentsParams.ts new file mode 100644 index 0000000..db4e67e --- /dev/null +++ b/src/client/models/gitGetContentsParams.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitGetContentsParams = { +/** + * @nullable + */ +ref?: string | null; +}; diff --git a/src/client/models/gitInitTagBody.ts b/src/client/models/gitInitTagBody.ts new file mode 100644 index 0000000..9d68e5a --- /dev/null +++ b/src/client/models/gitInitTagBody.ts @@ -0,0 +1,9 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitInitTagBody = { [key: string]: unknown }; diff --git a/src/client/models/gitListBranchesParams.ts b/src/client/models/gitListBranchesParams.ts new file mode 100644 index 0000000..c394d03 --- /dev/null +++ b/src/client/models/gitListBranchesParams.ts @@ -0,0 +1,22 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitListBranchesParams = { +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +summary?: boolean; +default_only?: boolean; +}; diff --git a/src/client/models/gitListCommitsParams.ts b/src/client/models/gitListCommitsParams.ts new file mode 100644 index 0000000..045eb9c --- /dev/null +++ b/src/client/models/gitListCommitsParams.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitListCommitsParams = { +summary?: boolean; +refs?: boolean; +/** + * @nullable + */ +prefix?: string | null; +}; diff --git a/src/client/models/gitListContributorsParams.ts b/src/client/models/gitListContributorsParams.ts new file mode 100644 index 0000000..fc88f91 --- /dev/null +++ b/src/client/models/gitListContributorsParams.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitListContributorsParams = { +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/gitListDeliveriesParams.ts b/src/client/models/gitListDeliveriesParams.ts new file mode 100644 index 0000000..17e6e9e --- /dev/null +++ b/src/client/models/gitListDeliveriesParams.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitListDeliveriesParams = { +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/gitListForksParams.ts b/src/client/models/gitListForksParams.ts new file mode 100644 index 0000000..21ddb1f --- /dev/null +++ b/src/client/models/gitListForksParams.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitListForksParams = { +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/gitListProtectsParams.ts b/src/client/models/gitListProtectsParams.ts new file mode 100644 index 0000000..c509de5 --- /dev/null +++ b/src/client/models/gitListProtectsParams.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitListProtectsParams = { +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/gitListRefsParams.ts b/src/client/models/gitListRefsParams.ts new file mode 100644 index 0000000..872770a --- /dev/null +++ b/src/client/models/gitListRefsParams.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitListRefsParams = { +/** + * @nullable + */ +ref?: string | null; +}; diff --git a/src/client/models/gitListReposParams.ts b/src/client/models/gitListReposParams.ts new file mode 100644 index 0000000..568ff3e --- /dev/null +++ b/src/client/models/gitListReposParams.ts @@ -0,0 +1,32 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitListReposParams = { +/** + * @nullable + */ +visibility?: string | null; +/** + * @nullable + */ +is_archived?: boolean | null; +/** + * @nullable + */ +search?: string | null; +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/gitListTagsParams.ts b/src/client/models/gitListTagsParams.ts new file mode 100644 index 0000000..56017b0 --- /dev/null +++ b/src/client/models/gitListTagsParams.ts @@ -0,0 +1,21 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitListTagsParams = { +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +summary?: boolean; +}; diff --git a/src/client/models/gitListWebhooksParams.ts b/src/client/models/gitListWebhooksParams.ts new file mode 100644 index 0000000..240a8f1 --- /dev/null +++ b/src/client/models/gitListWebhooksParams.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitListWebhooksParams = { +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/gitRefResponse.ts b/src/client/models/gitRefResponse.ts new file mode 100644 index 0000000..1b634f7 --- /dev/null +++ b/src/client/models/gitRefResponse.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface GitRefResponse { + is_default: boolean; + is_protected: boolean; + kind: string; + name: string; + target_sha: string; +} diff --git a/src/client/models/gitTreeEntriesParams.ts b/src/client/models/gitTreeEntriesParams.ts new file mode 100644 index 0000000..8ee4f69 --- /dev/null +++ b/src/client/models/gitTreeEntriesParams.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitTreeEntriesParams = { +/** + * @nullable + */ +path?: string | null; +last?: boolean; +resolve?: boolean; +}; diff --git a/src/client/models/gitTreeEntryByPathFromCommitParams.ts b/src/client/models/gitTreeEntryByPathFromCommitParams.ts new file mode 100644 index 0000000..c636f2a --- /dev/null +++ b/src/client/models/gitTreeEntryByPathFromCommitParams.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitTreeEntryByPathFromCommitParams = { +path: string; +}; diff --git a/src/client/models/gitTreeEntryByPathParams.ts b/src/client/models/gitTreeEntryByPathParams.ts new file mode 100644 index 0000000..e27e8f5 --- /dev/null +++ b/src/client/models/gitTreeEntryByPathParams.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitTreeEntryByPathParams = { +path: string; +}; diff --git a/src/client/models/gitUpdateTagBody.ts b/src/client/models/gitUpdateTagBody.ts new file mode 100644 index 0000000..714ba28 --- /dev/null +++ b/src/client/models/gitUpdateTagBody.ts @@ -0,0 +1,9 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitUpdateTagBody = { [key: string]: unknown }; diff --git a/src/client/models/gitWatchRepoBody.ts b/src/client/models/gitWatchRepoBody.ts new file mode 100644 index 0000000..eaad848 --- /dev/null +++ b/src/client/models/gitWatchRepoBody.ts @@ -0,0 +1,9 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type GitWatchRepoBody = { [key: string]: unknown }; diff --git a/src/client/models/index.ts b/src/client/models/index.ts new file mode 100644 index 0000000..2cb07c7 --- /dev/null +++ b/src/client/models/index.ts @@ -0,0 +1,271 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export * from './accessRequest'; +export * from './addIssueLabel'; +export * from './addPrLabel'; +export * from './addPrReaction'; +export * from './addReaction'; +export * from './addWorkspaceMember'; +export * from './agentCostInfo'; +export * from './agentListAllConversationsParams'; +export * from './agentListMessagesParams'; +export * from './agentRunRequest'; +export * from './agentRunResponse'; +export * from './agentSessionResponse'; +export * from './agentStepInfo'; +export * from './agentToolCallInfo'; +export * from './agentUsageInfo'; +export * from './aiAddRequest'; +export * from './aiDiscussionResponse'; +export * from './aiLikeResponse'; +export * from './aiListDiscussionsParams'; +export * from './aiListModelsParams'; +export * from './aiModelCardResponse'; +export * from './aiModelListItem'; +export * from './aiModelResponse'; +export * from './aiModelVersionResponse'; +export * from './aiProviderResponse'; +export * from './appNotificationItem'; +export * from './assignIssueUser'; +export * from './assignPrUser'; +export * from './authCaptchaParams'; +export * from './avatarUploadResponse'; +export * from './banCreateRequest'; +export * from './bindIssuePullRequest'; +export * from './bindIssueRepo'; +export * from './blameFileResponseDto'; +export * from './blobInfoResponse'; +export * from './blobLoadResponseDto'; +export * from './blobUploadBody'; +export * from './blobUploadResponseDto'; +export * from './branchAheadBehindResponseDto'; +export * from './branchInfoResponseDto'; +export * from './branchListItemDto'; +export * from './branchListResponseDto'; +export * from './branchUpstreamResponseDto'; +export * from './captchaQuery'; +export * from './captchaResponse'; +export * from './categoryCreateRequest'; +export * from './categoryUpdateRequest'; +export * from './channelListMessagesParams'; +export * from './channelMessagesAroundParams'; +export * from './channelMissedMessagesParams'; +export * from './channelSearchParams'; +export * from './cherryPickResponseDto'; +export * from './cloneRepo'; +export * from './combinedCommitStatus'; +export * from './commitBlameHunkDto'; +export * from './commitHistoryResponseDto'; +export * from './commitInfoResponseDto'; +export * from './commitMetaDto'; +export * from './commitSignatureDto'; +export * from './commitStatusResponse'; +export * from './compareCommit'; +export * from './compareResponse'; +export * from './contentResponse'; +export * from './contextMe'; +export * from './contributionHeatmapItem'; +export * from './contributionHeatmapResponse'; +export * from './contributorDto'; +export * from './conversationResponse'; +export * from './conversationWithSessionResponse'; +export * from './createAgentSession'; +export * from './createComment'; +export * from './createCommitStatus'; +export * from './createContent'; +export * from './createConversation'; +export * from './createdUserAccessToken'; +export * from './createFork'; +export * from './createIssue'; +export * from './createLabel'; +export * from './createMessageRequest'; +export * from './createMilestone'; +export * from './createPrComment'; +export * from './createProtect'; +export * from './createPrReview'; +export * from './createPrReviewComment'; +export * from './createPullRequest'; +export * from './createRelease'; +export * from './createRepo'; +export * from './createUserAccessToken'; +export * from './createUserSshKey'; +export * from './createWebhook'; +export * from './createWorkspace'; +export * from './createWorkspaceGroup'; +export * from './createWorkspaceJoinApply'; +export * from './customStatusRequest'; +export * from './diffDeltaDto'; +export * from './diffFileDto'; +export * from './diffHunkDto'; +export * from './diffLineDto'; +export * from './diffResultDto'; +export * from './diffStatsDto'; +export * from './disable2FAParams'; +export * from './dismissPrReview'; +export * from './dndRequest'; +export * from './draftSaveRequest'; +export * from './emailChangeRequest'; +export * from './emailResponse'; +export * from './emailVerifyRequest'; +export * from './enable2FAResponse'; +export * from './forkResponse'; +export * from './get2FAStatusResponse'; +export * from './gitAheadBehindParams'; +export * from './gitArchiveParams'; +export * from './gitBlameFileParams'; +export * from './gitBlobInfoParams'; +export * from './gitCherryPickBody'; +export * from './gitCommitHistoryParams'; +export * from './gitCommitWalkBody'; +export * from './gitDeleteBranchParams'; +export * from './gitDeleteContentsParams'; +export * from './gitDiffBranchesParams'; +export * from './gitDiffParams'; +export * from './gitForkBranchBody'; +export * from './gitGetContentsParams'; +export * from './gitInitTagBody'; +export * from './gitListBranchesParams'; +export * from './gitListCommitsParams'; +export * from './gitListContributorsParams'; +export * from './gitListDeliveriesParams'; +export * from './gitListForksParams'; +export * from './gitListProtectsParams'; +export * from './gitListRefsParams'; +export * from './gitListReposParams'; +export * from './gitListTagsParams'; +export * from './gitListWebhooksParams'; +export * from './gitRefResponse'; +export * from './gitTreeEntriesParams'; +export * from './gitTreeEntryByPathFromCommitParams'; +export * from './gitTreeEntryByPathParams'; +export * from './gitUpdateTagBody'; +export * from './gitWatchRepoBody'; +export * from './inviteAcceptRequest'; +export * from './inviteCreateRequest'; +export * from './issueAuthor'; +export * from './issueCommentResponse'; +export * from './issueEventResponse'; +export * from './issuePullRequestResponse'; +export * from './issueReactionResponse'; +export * from './issueRepoResponse'; +export * from './issueResponse'; +export * from './issuesListIssuesParams'; +export * from './labelResponse'; +export * from './languageStatDto'; +export * from './loginParams'; +export * from './mergePullRequest'; +export * from './messageResponse'; +export * from './milestoneResponse'; +export * from './notificationMarkAllReadRequest'; +export * from './pinRequest'; +export * from './presenceUpdateRequest'; +export * from './protectResponse'; +export * from './publicUserResponse'; +export * from './pullRequestCommentResponse'; +export * from './pullRequestListPrsParams'; +export * from './pullRequestReactionResponse'; +export * from './pullRequestResponse'; +export * from './pullRequestReviewCommentResponse'; +export * from './pullRequestReviewResponse'; +export * from './reactionRequest'; +export * from './readmeDto'; +export * from './readReceiptRequest'; +export * from './registerParams'; +export * from './releaseAssetResponse'; +export * from './releaseResponse'; +export * from './renameBranchBody'; +export * from './repoResponse'; +export * from './resetPasswordRequest'; +export * from './resetPasswordVerifyParams'; +export * from './roomCreateRequest'; +export * from './roomUpdateRequest'; +export * from './rsaResponse'; +export * from './screenShareRequest'; +export * from './searchGroupIssueHit'; +export * from './searchGroupIssueHitItemsItem'; +export * from './searchGroupRepoHit'; +export * from './searchGroupRepoHitItemsItem'; +export * from './searchGroupRoomHit'; +export * from './searchGroupRoomHitItemsItem'; +export * from './searchGroupWorkspaceHit'; +export * from './searchGroupWorkspaceHitItemsItem'; +export * from './searchParams'; +export * from './searchResponse'; +export * from './setIssueMilestone'; +export * from './tagInfoResponseDto'; +export * from './tagInitResponseDto'; +export * from './tagItemDto'; +export * from './tagListResponseDto'; +export * from './tagSummaryDto'; +export * from './tagSummaryResponseDto'; +export * from './threadCreateRequest'; +export * from './tokenRequest'; +export * from './tokenResponse'; +export * from './toolCallResponse'; +export * from './transferRepo'; +export * from './treeEntriesResponseDto'; +export * from './treeEntryByPathResponseDto'; +export * from './treeEntryDto'; +export * from './treeKindDto'; +export * from './typingAction'; +export * from './typingRequest'; +export * from './updateAgentSession'; +export * from './updateComment'; +export * from './updateContent'; +export * from './updateConversation'; +export * from './updateIssue'; +export * from './updateLabel'; +export * from './updateMessageRequest'; +export * from './updateMilestone'; +export * from './updatePrComment'; +export * from './updateProtect'; +export * from './updatePullRequest'; +export * from './updateRelease'; +export * from './updateRepo'; +export * from './updateUserAccessibilityConfig'; +export * from './updateUserAccessToken'; +export * from './updateUserAppearanceConfig'; +export * from './updateUserNotificationConfig'; +export * from './updateUserPrivacyConfig'; +export * from './updateUserProfileConfig'; +export * from './updateUserSshKey'; +export * from './updateWebhook'; +export * from './updateWorkspace'; +export * from './updateWorkspaceGroup'; +export * from './updateWorkspaceJoinStrategy'; +export * from './updateWorkspaceMember'; +export * from './userAccessibilityConfig'; +export * from './userAccessToken'; +export * from './userAppearanceConfig'; +export * from './userConfigResponse'; +export * from './userContributionHeatmapParams'; +export * from './userNotificationConfig'; +export * from './userPrivacyConfig'; +export * from './userProfileConfig'; +export * from './userRelationCard'; +export * from './userRelationCounts'; +export * from './userRelationStatus'; +export * from './usersFollowersParams'; +export * from './usersFollowingParams'; +export * from './userSshKey'; +export * from './userSummaryResponse'; +export * from './usersUserChpcParams'; +export * from './verify2FAParams'; +export * from './voiceDeafRequest'; +export * from './voiceMuteRequest'; +export * from './webhookDeliveryResponse'; +export * from './webhookResponse'; +export * from './workspaceGroupResponse'; +export * from './workspaceJoinApplyResponse'; +export * from './workspaceJoinApprovalResponse'; +export * from './workspaceJoinStrategyResponse'; +export * from './workspaceListJoinAppliesParams'; +export * from './workspaceListMembersParams'; +export * from './workspaceMemberResponse'; +export * from './workspaceResponse'; diff --git a/src/client/models/inviteAcceptRequest.ts b/src/client/models/inviteAcceptRequest.ts new file mode 100644 index 0000000..b95cdd0 --- /dev/null +++ b/src/client/models/inviteAcceptRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface InviteAcceptRequest { + code: string; +} diff --git a/src/client/models/inviteCreateRequest.ts b/src/client/models/inviteCreateRequest.ts new file mode 100644 index 0000000..90592fc --- /dev/null +++ b/src/client/models/inviteCreateRequest.ts @@ -0,0 +1,17 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface InviteCreateRequest { + /** @nullable */ + expires_at?: string | null; + /** @nullable */ + max_uses?: number | null; + /** @nullable */ + room?: string | null; + workspace: string; +} diff --git a/src/client/models/issueAuthor.ts b/src/client/models/issueAuthor.ts new file mode 100644 index 0000000..5cd61d7 --- /dev/null +++ b/src/client/models/issueAuthor.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface IssueAuthor { + /** @nullable */ + avatar_url?: string | null; + /** @nullable */ + display_name?: string | null; + username: string; +} diff --git a/src/client/models/issueCommentResponse.ts b/src/client/models/issueCommentResponse.ts new file mode 100644 index 0000000..da1df09 --- /dev/null +++ b/src/client/models/issueCommentResponse.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { IssueAuthor } from './issueAuthor'; + +export interface IssueCommentResponse { + author: IssueAuthor; + body: string; + created_at: string; + id: string; + updated_at: string; +} diff --git a/src/client/models/issueEventResponse.ts b/src/client/models/issueEventResponse.ts new file mode 100644 index 0000000..938912f --- /dev/null +++ b/src/client/models/issueEventResponse.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { IssueAuthor } from './issueAuthor'; + +export interface IssueEventResponse { + actor?: null | IssueAuthor; + created_at: string; + event: string; + /** @nullable */ + from_value?: string | null; + /** @nullable */ + to_value?: string | null; +} diff --git a/src/client/models/issuePullRequestResponse.ts b/src/client/models/issuePullRequestResponse.ts new file mode 100644 index 0000000..4109f86 --- /dev/null +++ b/src/client/models/issuePullRequestResponse.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface IssuePullRequestResponse { + id: string; + number: number; + state: string; + title: string; +} diff --git a/src/client/models/issueReactionResponse.ts b/src/client/models/issueReactionResponse.ts new file mode 100644 index 0000000..fde2355 --- /dev/null +++ b/src/client/models/issueReactionResponse.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { IssueAuthor } from './issueAuthor'; + +export interface IssueReactionResponse { + created_at: string; + id: string; + reaction: string; + user: IssueAuthor; +} diff --git a/src/client/models/issueRepoResponse.ts b/src/client/models/issueRepoResponse.ts new file mode 100644 index 0000000..ba2ba0a --- /dev/null +++ b/src/client/models/issueRepoResponse.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface IssueRepoResponse { + id: string; + name: string; +} diff --git a/src/client/models/issueResponse.ts b/src/client/models/issueResponse.ts new file mode 100644 index 0000000..ae138f3 --- /dev/null +++ b/src/client/models/issueResponse.ts @@ -0,0 +1,34 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { IssueAuthor } from './issueAuthor'; +import type { IssuePullRequestResponse } from './issuePullRequestResponse'; +import type { IssueRepoResponse } from './issueRepoResponse'; +import type { LabelResponse } from './labelResponse'; +import type { MilestoneResponse } from './milestoneResponse'; + +export interface IssueResponse { + assignees: IssueAuthor[]; + author: IssueAuthor; + /** @nullable */ + body?: string | null; + /** @nullable */ + closed_at?: string | null; + closed_by?: null | IssueAuthor; + created_at: string; + /** @nullable */ + due_at?: string | null; + labels: LabelResponse[]; + milestone?: null | MilestoneResponse; + number: number; + priority: string; + pull_requests: IssuePullRequestResponse[]; + repos: IssueRepoResponse[]; + state: string; + title: string; + updated_at: string; +} diff --git a/src/client/models/issuesListIssuesParams.ts b/src/client/models/issuesListIssuesParams.ts new file mode 100644 index 0000000..103abb4 --- /dev/null +++ b/src/client/models/issuesListIssuesParams.ts @@ -0,0 +1,40 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type IssuesListIssuesParams = { +/** + * @nullable + */ +state?: string | null; +/** + * @nullable + */ +label?: string | null; +/** + * @nullable + */ +milestone?: string | null; +/** + * @nullable + */ +assignee?: string | null; +/** + * @nullable + */ +priority?: string | null; +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/labelResponse.ts b/src/client/models/labelResponse.ts new file mode 100644 index 0000000..55e6af1 --- /dev/null +++ b/src/client/models/labelResponse.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface LabelResponse { + color: string; + /** @nullable */ + description?: string | null; + id: string; + name: string; +} diff --git a/src/client/models/languageStatDto.ts b/src/client/models/languageStatDto.ts new file mode 100644 index 0000000..d37e0cc --- /dev/null +++ b/src/client/models/languageStatDto.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface LanguageStatDto { + bytes: number; + language: string; + percentage: number; +} diff --git a/src/client/models/loginParams.ts b/src/client/models/loginParams.ts new file mode 100644 index 0000000..698da02 --- /dev/null +++ b/src/client/models/loginParams.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface LoginParams { + captcha: string; + password: string; + /** @nullable */ + totp_code?: string | null; + username: string; +} diff --git a/src/client/models/mergePullRequest.ts b/src/client/models/mergePullRequest.ts new file mode 100644 index 0000000..4b3c098 --- /dev/null +++ b/src/client/models/mergePullRequest.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface MergePullRequest { + /** @nullable */ + commit_message?: string | null; + /** @nullable */ + commit_title?: string | null; + /** @nullable */ + method?: string | null; +} diff --git a/src/client/models/messageResponse.ts b/src/client/models/messageResponse.ts new file mode 100644 index 0000000..c04f485 --- /dev/null +++ b/src/client/models/messageResponse.ts @@ -0,0 +1,28 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { ToolCallResponse } from './toolCallResponse'; + +export interface MessageResponse { + /** @nullable */ + author?: string | null; + content: string; + content_type: string; + conversation_id: string; + created_at: string; + id: string; + /** @nullable */ + model_invocation?: string | null; + /** @nullable */ + parent_id?: string | null; + /** @nullable */ + reasoning_content?: string | null; + role: string; + status: string; + tool_calls?: ToolCallResponse[]; + updated_at: string; +} diff --git a/src/client/models/milestoneResponse.ts b/src/client/models/milestoneResponse.ts new file mode 100644 index 0000000..b3784e2 --- /dev/null +++ b/src/client/models/milestoneResponse.ts @@ -0,0 +1,17 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface MilestoneResponse { + /** @nullable */ + description?: string | null; + /** @nullable */ + due_at?: string | null; + id: string; + state: string; + title: string; +} diff --git a/src/client/models/notificationMarkAllReadRequest.ts b/src/client/models/notificationMarkAllReadRequest.ts new file mode 100644 index 0000000..1409ea2 --- /dev/null +++ b/src/client/models/notificationMarkAllReadRequest.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface NotificationMarkAllReadRequest { + /** @nullable */ + workspace_id?: string | null; +} diff --git a/src/client/models/pinRequest.ts b/src/client/models/pinRequest.ts new file mode 100644 index 0000000..765ecac --- /dev/null +++ b/src/client/models/pinRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface PinRequest { + message: string; +} diff --git a/src/client/models/presenceUpdateRequest.ts b/src/client/models/presenceUpdateRequest.ts new file mode 100644 index 0000000..c82ee2b --- /dev/null +++ b/src/client/models/presenceUpdateRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface PresenceUpdateRequest { + status: string; +} diff --git a/src/client/models/protectResponse.ts b/src/client/models/protectResponse.ts new file mode 100644 index 0000000..aa24c05 --- /dev/null +++ b/src/client/models/protectResponse.ts @@ -0,0 +1,22 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ProtectResponse { + allow_deletions: boolean; + allow_force_pushes: boolean; + created_at: string; + enforce_admins: boolean; + id: string; + pattern: string; + repo: string; + require_pull_request: boolean; + require_status_checks: boolean; + required_approvals: number; + required_status_contexts: string[]; + updated_at: string; +} diff --git a/src/client/models/publicUserResponse.ts b/src/client/models/publicUserResponse.ts new file mode 100644 index 0000000..4ac4c64 --- /dev/null +++ b/src/client/models/publicUserResponse.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface PublicUserResponse { + allow_direct_messages: boolean; + avatar_url: string; + display_name: string; + language: string; + show_online_status: boolean; + timezone: string; + username: string; + website_url: string; +} diff --git a/src/client/models/pullRequestCommentResponse.ts b/src/client/models/pullRequestCommentResponse.ts new file mode 100644 index 0000000..d5f654b --- /dev/null +++ b/src/client/models/pullRequestCommentResponse.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { IssueAuthor } from './issueAuthor'; + +export interface PullRequestCommentResponse { + author: IssueAuthor; + body: string; + created_at: string; + id: string; + updated_at: string; +} diff --git a/src/client/models/pullRequestListPrsParams.ts b/src/client/models/pullRequestListPrsParams.ts new file mode 100644 index 0000000..c4c198a --- /dev/null +++ b/src/client/models/pullRequestListPrsParams.ts @@ -0,0 +1,36 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type PullRequestListPrsParams = { +/** + * @nullable + */ +state?: string | null; +/** + * @nullable + */ +author?: string | null; +/** + * @nullable + */ +assignee?: string | null; +/** + * @nullable + */ +label?: string | null; +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/pullRequestReactionResponse.ts b/src/client/models/pullRequestReactionResponse.ts new file mode 100644 index 0000000..89f75da --- /dev/null +++ b/src/client/models/pullRequestReactionResponse.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { IssueAuthor } from './issueAuthor'; + +export interface PullRequestReactionResponse { + created_at: string; + id: string; + reaction: string; + user: IssueAuthor; +} diff --git a/src/client/models/pullRequestResponse.ts b/src/client/models/pullRequestResponse.ts new file mode 100644 index 0000000..0afbc2b --- /dev/null +++ b/src/client/models/pullRequestResponse.ts @@ -0,0 +1,36 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { IssueAuthor } from './issueAuthor'; +import type { LabelResponse } from './labelResponse'; +import type { PullRequestReviewResponse } from './pullRequestReviewResponse'; + +export interface PullRequestResponse { + assignees: IssueAuthor[]; + author: IssueAuthor; + /** @nullable */ + body?: string | null; + /** @nullable */ + closed_at?: string | null; + closed_by?: null | IssueAuthor; + created_at: string; + draft: boolean; + labels: LabelResponse[]; + /** @nullable */ + merged_at?: string | null; + merged_by?: null | IssueAuthor; + number: number; + reviews: PullRequestReviewResponse[]; + source_branch: string; + source_repo: string; + source_sha: string; + state: string; + target_branch: string; + target_sha: string; + title: string; + updated_at: string; +} diff --git a/src/client/models/pullRequestReviewCommentResponse.ts b/src/client/models/pullRequestReviewCommentResponse.ts new file mode 100644 index 0000000..b87a309 --- /dev/null +++ b/src/client/models/pullRequestReviewCommentResponse.ts @@ -0,0 +1,21 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { IssueAuthor } from './issueAuthor'; + +export interface PullRequestReviewCommentResponse { + author: IssueAuthor; + body: string; + commit_sha: string; + created_at: string; + id: string; + /** @nullable */ + line?: number | null; + path: string; + resolved: boolean; + updated_at: string; +} diff --git a/src/client/models/pullRequestReviewResponse.ts b/src/client/models/pullRequestReviewResponse.ts new file mode 100644 index 0000000..2b3512d --- /dev/null +++ b/src/client/models/pullRequestReviewResponse.ts @@ -0,0 +1,21 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { IssueAuthor } from './issueAuthor'; + +export interface PullRequestReviewResponse { + /** @nullable */ + body?: string | null; + /** @nullable */ + commit_sha?: string | null; + created_at: string; + id: string; + reviewer: IssueAuthor; + state: string; + /** @nullable */ + submitted_at?: string | null; +} diff --git a/src/client/models/reactionRequest.ts b/src/client/models/reactionRequest.ts new file mode 100644 index 0000000..8a162c6 --- /dev/null +++ b/src/client/models/reactionRequest.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ReactionRequest { + emoji: string; + message: string; +} diff --git a/src/client/models/readReceiptRequest.ts b/src/client/models/readReceiptRequest.ts new file mode 100644 index 0000000..966b4b8 --- /dev/null +++ b/src/client/models/readReceiptRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ReadReceiptRequest { + last_read_seq: number; +} diff --git a/src/client/models/readmeDto.ts b/src/client/models/readmeDto.ts new file mode 100644 index 0000000..5d97012 --- /dev/null +++ b/src/client/models/readmeDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ReadmeDto { + content: string; + html: string; +} diff --git a/src/client/models/registerParams.ts b/src/client/models/registerParams.ts new file mode 100644 index 0000000..082bf47 --- /dev/null +++ b/src/client/models/registerParams.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface RegisterParams { + captcha: string; + email: string; + password: string; + username: string; +} diff --git a/src/client/models/releaseAssetResponse.ts b/src/client/models/releaseAssetResponse.ts new file mode 100644 index 0000000..2e618da --- /dev/null +++ b/src/client/models/releaseAssetResponse.ts @@ -0,0 +1,17 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ReleaseAssetResponse { + /** @nullable */ + content_type?: string | null; + created_at: string; + download_count: number; + id: string; + name: string; + size: number; +} diff --git a/src/client/models/releaseResponse.ts b/src/client/models/releaseResponse.ts new file mode 100644 index 0000000..380b7ab --- /dev/null +++ b/src/client/models/releaseResponse.ts @@ -0,0 +1,24 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { ReleaseAssetResponse } from './releaseAssetResponse'; + +export interface ReleaseResponse { + assets: ReleaseAssetResponse[]; + author: string; + /** @nullable */ + body?: string | null; + created_at: string; + draft: boolean; + id: string; + name: string; + prerelease: boolean; + /** @nullable */ + published_at?: string | null; + tag_name: string; + target_commit_sha: string; +} diff --git a/src/client/models/renameBranchBody.ts b/src/client/models/renameBranchBody.ts new file mode 100644 index 0000000..391023d --- /dev/null +++ b/src/client/models/renameBranchBody.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface RenameBranchBody { + force?: boolean; + new_branch: string; +} diff --git a/src/client/models/repoResponse.ts b/src/client/models/repoResponse.ts new file mode 100644 index 0000000..be0a11f --- /dev/null +++ b/src/client/models/repoResponse.ts @@ -0,0 +1,23 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface RepoResponse { + created_at: string; + created_by: string; + default_branch: string; + /** @nullable */ + description?: string | null; + id: string; + is_archived: boolean; + is_mirror: boolean; + is_template: boolean; + name: string; + size_bytes: number; + updated_at: string; + visibility: string; +} diff --git a/src/client/models/resetPasswordRequest.ts b/src/client/models/resetPasswordRequest.ts new file mode 100644 index 0000000..8e735c4 --- /dev/null +++ b/src/client/models/resetPasswordRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ResetPasswordRequest { + email: string; +} diff --git a/src/client/models/resetPasswordVerifyParams.ts b/src/client/models/resetPasswordVerifyParams.ts new file mode 100644 index 0000000..65b4dd2 --- /dev/null +++ b/src/client/models/resetPasswordVerifyParams.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ResetPasswordVerifyParams { + password: string; + token: string; +} diff --git a/src/client/models/roomCreateRequest.ts b/src/client/models/roomCreateRequest.ts new file mode 100644 index 0000000..7e5bc7a --- /dev/null +++ b/src/client/models/roomCreateRequest.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface RoomCreateRequest { + /** @nullable */ + category?: string | null; + public: boolean; + room_name: string; + workspace: string; +} diff --git a/src/client/models/roomUpdateRequest.ts b/src/client/models/roomUpdateRequest.ts new file mode 100644 index 0000000..ca3ff9a --- /dev/null +++ b/src/client/models/roomUpdateRequest.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface RoomUpdateRequest { + /** @nullable */ + category?: string | null; + /** @nullable */ + public?: boolean | null; + /** @nullable */ + room_name?: string | null; +} diff --git a/src/client/models/rsaResponse.ts b/src/client/models/rsaResponse.ts new file mode 100644 index 0000000..87609f9 --- /dev/null +++ b/src/client/models/rsaResponse.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface RsaResponse { + public_key: string; +} diff --git a/src/client/models/screenShareRequest.ts b/src/client/models/screenShareRequest.ts new file mode 100644 index 0000000..cbb212a --- /dev/null +++ b/src/client/models/screenShareRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ScreenShareRequest { + start: boolean; +} diff --git a/src/client/models/searchGroupIssueHit.ts b/src/client/models/searchGroupIssueHit.ts new file mode 100644 index 0000000..8d4f088 --- /dev/null +++ b/src/client/models/searchGroupIssueHit.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { SearchGroupIssueHitItemsItem } from './searchGroupIssueHitItemsItem'; + +export interface SearchGroupIssueHit { + has_more: boolean; + items: SearchGroupIssueHitItemsItem[]; + /** @minimum 0 */ + total: number; +} diff --git a/src/client/models/searchGroupIssueHitItemsItem.ts b/src/client/models/searchGroupIssueHitItemsItem.ts new file mode 100644 index 0000000..b0cfcab --- /dev/null +++ b/src/client/models/searchGroupIssueHitItemsItem.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type SearchGroupIssueHitItemsItem = { + number: number; + state: string; + title: string; + workspace: string; +}; diff --git a/src/client/models/searchGroupRepoHit.ts b/src/client/models/searchGroupRepoHit.ts new file mode 100644 index 0000000..bc5d5cb --- /dev/null +++ b/src/client/models/searchGroupRepoHit.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { SearchGroupRepoHitItemsItem } from './searchGroupRepoHitItemsItem'; + +export interface SearchGroupRepoHit { + has_more: boolean; + items: SearchGroupRepoHitItemsItem[]; + /** @minimum 0 */ + total: number; +} diff --git a/src/client/models/searchGroupRepoHitItemsItem.ts b/src/client/models/searchGroupRepoHitItemsItem.ts new file mode 100644 index 0000000..0c118d8 --- /dev/null +++ b/src/client/models/searchGroupRepoHitItemsItem.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type SearchGroupRepoHitItemsItem = { + /** @nullable */ + description?: string | null; + name: string; + workspace: string; +}; diff --git a/src/client/models/searchGroupRoomHit.ts b/src/client/models/searchGroupRoomHit.ts new file mode 100644 index 0000000..127336a --- /dev/null +++ b/src/client/models/searchGroupRoomHit.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { SearchGroupRoomHitItemsItem } from './searchGroupRoomHitItemsItem'; + +export interface SearchGroupRoomHit { + has_more: boolean; + items: SearchGroupRoomHitItemsItem[]; + /** @minimum 0 */ + total: number; +} diff --git a/src/client/models/searchGroupRoomHitItemsItem.ts b/src/client/models/searchGroupRoomHitItemsItem.ts new file mode 100644 index 0000000..a85889b --- /dev/null +++ b/src/client/models/searchGroupRoomHitItemsItem.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type SearchGroupRoomHitItemsItem = { + id: string; + name: string; + workspace: string; +}; diff --git a/src/client/models/searchGroupWorkspaceHit.ts b/src/client/models/searchGroupWorkspaceHit.ts new file mode 100644 index 0000000..e5bf225 --- /dev/null +++ b/src/client/models/searchGroupWorkspaceHit.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { SearchGroupWorkspaceHitItemsItem } from './searchGroupWorkspaceHitItemsItem'; + +export interface SearchGroupWorkspaceHit { + has_more: boolean; + items: SearchGroupWorkspaceHitItemsItem[]; + /** @minimum 0 */ + total: number; +} diff --git a/src/client/models/searchGroupWorkspaceHitItemsItem.ts b/src/client/models/searchGroupWorkspaceHitItemsItem.ts new file mode 100644 index 0000000..83c2820 --- /dev/null +++ b/src/client/models/searchGroupWorkspaceHitItemsItem.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type SearchGroupWorkspaceHitItemsItem = { + /** @nullable */ + description?: string | null; + name: string; +}; diff --git a/src/client/models/searchParams.ts b/src/client/models/searchParams.ts new file mode 100644 index 0000000..6a13e74 --- /dev/null +++ b/src/client/models/searchParams.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type SearchParams = { +/** + * @nullable + */ +q?: string | null; +}; diff --git a/src/client/models/searchResponse.ts b/src/client/models/searchResponse.ts new file mode 100644 index 0000000..47b8f0a --- /dev/null +++ b/src/client/models/searchResponse.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { SearchGroupIssueHit } from './searchGroupIssueHit'; +import type { SearchGroupRepoHit } from './searchGroupRepoHit'; +import type { SearchGroupRoomHit } from './searchGroupRoomHit'; +import type { SearchGroupWorkspaceHit } from './searchGroupWorkspaceHit'; + +export interface SearchResponse { + issues: SearchGroupIssueHit; + repos: SearchGroupRepoHit; + rooms: SearchGroupRoomHit; + workspaces: SearchGroupWorkspaceHit; +} diff --git a/src/client/models/setIssueMilestone.ts b/src/client/models/setIssueMilestone.ts new file mode 100644 index 0000000..170c993 --- /dev/null +++ b/src/client/models/setIssueMilestone.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface SetIssueMilestone { + milestone_id: string; +} diff --git a/src/client/models/tagInfoResponseDto.ts b/src/client/models/tagInfoResponseDto.ts new file mode 100644 index 0000000..b6a1d2b --- /dev/null +++ b/src/client/models/tagInfoResponseDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { TagItemDto } from './tagItemDto'; + +export interface TagInfoResponseDto { + tag?: null | TagItemDto; +} diff --git a/src/client/models/tagInitResponseDto.ts b/src/client/models/tagInitResponseDto.ts new file mode 100644 index 0000000..c925e60 --- /dev/null +++ b/src/client/models/tagInitResponseDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface TagInitResponseDto { + /** @nullable */ + oid?: string | null; +} diff --git a/src/client/models/tagItemDto.ts b/src/client/models/tagItemDto.ts new file mode 100644 index 0000000..faa00f4 --- /dev/null +++ b/src/client/models/tagItemDto.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface TagItemDto { + is_annotated: boolean; + /** @nullable */ + message?: string | null; + name: string; + oid: string; + /** @nullable */ + tagger?: string | null; + /** @nullable */ + tagger_email?: string | null; + target: string; +} diff --git a/src/client/models/tagListResponseDto.ts b/src/client/models/tagListResponseDto.ts new file mode 100644 index 0000000..7e09d82 --- /dev/null +++ b/src/client/models/tagListResponseDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { TagItemDto } from './tagItemDto'; + +export interface TagListResponseDto { + tags: TagItemDto[]; +} diff --git a/src/client/models/tagSummaryDto.ts b/src/client/models/tagSummaryDto.ts new file mode 100644 index 0000000..c449c3c --- /dev/null +++ b/src/client/models/tagSummaryDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface TagSummaryDto { + /** @minimum 0 */ + total_count: number; +} diff --git a/src/client/models/tagSummaryResponseDto.ts b/src/client/models/tagSummaryResponseDto.ts new file mode 100644 index 0000000..831cb56 --- /dev/null +++ b/src/client/models/tagSummaryResponseDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { TagSummaryDto } from './tagSummaryDto'; + +export interface TagSummaryResponseDto { + summary?: null | TagSummaryDto; +} diff --git a/src/client/models/threadCreateRequest.ts b/src/client/models/threadCreateRequest.ts new file mode 100644 index 0000000..05098ee --- /dev/null +++ b/src/client/models/threadCreateRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ThreadCreateRequest { + parent: number; +} diff --git a/src/client/models/tokenRequest.ts b/src/client/models/tokenRequest.ts new file mode 100644 index 0000000..a67aa14 --- /dev/null +++ b/src/client/models/tokenRequest.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface TokenRequest { + client_id: string; + device_id: string; +} diff --git a/src/client/models/tokenResponse.ts b/src/client/models/tokenResponse.ts new file mode 100644 index 0000000..de02d72 --- /dev/null +++ b/src/client/models/tokenResponse.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface TokenResponse { + access_token: string; + /** @minimum 0 */ + expires_in_secs: number; +} diff --git a/src/client/models/toolCallResponse.ts b/src/client/models/toolCallResponse.ts new file mode 100644 index 0000000..f4c61e9 --- /dev/null +++ b/src/client/models/toolCallResponse.ts @@ -0,0 +1,19 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface ToolCallResponse { + arguments: unknown; + /** @nullable */ + elapsed_ms?: number | null; + /** @nullable */ + error?: string | null; + id: string; + name: string; + output?: unknown; + status: string; +} diff --git a/src/client/models/transferRepo.ts b/src/client/models/transferRepo.ts new file mode 100644 index 0000000..95ab3c6 --- /dev/null +++ b/src/client/models/transferRepo.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface TransferRepo { + target_workspace: string; +} diff --git a/src/client/models/treeEntriesResponseDto.ts b/src/client/models/treeEntriesResponseDto.ts new file mode 100644 index 0000000..33049b4 --- /dev/null +++ b/src/client/models/treeEntriesResponseDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { TreeEntryDto } from './treeEntryDto'; + +export interface TreeEntriesResponseDto { + entries: TreeEntryDto[]; +} diff --git a/src/client/models/treeEntryByPathResponseDto.ts b/src/client/models/treeEntryByPathResponseDto.ts new file mode 100644 index 0000000..20e3b06 --- /dev/null +++ b/src/client/models/treeEntryByPathResponseDto.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { TreeEntryDto } from './treeEntryDto'; + +export interface TreeEntryByPathResponseDto { + entry?: null | TreeEntryDto; +} diff --git a/src/client/models/treeEntryDto.ts b/src/client/models/treeEntryDto.ts new file mode 100644 index 0000000..c81ef57 --- /dev/null +++ b/src/client/models/treeEntryDto.ts @@ -0,0 +1,26 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { TreeKindDto } from './treeKindDto'; + +export interface TreeEntryDto { + /** @minimum 0 */ + filemode: number; + is_binary: boolean; + is_lfs: boolean; + kind: TreeKindDto; + /** @nullable */ + last_commit_author_email?: string | null; + /** @nullable */ + last_commit_author_name?: string | null; + /** @nullable */ + last_commit_message?: string | null; + /** @nullable */ + last_commit_time?: string | null; + name: string; + oid: string; +} diff --git a/src/client/models/treeKindDto.ts b/src/client/models/treeKindDto.ts new file mode 100644 index 0000000..40021a7 --- /dev/null +++ b/src/client/models/treeKindDto.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type TreeKindDto = typeof TreeKindDto[keyof typeof TreeKindDto]; + + +export const TreeKindDto = { + blob: 'blob', + tree: 'tree', + lfs_pointer: 'lfs_pointer', +} as const; diff --git a/src/client/models/typingAction.ts b/src/client/models/typingAction.ts new file mode 100644 index 0000000..8d63b9f --- /dev/null +++ b/src/client/models/typingAction.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type TypingAction = typeof TypingAction[keyof typeof TypingAction]; + + +export const TypingAction = { + Start: 'Start', + Stop: 'Stop', +} as const; diff --git a/src/client/models/typingRequest.ts b/src/client/models/typingRequest.ts new file mode 100644 index 0000000..e7dcdba --- /dev/null +++ b/src/client/models/typingRequest.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { TypingAction } from './typingAction'; + +export interface TypingRequest { + action: TypingAction; +} diff --git a/src/client/models/updateAgentSession.ts b/src/client/models/updateAgentSession.ts new file mode 100644 index 0000000..bb8ff91 --- /dev/null +++ b/src/client/models/updateAgentSession.ts @@ -0,0 +1,40 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateAgentSession { + /** @nullable */ + description?: string | null; + /** @nullable */ + enabled?: boolean | null; + /** @nullable */ + iteration_budget?: number | null; + /** @nullable */ + knowledge_base_ids?: string[] | null; + /** @nullable */ + max_output_tokens?: number | null; + /** @nullable */ + memory_provider?: string | null; + /** @nullable */ + memory_provider_config?: string | null; + /** @nullable */ + model_version?: string | null; + /** @nullable */ + name?: string | null; + /** @nullable */ + system_prompt?: string | null; + /** @nullable */ + temperature?: number | null; + /** @nullable */ + tool_policy?: string | null; + /** @nullable */ + toolset_json?: string | null; + /** @nullable */ + variables?: string | null; + /** @nullable */ + visibility?: string | null; +} diff --git a/src/client/models/updateComment.ts b/src/client/models/updateComment.ts new file mode 100644 index 0000000..b92a396 --- /dev/null +++ b/src/client/models/updateComment.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateComment { + body: string; +} diff --git a/src/client/models/updateContent.ts b/src/client/models/updateContent.ts new file mode 100644 index 0000000..77d0b58 --- /dev/null +++ b/src/client/models/updateContent.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateContent { + /** @nullable */ + branch?: string | null; + content: string; + message: string; + sha: string; +} diff --git a/src/client/models/updateConversation.ts b/src/client/models/updateConversation.ts new file mode 100644 index 0000000..0a90cbd --- /dev/null +++ b/src/client/models/updateConversation.ts @@ -0,0 +1,12 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateConversation { + /** @nullable */ + title?: string | null; +} diff --git a/src/client/models/updateIssue.ts b/src/client/models/updateIssue.ts new file mode 100644 index 0000000..234e605 --- /dev/null +++ b/src/client/models/updateIssue.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateIssue { + /** @nullable */ + body?: string | null; + /** @nullable */ + due_at?: string | null; + /** @nullable */ + priority?: string | null; + /** @nullable */ + title?: string | null; +} diff --git a/src/client/models/updateLabel.ts b/src/client/models/updateLabel.ts new file mode 100644 index 0000000..a3465a3 --- /dev/null +++ b/src/client/models/updateLabel.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateLabel { + /** @nullable */ + color?: string | null; + /** @nullable */ + description?: string | null; + /** @nullable */ + name?: string | null; +} diff --git a/src/client/models/updateMessageRequest.ts b/src/client/models/updateMessageRequest.ts new file mode 100644 index 0000000..edab0d9 --- /dev/null +++ b/src/client/models/updateMessageRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateMessageRequest { + content: string; +} diff --git a/src/client/models/updateMilestone.ts b/src/client/models/updateMilestone.ts new file mode 100644 index 0000000..29a8a49 --- /dev/null +++ b/src/client/models/updateMilestone.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateMilestone { + /** @nullable */ + description?: string | null; + /** @nullable */ + due_at?: string | null; + /** @nullable */ + state?: string | null; + /** @nullable */ + title?: string | null; +} diff --git a/src/client/models/updatePrComment.ts b/src/client/models/updatePrComment.ts new file mode 100644 index 0000000..5a5c8a3 --- /dev/null +++ b/src/client/models/updatePrComment.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdatePrComment { + body: string; +} diff --git a/src/client/models/updateProtect.ts b/src/client/models/updateProtect.ts new file mode 100644 index 0000000..b625e74 --- /dev/null +++ b/src/client/models/updateProtect.ts @@ -0,0 +1,26 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateProtect { + /** @nullable */ + allow_deletions?: boolean | null; + /** @nullable */ + allow_force_pushes?: boolean | null; + /** @nullable */ + enforce_admins?: boolean | null; + /** @nullable */ + pattern?: string | null; + /** @nullable */ + require_pull_request?: boolean | null; + /** @nullable */ + require_status_checks?: boolean | null; + /** @nullable */ + required_approvals?: number | null; + /** @nullable */ + required_status_contexts?: string[] | null; +} diff --git a/src/client/models/updatePullRequest.ts b/src/client/models/updatePullRequest.ts new file mode 100644 index 0000000..9fc0838 --- /dev/null +++ b/src/client/models/updatePullRequest.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdatePullRequest { + /** @nullable */ + body?: string | null; + /** @nullable */ + draft?: boolean | null; + /** @nullable */ + state?: string | null; + /** @nullable */ + title?: string | null; +} diff --git a/src/client/models/updateRelease.ts b/src/client/models/updateRelease.ts new file mode 100644 index 0000000..749015d --- /dev/null +++ b/src/client/models/updateRelease.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateRelease { + /** @nullable */ + body?: string | null; + /** @nullable */ + draft?: boolean | null; + /** @nullable */ + name?: string | null; + /** @nullable */ + prerelease?: boolean | null; + /** @nullable */ + tag_name?: string | null; +} diff --git a/src/client/models/updateRepo.ts b/src/client/models/updateRepo.ts new file mode 100644 index 0000000..8f0edef --- /dev/null +++ b/src/client/models/updateRepo.ts @@ -0,0 +1,22 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateRepo { + /** @nullable */ + default_branch?: string | null; + /** @nullable */ + description?: string | null; + /** @nullable */ + is_archived?: boolean | null; + /** @nullable */ + is_template?: boolean | null; + /** @nullable */ + name?: string | null; + /** @nullable */ + visibility?: string | null; +} diff --git a/src/client/models/updateUserAccessToken.ts b/src/client/models/updateUserAccessToken.ts new file mode 100644 index 0000000..1ce1c66 --- /dev/null +++ b/src/client/models/updateUserAccessToken.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateUserAccessToken { + /** @nullable */ + expires_at?: string | null; + /** @nullable */ + name?: string | null; + /** @nullable */ + scopes?: string[] | null; +} diff --git a/src/client/models/updateUserAccessibilityConfig.ts b/src/client/models/updateUserAccessibilityConfig.ts new file mode 100644 index 0000000..e778168 --- /dev/null +++ b/src/client/models/updateUserAccessibilityConfig.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateUserAccessibilityConfig { + /** @nullable */ + color_blind_mode?: string | null; + /** @nullable */ + font_scale_percent?: number | null; + /** @nullable */ + high_contrast?: boolean | null; + /** @nullable */ + reduce_motion?: boolean | null; + /** @nullable */ + screen_reader_optimized?: boolean | null; +} diff --git a/src/client/models/updateUserAppearanceConfig.ts b/src/client/models/updateUserAppearanceConfig.ts new file mode 100644 index 0000000..c4c35dd --- /dev/null +++ b/src/client/models/updateUserAppearanceConfig.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateUserAppearanceConfig { + /** @nullable */ + code_theme?: string | null; + /** @nullable */ + layout_density?: string | null; + /** @nullable */ + show_line_numbers?: boolean | null; + /** @nullable */ + sidebar_collapsed?: boolean | null; + /** @nullable */ + theme?: string | null; +} diff --git a/src/client/models/updateUserNotificationConfig.ts b/src/client/models/updateUserNotificationConfig.ts new file mode 100644 index 0000000..da26014 --- /dev/null +++ b/src/client/models/updateUserNotificationConfig.ts @@ -0,0 +1,36 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateUserNotificationConfig { + /** @nullable */ + digest_mode?: string | null; + /** @nullable */ + dnd_enabled?: boolean | null; + /** @nullable */ + dnd_end_minute?: number | null; + /** @nullable */ + dnd_start_minute?: number | null; + /** @nullable */ + email_enabled?: boolean | null; + /** @nullable */ + in_app_enabled?: boolean | null; + /** @nullable */ + marketing_enabled?: boolean | null; + /** @nullable */ + product_enabled?: boolean | null; + /** @nullable */ + push_enabled?: boolean | null; + /** @nullable */ + push_subscription_endpoint?: string | null; + /** @nullable */ + push_subscription_keys_auth?: string | null; + /** @nullable */ + push_subscription_keys_p256dh?: string | null; + /** @nullable */ + security_enabled?: boolean | null; +} diff --git a/src/client/models/updateUserPrivacyConfig.ts b/src/client/models/updateUserPrivacyConfig.ts new file mode 100644 index 0000000..dc44e14 --- /dev/null +++ b/src/client/models/updateUserPrivacyConfig.ts @@ -0,0 +1,22 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateUserPrivacyConfig { + /** @nullable */ + activity_visibility?: string | null; + /** @nullable */ + allow_direct_messages?: boolean | null; + /** @nullable */ + allow_search_indexing?: boolean | null; + /** @nullable */ + email_visibility?: string | null; + /** @nullable */ + profile_visibility?: string | null; + /** @nullable */ + show_online_status?: boolean | null; +} diff --git a/src/client/models/updateUserProfileConfig.ts b/src/client/models/updateUserProfileConfig.ts new file mode 100644 index 0000000..4ecd8b6 --- /dev/null +++ b/src/client/models/updateUserProfileConfig.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateUserProfileConfig { + /** @nullable */ + avatar_url?: string | null; + /** @nullable */ + language?: string | null; + /** @nullable */ + theme?: string | null; + /** @nullable */ + timezone?: string | null; +} diff --git a/src/client/models/updateUserSshKey.ts b/src/client/models/updateUserSshKey.ts new file mode 100644 index 0000000..b35597b --- /dev/null +++ b/src/client/models/updateUserSshKey.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateUserSshKey { + /** @nullable */ + expires_at?: string | null; + /** @nullable */ + title?: string | null; +} diff --git a/src/client/models/updateWebhook.ts b/src/client/models/updateWebhook.ts new file mode 100644 index 0000000..6b222d9 --- /dev/null +++ b/src/client/models/updateWebhook.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateWebhook { + /** @nullable */ + active?: boolean | null; + /** @nullable */ + events?: string[] | null; + /** @nullable */ + secret?: string | null; + /** @nullable */ + url?: string | null; +} diff --git a/src/client/models/updateWorkspace.ts b/src/client/models/updateWorkspace.ts new file mode 100644 index 0000000..8f87434 --- /dev/null +++ b/src/client/models/updateWorkspace.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateWorkspace { + /** @nullable */ + avatar_url?: string | null; + /** @nullable */ + description?: string | null; + /** @nullable */ + name?: string | null; +} diff --git a/src/client/models/updateWorkspaceGroup.ts b/src/client/models/updateWorkspaceGroup.ts new file mode 100644 index 0000000..52e05b5 --- /dev/null +++ b/src/client/models/updateWorkspaceGroup.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateWorkspaceGroup { + /** @nullable */ + avatar_url?: string | null; + /** @nullable */ + name?: string | null; +} diff --git a/src/client/models/updateWorkspaceJoinStrategy.ts b/src/client/models/updateWorkspaceJoinStrategy.ts new file mode 100644 index 0000000..2b5ed17 --- /dev/null +++ b/src/client/models/updateWorkspaceJoinStrategy.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateWorkspaceJoinStrategy { + /** @nullable */ + answer?: string | null; + /** @nullable */ + enabled?: boolean | null; + /** @nullable */ + question?: string | null; + /** @nullable */ + require_approval?: boolean | null; + /** @nullable */ + require_question?: boolean | null; +} diff --git a/src/client/models/updateWorkspaceMember.ts b/src/client/models/updateWorkspaceMember.ts new file mode 100644 index 0000000..77a994c --- /dev/null +++ b/src/client/models/updateWorkspaceMember.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UpdateWorkspaceMember { + admin: boolean; +} diff --git a/src/client/models/userAccessToken.ts b/src/client/models/userAccessToken.ts new file mode 100644 index 0000000..ec1bbed --- /dev/null +++ b/src/client/models/userAccessToken.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UserAccessToken { + created_at: string; + /** @nullable */ + expires_at?: string | null; + id: number; + is_revoked: boolean; + name: string; + scopes: string[]; + updated_at: string; +} diff --git a/src/client/models/userAccessibilityConfig.ts b/src/client/models/userAccessibilityConfig.ts new file mode 100644 index 0000000..89247d0 --- /dev/null +++ b/src/client/models/userAccessibilityConfig.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UserAccessibilityConfig { + /** @nullable */ + color_blind_mode?: string | null; + font_scale_percent: number; + high_contrast: boolean; + reduce_motion: boolean; + screen_reader_optimized: boolean; +} diff --git a/src/client/models/userAppearanceConfig.ts b/src/client/models/userAppearanceConfig.ts new file mode 100644 index 0000000..ceb1be6 --- /dev/null +++ b/src/client/models/userAppearanceConfig.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UserAppearanceConfig { + code_theme: string; + layout_density: string; + show_line_numbers: boolean; + sidebar_collapsed: boolean; + theme: string; +} diff --git a/src/client/models/userConfigResponse.ts b/src/client/models/userConfigResponse.ts new file mode 100644 index 0000000..9a2606e --- /dev/null +++ b/src/client/models/userConfigResponse.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ +import type { UserAccessibilityConfig } from './userAccessibilityConfig'; +import type { UserAppearanceConfig } from './userAppearanceConfig'; +import type { UserNotificationConfig } from './userNotificationConfig'; +import type { UserPrivacyConfig } from './userPrivacyConfig'; +import type { UserProfileConfig } from './userProfileConfig'; + +export interface UserConfigResponse { + accessibility: UserAccessibilityConfig; + appearance: UserAppearanceConfig; + notifications: UserNotificationConfig; + privacy: UserPrivacyConfig; + profile: UserProfileConfig; +} diff --git a/src/client/models/userContributionHeatmapParams.ts b/src/client/models/userContributionHeatmapParams.ts new file mode 100644 index 0000000..1bc2934 --- /dev/null +++ b/src/client/models/userContributionHeatmapParams.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type UserContributionHeatmapParams = { +/** + * @nullable + */ +start_date?: string | null; +/** + * @nullable + */ +end_date?: string | null; +}; diff --git a/src/client/models/userNotificationConfig.ts b/src/client/models/userNotificationConfig.ts new file mode 100644 index 0000000..3b78821 --- /dev/null +++ b/src/client/models/userNotificationConfig.ts @@ -0,0 +1,28 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UserNotificationConfig { + digest_mode: string; + dnd_enabled: boolean; + /** @nullable */ + dnd_end_minute?: number | null; + /** @nullable */ + dnd_start_minute?: number | null; + email_enabled: boolean; + in_app_enabled: boolean; + marketing_enabled: boolean; + product_enabled: boolean; + push_enabled: boolean; + /** @nullable */ + push_subscription_endpoint?: string | null; + /** @nullable */ + push_subscription_keys_auth?: string | null; + /** @nullable */ + push_subscription_keys_p256dh?: string | null; + security_enabled: boolean; +} diff --git a/src/client/models/userPrivacyConfig.ts b/src/client/models/userPrivacyConfig.ts new file mode 100644 index 0000000..095550a --- /dev/null +++ b/src/client/models/userPrivacyConfig.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UserPrivacyConfig { + activity_visibility: string; + allow_direct_messages: boolean; + allow_search_indexing: boolean; + email_visibility: string; + profile_visibility: string; + show_online_status: boolean; +} diff --git a/src/client/models/userProfileConfig.ts b/src/client/models/userProfileConfig.ts new file mode 100644 index 0000000..4f87cd1 --- /dev/null +++ b/src/client/models/userProfileConfig.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UserProfileConfig { + language: string; + theme: string; + timezone: string; +} diff --git a/src/client/models/userRelationCard.ts b/src/client/models/userRelationCard.ts new file mode 100644 index 0000000..68f6936 --- /dev/null +++ b/src/client/models/userRelationCard.ts @@ -0,0 +1,17 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UserRelationCard { + /** @nullable */ + avatar_url?: string | null; + /** @nullable */ + display_name?: string | null; + is_blocked: boolean; + is_following: boolean; + username: string; +} diff --git a/src/client/models/userRelationCounts.ts b/src/client/models/userRelationCounts.ts new file mode 100644 index 0000000..1c3ddfb --- /dev/null +++ b/src/client/models/userRelationCounts.ts @@ -0,0 +1,13 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UserRelationCounts { + blocked: number; + followers: number; + following: number; +} diff --git a/src/client/models/userRelationStatus.ts b/src/client/models/userRelationStatus.ts new file mode 100644 index 0000000..3ee4d09 --- /dev/null +++ b/src/client/models/userRelationStatus.ts @@ -0,0 +1,17 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UserRelationStatus { + /** @nullable */ + avatar_url?: string | null; + has_blocked_me: boolean; + is_blocked: boolean; + is_followed_by: boolean; + is_following: boolean; + username: string; +} diff --git a/src/client/models/userSshKey.ts b/src/client/models/userSshKey.ts new file mode 100644 index 0000000..1408c62 --- /dev/null +++ b/src/client/models/userSshKey.ts @@ -0,0 +1,25 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UserSshKey { + created_at: string; + /** @nullable */ + expires_at?: string | null; + fingerprint: string; + id: number; + is_revoked: boolean; + is_verified: boolean; + /** @nullable */ + key_bits?: number | null; + key_type: string; + /** @nullable */ + last_used_at?: string | null; + public_key: string; + title: string; + updated_at: string; +} diff --git a/src/client/models/userSummaryResponse.ts b/src/client/models/userSummaryResponse.ts new file mode 100644 index 0000000..0f8120c --- /dev/null +++ b/src/client/models/userSummaryResponse.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface UserSummaryResponse { + avatar_url: string; + created_at: string; + display_name: string; + username: string; + website_url: string; +} diff --git a/src/client/models/usersFollowersParams.ts b/src/client/models/usersFollowersParams.ts new file mode 100644 index 0000000..a2ec999 --- /dev/null +++ b/src/client/models/usersFollowersParams.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type UsersFollowersParams = { +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/usersFollowingParams.ts b/src/client/models/usersFollowingParams.ts new file mode 100644 index 0000000..3b97318 --- /dev/null +++ b/src/client/models/usersFollowingParams.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type UsersFollowingParams = { +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/usersUserChpcParams.ts b/src/client/models/usersUserChpcParams.ts new file mode 100644 index 0000000..f7941c6 --- /dev/null +++ b/src/client/models/usersUserChpcParams.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type UsersUserChpcParams = { +/** + * @nullable + */ +start_date?: string | null; +/** + * @nullable + */ +end_date?: string | null; +}; diff --git a/src/client/models/verify2FAParams.ts b/src/client/models/verify2FAParams.ts new file mode 100644 index 0000000..efaceb1 --- /dev/null +++ b/src/client/models/verify2FAParams.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface Verify2FAParams { + code: string; +} diff --git a/src/client/models/voiceDeafRequest.ts b/src/client/models/voiceDeafRequest.ts new file mode 100644 index 0000000..3b037a3 --- /dev/null +++ b/src/client/models/voiceDeafRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface VoiceDeafRequest { + deafened: boolean; +} diff --git a/src/client/models/voiceMuteRequest.ts b/src/client/models/voiceMuteRequest.ts new file mode 100644 index 0000000..bc7d3bd --- /dev/null +++ b/src/client/models/voiceMuteRequest.ts @@ -0,0 +1,11 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface VoiceMuteRequest { + muted: boolean; +} diff --git a/src/client/models/webhookDeliveryResponse.ts b/src/client/models/webhookDeliveryResponse.ts new file mode 100644 index 0000000..cf7b45a --- /dev/null +++ b/src/client/models/webhookDeliveryResponse.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface WebhookDeliveryResponse { + created_at: string; + /** @nullable */ + delivered_at?: string | null; + /** @nullable */ + error?: string | null; + event: string; + id: string; + /** @nullable */ + response_status?: number | null; + webhook: string; +} diff --git a/src/client/models/webhookResponse.ts b/src/client/models/webhookResponse.ts new file mode 100644 index 0000000..40c38ce --- /dev/null +++ b/src/client/models/webhookResponse.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface WebhookResponse { + active: boolean; + created_at: string; + created_by: string; + events: string[]; + id: string; + repo: string; + updated_at: string; + url: string; +} diff --git a/src/client/models/workspaceGroupResponse.ts b/src/client/models/workspaceGroupResponse.ts new file mode 100644 index 0000000..3c459a8 --- /dev/null +++ b/src/client/models/workspaceGroupResponse.ts @@ -0,0 +1,15 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface WorkspaceGroupResponse { + /** @nullable */ + avatar_url?: string | null; + created_at: string; + is_deleted: boolean; + name: string; +} diff --git a/src/client/models/workspaceJoinApplyResponse.ts b/src/client/models/workspaceJoinApplyResponse.ts new file mode 100644 index 0000000..e4b5443 --- /dev/null +++ b/src/client/models/workspaceJoinApplyResponse.ts @@ -0,0 +1,24 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface WorkspaceJoinApplyResponse { + /** @nullable */ + answer?: string | null; + /** @nullable */ + avatar_url?: string | null; + created_at: string; + /** @nullable */ + message?: string | null; + /** @nullable */ + question?: string | null; + status: string; + updated_at: string; + username: string; + workspace_avatar_url: string; + workspace_name: string; +} diff --git a/src/client/models/workspaceJoinApprovalResponse.ts b/src/client/models/workspaceJoinApprovalResponse.ts new file mode 100644 index 0000000..47d76cd --- /dev/null +++ b/src/client/models/workspaceJoinApprovalResponse.ts @@ -0,0 +1,22 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface WorkspaceJoinApprovalResponse { + approved: boolean; + /** @nullable */ + approver_avatar_url?: string | null; + approver_username: string; + /** @nullable */ + avatar_url?: string | null; + created_at: string; + /** @nullable */ + reason?: string | null; + username: string; + workspace_avatar_url: string; + workspace_name: string; +} diff --git a/src/client/models/workspaceJoinStrategyResponse.ts b/src/client/models/workspaceJoinStrategyResponse.ts new file mode 100644 index 0000000..4713654 --- /dev/null +++ b/src/client/models/workspaceJoinStrategyResponse.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface WorkspaceJoinStrategyResponse { + created_at: string; + enabled: boolean; + has_answer: boolean; + /** @nullable */ + question?: string | null; + require_approval: boolean; + require_question: boolean; + updated_at: string; + workspace_avatar_url: string; + workspace_name: string; +} diff --git a/src/client/models/workspaceListJoinAppliesParams.ts b/src/client/models/workspaceListJoinAppliesParams.ts new file mode 100644 index 0000000..b537505 --- /dev/null +++ b/src/client/models/workspaceListJoinAppliesParams.ts @@ -0,0 +1,14 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type WorkspaceListJoinAppliesParams = { +/** + * @nullable + */ +status?: string | null; +}; diff --git a/src/client/models/workspaceListMembersParams.ts b/src/client/models/workspaceListMembersParams.ts new file mode 100644 index 0000000..24a4ac9 --- /dev/null +++ b/src/client/models/workspaceListMembersParams.ts @@ -0,0 +1,20 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export type WorkspaceListMembersParams = { +/** + * @minimum 0 + * @nullable + */ +offset?: number | null; +/** + * @minimum 0 + * @nullable + */ +limit?: number | null; +}; diff --git a/src/client/models/workspaceMemberResponse.ts b/src/client/models/workspaceMemberResponse.ts new file mode 100644 index 0000000..410e1dd --- /dev/null +++ b/src/client/models/workspaceMemberResponse.ts @@ -0,0 +1,18 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface WorkspaceMemberResponse { + admin: boolean; + /** @nullable */ + avatar_url?: string | null; + /** @nullable */ + display_name?: string | null; + join_at: string; + owner: boolean; + username: string; +} diff --git a/src/client/models/workspaceResponse.ts b/src/client/models/workspaceResponse.ts new file mode 100644 index 0000000..212481c --- /dev/null +++ b/src/client/models/workspaceResponse.ts @@ -0,0 +1,16 @@ +/** + * Generated by orval v8.12.3 🍺 + * Do not edit manually. + * GitDataAI API + * GitDataAI platform REST API + * OpenAPI spec version: 1.0.0 + */ + +export interface WorkspaceResponse { + admin: boolean; + avatar_url: string; + created_at: string; + description: string; + name: string; + owner: boolean; +} diff --git a/src/components/CodePreviewPanel.tsx b/src/components/CodePreviewPanel.tsx new file mode 100644 index 0000000..30297d6 --- /dev/null +++ b/src/components/CodePreviewPanel.tsx @@ -0,0 +1,136 @@ +import { createContext, useContext, useState, useCallback, type ReactNode } from "react"; +import { PanelRightClose, Copy, Check } from "lucide-react"; + +export interface CodePreviewPayload { + id: string; + code: string; + language: string; + title?: string; + subtitle?: string; + kind?: "code" | "subagent"; + status?: "pending" | "ok" | "error" | "stopped"; + onStop?: () => void; +} + +interface CodePreviewContextValue { + activeCode: CodePreviewPayload | null; + openCodePreview: (payload: CodePreviewPayload) => void; + closeCodePreview: () => void; + updateCodePreview: (updates: Partial) => void; +} + +const CodePreviewContext = createContext(null); + +export function CodePreviewProvider({ children }: { children: ReactNode }) { + const [activeCode, setActiveCode] = useState(null); + + const openCodePreview = useCallback((payload: CodePreviewPayload) => { + setActiveCode(payload); + }, []); + + const closeCodePreview = useCallback(() => { + setActiveCode(null); + }, []); + + const updateCodePreview = useCallback((updates: Partial) => { + setActiveCode((prev) => (prev ? { ...prev, ...updates } : null)); + }, []); + + return ( + + {children} + + ); +} + +export function useCodePreview() { + const ctx = useContext(CodePreviewContext); + return ctx; +} + +export function CodePreviewPanel() { + const { activeCode, closeCodePreview, updateCodePreview } = useCodePreview() || {}; + const [copied, setCopied] = useState(false); + + if (!activeCode) return null; + + const lines = activeCode.code.replace(/\n$/, "").split("\n"); + const isSubAgent = activeCode.kind === "subagent"; + + const handleCopy = () => { + navigator.clipboard.writeText(activeCode.code).then(() => { + setCopied(true); + setTimeout(() => setCopied(false), 2000); + }); + }; + + const handleStop = () => { + activeCode.onStop?.(); + updateCodePreview?.({ status: "stopped" }); + }; + + return ( + + ); +} diff --git a/src/components/CommandPalette.tsx b/src/components/CommandPalette.tsx new file mode 100644 index 0000000..2686fad --- /dev/null +++ b/src/components/CommandPalette.tsx @@ -0,0 +1,384 @@ +import { useState, useEffect, useMemo, useCallback, useRef } from "react"; +import { useNavigate } from "react-router"; +import { Command } from "cmdk"; +import { + Search, + Building2, + FolderGit2, + Hash, + CircleDot, + Loader2, + AlertCircle, + CornerDownLeft, + ArrowUpDown, + X, +} from "lucide-react"; +import { api } from "@/client"; + +// ---- Types ---- + +type HitType = "workspace" | "repo" | "room" | "issue"; + +interface SearchHit { + id: string; + type: HitType; + label: string; + subtitle: string; + url: string; + meta?: string; + metaTone?: "success" | "muted" | "warning"; +} + +interface SearchGroup { + items: T[]; + total: number; + has_more: boolean; +} + +interface SearchResponse { + workspaces: SearchGroup<{ name: string; description?: string | null }>; + repos: SearchGroup<{ name: string; workspace: string; description?: string | null }>; + rooms: SearchGroup<{ id: string; name: string; workspace: string }>; + issues: SearchGroup<{ number: number; title: string; state: string; workspace: string }>; +} + +const TYPE_ICONS: Record = { + workspace: Building2, + repo: FolderGit2, + room: Hash, + issue: CircleDot, +}; + +const TYPE_LABELS: Record = { + workspace: "Workspaces", + repo: "Repositories", + room: "Channels", + issue: "Issues", +}; + +const PLACEHOLDERS = [ + "Search workspaces, repos, channels...", + "Find an issue...", + "Jump to a channel...", + "Open a repository...", +]; + +// ---- Helpers ---- + +function buildHits(data: SearchResponse): SearchHit[] { + const results: SearchHit[] = []; + + for (const ws of data.workspaces.items) { + results.push({ + id: `ws-${ws.name}`, + type: "workspace", + label: ws.name, + subtitle: ws.description || "Workspace", + url: `/${ws.name}`, + }); + } + + for (const repo of data.repos.items) { + results.push({ + id: `repo-${repo.workspace}-${repo.name}`, + type: "repo", + label: `${repo.workspace}/${repo.name}`, + subtitle: repo.description || "Repository", + url: `/${repo.workspace}/repo/${repo.name}`, + }); + } + + for (const room of data.rooms.items) { + results.push({ + id: `room-${room.id}`, + type: "room", + label: `#${room.name}`, + subtitle: "Channel", + url: `/${room.workspace}/channel/${room.id}`, + }); + } + + for (const issue of data.issues.items) { + results.push({ + id: `issue-${issue.workspace}-${issue.number}`, + type: "issue", + label: `#${issue.number} ${issue.title}`, + subtitle: issue.workspace, + meta: issue.state, + metaTone: issue.state === "open" ? "success" : "muted", + url: `/${issue.workspace}/issues/${issue.number}`, + }); + } + + return results; +} + +// ---- Component ---- + +export function CommandPalette() { + const [open, setOpen] = useState(false); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const [items, setItems] = useState([]); + const [search, setSearch] = useState(""); + const [placeholderIndex, setPlaceholderIndex] = useState(0); + const navigate = useNavigate(); + const inputRef = useRef(null); + + // Rotating placeholder + useEffect(() => { + if (!open) return; + const interval = setInterval(() => { + setPlaceholderIndex((i) => (i + 1) % PLACEHOLDERS.length); + }, 3000); + return () => clearInterval(interval); + }, [open]); + + // Fetch data + const fetchData = useCallback(async (q: string) => { + setLoading(true); + setError(null); + try { + const { data } = await api.get("/api/v1/search", { + params: q.trim() ? { q: q.trim() } : {}, + }); + setItems(buildHits(data)); + } catch (err) { + setError(err instanceof Error ? err.message : "Search failed"); + setItems([]); + } + setLoading(false); + }, []); + + // Pre-fetch on open, debounced re-fetch on type + useEffect(() => { + if (!open) return; + const timer = setTimeout(() => fetchData(search), search ? 150 : 0); + return () => clearTimeout(timer); + }, [open, search, fetchData]); + + // Reset on close + const handleOpenChange = useCallback((o: boolean) => { + setOpen(o); + if (!o) { + setSearch(""); + setError(null); + } + }, []); + + // Group by type + const grouped = useMemo(() => { + const g: Record = {}; + for (const it of items) { + const key = TYPE_LABELS[it.type]; + (g[key] ??= []).push(it); + } + return g; + }, [items]); + + const totalHits = items.length; + + // ⌘K global shortcut + useEffect(() => { + const down = (e: KeyboardEvent) => { + if (e.key === "k" && (e.metaKey || e.ctrlKey)) { + e.preventDefault(); + setOpen((v) => !v); + } + }; + document.addEventListener("keydown", down); + return () => document.removeEventListener("keydown", down); + }, []); + + const handleSelect = useCallback( + (url: string) => { + setOpen(false); + navigate(url); + }, + [navigate], + ); + + return ( + <> + {/* Trigger button */} + + + {/* Modal dialog */} + + {/* Search input */} +
+ + +
+ {loading && ( + + )} + {error && ( + + )} +
+
+ + {/* Results list */} + + {/* Error state */} + {error && ( +
+
+ +
+

{error}

+ +
+ )} + + {/* Empty state */} + {!error && !loading && totalHits === 0 && ( +
+
+ +
+

+ {search.trim() + ? `No results for "${search.trim()}"` + : "Type to search across all workspaces"} +

+
+ )} + + {/* Loading skeleton */} + {loading && totalHits === 0 && !error && ( +
+ {Array.from({ length: 4 }).map((_, i) => ( +
+
+
+
+
+
+
+ ))} +
+ )} + + {/* Grouped results */} + {Object.entries(grouped).map(([groupName, groupItems]) => ( + + {groupName} + + {groupItems.length} + + + } + key={groupName} + > + {groupItems.map((item) => { + const Icon = TYPE_ICONS[item.type]; + return ( + handleSelect(item.url)} + value={item.id} + > + + + +
+
+ {item.label} +
+
+ + {item.subtitle} + + {item.meta && ( + + {item.meta} + + )} +
+
+ + ↵ + +
+ ); + })} +
+ ))} + + + {/* Footer */} +
+
+ + + Navigate + + + + Open + + + + Close + +
+ {totalHits > 0 && ( + + {totalHits} result{totalHits !== 1 ? "s" : ""} + + )} +
+ + + ); +} diff --git a/src/components/MarkdownRenderer.tsx b/src/components/MarkdownRenderer.tsx new file mode 100644 index 0000000..be42e9a --- /dev/null +++ b/src/components/MarkdownRenderer.tsx @@ -0,0 +1,233 @@ +import { memo, useMemo, useState, useCallback } from "react"; +import ReactMarkdown from "react-markdown"; +import remarkGfm from "remark-gfm"; +import { Check, Copy, Code2, PanelRightOpen } from "lucide-react"; +import { cn } from "@/lib/utils"; + +interface MarkdownRendererProps { + content: string; + className?: string; + /** Called when user clicks to open a code block in a side panel. */ + onOpenCodePanel?: (payload: { + code: string; + language: string; + lineCount: number; + title?: string; + }) => void; +} + +/** Extract plain text from react-markdown children (string or array). */ +function extractText(children: React.ReactNode): string { + if (typeof children === "string") return children; + if (Array.isArray(children)) { + return children.map((c) => (typeof c === "string" ? c : "")).join(""); + } + return ""; +} + +/** + * Code block wrapped in Artifact-style container. + */ +const ArtifactCodeBlock = memo(function ArtifactCodeBlock({ + language, + children, + onOpen, +}: { + language: string; + children: string; + onOpen?: () => void; +}) { + const [copied, setCopied] = useState(false); + const lines = useMemo( + () => children.replace(/\n$/, "").split("\n"), + [children], + ); + + const handleCopy = useCallback(() => { + navigator.clipboard + .writeText(children) + .then(() => { + setCopied(true); + setTimeout(() => setCopied(false), 2000); + }) + .catch(() => {}); + }, [children]); + + return ( +
{ + if (e.key === "Enter" || e.key === " ") { + e.preventDefault(); + onOpen(); + } + } + : undefined + } + aria-label={onOpen ? `Open ${language || "code"} in side panel` : undefined} + > + {/* Header */} +
+
+ + + + + {language || "code"} + + + {lines.length} line{lines.length !== 1 ? "s" : ""} + +
+
+ {onOpen && ( + + + Open in panel + + )} + {!onOpen && ( + + )} +
+
+ + {/* Content with line numbers — only when no side panel (inline mode) */} + {!onOpen && ( +
+ +
+            
+              {children.replace(/\n$/, "")}
+            
+          
+
+ )} + + {/* Compact summary when panel mode */} + {onOpen && ( +
+

+ {lines[0]?.slice(0, 80)}{(lines[0]?.length ?? 0) > 80 ? "…" : ""} +

+
+ )} +
+ ); +}); + +export const MarkdownRenderer = memo(function MarkdownRenderer({ + content, + className = "", + onOpenCodePanel, +}: MarkdownRendererProps) { + const components = useMemo( + () => ({ + code({ + className: codeClassName, + children, + ...props + }: { + className?: string; + children: React.ReactNode; + node?: unknown; + }) { + const match = /language-(\w+)/.exec(codeClassName || ""); + const codeContent = extractText(children).replace(/\n$/, ""); + const isInline = !match && !codeContent.includes("\n"); + + if (isInline) { + return ( + + {extractText(children)} + + ); + } + + return ( + + onOpenCodePanel({ + code: codeContent, + language: match?.[1] || "text", + lineCount: codeContent.split("\n").length, + }) + : undefined + } + > + {codeContent} + + ); + }, + pre({ children }: { children: React.ReactNode }) { + return <>{children}; + }, + }), + [onOpenCodePanel], + ); + + return ( +
+ + {content} + +
+ ); +}); + +export default MarkdownRenderer; diff --git a/src/components/Reasoning.tsx b/src/components/Reasoning.tsx new file mode 100644 index 0000000..1a70627 --- /dev/null +++ b/src/components/Reasoning.tsx @@ -0,0 +1,57 @@ +import { useState, type ReactNode } from "react"; +import { Brain, ChevronDown } from "lucide-react"; +import { cn } from "@/lib/utils"; + +interface ReasoningProps { + children: ReactNode; + defaultOpen?: boolean; + summary?: string; +} + +export function Reasoning({ children, defaultOpen = false, summary }: ReasoningProps) { + const [isOpen, setIsOpen] = useState(defaultOpen); + + return ( +
+ + {isOpen && ( +
+ {children} +
+ )} +
+ ); +} + +interface ThinkingBlockProps { + content: string; + isStreaming?: boolean; +} + +export function ThinkingBlock({ content, isStreaming }: ThinkingBlockProps) { + return ( + +
+ {content} + {isStreaming && ( + + )} +
+
+ ); +} diff --git a/src/components/ai-elements/artifact.tsx b/src/components/ai-elements/artifact.tsx new file mode 100644 index 0000000..0597d4a --- /dev/null +++ b/src/components/ai-elements/artifact.tsx @@ -0,0 +1,148 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from "@/components/ui/tooltip"; +import { cn } from "@/lib/utils"; +import type { LucideIcon } from "lucide-react"; +import { XIcon } from "lucide-react"; +import type { ComponentProps, HTMLAttributes } from "react"; + +export type ArtifactProps = HTMLAttributes; + +export const Artifact = ({ className, ...props }: ArtifactProps) => ( +
+); + +export type ArtifactHeaderProps = HTMLAttributes; + +export const ArtifactHeader = ({ + className, + ...props +}: ArtifactHeaderProps) => ( +
+); + +export type ArtifactCloseProps = ComponentProps; + +export const ArtifactClose = ({ + className, + children, + size = "sm", + variant = "ghost", + ...props +}: ArtifactCloseProps) => ( + +); + +export type ArtifactTitleProps = HTMLAttributes; + +export const ArtifactTitle = ({ className, ...props }: ArtifactTitleProps) => ( +

+); + +export type ArtifactDescriptionProps = HTMLAttributes; + +export const ArtifactDescription = ({ + className, + ...props +}: ArtifactDescriptionProps) => ( +

+); + +export type ArtifactActionsProps = HTMLAttributes; + +export const ArtifactActions = ({ + className, + ...props +}: ArtifactActionsProps) => ( +

+); + +export type ArtifactActionProps = ComponentProps & { + tooltip?: string; + label?: string; + icon?: LucideIcon; +}; + +export const ArtifactAction = ({ + tooltip, + label, + icon: Icon, + children, + className, + size = "sm", + variant = "ghost", + ...props +}: ArtifactActionProps) => { + const button = ( + + ); + + if (tooltip) { + return ( + + + {button} + +

{tooltip}

+
+
+
+ ); + } + + return button; +}; + +export type ArtifactContentProps = HTMLAttributes; + +export const ArtifactContent = ({ + className, + ...props +}: ArtifactContentProps) => ( +
+); diff --git a/src/components/ai-elements/code-block.tsx b/src/components/ai-elements/code-block.tsx new file mode 100644 index 0000000..bc1441a --- /dev/null +++ b/src/components/ai-elements/code-block.tsx @@ -0,0 +1,562 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { cn } from "@/lib/utils"; +import { CheckIcon, CopyIcon } from "lucide-react"; +import type { ComponentProps, CSSProperties, HTMLAttributes } from "react"; +import { + createContext, + memo, + useCallback, + useContext, + useEffect, + useMemo, + useRef, + useState, +} from "react"; +import type { + BundledLanguage, + BundledTheme, + HighlighterGeneric, + ThemedToken, +} from "shiki"; +import { createHighlighter } from "shiki"; + +// Shiki uses bitflags for font styles: 1=italic, 2=bold, 4=underline +// oxlint-disable-next-line eslint(no-bitwise) +const isItalic = (fontStyle: number | undefined) => fontStyle && fontStyle & 1; +// oxlint-disable-next-line eslint(no-bitwise) +const isBold = (fontStyle: number | undefined) => fontStyle && fontStyle & 2; +const isUnderline = (fontStyle: number | undefined) => + // oxlint-disable-next-line eslint(no-bitwise) + fontStyle && fontStyle & 4; + +// Transform tokens to include pre-computed keys to avoid noArrayIndexKey lint +interface KeyedToken { + token: ThemedToken; + key: string; +} +interface KeyedLine { + tokens: KeyedToken[]; + key: string; +} + +const addKeysToTokens = (lines: ThemedToken[][]): KeyedLine[] => + lines.map((line, lineIdx) => ({ + key: `line-${lineIdx}`, + tokens: line.map((token, tokenIdx) => ({ + key: `line-${lineIdx}-${tokenIdx}`, + token, + })), + })); + +// Token rendering component +const TokenSpan = ({ token }: { token: ThemedToken }) => ( + + {token.content} + +); + +// Line number styles using CSS counters +const LINE_NUMBER_CLASSES = cn( + "block", + "before:content-[counter(line)]", + "before:inline-block", + "before:[counter-increment:line]", + "before:w-8", + "before:mr-4", + "before:text-right", + "before:text-muted-foreground/50", + "before:font-mono", + "before:select-none" +); + +// Line rendering component +const LineSpan = ({ + keyedLine, + showLineNumbers, +}: { + keyedLine: KeyedLine; + showLineNumbers: boolean; +}) => ( + + {keyedLine.tokens.length === 0 + ? "\n" + : keyedLine.tokens.map(({ token, key }) => ( + + ))} + +); + +// Types +type CodeBlockProps = HTMLAttributes & { + code: string; + language: BundledLanguage; + showLineNumbers?: boolean; +}; + +interface TokenizedCode { + tokens: ThemedToken[][]; + fg: string; + bg: string; +} + +interface CodeBlockContextType { + code: string; +} + +// Context +const CodeBlockContext = createContext({ + code: "", +}); + +// Highlighter cache (singleton per language) +const highlighterCache = new Map< + string, + Promise> +>(); + +// Token cache +const tokensCache = new Map(); + +// Subscribers for async token updates +const subscribers = new Map void>>(); + +const getTokensCacheKey = (code: string, language: BundledLanguage) => { + const start = code.slice(0, 100); + const end = code.length > 100 ? code.slice(-100) : ""; + return `${language}:${code.length}:${start}:${end}`; +}; + +const getHighlighter = ( + language: BundledLanguage +): Promise> => { + const cached = highlighterCache.get(language); + if (cached) { + return cached; + } + + const highlighterPromise = createHighlighter({ + langs: [language], + themes: ["github-light", "github-dark"], + }); + + highlighterCache.set(language, highlighterPromise); + return highlighterPromise; +}; + +// Create raw tokens for immediate display while highlighting loads +const createRawTokens = (code: string): TokenizedCode => ({ + bg: "transparent", + fg: "inherit", + tokens: code.split("\n").map((line) => + line === "" + ? [] + : [ + { + color: "inherit", + content: line, + } as ThemedToken, + ] + ), +}); + +// Synchronous highlight with callback for async results +export const highlightCode = ( + code: string, + language: BundledLanguage, + // oxlint-disable-next-line eslint-plugin-promise(prefer-await-to-callbacks) + callback?: (result: TokenizedCode) => void +): TokenizedCode | null => { + const tokensCacheKey = getTokensCacheKey(code, language); + + // Return cached result if available + const cached = tokensCache.get(tokensCacheKey); + if (cached) { + return cached; + } + + // Subscribe callback if provided + if (callback) { + if (!subscribers.has(tokensCacheKey)) { + subscribers.set(tokensCacheKey, new Set()); + } + subscribers.get(tokensCacheKey)?.add(callback); + } + + // Start highlighting in background - fire-and-forget async pattern + getHighlighter(language) + // oxlint-disable-next-line eslint-plugin-promise(prefer-await-to-then) + .then((highlighter) => { + const availableLangs = highlighter.getLoadedLanguages(); + const langToUse = availableLangs.includes(language) ? language : "text"; + + const result = highlighter.codeToTokens(code, { + lang: langToUse, + themes: { + dark: "github-dark", + light: "github-light", + }, + }); + + const tokenized: TokenizedCode = { + bg: result.bg ?? "transparent", + fg: result.fg ?? "inherit", + tokens: result.tokens, + }; + + // Cache the result + tokensCache.set(tokensCacheKey, tokenized); + + // Notify all subscribers + const subs = subscribers.get(tokensCacheKey); + if (subs) { + for (const sub of subs) { + sub(tokenized); + } + subscribers.delete(tokensCacheKey); + } + }) + // oxlint-disable-next-line eslint-plugin-promise(prefer-await-to-then), eslint-plugin-promise(prefer-await-to-callbacks) + .catch((error) => { + console.error("Failed to highlight code:", error); + subscribers.delete(tokensCacheKey); + }); + + return null; +}; + +const CodeBlockBody = memo( + ({ + tokenized, + showLineNumbers, + className, + }: { + tokenized: TokenizedCode; + showLineNumbers: boolean; + className?: string; + }) => { + const preStyle = useMemo( + () => ({ + backgroundColor: tokenized.bg, + color: tokenized.fg, + }), + [tokenized.bg, tokenized.fg] + ); + + const keyedLines = useMemo( + () => addKeysToTokens(tokenized.tokens), + [tokenized.tokens] + ); + + return ( +
+        
+          {keyedLines.map((keyedLine) => (
+            
+          ))}
+        
+      
+ ); + }, + (prevProps, nextProps) => + prevProps.tokenized === nextProps.tokenized && + prevProps.showLineNumbers === nextProps.showLineNumbers && + prevProps.className === nextProps.className +); + +CodeBlockBody.displayName = "CodeBlockBody"; + +export const CodeBlockContainer = ({ + className, + language, + style, + ...props +}: HTMLAttributes & { language: string }) => ( +
+); + +export const CodeBlockHeader = ({ + children, + className, + ...props +}: HTMLAttributes) => ( +
+ {children} +
+); + +export const CodeBlockTitle = ({ + children, + className, + ...props +}: HTMLAttributes) => ( +
+ {children} +
+); + +export const CodeBlockFilename = ({ + children, + className, + ...props +}: HTMLAttributes) => ( + + {children} + +); + +export const CodeBlockActions = ({ + children, + className, + ...props +}: HTMLAttributes) => ( +
+ {children} +
+); + +export const CodeBlockContent = ({ + code, + language, + showLineNumbers = false, +}: { + code: string; + language: BundledLanguage; + showLineNumbers?: boolean; +}) => { + // Memoized raw tokens for immediate display + const rawTokens = useMemo(() => createRawTokens(code), [code]); + + // Synchronous cache lookup — avoids setState in effect for cached results + const syncTokens = useMemo( + () => highlightCode(code, language) ?? rawTokens, + [code, language, rawTokens] + ); + + // Async highlighting result (populated after shiki loads) + const [asyncTokens, setAsyncTokens] = useState(null); + const asyncKeyRef = useRef({ code, language }); + + // Invalidate stale async tokens synchronously during render + if ( + asyncKeyRef.current.code !== code || + asyncKeyRef.current.language !== language + ) { + asyncKeyRef.current = { code, language }; + setAsyncTokens(null); + } + + useEffect(() => { + let cancelled = false; + + highlightCode(code, language, (result) => { + if (!cancelled) { + setAsyncTokens(result); + } + }); + + return () => { + cancelled = true; + }; + }, [code, language]); + + const tokenized = asyncTokens ?? syncTokens; + + return ( +
+ +
+ ); +}; + +export const CodeBlock = ({ + code, + language, + showLineNumbers = false, + className, + children, + ...props +}: CodeBlockProps) => { + const contextValue = useMemo(() => ({ code }), [code]); + + return ( + + + {children} + + + + ); +}; + +export type CodeBlockCopyButtonProps = ComponentProps & { + onCopy?: () => void; + onError?: (error: Error) => void; + timeout?: number; +}; + +export const CodeBlockCopyButton = ({ + onCopy, + onError, + timeout = 2000, + children, + className, + ...props +}: CodeBlockCopyButtonProps) => { + const [isCopied, setIsCopied] = useState(false); + const timeoutRef = useRef(0); + const { code } = useContext(CodeBlockContext); + + const copyToClipboard = useCallback(async () => { + if (typeof window === "undefined" || !navigator?.clipboard?.writeText) { + onError?.(new Error("Clipboard API not available")); + return; + } + + try { + if (!isCopied) { + await navigator.clipboard.writeText(code); + setIsCopied(true); + onCopy?.(); + timeoutRef.current = window.setTimeout( + () => setIsCopied(false), + timeout + ); + } + } catch (error) { + onError?.(error as Error); + } + }, [code, onCopy, onError, timeout, isCopied]); + + useEffect( + () => () => { + window.clearTimeout(timeoutRef.current); + }, + [] + ); + + const Icon = isCopied ? CheckIcon : CopyIcon; + + return ( + + ); +}; + +export type CodeBlockLanguageSelectorProps = ComponentProps; + +export const CodeBlockLanguageSelector = ( + props: CodeBlockLanguageSelectorProps +) => + ); +}; + +export type WebPreviewBodyProps = ComponentProps<"iframe"> & { + loading?: ReactNode; +}; + +export const WebPreviewBody = ({ + className, + loading, + src, + ...props +}: WebPreviewBodyProps) => { + const { url } = useWebPreview(); + + return ( +
+