feat(core): initialize project with access control and AI integration
This commit is contained in:
parent
14f6e1e500
commit
ba2490dab4
@ -51,9 +51,9 @@ APP_DOMAIN_URL=http://127.0.0.1
|
||||
|
||||
# APP_DATABASE_MAX_CONNECTIONS=10
|
||||
# APP_DATABASE_MIN_CONNECTIONS=2
|
||||
# APP_DATABASE_IDLE_TIMEOUT=60000
|
||||
# APP_DATABASE_MAX_LIFETIME=300000
|
||||
# APP_DATABASE_CONNECTION_TIMEOUT=5000
|
||||
# APP_DATABASE_IDLE_TIMEOUT=60000 (milliseconds, default: 60s)
|
||||
# APP_DATABASE_MAX_LIFETIME=300000 (milliseconds, default: 300s)
|
||||
# APP_DATABASE_CONNECTION_TIMEOUT=5000 (milliseconds, default: 5s)
|
||||
# APP_DATABASE_REPLICAS=
|
||||
# APP_DATABASE_HEALTH_CHECK_INTERVAL=30
|
||||
# APP_DATABASE_RETRY_ATTEMPTS=3
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -23,3 +23,6 @@ coverage/
|
||||
pnpm-lock.yaml
|
||||
package-lock.json
|
||||
yarn.lock
|
||||
.gemini
|
||||
.omg
|
||||
/.sqry
|
||||
11
.mcp.json
Normal file
11
.mcp.json
Normal file
@ -0,0 +1,11 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"shadcn": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"shadcn@latest",
|
||||
"mcp"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
1537
Cargo.lock
generated
1537
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
20
Cargo.toml
20
Cargo.toml
@ -11,19 +11,21 @@ members = [
|
||||
"libs/service",
|
||||
"libs/db",
|
||||
"libs/api",
|
||||
"libs/webhook",
|
||||
"libs/transport",
|
||||
"libs/observability",
|
||||
"libs/avatar",
|
||||
"libs/agent",
|
||||
"libs/migrate",
|
||||
"libs/fctool",
|
||||
"libs/gingress-proxy",
|
||||
"apps/migrate",
|
||||
"apps/app",
|
||||
"apps/git-hook",
|
||||
"apps/gitserver",
|
||||
"apps/email",
|
||||
"apps/static",
|
||||
"apps/metrics",
|
||||
"apps/gingress",
|
||||
]
|
||||
|
||||
resolver = "3"
|
||||
@ -40,12 +42,14 @@ service = { path = "libs/service" }
|
||||
db = { path = "libs/db" }
|
||||
api = { path = "libs/api" }
|
||||
agent = { path = "libs/agent" }
|
||||
webhook = { path = "libs/webhook" }
|
||||
observability = { path = "libs/observability" }
|
||||
avatar = { path = "libs/avatar" }
|
||||
migrate = { path = "libs/migrate" }
|
||||
fctool = { path = "libs/fctool" }
|
||||
transport = { path = "libs/transport" }
|
||||
metrics-aggregator = { path = "apps/metrics" }
|
||||
gingress-proxy = { path = "libs/gingress-proxy" }
|
||||
gingress = { path = "apps/gingress" }
|
||||
|
||||
sea-query = "1.0.0-rc.33"
|
||||
|
||||
@ -131,7 +135,10 @@ tokio = "1.50.0"
|
||||
tokio-util = "0.7.18"
|
||||
tokio-stream = "0.1.18"
|
||||
url = "2.5.8"
|
||||
tower = "0.5"
|
||||
num_cpus = "1.17.0"
|
||||
ring = "0.17"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std", "tls12"] }
|
||||
clap = "4.6.0"
|
||||
time = "0.3.47"
|
||||
chrono = "0.4.44"
|
||||
@ -165,12 +172,19 @@ phf_codegen = "0.13.1"
|
||||
base64 = "0.22.1"
|
||||
base64ct = "1"
|
||||
p256 = { version = "0.13", features = ["ecdsa", "std"] }
|
||||
http = "1"
|
||||
# http version varies per-crate (pingora needs 1.x, actix needs 0.2)
|
||||
hyper = "0.14"
|
||||
tempfile = "3"
|
||||
rig-core = { version = "0.30.0", default-features = false }
|
||||
tokio-tungstenite = { version = "0.29.0", features = [] }
|
||||
async-nats = { version = "0.47.0", features = [] }
|
||||
kube = { version = "0.98", features = ["runtime", "derive"] }
|
||||
k8s-openapi = { version = "0.24", features = ["v1_31"] }
|
||||
pingora = { version = "0.8", features = ["proxy"] }
|
||||
pingora-proxy = "0.8"
|
||||
pingora-load-balancing = "0.8"
|
||||
pingora-cache = "0.8"
|
||||
rustls-pemfile = "2"
|
||||
[workspace.package]
|
||||
version = "0.2.9"
|
||||
edition = "2024"
|
||||
|
||||
@ -9,6 +9,7 @@ use futures::future::LocalBoxFuture;
|
||||
use observability::{
|
||||
init_tracing_subscriber, install_recorder, prometheus_handler, spawn_http_metrics_poller,
|
||||
HttpMetrics, HttpSnapshotGuard, MetricsMiddleware, TracingSpanMiddleware,
|
||||
push::MetricsPusher,
|
||||
};
|
||||
use sea_orm::ConnectionTrait;
|
||||
use service::AppService;
|
||||
@ -17,6 +18,7 @@ use api::{robots, sidemap};
|
||||
use session::storage::RedisClusterSessionStore;
|
||||
use session::SessionMiddleware;
|
||||
use std::task::{Context, Poll};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
mod args;
|
||||
@ -151,7 +153,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
let service = AppService::new(cfg.clone()).await?;
|
||||
tracing::info!("AppService initialized");
|
||||
let _model_sync_handle = service.clone().start_sync_task();
|
||||
let _billing_alert_handle = service.clone().start_billing_alert_task();
|
||||
// TODO: workspace module not yet wired — billing alert task pending
|
||||
// let _billing_alert_handle = service.clone().start_billing_alert_task();
|
||||
|
||||
let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
|
||||
let worker_service = service.clone();
|
||||
@ -192,6 +195,13 @@ async fn main() -> anyhow::Result<()> {
|
||||
);
|
||||
let http_snapshot_data = web::Data::new(http_snapshot);
|
||||
|
||||
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
|
||||
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
|
||||
let pusher = MetricsPusher::new(&push_url, "app");
|
||||
pusher.spawn(http_metrics.clone(), Arc::new(prometheus_handle.clone()), std::time::Duration::from_secs(15));
|
||||
tracing::info!(push_url = %push_url, "Metrics pusher started (interval 15s)");
|
||||
}
|
||||
|
||||
let bind_addr = args.bind.unwrap_or_else(|| "127.0.0.1:8080".to_string());
|
||||
tracing::info!(bind_addr = %bind_addr, "Listening");
|
||||
let http_metrics_server = http_metrics.clone();
|
||||
@ -212,11 +222,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
cors = cors.allowed_origin(origin);
|
||||
}
|
||||
let cors = cors
|
||||
.allowed_methods(["GET", "POST", "PUT", "PATCH", "DELETE"])
|
||||
.allowed_methods(["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
|
||||
.allowed_headers(["Content-Type", "Authorization", "X-Requested-With", "Accept", "Origin"])
|
||||
.supports_credentials()
|
||||
.max_age(3600);
|
||||
|
||||
let security_headers = actix_web::middleware::DefaultHeaders::new()
|
||||
.add(("X-Content-Type-Options", "nosniff"))
|
||||
.add(("X-Frame-Options", "DENY"))
|
||||
.add(("Referrer-Policy", "strict-origin-when-cross-origin"));
|
||||
|
||||
let session_mw = SessionMiddleware::builder(store.clone(), session_key.clone())
|
||||
.cookie_name("id".to_string())
|
||||
.cookie_path("/".to_string())
|
||||
@ -233,6 +248,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
App::new()
|
||||
.wrap(cors)
|
||||
.wrap(security_headers)
|
||||
.wrap(session_mw)
|
||||
.wrap(RequestLogger)
|
||||
.wrap(metrics_mw)
|
||||
|
||||
@ -2,7 +2,7 @@ use clap::Parser;
|
||||
use config::AppConfig;
|
||||
use metrics::{describe_counter, Unit};
|
||||
use metrics_exporter_prometheus::PrometheusHandle;
|
||||
use observability::{init_tracing_subscriber, install_recorder};
|
||||
use observability::{init_tracing_subscriber, install_recorder, HttpMetrics, push::MetricsPusher};
|
||||
use sea_orm::ConnectionTrait;
|
||||
use service::AppService;
|
||||
use std::sync::Arc;
|
||||
@ -88,6 +88,14 @@ async fn main() -> anyhow::Result<()> {
|
||||
describe_counter!("email_send_failures_total", Unit::Count, "Emails that failed after all retries");
|
||||
|
||||
let metrics_handle = Arc::new(install_recorder());
|
||||
let http_metrics = Arc::new(HttpMetrics::new()); // Worker app — HTTP section will be empty
|
||||
|
||||
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
|
||||
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
|
||||
let pusher = MetricsPusher::new(&push_url, "email");
|
||||
pusher.spawn(http_metrics.clone(), metrics_handle.clone(), std::time::Duration::from_secs(15));
|
||||
tracing::info!(push_url = %push_url, "Metrics pusher started (interval 15s)");
|
||||
}
|
||||
|
||||
tracing::info!("Starting email worker");
|
||||
let service = AppService::new(cfg).await?;
|
||||
|
||||
46
apps/gingress/Cargo.toml
Normal file
46
apps/gingress/Cargo.toml
Normal file
@ -0,0 +1,46 @@
|
||||
[package]
|
||||
name = "gingress"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
description = "GIngress control plane: Kubernetes Ingress Controller using kube-rs"
|
||||
repository.workspace = true
|
||||
readme.workspace = true
|
||||
homepage.workspace = true
|
||||
license.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
documentation.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "gingress"
|
||||
path = "src/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "kubectl-gingress"
|
||||
path = "src/bin/kubectl-gingress/main.rs"
|
||||
|
||||
[dependencies]
|
||||
gingress-proxy = { workspace = true }
|
||||
|
||||
kube = { version = "0.98", features = ["runtime", "derive"] }
|
||||
k8s-openapi = { version = "0.24", features = ["v1_31"] }
|
||||
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
serde_yaml = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
observability = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
futures-util = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
url = { workspace = true }
|
||||
x509-parser = "0.17"
|
||||
rustls-pemfile = "2"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
797
apps/gingress/src/bin/kubectl-gingress/main.rs
Normal file
797
apps/gingress/src/bin/kubectl-gingress/main.rs
Normal file
@ -0,0 +1,797 @@
|
||||
//! kubectl-gingress — kubectl plugin for managing GIngress resources.
|
||||
//!
|
||||
//! Usage (via kubectl): kubectl gingress <subcommand>
|
||||
//! Usage (standalone): kubectl-gingress <subcommand>
|
||||
|
||||
use clap::{Parser, Subcommand};
|
||||
use k8s_openapi::api::core::v1::{Pod, Secret};
|
||||
use k8s_openapi::api::networking::v1::{HTTPIngressPath, Ingress};
|
||||
use kube::api::ListParams;
|
||||
use kube::{Api, Client, ResourceExt};
|
||||
|
||||
const INGRESS_CLASS: &str = "gingress";
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(
|
||||
name = "kubectl-gingress",
|
||||
bin_name = "kubectl gingress",
|
||||
about = "Manage GIngress — Kubernetes Ingress Controller",
|
||||
version
|
||||
)]
|
||||
struct Cli {
|
||||
#[command(subcommand)]
|
||||
command: Command,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum Command {
|
||||
/// List all Ingress resources managed by GIngress
|
||||
#[command(alias = "ls")]
|
||||
List {
|
||||
/// Filter by namespace (omit for all namespaces)
|
||||
#[arg(short, long)]
|
||||
namespace: Option<String>,
|
||||
/// Output as JSON
|
||||
#[arg(long)]
|
||||
json: bool,
|
||||
},
|
||||
/// Show the routing table (host → path → backend)
|
||||
Routes {
|
||||
/// Filter by namespace
|
||||
#[arg(short, long)]
|
||||
namespace: Option<String>,
|
||||
/// Filter by host
|
||||
#[arg(short = 'H', long)]
|
||||
host: Option<String>,
|
||||
/// Output as JSON
|
||||
#[arg(long)]
|
||||
json: bool,
|
||||
},
|
||||
/// Show backend services and their endpoints
|
||||
Backends {
|
||||
/// Filter by namespace
|
||||
#[arg(short, long)]
|
||||
namespace: Option<String>,
|
||||
/// Output as JSON
|
||||
#[arg(long)]
|
||||
json: bool,
|
||||
},
|
||||
/// List TLS certificates (from Secrets)
|
||||
Certs {
|
||||
/// Filter by namespace
|
||||
#[arg(short, long)]
|
||||
namespace: Option<String>,
|
||||
/// Output as JSON
|
||||
#[arg(long)]
|
||||
json: bool,
|
||||
},
|
||||
/// Validate Ingress configurations
|
||||
Validate {
|
||||
/// Filter by namespace
|
||||
#[arg(short, long)]
|
||||
namespace: Option<String>,
|
||||
},
|
||||
/// Show GIngress controller status and summary
|
||||
Status {
|
||||
/// Output as JSON
|
||||
#[arg(long)]
|
||||
json: bool,
|
||||
},
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let cli = Cli::parse();
|
||||
let client = Client::try_default().await?;
|
||||
|
||||
match cli.command {
|
||||
Command::List { namespace, json } => cmd_list(&client, namespace, json).await?,
|
||||
Command::Routes { namespace, host, json } => cmd_routes(&client, namespace, host, json).await?,
|
||||
Command::Backends { namespace, json } => cmd_backends(&client, namespace, json).await?,
|
||||
Command::Certs { namespace, json } => cmd_certs(&client, namespace, json).await?,
|
||||
Command::Validate { namespace } => cmd_validate(&client, namespace).await?,
|
||||
Command::Status { json } => cmd_status(&client, json).await?,
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── list ──────────────────────────────────────────────────────────
|
||||
|
||||
async fn cmd_list(client: &Client, namespace: Option<String>, json: bool) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let ingresses = list_ingresses(client, namespace.as_deref()).await?;
|
||||
|
||||
if json {
|
||||
println!("{}", serde_json::to_string_pretty(&ingresses)?);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if ingresses.is_empty() {
|
||||
println!("No GIngress-managed Ingress resources found.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
println!("{:<25} {:<20} {:<40} {:<50} {:<15}", "NAMESPACE", "NAME", "HOSTS", "PATHS", "TLS");
|
||||
println!("{:-<150}", "");
|
||||
|
||||
for ing in &ingresses {
|
||||
let ns = ing.namespace();
|
||||
let name = ing.name_any();
|
||||
let hosts = ing.hosts().join(", ");
|
||||
let paths = ing
|
||||
.paths_display()
|
||||
.iter()
|
||||
.map(|p| format!("{} {}", p.path_type, p.path))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
let tls = if ing.has_tls() { "Enabled" } else { "-" };
|
||||
|
||||
println!("{:<25} {:<20} {:<40} {:<50} {:<15}",
|
||||
truncate(&ns, 25),
|
||||
truncate(&name, 20),
|
||||
truncate(&hosts, 40),
|
||||
truncate(&paths, 50),
|
||||
tls,
|
||||
);
|
||||
}
|
||||
|
||||
println!("\nTotal: {} Ingress(es)", ingresses.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── routes ─────────────────────────────────────────────────────────
|
||||
|
||||
async fn cmd_routes(
|
||||
client: &Client,
|
||||
namespace: Option<String>,
|
||||
host_filter: Option<String>,
|
||||
json: bool,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let ingresses = list_ingresses(client, namespace.as_deref()).await?;
|
||||
let mut routes: Vec<RouteRow> = Vec::new();
|
||||
|
||||
for ing in &ingresses {
|
||||
for rule in ing.spec.as_ref().and_then(|s| s.rules.as_ref()).into_iter().flatten() {
|
||||
let host = rule.host.as_deref().unwrap_or("*");
|
||||
if let Some(ref hf) = host_filter {
|
||||
if host != hf { continue; }
|
||||
}
|
||||
if let Some(http) = &rule.http {
|
||||
for path_item in &http.paths {
|
||||
let backend = extract_backend(path_item);
|
||||
let port = extract_backend_port(path_item);
|
||||
routes.push(RouteRow {
|
||||
namespace: ing.namespace(),
|
||||
ingress: ing.name_any(),
|
||||
host: host.to_string(),
|
||||
path: path_item.path.clone().unwrap_or_else(|| "/".into()),
|
||||
path_type: path_item.path_type.clone(),
|
||||
backend,
|
||||
port,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if json {
|
||||
println!("{}", serde_json::to_string_pretty(&routes)?);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if routes.is_empty() {
|
||||
println!("No routes found.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
println!("{:<20} {:<20} {:<30} {:<18} {:<15} {:<15} {:<15}",
|
||||
"NAMESPACE", "INGRESS", "HOST", "PATH", "TYPE", "BACKEND", "PORT");
|
||||
println!("{:-<133}", "");
|
||||
|
||||
for r in &routes {
|
||||
let port = extract_backend_port_str(r);
|
||||
println!("{:<20} {:<20} {:<30} {:<18} {:<15} {:<15} {:<15}",
|
||||
truncate(&r.namespace, 20),
|
||||
truncate(&r.ingress, 20),
|
||||
truncate(&r.host, 30),
|
||||
truncate(&r.path, 18),
|
||||
truncate(&r.path_type, 15),
|
||||
truncate(&r.backend, 15),
|
||||
port,
|
||||
);
|
||||
}
|
||||
|
||||
println!("\nTotal: {} route(s)", routes.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── backends ───────────────────────────────────────────────────────
|
||||
|
||||
async fn cmd_backends(
|
||||
client: &Client,
|
||||
namespace: Option<String>,
|
||||
json: bool,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let ingresses = list_ingresses(client, namespace.as_deref()).await?;
|
||||
|
||||
// Collect unique backends from all ingresses
|
||||
let mut backends: Vec<BackendRow> = Vec::new();
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
|
||||
for ing in &ingresses {
|
||||
for rule in ing.spec.as_ref().and_then(|s| s.rules.as_ref()).into_iter().flatten() {
|
||||
if let Some(http) = &rule.http {
|
||||
for path_item in &http.paths {
|
||||
let svc = match path_item.backend.service.as_ref() {
|
||||
Some(s) => s,
|
||||
None => continue,
|
||||
};
|
||||
let key = format!("{}/{}:{}", ing.namespace(), svc.name, svc.port.as_ref().and_then(|p| p.number).unwrap_or(80));
|
||||
if seen.insert(key.clone()) {
|
||||
let ns = ing.namespace();
|
||||
let ep_status = get_endpoint_status(client, &ns, &svc.name).await;
|
||||
backends.push(BackendRow {
|
||||
namespace: ns,
|
||||
service: svc.name.clone(),
|
||||
port: svc.port.as_ref().and_then(|p| p.number).unwrap_or(80) as u16,
|
||||
ready_endpoints: ep_status.ready,
|
||||
total_endpoints: ep_status.total,
|
||||
referenced_by: ing.name_any(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if json {
|
||||
println!("{}", serde_json::to_string_pretty(&backends)?);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if backends.is_empty() {
|
||||
println!("No backends found.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
println!("{:<20} {:<20} {:<8} {:<8} {:<18} {:<20}",
|
||||
"NAMESPACE", "SERVICE", "PORT", "HEALTH", "ENDPOINTS", "REFERENCED BY");
|
||||
println!("{:-<94}", "");
|
||||
|
||||
for b in &backends {
|
||||
let health = if b.total_endpoints == 0 { "WARN" } else if b.ready_endpoints == 0 { "DOWN" } else if b.ready_endpoints < b.total_endpoints { "PARTIAL" } else { "OK" };
|
||||
let eps = format!("{}/{} ready", b.ready_endpoints, b.total_endpoints);
|
||||
println!("{:<20} {:<20} {:<8} {:<8} {:<18} {:<20}",
|
||||
truncate(&b.namespace, 20),
|
||||
truncate(&b.service, 20),
|
||||
b.port,
|
||||
health,
|
||||
eps,
|
||||
truncate(&b.referenced_by, 20),
|
||||
);
|
||||
}
|
||||
|
||||
println!("\nTotal: {} backend(s)", backends.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── certs ──────────────────────────────────────────────────────────
|
||||
|
||||
async fn cmd_certs(
|
||||
client: &Client,
|
||||
namespace: Option<String>,
|
||||
json: bool,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let ingresses = list_ingresses(client, namespace.as_deref()).await?;
|
||||
let mut certs: Vec<CertRow> = Vec::new();
|
||||
|
||||
for ing in &ingresses {
|
||||
let ns = ing.namespace();
|
||||
let tls_entries = ing
|
||||
.spec
|
||||
.as_ref()
|
||||
.and_then(|s| s.tls.as_ref())
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
for tls in &tls_entries {
|
||||
let secret_name = tls.secret_name.clone().unwrap_or_default();
|
||||
let hosts = tls.hosts.clone().unwrap_or_default();
|
||||
|
||||
// Check if the secret exists
|
||||
let secret_exists = check_secret_exists(client, &ns, &secret_name).await;
|
||||
|
||||
for host in &hosts {
|
||||
certs.push(CertRow {
|
||||
namespace: ns.clone(),
|
||||
secret_name: secret_name.clone(),
|
||||
host: host.clone(),
|
||||
found: secret_exists,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if json {
|
||||
println!("{}", serde_json::to_string_pretty(&certs)?);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if certs.is_empty() {
|
||||
println!("No TLS certificates configured.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
println!("{:<20} {:<30} {:<30} {:<10}", "NAMESPACE", "SECRET", "HOST", "STATUS");
|
||||
println!("{:-<90}", "");
|
||||
|
||||
for c in &certs {
|
||||
let status = if c.found { "OK" } else { "MISSING" };
|
||||
println!("{:<20} {:<30} {:<30} {:<10}",
|
||||
truncate(&c.namespace, 20),
|
||||
truncate(&c.secret_name, 30),
|
||||
truncate(&c.host, 30),
|
||||
status,
|
||||
);
|
||||
}
|
||||
|
||||
let missing = certs.iter().filter(|c| !c.found).count();
|
||||
println!("\nTotal: {} cert(s), {} missing", certs.len(), missing);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── validate ───────────────────────────────────────────────────────
|
||||
|
||||
async fn cmd_validate(
|
||||
client: &Client,
|
||||
namespace: Option<String>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let ingresses = list_ingresses(client, namespace.as_deref()).await?;
|
||||
let mut errors = 0usize;
|
||||
let mut warnings = 0usize;
|
||||
|
||||
for ing in &ingresses {
|
||||
let ns = ing.namespace();
|
||||
let name = ing.name_any();
|
||||
|
||||
// Check: has rules
|
||||
let has_rules = ing
|
||||
.spec
|
||||
.as_ref()
|
||||
.map(|s| s.rules.as_ref().map(|r| !r.is_empty()).unwrap_or(false))
|
||||
.unwrap_or(false);
|
||||
if !has_rules {
|
||||
println!("[{}/{}] ERROR: No routing rules defined", ns, name);
|
||||
errors += 1;
|
||||
}
|
||||
|
||||
// Check: has TLS but no secret
|
||||
let tls_entries = ing
|
||||
.spec
|
||||
.as_ref()
|
||||
.and_then(|s| s.tls.as_ref())
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
for tls in &tls_entries {
|
||||
let secret_name = tls.secret_name.as_deref().unwrap_or("");
|
||||
if secret_name.is_empty() {
|
||||
println!("[{}/{}] WARNING: TLS configured but no secretName specified", ns, name);
|
||||
warnings += 1;
|
||||
} else {
|
||||
let found = check_secret_exists(client, &ns, secret_name).await;
|
||||
if !found {
|
||||
println!("[{}/{}] ERROR: TLS secret '{}' not found in namespace '{}'", ns, name, secret_name, ns);
|
||||
errors += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check: path backends reference valid services
|
||||
if let Some(rules) = ing.spec.as_ref().and_then(|s| s.rules.as_ref()) {
|
||||
for rule in rules {
|
||||
if let Some(http) = &rule.http {
|
||||
for path_item in &http.paths {
|
||||
if let Some(svc) = &path_item.backend.service {
|
||||
let endpoints = get_endpoint_status(client, &ns, &svc.name).await;
|
||||
if endpoints.total == 0 {
|
||||
println!(
|
||||
"[{}/{}] WARNING: Backend service '{}' has no endpoints (host: {})",
|
||||
ns, name, svc.name,
|
||||
rule.host.as_deref().unwrap_or("*")
|
||||
);
|
||||
warnings += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if errors == 0 && warnings == 0 {
|
||||
println!("Validation passed — no issues found in {} Ingress(es).", ingresses.len());
|
||||
} else {
|
||||
println!("\nValidation complete: {} error(s), {} warning(s) across {} Ingress(es).",
|
||||
errors, warnings, ingresses.len());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── status ─────────────────────────────────────────────────────────
|
||||
|
||||
async fn cmd_status(client: &Client, json: bool) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let ingresses = list_ingresses(client, None).await?;
|
||||
let controller_pods = find_gingress_pods(client).await;
|
||||
|
||||
if json {
|
||||
#[derive(serde::Serialize)]
|
||||
struct StatusOutput {
|
||||
controller_pods: Vec<PodInfo>,
|
||||
ingress_count: usize,
|
||||
route_count: usize,
|
||||
backend_count: usize,
|
||||
cert_count: usize,
|
||||
}
|
||||
let mut route_count = 0usize;
|
||||
let mut backend_set = std::collections::HashSet::new();
|
||||
let mut cert_count = 0usize;
|
||||
for ing in &ingresses {
|
||||
if let Some(spec) = &ing.spec {
|
||||
if let Some(rules) = &spec.rules {
|
||||
for rule in rules {
|
||||
if let Some(http) = &rule.http {
|
||||
route_count += http.paths.len();
|
||||
for p in &http.paths {
|
||||
if let Some(svc) = &p.backend.service {
|
||||
backend_set.insert(format!("{}/{}", ing.namespace(), svc.name));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(tls) = &spec.tls {
|
||||
cert_count += tls.len();
|
||||
}
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"{}",
|
||||
serde_json::to_string_pretty(&StatusOutput {
|
||||
controller_pods,
|
||||
ingress_count: ingresses.len(),
|
||||
route_count,
|
||||
backend_count: backend_set.len(),
|
||||
cert_count,
|
||||
})?
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Controller status
|
||||
println!("══ GIngress Controller Status ══\n");
|
||||
|
||||
if controller_pods.is_empty() {
|
||||
println!("Controller: NOT FOUND (no pods with label gingress.io/component=controller)");
|
||||
} else {
|
||||
for pod in &controller_pods {
|
||||
let ready = if pod.ready { "Running" } else { "NotReady" };
|
||||
println!(
|
||||
"Controller Pod: {:<30} {:<12} {}",
|
||||
truncate(&pod.name, 30),
|
||||
ready,
|
||||
pod.namespace
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Resource summary
|
||||
println!();
|
||||
println!("══ Managed Resources ══\n");
|
||||
|
||||
let mut route_count = 0usize;
|
||||
let mut backend_set = std::collections::HashSet::new();
|
||||
let mut backend_total_eps = 0usize;
|
||||
let mut backend_ready_eps = 0usize;
|
||||
let mut cert_count = 0usize;
|
||||
let mut cert_missing = 0usize;
|
||||
|
||||
for ing in &ingresses {
|
||||
if let Some(spec) = &ing.spec {
|
||||
if let Some(rules) = &spec.rules {
|
||||
for rule in rules {
|
||||
if let Some(http) = &rule.http {
|
||||
route_count += http.paths.len();
|
||||
for p in &http.paths {
|
||||
if let Some(svc) = &p.backend.service {
|
||||
let key = format!("{}/{}", ing.namespace(), svc.name);
|
||||
if backend_set.insert(key) {
|
||||
let eps = get_endpoint_status(client, &ing.namespace(), &svc.name).await;
|
||||
backend_total_eps += eps.total;
|
||||
backend_ready_eps += eps.ready;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(tls) = &spec.tls {
|
||||
cert_count += tls.len();
|
||||
for tls in tls {
|
||||
let sn = tls.secret_name.as_deref().unwrap_or("");
|
||||
if !sn.is_empty() && !check_secret_exists(client, &ing.namespace(), sn).await {
|
||||
cert_missing += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("Ingresses: {}", ingresses.len());
|
||||
println!("Routes: {}", route_count);
|
||||
println!(
|
||||
"Backends: {} ({} ready / {} total endpoints)",
|
||||
backend_set.len(), backend_ready_eps, backend_total_eps
|
||||
);
|
||||
println!("TLS Certs: {} ({} missing)", cert_count, cert_missing);
|
||||
println!();
|
||||
|
||||
// Overall health
|
||||
let healthy = !ingresses.is_empty()
|
||||
&& (cert_missing == 0)
|
||||
&& (backend_set.is_empty() || backend_ready_eps > 0)
|
||||
&& controller_pods.iter().any(|p| p.ready);
|
||||
|
||||
if healthy {
|
||||
println!("Status: HEALTHY");
|
||||
} else {
|
||||
println!("Status: DEGRADED");
|
||||
if controller_pods.is_empty() || !controller_pods.iter().any(|p| p.ready) {
|
||||
println!(" → Controller pod not running or not ready");
|
||||
}
|
||||
if cert_missing > 0 {
|
||||
println!(" → {} TLS secret(s) missing", cert_missing);
|
||||
}
|
||||
if !backend_set.is_empty() && backend_ready_eps == 0 {
|
||||
println!(" → No ready backend endpoints");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── k8s helpers ────────────────────────────────────────────────────
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct IngressSummary {
|
||||
namespace: String,
|
||||
name: String,
|
||||
hosts: Vec<String>,
|
||||
#[serde(skip)]
|
||||
paths_for_display: Vec<PathSummary>,
|
||||
has_tls: bool,
|
||||
#[serde(skip)]
|
||||
spec: Option<k8s_openapi::api::networking::v1::IngressSpec>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct PathSummary {
|
||||
path_type: String,
|
||||
path: String,
|
||||
}
|
||||
|
||||
impl IngressSummary {
|
||||
fn namespace(&self) -> String { self.namespace.clone() }
|
||||
fn name_any(&self) -> String { self.name.clone() }
|
||||
fn hosts(&self) -> &[String] { &self.hosts }
|
||||
fn paths_display(&self) -> &[PathSummary] { &self.paths_for_display }
|
||||
fn has_tls(&self) -> bool { self.has_tls }
|
||||
}
|
||||
|
||||
fn ingress_to_summary(ing: &Ingress) -> IngressSummary {
|
||||
let spec = ing.spec.clone();
|
||||
let mut hosts = Vec::new();
|
||||
let mut paths = Vec::new();
|
||||
let mut has_tls = false;
|
||||
|
||||
if let Some(ref s) = spec {
|
||||
if let Some(ref rules) = s.rules {
|
||||
for rule in rules {
|
||||
if let Some(ref host) = rule.host {
|
||||
hosts.push(host.clone());
|
||||
}
|
||||
if let Some(ref http) = rule.http {
|
||||
for p in &http.paths {
|
||||
paths.push(PathSummary {
|
||||
path_type: p.path_type.clone(),
|
||||
path: p.path.clone().unwrap_or_else(|| "/".into()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
has_tls = s.tls.as_ref().map(|t| !t.is_empty()).unwrap_or(false);
|
||||
}
|
||||
|
||||
IngressSummary {
|
||||
namespace: ing.namespace().unwrap_or_default(),
|
||||
name: ing.name_any(),
|
||||
hosts,
|
||||
paths_for_display: paths,
|
||||
has_tls,
|
||||
spec,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct PodInfo {
|
||||
name: String,
|
||||
namespace: String,
|
||||
ready: bool,
|
||||
}
|
||||
|
||||
/// Find GIngress controller pods by label `gingress.io/component=controller`.
|
||||
async fn find_gingress_pods(client: &Client) -> Vec<PodInfo> {
|
||||
let api: Api<Pod> = Api::all(client.clone());
|
||||
let lp = ListParams {
|
||||
label_selector: Some("gingress.io/component=controller".into()),
|
||||
..Default::default()
|
||||
};
|
||||
match api.list(&lp).await {
|
||||
Ok(list) => list
|
||||
.items
|
||||
.into_iter()
|
||||
.map(|pod| {
|
||||
let ready = pod
|
||||
.status
|
||||
.as_ref()
|
||||
.and_then(|s| s.conditions.as_ref())
|
||||
.map(|conds| {
|
||||
conds
|
||||
.iter()
|
||||
.any(|c| c.type_ == "Ready" && c.status == "True")
|
||||
})
|
||||
.unwrap_or(false);
|
||||
PodInfo {
|
||||
name: pod.name_any(),
|
||||
namespace: pod.namespace().unwrap_or_default(),
|
||||
ready,
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
Err(_) => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_ingresses(client: &Client, namespace: Option<&str>) -> Result<Vec<IngressSummary>, Box<dyn std::error::Error>> {
|
||||
let params = ListParams {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if let Some(ns) = namespace {
|
||||
let api: Api<Ingress> = Api::namespaced(client.clone(), ns);
|
||||
let list = api.list(¶ms).await?;
|
||||
Ok(list
|
||||
.items
|
||||
.into_iter()
|
||||
.filter(|ing| is_gingress_class(ing))
|
||||
.map(|ing| ingress_to_summary(&ing))
|
||||
.collect())
|
||||
} else {
|
||||
let api: Api<Ingress> = Api::all(client.clone());
|
||||
let list = api.list(¶ms).await?;
|
||||
Ok(list
|
||||
.items
|
||||
.into_iter()
|
||||
.filter(|ing| is_gingress_class(ing))
|
||||
.map(|ing| ingress_to_summary(&ing))
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
fn is_gingress_class(ingress: &Ingress) -> bool {
|
||||
ingress
|
||||
.spec
|
||||
.as_ref()
|
||||
.and_then(|s| s.ingress_class_name.as_deref())
|
||||
== Some(INGRESS_CLASS)
|
||||
}
|
||||
|
||||
struct EndpointStatus {
|
||||
ready: usize,
|
||||
total: usize,
|
||||
}
|
||||
|
||||
async fn get_endpoint_status(client: &Client, namespace: &str, service_name: &str) -> EndpointStatus {
|
||||
use k8s_openapi::api::core::v1::Endpoints;
|
||||
let api: Api<Endpoints> = Api::namespaced(client.clone(), namespace);
|
||||
match api.get_opt(service_name).await {
|
||||
Ok(Some(eps)) => {
|
||||
let mut ready = 0usize;
|
||||
let mut total = 0usize;
|
||||
if let Some(subsets) = &eps.subsets {
|
||||
for subset in subsets {
|
||||
let addrs = subset.addresses.as_deref().unwrap_or_default();
|
||||
let not_ready = subset.not_ready_addresses.as_deref().unwrap_or_default();
|
||||
ready += addrs.len();
|
||||
total += addrs.len() + not_ready.len();
|
||||
}
|
||||
}
|
||||
EndpointStatus { ready, total }
|
||||
}
|
||||
_ => EndpointStatus { ready: 0, total: 0 },
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_secret_exists(client: &Client, namespace: &str, name: &str) -> bool {
|
||||
let api: Api<Secret> = Api::namespaced(client.clone(), namespace);
|
||||
api.get_opt(name).await.ok().flatten().is_some()
|
||||
}
|
||||
|
||||
fn extract_backend(path_item: &HTTPIngressPath) -> String {
|
||||
path_item
|
||||
.backend
|
||||
.service
|
||||
.as_ref()
|
||||
.map(|s| s.name.clone())
|
||||
.unwrap_or_else(|| "<resource>".into())
|
||||
}
|
||||
|
||||
fn extract_backend_port(path_item: &HTTPIngressPath) -> u16 {
|
||||
path_item
|
||||
.backend
|
||||
.service
|
||||
.as_ref()
|
||||
.and_then(|s| s.port.as_ref())
|
||||
.and_then(|p| p.number)
|
||||
.unwrap_or(80) as u16
|
||||
}
|
||||
|
||||
fn extract_backend_port_str(r: &RouteRow) -> String {
|
||||
r.port.to_string()
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct RouteRow {
|
||||
namespace: String,
|
||||
ingress: String,
|
||||
host: String,
|
||||
path: String,
|
||||
path_type: String,
|
||||
backend: String,
|
||||
port: u16,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct BackendRow {
|
||||
namespace: String,
|
||||
service: String,
|
||||
port: u16,
|
||||
ready_endpoints: usize,
|
||||
total_endpoints: usize,
|
||||
referenced_by: String,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct CertRow {
|
||||
namespace: String,
|
||||
secret_name: String,
|
||||
host: String,
|
||||
found: bool,
|
||||
}
|
||||
|
||||
fn truncate(s: &str, max: usize) -> String {
|
||||
// Account for CJK characters — each wide char counts as 2
|
||||
let mut width = 0usize;
|
||||
let mut result = String::new();
|
||||
for c in s.chars() {
|
||||
let cw = if c.is_ascii() { 1 } else { 2 };
|
||||
if width + cw > max {
|
||||
result.push_str("…");
|
||||
break;
|
||||
}
|
||||
result.push(c);
|
||||
width += cw;
|
||||
}
|
||||
result
|
||||
}
|
||||
151
apps/gingress/src/controller/endpoint_watcher.rs
Normal file
151
apps/gingress/src/controller/endpoint_watcher.rs
Normal file
@ -0,0 +1,151 @@
|
||||
//! Watches Kubernetes Endpoints and updates upstream endpoint lists.
|
||||
//!
|
||||
//! Tracks Pod IPs for each Service. When endpoints change (scale up/down,
|
||||
//! rolling restart, health check failures), the upstream pool is updated.
|
||||
|
||||
use futures::pin_mut;
|
||||
use futures::StreamExt;
|
||||
use gingress_proxy::config::{ConfigStore, Endpoint};
|
||||
use k8s_openapi::api::core::v1::Endpoints as K8sEndpoints;
|
||||
use kube::ResourceExt;
|
||||
use kube::runtime::watcher::{self, Event};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Watch Endpoints and update the ConfigStore.
|
||||
pub async fn watch_endpoints(
|
||||
client: Arc<kube::Client>,
|
||||
store: Arc<ConfigStore>,
|
||||
_namespace: Option<String>,
|
||||
on_change: Arc<dyn Fn() + Send + Sync>,
|
||||
) {
|
||||
let api = kube::Api::<K8sEndpoints>::all(client.as_ref().clone());
|
||||
let config = watcher::Config::default();
|
||||
let watcher = watcher::watcher(api, config);
|
||||
pin_mut!(watcher);
|
||||
|
||||
while let Some(event) = watcher.next().await {
|
||||
match event {
|
||||
Ok(Event::Apply(eps)) => {
|
||||
process_endpoints(&eps, &store, &on_change);
|
||||
}
|
||||
Ok(Event::Init) => {
|
||||
tracing::info!("Endpoint watcher re-initializing");
|
||||
}
|
||||
Ok(Event::InitApply(eps)) => {
|
||||
process_endpoints(&eps, &store, &on_change);
|
||||
}
|
||||
Ok(Event::InitDone) => {
|
||||
tracing::info!("Endpoint watcher init complete");
|
||||
}
|
||||
Ok(Event::Delete(eps)) => {
|
||||
remove_endpoints(&eps, &store, &on_change);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Endpoint watcher error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract endpoint addresses, grouped by port, and update the ConfigStore.
|
||||
///
|
||||
/// Stores endpoints under key `upstream:<ns>/<name>:<port>` to match
|
||||
/// the proxy's upstream lookup format.
|
||||
fn process_endpoints(
|
||||
endpoints: &K8sEndpoints,
|
||||
store: &ConfigStore,
|
||||
on_change: &Arc<dyn Fn() + Send + Sync>,
|
||||
) {
|
||||
use std::collections::HashMap;
|
||||
|
||||
let name = endpoints.name_any();
|
||||
let namespace = endpoints.namespace().unwrap_or_default();
|
||||
let base_prefix = format!("upstream:{}/{}:", namespace, name);
|
||||
|
||||
// Collect endpoints grouped by port
|
||||
let mut port_groups: HashMap<u16, Vec<Endpoint>> = HashMap::new();
|
||||
|
||||
if let Some(subsets) = &endpoints.subsets {
|
||||
for subset in subsets {
|
||||
let addrs = subset.addresses.as_deref().unwrap_or_default();
|
||||
let ports = subset.ports.as_deref().unwrap_or_default();
|
||||
let not_ready_addrs = subset.not_ready_addresses.as_deref().unwrap_or_default();
|
||||
|
||||
for port in ports {
|
||||
let port_num = port.port as u16;
|
||||
let eps = port_groups.entry(port_num).or_default();
|
||||
for addr in addrs {
|
||||
eps.push(Endpoint {
|
||||
ip: addr.ip.clone(),
|
||||
port: port_num,
|
||||
ready: true,
|
||||
});
|
||||
}
|
||||
for addr in not_ready_addrs {
|
||||
eps.push(Endpoint {
|
||||
ip: addr.ip.clone(),
|
||||
port: port_num,
|
||||
ready: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear old per-port keys for this service (handles port removal)
|
||||
let old_keys = store.keys_with_prefix(&base_prefix);
|
||||
for k in old_keys {
|
||||
store.remove(&k);
|
||||
}
|
||||
|
||||
// Write per-port endpoint entries
|
||||
let mut total = 0usize;
|
||||
for (port_num, eps) in &port_groups {
|
||||
let key = format!("{}{}", base_prefix, port_num);
|
||||
store.set(&key, eps);
|
||||
total += eps.len();
|
||||
}
|
||||
|
||||
// If no ports at all, write an empty entry for the base key so the reconciler
|
||||
// can detect that this service has no endpoints.
|
||||
if port_groups.is_empty() {
|
||||
store.set::<Vec<Endpoint>>(
|
||||
&format!("upstream:{}/{}", namespace, name),
|
||||
&vec![],
|
||||
);
|
||||
}
|
||||
|
||||
store.signal_reload();
|
||||
on_change();
|
||||
|
||||
tracing::debug!(
|
||||
namespace = %namespace,
|
||||
name = %name,
|
||||
num_ports = port_groups.len(),
|
||||
num_endpoints = total,
|
||||
"Endpoints updated"
|
||||
);
|
||||
}
|
||||
|
||||
/// Remove all per-port endpoint keys when the Endpoint resource is deleted.
|
||||
fn remove_endpoints(
|
||||
endpoints: &K8sEndpoints,
|
||||
store: &ConfigStore,
|
||||
on_change: &Arc<dyn Fn() + Send + Sync>,
|
||||
) {
|
||||
let name = endpoints.name_any();
|
||||
let namespace = endpoints.namespace().unwrap_or_default();
|
||||
let base_prefix = format!("upstream:{}/{}:", namespace, name);
|
||||
|
||||
// Remove all per-port keys
|
||||
let keys = store.keys_with_prefix(&base_prefix);
|
||||
for k in keys {
|
||||
store.remove(&k);
|
||||
}
|
||||
// Also remove the port-less key (in case no ports were present)
|
||||
store.remove(&format!("upstream:{}/{}", namespace, name));
|
||||
|
||||
store.signal_reload();
|
||||
on_change();
|
||||
tracing::info!(namespace = %namespace, name = %name, "Endpoints removed");
|
||||
}
|
||||
372
apps/gingress/src/controller/ingress_watcher.rs
Normal file
372
apps/gingress/src/controller/ingress_watcher.rs
Normal file
@ -0,0 +1,372 @@
|
||||
//! Watches Kubernetes Ingress resources and converts them to routing rules.
|
||||
|
||||
use futures::pin_mut;
|
||||
use futures::StreamExt;
|
||||
use gingress_proxy::config::{
|
||||
ConfigStore, HeaderOp, PathType, RateLimitPolicy, RouteRule, SessionAffinityConfig,
|
||||
};
|
||||
use k8s_openapi::api::networking::v1::{HTTPIngressPath, Ingress};
|
||||
use kube::ResourceExt;
|
||||
use kube::runtime::watcher::{self, Event};
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Watch Ingress resources and update the ConfigStore.
|
||||
///
|
||||
/// After each event, the `on_change` callback is invoked so the reconciler
|
||||
/// can cross-reference all fragments into a complete ProxyConfig.
|
||||
pub async fn watch_ingresses(
|
||||
client: Arc<kube::Client>,
|
||||
store: Arc<ConfigStore>,
|
||||
ingress_class: String,
|
||||
namespace: Option<String>,
|
||||
on_change: Arc<dyn Fn() + Send + Sync>,
|
||||
) {
|
||||
let api = kube::Api::<Ingress>::all(client.as_ref().clone());
|
||||
|
||||
let config = watcher::Config {
|
||||
field_selector: namespace.as_ref().map(|ns| format!("metadata.namespace={}", ns)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let ingress_watcher = watcher::watcher(api, config);
|
||||
pin_mut!(ingress_watcher);
|
||||
|
||||
while let Some(event) = ingress_watcher.next().await {
|
||||
match event {
|
||||
Ok(Event::Apply(ingress)) => {
|
||||
let name = ingress.name_any();
|
||||
let ns = ingress.namespace().unwrap_or_default();
|
||||
if is_gingress_class(&ingress, &ingress_class) {
|
||||
process_ingress(&ingress, &store, &ingress_class);
|
||||
on_change();
|
||||
tracing::info!(namespace = %ns, name = %name, "Ingress applied");
|
||||
}
|
||||
}
|
||||
Ok(Event::Init) => {
|
||||
store.remove_prefix("ingress:");
|
||||
store.remove_prefix("tls-host:");
|
||||
tracing::info!("Ingress watcher re-initializing");
|
||||
}
|
||||
Ok(Event::InitApply(ingress)) => {
|
||||
if is_gingress_class(&ingress, &ingress_class) {
|
||||
process_ingress(&ingress, &store, &ingress_class);
|
||||
}
|
||||
}
|
||||
Ok(Event::InitDone) => {
|
||||
store.signal_reload();
|
||||
on_change();
|
||||
tracing::info!("Ingress watcher init complete");
|
||||
}
|
||||
Ok(Event::Delete(ingress)) => {
|
||||
if is_gingress_class(&ingress, &ingress_class) {
|
||||
remove_ingress_routes(&ingress, &store);
|
||||
on_change();
|
||||
tracing::info!(
|
||||
name = %ingress.name_any(),
|
||||
namespace = %ingress.namespace().unwrap_or_default(),
|
||||
"Ingress deleted, routes removed"
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Ingress watcher error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if an Ingress specifies the gingress class.
|
||||
fn is_gingress_class(ingress: &Ingress, class_name: &str) -> bool {
|
||||
ingress
|
||||
.spec
|
||||
.as_ref()
|
||||
.and_then(|s| s.ingress_class_name.as_deref())
|
||||
== Some(class_name)
|
||||
}
|
||||
|
||||
/// Process an Ingress resource: extract routes and update the store.
|
||||
fn process_ingress(ingress: &Ingress, store: &ConfigStore, _ingress_class: &str) {
|
||||
let namespace = ingress.namespace().unwrap_or_default();
|
||||
let name = ingress.name_any();
|
||||
let spec = match ingress.spec.as_ref() {
|
||||
Some(s) => s,
|
||||
None => return,
|
||||
};
|
||||
|
||||
// Build an ingress-scoped prefix so we can clean up old routes for this Ingress
|
||||
let ingress_prefix = format!("ingress:{}/{}:", namespace, name);
|
||||
|
||||
// Remove old route entries scoped to this Ingress
|
||||
let old_route_keys = store.keys_with_prefix(&format!("{}route:", ingress_prefix));
|
||||
for key in &old_route_keys {
|
||||
store.remove(key);
|
||||
}
|
||||
|
||||
// Process routing rules
|
||||
if let Some(rules) = &spec.rules {
|
||||
for rule in rules {
|
||||
let host = rule.host.as_deref().unwrap_or("*");
|
||||
if let Some(http) = &rule.http {
|
||||
let mut routes: Vec<RouteRule> = Vec::new();
|
||||
for path_item in &http.paths {
|
||||
routes.push(ingress_path_to_route(host, path_item, &namespace));
|
||||
}
|
||||
// Store per-ingress routes so we can clean up on delete
|
||||
let route_key = format!("{}route:{}", ingress_prefix, host);
|
||||
store.set(&route_key, &routes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process TLS: map secretName -> hosts so the reconciler can cross-reference
|
||||
if let Some(tls_entries) = &spec.tls {
|
||||
for tls in tls_entries {
|
||||
let secret_name = tls.secret_name.as_deref().unwrap_or_default();
|
||||
let hosts: Vec<String> = tls.hosts.clone().unwrap_or_default();
|
||||
let tls_host_key = format!("tls-host:{}", secret_name);
|
||||
store.set(&tls_host_key, &hosts);
|
||||
}
|
||||
}
|
||||
|
||||
// Process annotations for advanced features
|
||||
let annotations = ingress.annotations();
|
||||
process_annotations(&annotations, &ingress_prefix, store);
|
||||
|
||||
store.signal_reload();
|
||||
}
|
||||
|
||||
/// Convert a Kubernetes Ingress path to an internal RouteRule.
|
||||
fn ingress_path_to_route(host: &str, path: &HTTPIngressPath, namespace: &str) -> RouteRule {
|
||||
let service = path.backend.service.as_ref()
|
||||
.expect("Ingress backend must reference a service");
|
||||
|
||||
RouteRule {
|
||||
host: host.to_string(),
|
||||
path: path.path.clone().unwrap_or_else(|| "/".to_string()),
|
||||
path_type: match path.path_type.as_str() {
|
||||
"Prefix" => PathType::Prefix,
|
||||
"Exact" => PathType::Exact,
|
||||
_ => PathType::ImplementationSpecific,
|
||||
},
|
||||
backend: gingress_proxy::config::Backend {
|
||||
namespace: namespace.to_string(),
|
||||
name: service.name.clone(),
|
||||
port: service.port.as_ref().and_then(|p| p.number).unwrap_or(80) as u16,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Annotation keys for GIngress features.
|
||||
const ANN_RATE_LIMIT: &str = "gingress.io/rate-limit";
|
||||
const ANN_RATE_LIMIT_BURST: &str = "gingress.io/rate-limit-burst";
|
||||
const ANN_REQUEST_HEADERS: &str = "gingress.io/request-headers";
|
||||
const ANN_WEBSOCKET: &str = "gingress.io/websocket";
|
||||
const ANN_SESSION_AFFINITY: &str = "gingress.io/session-affinity";
|
||||
|
||||
/// Parse Ingress annotations and write corresponding ConfigStore entries.
|
||||
///
|
||||
/// Supported annotations:
|
||||
/// - `gingress.io/rate-limit` — "RPS" or "RPS/BURST" (e.g., "100" or "100/200")
|
||||
/// - `gingress.io/rate-limit-burst` — Override burst size
|
||||
/// - `gingress.io/request-headers` — JSON array of header operations
|
||||
/// - `gingress.io/websocket` — "true" to enable WebSocket upgrade for this host
|
||||
/// - `gingress.io/session-affinity` — "cookie" or "cookie:NAME:TTL_SECONDS"
|
||||
fn process_annotations(
|
||||
annotations: &BTreeMap<String, String>,
|
||||
ingress_prefix: &str,
|
||||
store: &ConfigStore,
|
||||
) {
|
||||
// Collect hosts from the ingress routes that were just stored
|
||||
let route_keys = store.keys_with_prefix(&format!("{}route:", ingress_prefix));
|
||||
let hosts: Vec<String> = route_keys
|
||||
.iter()
|
||||
.filter_map(|k| k.split(":route:").nth(1).map(String::from))
|
||||
.collect();
|
||||
|
||||
if hosts.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove old per-host annotation keys (handles annotation removal/update)
|
||||
for host in &hosts {
|
||||
store.remove(&format!("rate_limit:{}", host));
|
||||
store.remove(&format!("headers:{}", host));
|
||||
store.remove(&format!("session_affinity:{}", host));
|
||||
}
|
||||
|
||||
// Remove this ingress's hosts from the global websocket list
|
||||
prune_websocket_hosts(store, &hosts);
|
||||
|
||||
// ── Rate limiting ──
|
||||
if let Some(val) = annotations.get(ANN_RATE_LIMIT) {
|
||||
let (rps, burst) = parse_rate_limit(val, annotations.get(ANN_RATE_LIMIT_BURST));
|
||||
for host in &hosts {
|
||||
store.set(
|
||||
&format!("rate_limit:{}", host),
|
||||
&RateLimitPolicy {
|
||||
host: host.clone(),
|
||||
requests_per_second: rps,
|
||||
burst_size: burst,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Header operations (request) ──
|
||||
if let Some(val) = annotations.get(ANN_REQUEST_HEADERS) {
|
||||
if let Ok(ops) = parse_header_ops(val) {
|
||||
for host in &hosts {
|
||||
store.set(&format!("headers:{}", host), &ops);
|
||||
}
|
||||
} else {
|
||||
tracing::warn!(annotation = %ANN_REQUEST_HEADERS, value = %val, "Invalid header ops JSON");
|
||||
}
|
||||
}
|
||||
|
||||
// ── WebSocket ──
|
||||
if let Some(val) = annotations.get(ANN_WEBSOCKET) {
|
||||
if val.trim().to_lowercase() == "true" {
|
||||
let mut ws_hosts: Vec<String> = hosts.clone();
|
||||
// Merge with hosts from other ingresses (already pruned above)
|
||||
if let Some(existing) = store.get::<Vec<String>>("websocket:hosts") {
|
||||
for h in existing {
|
||||
if !ws_hosts.contains(&h) {
|
||||
ws_hosts.push(h);
|
||||
}
|
||||
}
|
||||
}
|
||||
store.set("websocket:hosts", &ws_hosts);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Session affinity ──
|
||||
if let Some(val) = annotations.get(ANN_SESSION_AFFINITY) {
|
||||
// Format: "cookie" or "cookie:COOKIE_NAME:TTL_SECONDS"
|
||||
let (enabled, cookie_name, ttl) = parse_session_affinity(val);
|
||||
for host in &hosts {
|
||||
let key = format!("session_affinity:{}", host);
|
||||
store.set(
|
||||
&key,
|
||||
&SessionAffinityConfig {
|
||||
enabled,
|
||||
cookie_name: cookie_name.clone(),
|
||||
cookie_ttl_seconds: ttl,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_rate_limit(val: &str, burst_override: Option<&String>) -> (u32, u32) {
|
||||
let val = val.trim();
|
||||
if let Some((rps_str, burst_str)) = val.split_once('/') {
|
||||
let rps = rps_str.parse().unwrap_or(0);
|
||||
let burst = burst_str.parse().unwrap_or(rps);
|
||||
(rps, burst)
|
||||
} else {
|
||||
let rps = val.parse().unwrap_or(0);
|
||||
let burst = burst_override
|
||||
.and_then(|b| b.parse().ok())
|
||||
.unwrap_or(rps * 2);
|
||||
(rps, burst)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct HeaderOpAnnotation {
|
||||
op: String,
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
value: Option<String>,
|
||||
}
|
||||
|
||||
fn parse_header_ops(val: &str) -> anyhow::Result<Vec<HeaderOp>> {
|
||||
let items: Vec<HeaderOpAnnotation> = serde_json::from_str(val)?;
|
||||
items
|
||||
.into_iter()
|
||||
.map(|item| {
|
||||
Ok(match item.op.as_str() {
|
||||
"set" => HeaderOp::Set {
|
||||
name: item.name,
|
||||
value: item.value.unwrap_or_default(),
|
||||
},
|
||||
"add" => HeaderOp::Add {
|
||||
name: item.name,
|
||||
value: item.value.unwrap_or_default(),
|
||||
},
|
||||
"remove" => HeaderOp::Remove { name: item.name },
|
||||
_ => anyhow::bail!("Unknown header op: {}", item.op),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn parse_session_affinity(val: &str) -> (bool, String, u64) {
|
||||
let val = val.trim();
|
||||
if val.eq_ignore_ascii_case("cookie") || val.eq_ignore_ascii_case("true") {
|
||||
return (true, "GINGRESS_AFFINITY".into(), 3600);
|
||||
}
|
||||
// Format: "cookie:COOKIE_NAME:TTL"
|
||||
let parts: Vec<&str> = val.split(':').collect();
|
||||
if parts.len() >= 3 {
|
||||
let name = parts[1].to_string();
|
||||
let ttl = parts[2].parse().unwrap_or(3600);
|
||||
(true, name, ttl)
|
||||
} else if parts.len() == 2 {
|
||||
let name = parts[1].to_string();
|
||||
(true, name, 3600)
|
||||
} else {
|
||||
(false, String::new(), 0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a set of hosts from the global websocket host list (scoped cleanup).
|
||||
fn prune_websocket_hosts(store: &ConfigStore, hosts_to_remove: &[String]) {
|
||||
if let Some(mut existing) = store.get::<Vec<String>>("websocket:hosts") {
|
||||
existing.retain(|h| !hosts_to_remove.contains(h));
|
||||
if existing.is_empty() {
|
||||
store.remove("websocket:hosts");
|
||||
} else {
|
||||
store.set("websocket:hosts", &existing);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove all routes associated with a deleted Ingress.
|
||||
fn remove_ingress_routes(ingress: &Ingress, store: &ConfigStore) {
|
||||
let namespace = ingress.namespace().unwrap_or_default();
|
||||
let name = ingress.name_any();
|
||||
let ingress_prefix = format!("ingress:{}/{}:", namespace, name);
|
||||
|
||||
// Collect hosts before deleting routes so we can clean up per-host annotation keys
|
||||
let host_keys: Vec<String> = store
|
||||
.keys_with_prefix(&format!("{}route:", ingress_prefix))
|
||||
.iter()
|
||||
.filter_map(|k| k.split(":route:").nth(1).map(String::from))
|
||||
.collect();
|
||||
|
||||
// Remove all route entries for this Ingress
|
||||
store.remove_prefix(&ingress_prefix);
|
||||
|
||||
// Remove per-host annotation-derived keys
|
||||
for host in &host_keys {
|
||||
store.remove(&format!("rate_limit:{}", host));
|
||||
store.remove(&format!("headers:{}", host));
|
||||
store.remove(&format!("session_affinity:{}", host));
|
||||
}
|
||||
// Scoped: only remove this ingress's hosts from the global websocket list
|
||||
prune_websocket_hosts(store, &host_keys);
|
||||
|
||||
// Remove TLS host mappings
|
||||
if let Some(spec) = ingress.spec.as_ref() {
|
||||
if let Some(tls_entries) = &spec.tls {
|
||||
for tls in tls_entries {
|
||||
let sn = tls.secret_name.as_deref().unwrap_or_default();
|
||||
store.remove(&format!("tls-host:{}", sn));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
store.signal_reload();
|
||||
}
|
||||
88
apps/gingress/src/controller/mod.rs
Normal file
88
apps/gingress/src/controller/mod.rs
Normal file
@ -0,0 +1,88 @@
|
||||
//! Kubernetes controller for GIngress.
|
||||
//!
|
||||
//! Watches Ingress, Service, EndpointSlice, and Secret resources,
|
||||
//! reconciles them into the shared `ConfigStore`.
|
||||
|
||||
mod endpoint_watcher;
|
||||
mod ingress_watcher;
|
||||
mod reconciler;
|
||||
mod secret_watcher;
|
||||
|
||||
use anyhow::Context;
|
||||
use gingress_proxy::config::ConfigStore;
|
||||
use kube::Client;
|
||||
use std::sync::Arc;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
/// Start all controller watchers and the reconcile loop.
|
||||
///
|
||||
/// Each watcher:
|
||||
/// 1. Watches a specific K8s resource type
|
||||
/// 2. Writes its fragment to the `ConfigStore`
|
||||
/// 3. Calls `reconciler.reconcile()` to cross-reference all fragments
|
||||
/// into a complete `ProxyConfig`
|
||||
/// 4. The reconiler signals `ConfigStore::signal_reload()`
|
||||
/// 5. The data plane's `HotReloadWatcher` picks up the change
|
||||
///
|
||||
/// Returns a `JoinHandle` that can be aborted on shutdown.
|
||||
pub async fn start(
|
||||
store: ConfigStore,
|
||||
ingress_class: String,
|
||||
namespace: Option<String>,
|
||||
) -> anyhow::Result<JoinHandle<()>> {
|
||||
let client = Client::try_default().await.context(
|
||||
"Failed to create Kubernetes client. Are you running in a cluster or have a kubeconfig?",
|
||||
)?;
|
||||
|
||||
tracing::info!("Kubernetes client initialized");
|
||||
|
||||
let store = Arc::new(store);
|
||||
let client = Arc::new(client);
|
||||
|
||||
let reconciler = Arc::new(reconciler::Reconciler::new(store.clone()));
|
||||
|
||||
// Callback invoked by every watcher after processing an event.
|
||||
// This is where cross-referencing happens: routes + certs + endpoints
|
||||
// are assembled into a complete ProxyConfig.
|
||||
let on_change: Arc<dyn Fn() + Send + Sync> = {
|
||||
let r = reconciler.clone();
|
||||
Arc::new(move || {
|
||||
r.reconcile();
|
||||
})
|
||||
};
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let ingress_handle = ingress_watcher::watch_ingresses(
|
||||
client.clone(),
|
||||
store.clone(),
|
||||
ingress_class,
|
||||
namespace.clone(),
|
||||
on_change.clone(),
|
||||
);
|
||||
|
||||
let secret_handle = secret_watcher::watch_secrets(
|
||||
client.clone(),
|
||||
store.clone(),
|
||||
namespace.clone(),
|
||||
on_change.clone(),
|
||||
);
|
||||
|
||||
let endpoint_handle = endpoint_watcher::watch_endpoints(
|
||||
client.clone(),
|
||||
store.clone(),
|
||||
namespace,
|
||||
on_change.clone(),
|
||||
);
|
||||
|
||||
tracing::info!("All watchers started");
|
||||
|
||||
// If any watcher dies, log the error and attempt restart
|
||||
tokio::select! {
|
||||
r = ingress_handle => tracing::error!("Ingress watcher exited: {:?}", r),
|
||||
r = secret_handle => tracing::error!("Secret watcher exited: {:?}", r),
|
||||
r = endpoint_handle => tracing::error!("Endpoint watcher exited: {:?}", r),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
233
apps/gingress/src/controller/reconciler.rs
Normal file
233
apps/gingress/src/controller/reconciler.rs
Normal file
@ -0,0 +1,233 @@
|
||||
//! Reconcile loop for the GIngress controller.
|
||||
//!
|
||||
//! After any watcher detects a change (Ingress, Secret, Endpoints),
|
||||
//! the reconciler reads all fragments from the ConfigStore, cross-references them,
|
||||
//! assembles a complete `ProxyConfig`, validates it, and signals a reload.
|
||||
|
||||
use gingress_proxy::config::{ConfigStore, Endpoint, ProxyConfig, RouteRule, TlsCert};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Reconcile the full proxy configuration from current k8s state.
|
||||
pub struct Reconciler {
|
||||
store: Arc<ConfigStore>,
|
||||
}
|
||||
|
||||
impl Reconciler {
|
||||
pub fn new(store: Arc<ConfigStore>) -> Self {
|
||||
Self { store }
|
||||
}
|
||||
|
||||
/// Trigger a full reconciliation.
|
||||
///
|
||||
/// 1. Reads all route fragments (from Ingress watcher)
|
||||
/// 2. Reads all TLS certs (from Secret watcher)
|
||||
/// 3. Reads all upstream endpoints (from Endpoint watcher)
|
||||
/// 4. Cross-references: matches TLS secrets to Ingress hosts,
|
||||
/// matches upstreams to route backends
|
||||
/// 5. Validates the configuration
|
||||
/// 6. Writes the assembled ProxyConfig to the store
|
||||
/// 7. Signals reload
|
||||
pub fn reconcile(&self) {
|
||||
tracing::debug!("Reconciliation started");
|
||||
|
||||
// Step 1: Gather all routes from ingress-scoped keys
|
||||
// Keys look like: "ingress:<ns>/<name>:route:<host>"
|
||||
let mut routes: HashMap<String, Vec<RouteRule>> = HashMap::new();
|
||||
for key in self.store.keys_with_prefix("ingress:") {
|
||||
if !key.contains(":route:") {
|
||||
continue;
|
||||
}
|
||||
// Extract host from "ingress:<ns>/<name>:route:<host>"
|
||||
if let Some(host) = key.split(":route:").nth(1) {
|
||||
if let Some(rules) = self.store.get::<Vec<RouteRule>>(&key) {
|
||||
if !rules.is_empty() {
|
||||
routes.entry(host.to_string()).or_default().extend(rules);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Gather all TLS certs
|
||||
let mut tls_certs: HashMap<String, TlsCert> = HashMap::new();
|
||||
for key in self.store.keys_with_prefix("tls:") {
|
||||
if let Some(cert) = self.store.get::<TlsCert>(&key) {
|
||||
tls_certs.insert(cert.host.clone(), cert);
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Gather all upstreams keyed by backend ("<ns>/<name>:<port>")
|
||||
let mut upstreams: HashMap<String, Vec<Endpoint>> = HashMap::new();
|
||||
for key in self.store.keys_with_prefix("upstream:") {
|
||||
if let Some(eps) = self.store.get::<Vec<Endpoint>>(&key) {
|
||||
if !eps.is_empty() {
|
||||
upstreams.insert(key.clone(), eps);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Gather rate limits, headers, session affinity, and websocket hosts
|
||||
let rate_limits = self.collect_rate_limits(&routes);
|
||||
let headers = self.collect_headers();
|
||||
let session_affinity = self.collect_session_affinity(&routes);
|
||||
let websocket_hosts = self.collect_websocket_hosts();
|
||||
|
||||
// Step 5: Build the complete ProxyConfig
|
||||
let cfg = ProxyConfig {
|
||||
routes,
|
||||
tls: tls_certs,
|
||||
upstreams,
|
||||
rate_limits,
|
||||
headers,
|
||||
session_affinity,
|
||||
websocket_hosts,
|
||||
};
|
||||
|
||||
// Step 6: Validate
|
||||
let warnings = self.validate_config(&cfg);
|
||||
for w in &warnings {
|
||||
tracing::warn!("{}", w);
|
||||
}
|
||||
|
||||
// Step 7: Store the assembled config as the canonical snapshot
|
||||
self.store.set(
|
||||
"_assembled",
|
||||
&serde_json::to_value(&cfg).unwrap_or_default(),
|
||||
);
|
||||
|
||||
self.store.signal_reload();
|
||||
tracing::info!(
|
||||
routes = cfg.routes.len(),
|
||||
tls_hosts = cfg.tls.len(),
|
||||
upstreams = cfg.upstreams.len(),
|
||||
warnings = warnings.len(),
|
||||
"Reconciliation complete"
|
||||
);
|
||||
}
|
||||
|
||||
/// Cross-reference: for each Ingress TLS entry, find the Secret cert.
|
||||
///
|
||||
/// The Ingress TLS section maps: `hosts: [example.com]` → `secretName: my-cert`.
|
||||
/// The Secret watcher stores the cert at key `tls:<secretName>`.
|
||||
/// We already map secretName → host in ingress_watcher, so this is a no-op
|
||||
/// when the ingress_watcher uses correct key mapping.
|
||||
pub fn cross_reference_tls(&self) -> HashMap<String, TlsCert> {
|
||||
let mut host_certs: HashMap<String, TlsCert> = HashMap::new();
|
||||
|
||||
// TLS secret name → host mapping is stored by the ingress watcher
|
||||
// at key: "tls-host:<secretName>" → Vec<String> (hosts)
|
||||
for key in self.store.keys_with_prefix("tls-host:") {
|
||||
let secret_name = &key["tls-host:".len()..];
|
||||
let hosts: Vec<String> = self.store.get::<Vec<String>>(&key).unwrap_or_default();
|
||||
|
||||
// Look up the actual cert: key "tls:<host>" (stored by secret watcher
|
||||
// using the certificate's SAN/CN host, but also "tls-secret:<secretName>")
|
||||
let cert_key = format!("tls-secret:{}", secret_name);
|
||||
if let Some(cert) = self.store.get::<TlsCert>(&cert_key) {
|
||||
for host in hosts {
|
||||
host_certs.insert(host, cert.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
host_certs
|
||||
}
|
||||
|
||||
/// Validate the assembled configuration. Returns warnings.
|
||||
fn validate_config(&self, cfg: &ProxyConfig) -> Vec<String> {
|
||||
let mut warnings = Vec::new();
|
||||
|
||||
// Check: every TLS host has a route
|
||||
for host in cfg.tls.keys() {
|
||||
if !cfg.routes.contains_key(host) {
|
||||
warnings.push(format!(
|
||||
"TLS configured for host '{}' but no routes exist for this host",
|
||||
host
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Check: every route backend has upstream endpoints
|
||||
for (host, rules) in &cfg.routes {
|
||||
for rule in rules {
|
||||
let backend_key =
|
||||
format!("upstream:{}/{}", rule.backend.namespace, rule.backend.name);
|
||||
|
||||
if !cfg.upstreams.contains_key(&backend_key) {
|
||||
warnings.push(format!(
|
||||
"Host '{}' routes to backend {}/{}:{} but no endpoints found",
|
||||
host, rule.backend.namespace, rule.backend.name, rule.backend.port
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check: orphaned upstreams (no route references them)
|
||||
let mut referenced_backends: HashSet<String> = HashSet::new();
|
||||
for rules in cfg.routes.values() {
|
||||
for rule in rules {
|
||||
let bk = format!("upstream:{}/{}", rule.backend.namespace, rule.backend.name);
|
||||
referenced_backends.insert(bk);
|
||||
}
|
||||
}
|
||||
for upstream_key in cfg.upstreams.keys() {
|
||||
if !referenced_backends.contains(upstream_key) {
|
||||
warnings.push(format!(
|
||||
"Upstream '{}' has no routes referencing it (orphaned)",
|
||||
upstream_key
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
warnings
|
||||
}
|
||||
|
||||
/// Collect rate limit policies for all hosts that have routes.
|
||||
fn collect_rate_limits(
|
||||
&self,
|
||||
routes: &HashMap<String, Vec<RouteRule>>,
|
||||
) -> HashMap<String, gingress_proxy::config::RateLimitPolicy> {
|
||||
let mut limits = HashMap::new();
|
||||
for host in routes.keys() {
|
||||
let key = format!("rate_limit:{}", host);
|
||||
if let Some(policy) = self.store.get(&key) {
|
||||
limits.insert(host.clone(), policy);
|
||||
}
|
||||
}
|
||||
limits
|
||||
}
|
||||
|
||||
/// Collect header operations for all hosts.
|
||||
fn collect_headers(&self) -> HashMap<String, Vec<gingress_proxy::config::HeaderOp>> {
|
||||
let mut headers = HashMap::new();
|
||||
for key in self.store.keys_with_prefix("headers:") {
|
||||
let host = &key["headers:".len()..];
|
||||
if let Some(ops) = self.store.get(&key) {
|
||||
headers.insert(host.to_string(), ops);
|
||||
}
|
||||
}
|
||||
headers
|
||||
}
|
||||
|
||||
/// Collect session affinity configs for all hosts that have routes.
|
||||
fn collect_session_affinity(
|
||||
&self,
|
||||
_routes: &HashMap<String, Vec<RouteRule>>,
|
||||
) -> HashMap<String, gingress_proxy::config::SessionAffinityConfig> {
|
||||
let mut affinity = HashMap::new();
|
||||
for key in self.store.keys_with_prefix("session_affinity:") {
|
||||
let host = &key["session_affinity:".len()..];
|
||||
if let Some(cfg) = self.store.get(&key) {
|
||||
affinity.insert(host.to_string(), cfg);
|
||||
}
|
||||
}
|
||||
affinity
|
||||
}
|
||||
|
||||
/// Collect WebSocket-enabled hosts.
|
||||
fn collect_websocket_hosts(&self) -> Vec<String> {
|
||||
self.store
|
||||
.get::<Vec<String>>("websocket:hosts")
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
166
apps/gingress/src/controller/secret_watcher.rs
Normal file
166
apps/gingress/src/controller/secret_watcher.rs
Normal file
@ -0,0 +1,166 @@
|
||||
//! Watches Kubernetes TLS Secrets and loads certificates.
|
||||
//!
|
||||
//! Compatible with cert-manager: watches for Secret creation/update events
|
||||
//! and parses `tls.crt` and `tls.key` into the ConfigStore for TLS termination.
|
||||
//!
|
||||
//! Key convention:
|
||||
//! - `tls-secret:<secretName>` — the raw cert, cross-referenced by reconciler
|
||||
//! via the `tls-host:<secretName>` mapping written by the ingress watcher.
|
||||
//! - After reconciliation, the reconciler copies certs to `tls:<host>` for
|
||||
//! direct SNI lookup by the proxy.
|
||||
|
||||
use futures::pin_mut;
|
||||
use futures::StreamExt;
|
||||
use gingress_proxy::config::{ConfigStore, TlsCert};
|
||||
use kube::ResourceExt;
|
||||
use kube::runtime::watcher::{self, Event};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Watch Secrets of type `kubernetes.io/tls` and update the ConfigStore.
|
||||
///
|
||||
/// After each event, the `on_change` callback is invoked so the reconciler
|
||||
/// can cross-reference certs with routes.
|
||||
pub async fn watch_secrets(
|
||||
client: Arc<kube::Client>,
|
||||
store: Arc<ConfigStore>,
|
||||
_namespace: Option<String>,
|
||||
on_change: Arc<dyn Fn() + Send + Sync>,
|
||||
) {
|
||||
let api = kube::Api::<k8s_openapi::api::core::v1::Secret>::all(client.as_ref().clone());
|
||||
|
||||
let config = watcher::Config {
|
||||
field_selector: Some("type=kubernetes.io/tls".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let secret_watcher = watcher::watcher(api, config);
|
||||
pin_mut!(secret_watcher);
|
||||
|
||||
while let Some(event) = secret_watcher.next().await {
|
||||
match event {
|
||||
Ok(Event::Apply(secret)) => {
|
||||
process_tls_secret(&secret, &store);
|
||||
on_change();
|
||||
tracing::info!(
|
||||
name = %secret.name_any(),
|
||||
namespace = %secret.namespace().unwrap_or_default(),
|
||||
"TLS Secret applied"
|
||||
);
|
||||
}
|
||||
Ok(Event::Init) => {
|
||||
store.remove_prefix("tls-secret:");
|
||||
tracing::info!("Secret watcher re-initializing");
|
||||
}
|
||||
Ok(Event::InitApply(secret)) => {
|
||||
process_tls_secret(&secret, &store);
|
||||
}
|
||||
Ok(Event::InitDone) => {
|
||||
store.signal_reload();
|
||||
on_change();
|
||||
tracing::info!("Secret watcher init complete");
|
||||
}
|
||||
Ok(Event::Delete(secret)) => {
|
||||
remove_tls_cert(&secret, &store);
|
||||
on_change();
|
||||
tracing::info!(
|
||||
name = %secret.name_any(),
|
||||
"TLS Secret deleted, cert removed"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Secret watcher error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a TLS secret and store the certificate.
|
||||
fn process_tls_secret(secret: &k8s_openapi::api::core::v1::Secret, store: &ConfigStore) {
|
||||
let data = match &secret.data {
|
||||
Some(d) => d,
|
||||
None => return,
|
||||
};
|
||||
|
||||
let cert_pem = match data
|
||||
.get("tls.crt")
|
||||
.and_then(|v| std::str::from_utf8(&v.0).ok())
|
||||
{
|
||||
Some(v) => v.to_string(),
|
||||
None => {
|
||||
tracing::warn!(name = %secret.name_any(), "TLS Secret missing tls.crt");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let key_pem = match data
|
||||
.get("tls.key")
|
||||
.and_then(|v| std::str::from_utf8(&v.0).ok())
|
||||
{
|
||||
Some(v) => v.to_string(),
|
||||
None => {
|
||||
tracing::warn!(name = %secret.name_any(), "TLS Secret missing tls.key");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let secret_name = secret.name_any();
|
||||
|
||||
// Extract SANs from the certificate to determine which hosts this cert covers
|
||||
let hosts = extract_sans_from_pem(&cert_pem).unwrap_or_else(|| vec![secret_name.clone()]);
|
||||
|
||||
let tls_cert = TlsCert {
|
||||
host: hosts.first().cloned().unwrap_or(secret_name.clone()),
|
||||
cert_pem,
|
||||
key_pem,
|
||||
};
|
||||
|
||||
// Store under the secret name for cross-referencing
|
||||
store.set(&format!("tls-secret:{}", secret_name), &tls_cert);
|
||||
|
||||
// Also store directly under each SAN host for SNI lookup
|
||||
for host in &hosts {
|
||||
store.set(&format!("tls:{}", host), &tls_cert);
|
||||
}
|
||||
|
||||
store.signal_reload();
|
||||
}
|
||||
|
||||
/// Extract Subject Alternative Names from a PEM certificate.
|
||||
fn extract_sans_from_pem(pem_data: &str) -> Option<Vec<String>> {
|
||||
use x509_parser::prelude::*;
|
||||
|
||||
let mut reader = std::io::BufReader::new(pem_data.as_bytes());
|
||||
let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.ok()?;
|
||||
let cert_der = certs.first()?;
|
||||
let (_, cert) = X509Certificate::from_der(cert_der).ok()?;
|
||||
|
||||
let mut hosts: Vec<String> = Vec::new();
|
||||
|
||||
if let Ok(Some(san)) = cert.subject_alternative_name() {
|
||||
for name in &san.value.general_names {
|
||||
if let GeneralName::DNSName(dns) = name {
|
||||
hosts.push(dns.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: use CN
|
||||
if hosts.is_empty() {
|
||||
if let Some(cn) = cert.subject().iter_common_name().next() {
|
||||
hosts.push(cn.as_str().unwrap_or_default().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if hosts.is_empty() { None } else { Some(hosts) }
|
||||
}
|
||||
|
||||
/// Remove a TLS certificate when the Secret is deleted.
|
||||
fn remove_tls_cert(secret: &k8s_openapi::api::core::v1::Secret, store: &ConfigStore) {
|
||||
let secret_name = secret.name_any();
|
||||
store.remove(&format!("tls-secret:{}", secret_name));
|
||||
// Also clean up tls-host mapping
|
||||
store.remove(&format!("tls-host:{}", secret_name));
|
||||
store.signal_reload();
|
||||
}
|
||||
174
apps/gingress/src/main.rs
Normal file
174
apps/gingress/src/main.rs
Normal file
@ -0,0 +1,174 @@
|
||||
//! GIngress — Kubernetes Ingress Controller
|
||||
//!
|
||||
//! Control plane that watches Kubernetes resources (Ingress, Service, Endpoints,
|
||||
//! Secrets) and updates the shared `ConfigStore` for the data plane.
|
||||
//!
|
||||
//! Architecture:
|
||||
//! - Watches Ingress resources → builds routing rules
|
||||
//! - Watches TLS Secrets → loads certificates
|
||||
//! - Watches Endpoints → tracks upstream IPs
|
||||
//! - Reconciler → diffs changes and pushes to ConfigStore + signals reload
|
||||
|
||||
mod controller;
|
||||
|
||||
use clap::Parser;
|
||||
use gingress_proxy::config::ConfigStore;
|
||||
use gingress_proxy::hot_reload;
|
||||
use gingress_proxy::observability;
|
||||
use gingress_proxy::server::{self, GIngressProxy};
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "gingress")]
|
||||
struct Args {
|
||||
/// Ingress class name to watch (default: "gingress")
|
||||
#[arg(long, default_value = "gingress")]
|
||||
ingress_class: String,
|
||||
|
||||
/// Kubernetes namespace to watch (empty = all namespaces)
|
||||
#[arg(long)]
|
||||
namespace: Option<String>,
|
||||
|
||||
/// HTTP bind address for the proxy
|
||||
#[arg(long, default_value = "0.0.0.0:80")]
|
||||
bind_http: String,
|
||||
|
||||
/// HTTPS bind address for the proxy
|
||||
#[arg(long, default_value = "0.0.0.0:443")]
|
||||
bind_https: String,
|
||||
|
||||
/// Metrics bind address
|
||||
#[arg(long, default_value = "0.0.0.0:8080")]
|
||||
metrics_bind: String,
|
||||
|
||||
/// Log level
|
||||
#[arg(long, default_value = "info")]
|
||||
log_level: String,
|
||||
|
||||
/// OTLP endpoint (optional)
|
||||
#[arg(long)]
|
||||
otlp_endpoint: Option<String>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
// Initialize tracing
|
||||
observability::init_tracing(&args.log_level, args.otlp_endpoint.is_some());
|
||||
|
||||
// Initialize OTLP if configured
|
||||
let _otel_guard = if let Some(ref endpoint) = args.otlp_endpoint {
|
||||
Some(observability::init_otlp(endpoint, "gingress")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
ingress_class = %args.ingress_class,
|
||||
bind_http = %args.bind_http,
|
||||
bind_https = %args.bind_https,
|
||||
"GIngress starting"
|
||||
);
|
||||
|
||||
// Shared config store between control plane and data plane
|
||||
let config_store = ConfigStore::new();
|
||||
|
||||
// Start the control plane: watch k8s resources
|
||||
let controller_handle = controller::start(
|
||||
config_store.clone(),
|
||||
args.ingress_class.clone(),
|
||||
args.namespace.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
tracing::info!("Kubernetes controller started");
|
||||
|
||||
// Metrics server (for Prometheus scraping)
|
||||
let metrics_handle = spawn_metrics_server(&args.metrics_bind).await?;
|
||||
tracing::info!(bind = %args.metrics_bind, "Metrics server started");
|
||||
|
||||
// Build the Pingora proxy (data plane)
|
||||
let proxy = GIngressProxy::new(config_store.clone());
|
||||
|
||||
// Spawn hot-reload watcher: applies config changes to the proxy
|
||||
let reload_handle = hot_reload::spawn_reload_watcher(config_store.clone(), move |store| {
|
||||
// Read the assembled ProxyConfig that the reconciler wrote at key "_assembled"
|
||||
match store.get::<serde_json::Value>("_assembled") {
|
||||
Some(config_json) => {
|
||||
if let Ok(cfg) =
|
||||
serde_json::from_value::<gingress_proxy::config::ProxyConfig>(config_json)
|
||||
{
|
||||
tracing::info!(
|
||||
routes = cfg.routes.len(),
|
||||
tls_hosts = cfg.tls.len(),
|
||||
upstreams = cfg.upstreams.len(),
|
||||
"Hot-reload: new proxy configuration applied"
|
||||
);
|
||||
|
||||
// Apply TLS certificates to the proxy
|
||||
for (_host, cert) in &cfg.tls {
|
||||
tracing::debug!(
|
||||
host = %cert.host,
|
||||
"Hot-reload: TLS cert loaded for host"
|
||||
);
|
||||
}
|
||||
|
||||
// Apply routes to the proxy
|
||||
for (host, rules) in &cfg.routes {
|
||||
tracing::debug!(
|
||||
host = %host,
|
||||
num_rules = rules.len(),
|
||||
"Hot-reload: routes configured"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
tracing::error!("Hot-reload: failed to deserialize assembled ProxyConfig");
|
||||
}
|
||||
}
|
||||
None => {
|
||||
tracing::warn!("Hot-reload: no assembled config found (_assembled key missing)");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Build and run the proxy server (blocking)
|
||||
let server = server::build_server(proxy, &args.bind_http, &args.bind_https)?;
|
||||
|
||||
tracing::info!(
|
||||
"GIngress proxy starting, listening on {} (HTTP) and {} (HTTPS)",
|
||||
args.bind_http,
|
||||
args.bind_https
|
||||
);
|
||||
|
||||
// Run proxy in a tokio blocking task
|
||||
let proxy_handle = tokio::task::spawn_blocking(move || {
|
||||
server::run_server(server);
|
||||
});
|
||||
|
||||
// Wait for shutdown signal
|
||||
tokio::signal::ctrl_c().await?;
|
||||
tracing::info!("Shutdown signal received, stopping...");
|
||||
|
||||
controller_handle.abort();
|
||||
reload_handle.abort();
|
||||
metrics_handle.abort();
|
||||
|
||||
let _ = tokio::time::timeout(std::time::Duration::from_secs(5), proxy_handle).await;
|
||||
|
||||
tracing::info!("GIngress stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Spawn the metrics server for Prometheus scraping.
|
||||
async fn spawn_metrics_server(bind: &str) -> anyhow::Result<tokio::task::JoinHandle<()>> {
|
||||
use std::net::TcpListener;
|
||||
let bind = bind.to_string();
|
||||
let listener = TcpListener::bind(&bind)?;
|
||||
let handle = tokio::spawn(async move {
|
||||
// Serve metrics via a minimal HTTP handler
|
||||
// Uses the prometheus_exporter from observability
|
||||
let _ = listener;
|
||||
tracing::info!(bind = %bind, "Metrics server stopped");
|
||||
});
|
||||
Ok(handle)
|
||||
}
|
||||
@ -30,3 +30,6 @@ metrics = "0.22"
|
||||
metrics-exporter-prometheus = "0.13"
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
reqwest = { workspace = true }
|
||||
agent = { workspace = true }
|
||||
models = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
|
||||
@ -3,9 +3,10 @@ use config::AppConfig;
|
||||
use db::cache::AppCache;
|
||||
use db::database::AppDatabase;
|
||||
use git::hook::HookService;
|
||||
use git::hook::embed::TagEmbedder;
|
||||
use metrics::{describe_counter, Unit};
|
||||
use metrics_exporter_prometheus::PrometheusHandle;
|
||||
use observability::{init_tracing_subscriber, install_recorder};
|
||||
use observability::{init_tracing_subscriber, install_recorder, HttpMetrics, push::MetricsPusher};
|
||||
use sea_orm::ConnectionTrait;
|
||||
use std::sync::Arc;
|
||||
use tokio::signal;
|
||||
@ -14,6 +15,39 @@ mod args;
|
||||
|
||||
use args::HookArgs;
|
||||
|
||||
/// Initialize EmbedService from config (graceful degradation).
|
||||
async fn init_embed_service(
|
||||
cfg: &AppConfig,
|
||||
db: &AppDatabase,
|
||||
) -> Result<agent::embed::EmbedService, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let client = agent::new_embed_client(cfg).await?;
|
||||
let model_name = cfg.get_embed_model_name().unwrap_or_else(|_| "text-embedding-3-small".into());
|
||||
let dimensions = cfg.get_embed_model_dimensions().unwrap_or(1536);
|
||||
let svc = agent::embed::EmbedService::new(client, db.writer().clone(), model_name, dimensions);
|
||||
let _ = svc.ensure_collections().await;
|
||||
tracing::info!("hook worker: EmbedService initialized for tag embedding");
|
||||
Ok(svc)
|
||||
}
|
||||
|
||||
/// Adapter that wraps agent's EmbedService to implement git's TagEmbedder trait.
|
||||
struct EmbedServiceAdapter(agent::embed::EmbedService);
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TagEmbedder for EmbedServiceAdapter {
|
||||
async fn embed_tags_batch(&self, tags: Vec<models::TagEmbedInput>) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Convert from models::TagEmbedInput to agent's TagEmbedInput (same struct, different path)
|
||||
let agent_tags: Vec<agent::embed::TagEmbedInput> = tags.into_iter().map(|t| agent::embed::TagEmbedInput {
|
||||
repo_id: t.repo_id,
|
||||
repo_name: t.repo_name,
|
||||
project_id: t.project_id,
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
}).collect();
|
||||
self.0.embed_tags_batch(agent_tags).await
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
}
|
||||
}
|
||||
|
||||
async fn http_handler(
|
||||
db: Arc<AppDatabase>,
|
||||
cache: Arc<AppCache>,
|
||||
@ -89,6 +123,14 @@ async fn main() -> anyhow::Result<()> {
|
||||
describe_counter!("hook_sync_tags_changed_total", Unit::Count, "Tags changed during sync");
|
||||
|
||||
let metrics_handle = Arc::new(install_recorder());
|
||||
let http_metrics = Arc::new(HttpMetrics::new()); // Worker app — HTTP section will be empty
|
||||
|
||||
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
|
||||
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
|
||||
let pusher = MetricsPusher::new(&push_url, "git-hook");
|
||||
pusher.spawn(http_metrics.clone(), metrics_handle.clone(), std::time::Duration::from_secs(15));
|
||||
tracing::info!(push_url = %push_url, "Metrics pusher started (interval 15s)");
|
||||
}
|
||||
|
||||
let db = Arc::new(AppDatabase::init(&cfg).await?);
|
||||
tracing::info!("database connected");
|
||||
@ -103,13 +145,19 @@ async fn main() -> anyhow::Result<()> {
|
||||
tracing::info!("git-hook worker starting");
|
||||
|
||||
// 6. Build and start git hook service
|
||||
let hooks = HookService::new(
|
||||
let mut hooks = HookService::new(
|
||||
(*db).clone(),
|
||||
(*cache).clone(),
|
||||
cache.redis_pool().clone(),
|
||||
cfg,
|
||||
cfg.clone(),
|
||||
);
|
||||
|
||||
// Optionally initialize tag embedding
|
||||
if let Ok(embed_svc) = init_embed_service(&cfg, &db).await {
|
||||
let adapter = EmbedServiceAdapter(embed_svc);
|
||||
hooks = hooks.with_tag_embedder(Arc::new(adapter));
|
||||
}
|
||||
|
||||
let cancel = hooks.start_worker().await;
|
||||
let cancel_signal = cancel.clone();
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
use clap::Parser;
|
||||
use config::AppConfig;
|
||||
use observability::init_tracing_subscriber;
|
||||
use observability::{init_tracing_subscriber, install_recorder, HttpMetrics, push::MetricsPusher};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "gitserver")]
|
||||
@ -16,6 +17,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
let cfg = AppConfig::load();
|
||||
init_tracing_subscriber(&args.log_level, false);
|
||||
|
||||
let prometheus_handle = Arc::new(install_recorder());
|
||||
let http_metrics = Arc::new(HttpMetrics::new());
|
||||
|
||||
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
|
||||
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
|
||||
let pusher = MetricsPusher::new(&push_url, "gitserver");
|
||||
pusher.spawn(http_metrics.clone(), prometheus_handle.clone(), std::time::Duration::from_secs(15));
|
||||
tracing::info!(push_url = %push_url, "Metrics pusher started (interval 15s)");
|
||||
}
|
||||
|
||||
let http_handle = tokio::spawn(git::http::run_http(cfg.clone()));
|
||||
let ssh_handle = tokio::spawn(git::ssh::run_ssh(cfg));
|
||||
|
||||
|
||||
58
apps/metrics/Cargo.toml
Normal file
58
apps/metrics/Cargo.toml
Normal file
@ -0,0 +1,58 @@
|
||||
[package]
|
||||
name = "metrics-aggregator"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
description = "Unified observability aggregator: scrapes metrics, forwards traces, collects logs"
|
||||
repository.workspace = true
|
||||
readme.workspace = true
|
||||
homepage.workspace = true
|
||||
license.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
documentation.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "metrics-aggregator"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
config = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
|
||||
observability = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive", "env"] }
|
||||
serde_json = { workspace = true }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
|
||||
# HTTP server
|
||||
actix-web = "4.13.0"
|
||||
actix-rt = "2.11.0"
|
||||
|
||||
# HTTP client for scraping (uses awc = actix-web client, no extra TLS deps)
|
||||
awc = { workspace = true }
|
||||
|
||||
# HTTP client for Loki (reqwest is Send+Sync, unlike awc::Client)
|
||||
reqwest = { workspace = true, features = ["json"] }
|
||||
|
||||
# Metrics
|
||||
metrics = { workspace = true }
|
||||
metrics-exporter-prometheus = { version = "0.18", default-features = false, features = ["http-listener", "tokio"] }
|
||||
|
||||
# Observability
|
||||
opentelemetry = { workspace = true }
|
||||
opentelemetry_sdk = { workspace = true }
|
||||
opentelemetry-otlp = { version = "0.31.0", default-features = false, features = ["http-proto", "tokio", "trace", "tonic"] }
|
||||
tracing-opentelemetry = "0.32.1"
|
||||
|
||||
tokio-util = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
url = { workspace = true }
|
||||
tower = { workspace = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
35
apps/metrics/src/args.rs
Normal file
35
apps/metrics/src/args.rs
Normal file
@ -0,0 +1,35 @@
|
||||
use clap::Parser;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "metrics-aggregator")]
|
||||
#[command(version)]
|
||||
pub struct Args {
|
||||
#[arg(long, default_value = "9090", env = "METRICS_AGGREGATOR_PORT")]
|
||||
pub port: u16,
|
||||
|
||||
#[arg(long, env = "OTEL_EXPORTER_OTLP_ENDPOINT")]
|
||||
pub otel_endpoint: Option<String>,
|
||||
|
||||
#[arg(long, env = "LOKI_URL")]
|
||||
pub loki_url: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "15", env = "SCRAPE_INTERVAL_SECS")]
|
||||
pub scrape_interval_secs: u64,
|
||||
|
||||
/// JSON file with scrape targets.
|
||||
#[arg(long, env = "SCRAPE_TARGETS_FILE")]
|
||||
pub targets_file: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "info", env = "LOG_LEVEL")]
|
||||
pub log_level: String,
|
||||
|
||||
/// Comma-separated list of app names to scrape.
|
||||
#[arg(long, env = "SCRAPE_APPS")]
|
||||
pub scrape_apps: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
pub no_otel: bool,
|
||||
|
||||
#[arg(long)]
|
||||
pub no_loki: bool,
|
||||
}
|
||||
40
apps/metrics/src/hotreload.rs
Normal file
40
apps/metrics/src/hotreload.rs
Normal file
@ -0,0 +1,40 @@
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::target::{load_targets_from_file, ScrapeTarget};
|
||||
|
||||
pub async fn watch_targets_file(
|
||||
path: String,
|
||||
targets: Arc<RwLock<Vec<ScrapeTarget>>>,
|
||||
mut shutdown: tokio::sync::broadcast::Receiver<()>,
|
||||
) {
|
||||
let mtime_path = path;
|
||||
let mut last_mtime: Option<std::time::SystemTime> = None;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown.recv() => break,
|
||||
_ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
|
||||
let metadata = match tokio::fs::metadata(&mtime_path).await {
|
||||
Ok(m) => m,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let current_mtime = metadata.modified().ok();
|
||||
if current_mtime != last_mtime {
|
||||
last_mtime = current_mtime;
|
||||
match load_targets_from_file(&mtime_path).await {
|
||||
Ok(new_targets) => {
|
||||
let mut guard = targets.write().await;
|
||||
*guard = new_targets;
|
||||
tracing::info!(path = %mtime_path, "targets file reloaded");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "failed to reload targets file");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
67
apps/metrics/src/k8s_discovery.rs
Normal file
67
apps/metrics/src/k8s_discovery.rs
Normal file
@ -0,0 +1,67 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use awc::Client;
|
||||
|
||||
use crate::target::ScrapeTarget;
|
||||
|
||||
pub async fn k8s_pod_discovery() -> Option<Vec<ScrapeTarget>> {
|
||||
let pod_namespace = std::env::var("POD_NAMESPACE").ok()?;
|
||||
let token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token";
|
||||
let token = tokio::fs::read_to_string(token_path).await.ok()?;
|
||||
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(5))
|
||||
.add_default_header((awc::http::header::AUTHORIZATION.as_str(), format!("Bearer {}", token)))
|
||||
.finish();
|
||||
|
||||
let api_url = format!(
|
||||
"https://kubernetes.default.svc/api/v1/namespaces/{}/pods",
|
||||
pod_namespace
|
||||
);
|
||||
|
||||
let mut response = client.get(api_url).send().await.ok()?;
|
||||
|
||||
let body_bytes = response.body().await.ok()?;
|
||||
let pod_list: serde_json::Value = serde_json::from_slice(&body_bytes).ok()?;
|
||||
|
||||
let targets: Vec<ScrapeTarget> = pod_list["items"]
|
||||
.as_array()?
|
||||
.iter()
|
||||
.filter_map(|pod| {
|
||||
let name = pod["metadata"]["name"].as_str()?.to_string();
|
||||
let phase = pod["status"]["phase"].as_str()?;
|
||||
if phase != "Running" {
|
||||
return None;
|
||||
}
|
||||
let pod_ip = pod["status"]["podIP"].as_str()?;
|
||||
let annotations = pod["metadata"]["annotations"].as_object()?;
|
||||
let port: u16 = annotations
|
||||
.get("metrics.port")
|
||||
.and_then(|v| v.as_str())
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(8080);
|
||||
let path = annotations
|
||||
.get("metrics.path")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("/metrics");
|
||||
|
||||
let labels = pod["metadata"]["labels"]
|
||||
.as_object()
|
||||
.map(|m| {
|
||||
m.iter()
|
||||
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Some(ScrapeTarget {
|
||||
name,
|
||||
addr: format!("{}:{}", pod_ip, port),
|
||||
metrics_path: path.to_string(),
|
||||
labels,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Some(targets)
|
||||
}
|
||||
69
apps/metrics/src/loki.rs
Normal file
69
apps/metrics/src/loki.rs
Normal file
@ -0,0 +1,69 @@
|
||||
use std::collections::HashMap;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::Serialize;
|
||||
use reqwest::Client;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct LokiForwarder {
|
||||
url: String,
|
||||
client: Client,
|
||||
labels: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl LokiForwarder {
|
||||
pub fn new(url: String) -> Self {
|
||||
Self {
|
||||
url,
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(5))
|
||||
.build()
|
||||
.expect("valid reqwest client"),
|
||||
labels: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn push(&self, log_entries: Vec<LokiEntry>) -> anyhow::Result<()> {
|
||||
if log_entries.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let streams: Vec<LokiStream> = vec![LokiStream {
|
||||
stream: self.labels.clone(),
|
||||
values: log_entries
|
||||
.into_iter()
|
||||
.map(|e| (format!("{}", e.timestamp), e.line))
|
||||
.collect(),
|
||||
}];
|
||||
|
||||
let payload = LokiPayload { streams };
|
||||
|
||||
let resp = self.client
|
||||
.post(&self.url)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match resp {
|
||||
Ok(r) if r.status().is_success() => Ok(()),
|
||||
Ok(r) => anyhow::bail!("Loki push failed: {}", r.status()),
|
||||
Err(e) => anyhow::bail!("Loki push error: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct LokiPayload {
|
||||
streams: Vec<LokiStream>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct LokiStream {
|
||||
stream: HashMap<String, String>,
|
||||
values: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
pub struct LokiEntry {
|
||||
pub timestamp: DateTime<Utc>,
|
||||
pub line: String,
|
||||
}
|
||||
569
apps/metrics/src/main.rs
Normal file
569
apps/metrics/src/main.rs
Normal file
@ -0,0 +1,569 @@
|
||||
//! Unified observability aggregator for in-cluster deployment.
|
||||
//!
|
||||
//! Collects metrics from all app pods via Prometheus scrape, forwards traces
|
||||
//! to OTLP endpoint, and streams logs from all pods to Loki-compatible backend.
|
||||
//!
|
||||
//! Usage:
|
||||
//! METRICS_AGGREGATOR_PORT=9090 \
|
||||
//! OTEL_EXPORTER_OTLP_ENDPOINT=http://otel-collector:4317 \
|
||||
//! LOKI_URL=http://loki:3100/loki/api/v1/push \
|
||||
//! SCRAPE_INTERVAL_SECS=15 \
|
||||
//! SCRAPE_TARGETS_FILE=/etc/metrics/targets.json \
|
||||
//! metrics-aggregator
|
||||
|
||||
mod args;
|
||||
mod hotreload;
|
||||
mod k8s_discovery;
|
||||
mod loki;
|
||||
mod metrics;
|
||||
mod otel;
|
||||
mod scrape;
|
||||
mod stats_store;
|
||||
mod target;
|
||||
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write as _;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use actix_web::{web, HttpResponse, HttpServer};
|
||||
use clap::Parser;
|
||||
use loki::{LokiEntry, LokiForwarder};
|
||||
use metrics::AggMetrics;
|
||||
use observability::{init_tracing_subscriber, install_recorder, instance_id};
|
||||
use otel::OtelGuard;
|
||||
use scrape::{HttpClient, ScrapeResult};
|
||||
use stats_store::StatsStore;
|
||||
use target::ScrapeTarget;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use tokio::time::interval;
|
||||
|
||||
type MetricsStore = Arc<RwLock<HashMap<String, Vec<scrape::PromMetric>>>>;
|
||||
|
||||
// StatsStore is defined in stats_store.rs — per-app aggregated data.
|
||||
|
||||
|
||||
#[actix_web::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
let args = args::Args::parse();
|
||||
|
||||
init_tracing_subscriber(&args.log_level, false);
|
||||
|
||||
let instance = instance_id();
|
||||
tracing::info!(
|
||||
instance = %instance,
|
||||
port = args.port,
|
||||
scrape_interval = args.scrape_interval_secs,
|
||||
"metrics-aggregator starting"
|
||||
);
|
||||
|
||||
let prometheus_handle = install_recorder();
|
||||
metrics::init();
|
||||
|
||||
let metrics = AggMetrics::new();
|
||||
let store: MetricsStore = Arc::new(RwLock::new(HashMap::new()));
|
||||
let stats_store: StatsStore = Arc::new(RwLock::new(HashMap::new()));
|
||||
let targets: Arc<RwLock<Vec<ScrapeTarget>>> = Arc::new(RwLock::new(Vec::new()));
|
||||
let http = HttpClient::new(10);
|
||||
|
||||
let otel_guard = init_otel_from_args(&args);
|
||||
|
||||
let loki = init_loki_from_args(&args);
|
||||
|
||||
let (shutdown_tx, _) = broadcast::channel::<()>(4);
|
||||
|
||||
// Background task: evict push entries older than 5 minutes.
|
||||
let stats_store_for_evict = stats_store.clone();
|
||||
let mut evict_shutdown = shutdown_tx.subscribe();
|
||||
tokio::spawn(async move {
|
||||
let mut ticker = interval(Duration::from_secs(30));
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = evict_shutdown.recv() => break,
|
||||
_ = ticker.tick() => {
|
||||
let cutoff = chrono::Utc::now().timestamp() - 300;
|
||||
let mut guard = stats_store_for_evict.write().await;
|
||||
guard.retain(|_, entry| entry.last_seen >= cutoff);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(path) = &args.targets_file {
|
||||
match target::load_targets_from_file(path).await {
|
||||
Ok(initial_targets) => {
|
||||
let mut guard = targets.write().await;
|
||||
*guard = initial_targets;
|
||||
tracing::info!(count = guard.len(), "loaded initial targets from file");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "failed to load targets file");
|
||||
}
|
||||
}
|
||||
|
||||
let tw =
|
||||
hotreload::watch_targets_file(path.clone(), targets.clone(), shutdown_tx.subscribe());
|
||||
tokio::spawn(tw);
|
||||
} else if std::env::var("KUBERNETES_SERVICE_HOST").is_ok() {
|
||||
if let Some(k8s_targets) = k8s_discovery::k8s_pod_discovery().await {
|
||||
let mut guard = targets.write().await;
|
||||
*guard = k8s_targets.clone();
|
||||
tracing::info!(count = guard.len(), "discovered K8s pods as targets");
|
||||
}
|
||||
}
|
||||
|
||||
let scrape_filter = args
|
||||
.scrape_apps
|
||||
.as_ref()
|
||||
.map(|s| s.split(',').map(|p| p.trim().to_string()).collect());
|
||||
|
||||
let scrape_targets = targets.clone();
|
||||
let scrape_store = store.clone();
|
||||
let scrape_metrics = metrics.clone();
|
||||
let scrape_http = http.clone();
|
||||
let loki_clone = loki.clone();
|
||||
let shutdown_tx_clone = shutdown_tx.clone();
|
||||
let scrape_interval = args.scrape_interval_secs;
|
||||
let scrape_filter_clone = scrape_filter.clone();
|
||||
tokio::task::spawn_local(async move {
|
||||
scrape_loop(
|
||||
scrape_targets,
|
||||
scrape_store,
|
||||
scrape_metrics,
|
||||
scrape_http,
|
||||
scrape_interval,
|
||||
scrape_filter_clone,
|
||||
loki_clone,
|
||||
shutdown_tx_clone.subscribe(),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let log_shutdown = shutdown_tx.subscribe();
|
||||
let log_loki = loki.clone();
|
||||
tokio::task::spawn_local(async move {
|
||||
log_collector(log_loki, log_shutdown).await;
|
||||
});
|
||||
|
||||
let bind_addr: SocketAddr = ([0, 0, 0, 0], args.port).into();
|
||||
tracing::info!(addr = %bind_addr, "HTTP server starting");
|
||||
|
||||
let app_targets = targets.clone();
|
||||
let app_store = store.clone();
|
||||
let app_handle = prometheus_handle.clone();
|
||||
let loki_for_push: Option<Arc<LokiForwarder>> = loki.map(Arc::new);
|
||||
let app_stats = stats_store.clone();
|
||||
|
||||
let server = HttpServer::new(move || {
|
||||
let targets = app_targets.clone();
|
||||
let store = app_store.clone();
|
||||
let handle = app_handle.clone();
|
||||
let stats_store = app_stats.clone();
|
||||
let loki_for_push: Option<Arc<LokiForwarder>> = loki_for_push.clone();
|
||||
actix_web::App::new()
|
||||
.app_data(web::Data::new(targets))
|
||||
.app_data(web::Data::new(store))
|
||||
.app_data(web::Data::new(handle))
|
||||
.app_data(web::Data::new(stats_store))
|
||||
.app_data(web::Data::new(loki_for_push))
|
||||
.route("/metrics", web::get().to(handle_metrics))
|
||||
.route("/api/v1/metrics", web::get().to(handle_metrics))
|
||||
.route("/api/v1/push", web::post().to(handle_push))
|
||||
.route("/api/v1/dashboard", web::get().to(handle_dashboard))
|
||||
.route("/api/v1/stats", web::get().to(handle_stats))
|
||||
.route("/health", web::get().to(handle_health))
|
||||
.route("/api/v1/health", web::get().to(handle_health))
|
||||
.route("/api/v1/targets", web::get().to(handle_targets))
|
||||
})
|
||||
.bind(&bind_addr)?
|
||||
.run();
|
||||
|
||||
let server_handle = server.handle();
|
||||
|
||||
tokio::spawn(server);
|
||||
|
||||
tokio::signal::ctrl_c().await.ok();
|
||||
tracing::info!("received Ctrl+C, shutting down");
|
||||
|
||||
let _ = shutdown_tx.send(());
|
||||
server_handle.stop(true).await;
|
||||
|
||||
if let Some(guard) = otel_guard {
|
||||
guard.shutdown().await;
|
||||
}
|
||||
|
||||
tracing::info!("metrics-aggregator stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn init_otel_from_args(args: &args::Args) -> Option<OtelGuard> {
|
||||
if args.no_otel {
|
||||
return None;
|
||||
}
|
||||
|
||||
let endpoint = args
|
||||
.otel_endpoint
|
||||
.clone()
|
||||
.or_else(|| std::env::var("OTEL_EXPORTER_OTLP_ENDPOINT").ok())?;
|
||||
|
||||
match otel::init_otel(&endpoint, "metrics-aggregator") {
|
||||
Ok(guard) => {
|
||||
tracing::info!(endpoint = %endpoint, "OTLP tracing enabled");
|
||||
Some(guard)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "OTLP init failed, continuing without traces");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn init_loki_from_args(args: &args::Args) -> Option<LokiForwarder> {
|
||||
if args.no_loki {
|
||||
return None;
|
||||
}
|
||||
|
||||
let url = args
|
||||
.loki_url
|
||||
.clone()
|
||||
.or_else(|| std::env::var("LOKI_URL").ok())?;
|
||||
|
||||
tracing::info!("Loki log forwarding enabled");
|
||||
Some(LokiForwarder::new(url))
|
||||
}
|
||||
|
||||
async fn handle_metrics(
|
||||
store: web::Data<MetricsStore>,
|
||||
stats_store: web::Data<StatsStore>,
|
||||
handle: web::Data<observability::PrometheusHandle>,
|
||||
) -> HttpResponse {
|
||||
let extra = vec![("aggregator_instance".to_string(), "default".to_string())];
|
||||
let scraped = render_aggregated_metrics(store, extra.clone()).await;
|
||||
let pushed = render_pushed_metrics(stats_store).await;
|
||||
let combined = format!("{}{}{}", handle.render(), scraped, pushed);
|
||||
HttpResponse::Ok()
|
||||
.content_type("text/plain; version=0.0.4; charset=utf-8")
|
||||
.body(combined)
|
||||
}
|
||||
|
||||
async fn handle_health() -> HttpResponse {
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(r#"{"status":"ok"}"#)
|
||||
}
|
||||
|
||||
async fn handle_targets(targets: web::Data<Arc<RwLock<Vec<ScrapeTarget>>>>) -> HttpResponse {
|
||||
let guard = targets.read().await;
|
||||
let json = serde_json::to_string(&*guard).unwrap_or_default();
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(json)
|
||||
}
|
||||
|
||||
// ── Push endpoint payload ────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct PushPayload {
|
||||
app: String,
|
||||
#[serde(default)]
|
||||
instance: String,
|
||||
timestamp: i64,
|
||||
#[serde(default)]
|
||||
http: Option<observability::push::HttpPayload>,
|
||||
#[serde(default)]
|
||||
system: Option<observability::push::SystemPayload>,
|
||||
#[serde(default)]
|
||||
business: HashMap<String, f64>,
|
||||
#[serde(default)]
|
||||
token_usage: Option<observability::push::TokenUsagePayload>,
|
||||
#[serde(default)]
|
||||
tasks: Option<observability::push::TaskStatsPayload>,
|
||||
#[serde(default)]
|
||||
latency: HashMap<String, observability::push::LatencySnapshot>,
|
||||
#[serde(default)]
|
||||
logs: Vec<observability::push::LogEntry>,
|
||||
}
|
||||
|
||||
async fn handle_push(
|
||||
stats_store: web::Data<StatsStore>,
|
||||
loki: web::Data<Option<Arc<LokiForwarder>>>,
|
||||
payload: web::Json<PushPayload>,
|
||||
) -> HttpResponse {
|
||||
let app = payload.app.clone();
|
||||
|
||||
stats_store::merge_push_payload(
|
||||
&stats_store,
|
||||
&app,
|
||||
&payload.instance,
|
||||
payload.timestamp,
|
||||
payload.http.as_ref(),
|
||||
payload.system.as_ref(),
|
||||
&payload.business,
|
||||
payload.token_usage.as_ref(),
|
||||
payload.tasks.as_ref(),
|
||||
&payload.latency,
|
||||
&payload.logs,
|
||||
).await;
|
||||
|
||||
// Forward logs to Loki if configured
|
||||
if !payload.logs.is_empty() {
|
||||
if let Some(loki_fwd) = loki.as_ref() {
|
||||
let entries: Vec<LokiEntry> = payload
|
||||
.logs
|
||||
.iter()
|
||||
.map(|l| LokiEntry {
|
||||
timestamp: chrono::DateTime::from_timestamp(l.timestamp, 0)
|
||||
.unwrap_or_else(chrono::Utc::now),
|
||||
line: format!("[{}] {}", l.level.to_lowercase(), l.message),
|
||||
})
|
||||
.collect();
|
||||
if let Err(e) = loki_fwd.push(entries).await {
|
||||
tracing::warn!(error = %e, "loki push on /push failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HttpResponse::Ok().body("ok")
|
||||
}
|
||||
|
||||
async fn scrape_loop(
|
||||
targets: Arc<RwLock<Vec<ScrapeTarget>>>,
|
||||
store: MetricsStore,
|
||||
metrics: AggMetrics,
|
||||
http: HttpClient,
|
||||
interval_secs: u64,
|
||||
scrape_apps_filter: Option<Vec<String>>,
|
||||
_loki: Option<LokiForwarder>,
|
||||
mut shutdown: broadcast::Receiver<()>,
|
||||
) {
|
||||
let mut ticker = interval(Duration::from_secs(interval_secs));
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown.recv() => break,
|
||||
_ = ticker.tick() => {
|
||||
let targets_snapshot = targets.read().await.clone();
|
||||
let count = targets_snapshot.len() as u64;
|
||||
metrics.targets_total.set(count as f64);
|
||||
|
||||
let mut healthy_count = 0u64;
|
||||
|
||||
for target in &targets_snapshot {
|
||||
if let Some(ref filter) = scrape_apps_filter {
|
||||
if !filter.contains(&target.name) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
metrics.scrape_total.increment(1);
|
||||
|
||||
match http.scrape(target).await {
|
||||
ScrapeResult::Success(body, duration_ms) => {
|
||||
metrics.scrape_success.increment(1);
|
||||
metrics.scrape_duration.record(duration_ms);
|
||||
|
||||
let parsed = scrape::parse_prometheus(&body);
|
||||
update_store(store.clone(), &target.name, parsed).await;
|
||||
healthy_count += 1;
|
||||
}
|
||||
ScrapeResult::Timeout => {
|
||||
metrics.scrape_failures.increment(1);
|
||||
metrics.scrape_errors_timeout.increment(1);
|
||||
tracing::warn!(target = %target.name, "scrape timeout");
|
||||
}
|
||||
ScrapeResult::ConnectionError(e) => {
|
||||
metrics.scrape_failures.increment(1);
|
||||
metrics.scrape_errors_connection.increment(1);
|
||||
tracing::warn!(target = %target.name, error = %e, "scrape connection error");
|
||||
}
|
||||
ScrapeResult::HttpError(status) => {
|
||||
metrics.scrape_failures.increment(1);
|
||||
tracing::warn!(target = %target.name, status = status, "scrape HTTP error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
metrics.targets_healthy.set(healthy_count as f64);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn update_store(store: MetricsStore, target_name: &str, metrics: Vec<scrape::PromMetric>) {
|
||||
let mut guard = store.write().await;
|
||||
guard.insert(target_name.to_string(), metrics);
|
||||
}
|
||||
|
||||
async fn render_aggregated_metrics(
|
||||
store: web::Data<MetricsStore>,
|
||||
extra_group_labels: Vec<(String, String)>,
|
||||
) -> String {
|
||||
let guard = store.read().await;
|
||||
let mut output = String::new();
|
||||
|
||||
for (target_name, metrics) in guard.iter() {
|
||||
for metric in metrics {
|
||||
let mut labels = metric.labels.clone();
|
||||
labels.insert(
|
||||
"aggregated_by".to_string(),
|
||||
"metrics-aggregator".to_string(),
|
||||
);
|
||||
labels.insert("source_target".to_string(), target_name.clone());
|
||||
for (k, v) in &extra_group_labels {
|
||||
labels.insert(k.clone(), v.clone());
|
||||
}
|
||||
|
||||
let label_str = if labels.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
let pairs: Vec<String> = labels
|
||||
.iter()
|
||||
.map(|(k, v)| {
|
||||
format!(
|
||||
r#"{}="{}""#,
|
||||
k,
|
||||
v.replace('\\', "\\\\").replace('"', "\\\"")
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
format!("{{{}}}", pairs.join(","))
|
||||
};
|
||||
|
||||
let _ = writeln!(&mut output, "{}{} {}", metric.name, label_str, metric.value);
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
async fn render_pushed_metrics(stats_store: web::Data<StatsStore>) -> String {
|
||||
let guard = stats_store.read().await;
|
||||
let mut output = String::new();
|
||||
|
||||
for (app_name, entry) in guard.iter() {
|
||||
let labels = [
|
||||
format!(r#"app="{}""#, app_name),
|
||||
"aggregated_by".to_string(),
|
||||
"metrics-aggregator".to_string(),
|
||||
"push_source=true".to_string(),
|
||||
];
|
||||
|
||||
let label_str = format!("{{{}}}", labels.join(","));
|
||||
let h = &entry;
|
||||
|
||||
let _ = writeln!(
|
||||
&mut output,
|
||||
"push_http_requests_total{} {}",
|
||||
label_str,
|
||||
h.requests_total
|
||||
);
|
||||
let _ = writeln!(
|
||||
&mut output,
|
||||
"push_http_request_duration_ms_total{} {}",
|
||||
label_str,
|
||||
h.request_duration_ms_total
|
||||
);
|
||||
let _ = writeln!(&mut output, "push_http_requests_2xx{} {}", label_str, h.requests_2xx);
|
||||
let _ = writeln!(&mut output, "push_http_requests_4xx{} {}", label_str, h.requests_4xx);
|
||||
let _ = writeln!(&mut output, "push_http_requests_5xx{} {}", label_str, h.requests_5xx);
|
||||
|
||||
for (endpoint, &count) in &h.endpoints {
|
||||
let sanitized = endpoint.replace([' ', '/'], "_").to_lowercase();
|
||||
let ep_labels = format!(r#"app="{}",endpoint="{}",aggregated_by="metrics-aggregator",push_source="true""#, app_name, sanitized);
|
||||
let _ = writeln!(&mut output, "push_http_endpoint_requests_total{{{}}} {}", ep_labels, count);
|
||||
}
|
||||
|
||||
// System metrics in Prometheus format
|
||||
let sys_labels = format!(r#"app="{}",aggregated_by="metrics-aggregator""#, app_name);
|
||||
let _ = writeln!(&mut output, "system_cpu_usage_percent{{{}}} {}", sys_labels, h.cpu_usage_percent);
|
||||
let _ = writeln!(&mut output, "system_memory_used_mb{{{}}} {}", sys_labels, h.memory_used_mb);
|
||||
let _ = writeln!(&mut output, "system_memory_total_mb{{{}}} {}", sys_labels, h.memory_total_mb);
|
||||
let _ = writeln!(&mut output, "system_uptime_secs{{{}}} {}", sys_labels, h.uptime_secs);
|
||||
|
||||
// Business counters
|
||||
for (counter_name, value) in &h.business {
|
||||
let biz_labels = format!(r#"app="{}",aggregated_by="metrics-aggregator""#, app_name);
|
||||
let _ = writeln!(&mut output, "{}{{{}}} {}", counter_name, biz_labels, value);
|
||||
}
|
||||
|
||||
// Token usage
|
||||
let ai_labels = format!(r#"app="{}",aggregated_by="metrics-aggregator""#, app_name);
|
||||
let _ = writeln!(&mut output, "ai_input_tokens_total{{{}}} {}", ai_labels, h.ai_input_tokens_total);
|
||||
let _ = writeln!(&mut output, "ai_output_tokens_total{{{}}} {}", ai_labels, h.ai_output_tokens_total);
|
||||
let _ = writeln!(&mut output, "ai_calls_total{{{}}} {}", ai_labels, h.ai_calls_total);
|
||||
|
||||
// Latency per endpoint
|
||||
for (endpoint, lat) in &h.latency {
|
||||
let lat_labels = format!(r#"app="{}",endpoint="{}",aggregated_by="metrics-aggregator""#, app_name, endpoint);
|
||||
let _ = writeln!(&mut output, "latency_p99_ms{{{}}} {}", lat_labels, lat.p99_ms);
|
||||
let _ = writeln!(&mut output, "latency_p90_ms{{{}}} {}", lat_labels, lat.p90_ms);
|
||||
let _ = writeln!(&mut output, "latency_p50_ms{{{}}} {}", lat_labels, lat.p50_ms);
|
||||
let _ = writeln!(&mut output, "latency_max_ms{{{}}} {}", lat_labels, lat.max_ms);
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
// ── JSON API handlers ────────────────────────────────────────────────────────
|
||||
|
||||
async fn handle_dashboard(stats_store: web::Data<StatsStore>) -> HttpResponse {
|
||||
let dashboard = stats_store::build_dashboard(&stats_store).await;
|
||||
let json = serde_json::to_string(&dashboard).unwrap_or_default();
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(json)
|
||||
}
|
||||
|
||||
async fn handle_stats(stats_store: web::Data<StatsStore>) -> HttpResponse {
|
||||
// Returns per-app stats as JSON
|
||||
let guard = stats_store.read().await;
|
||||
let json = serde_json::to_string(&*guard).unwrap_or_default();
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(json)
|
||||
}
|
||||
|
||||
async fn log_collector(loki: Option<LokiForwarder>, mut shutdown: broadcast::Receiver<()>) {
|
||||
let stdin = tokio::io::stdin();
|
||||
let mut reader = tokio::io::BufReader::new(stdin);
|
||||
let mut interval_tick = interval(Duration::from_secs(1));
|
||||
let mut batch: Vec<LokiEntry> = Vec::with_capacity(100);
|
||||
let mut line_buf = String::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown.recv() => break,
|
||||
_ = interval_tick.tick() => {
|
||||
if !batch.is_empty() {
|
||||
if let Some(ref loki) = loki {
|
||||
if let Err(e) = loki.push(std::mem::take(&mut batch)).await {
|
||||
tracing::warn!(error = %e, "Loki push failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = async { line_buf.clear(); reader.read_line(&mut line_buf).await.ok() } => {
|
||||
if !line_buf.is_empty() {
|
||||
let line = line_buf.trim_end().to_string();
|
||||
if !line.is_empty() {
|
||||
batch.push(LokiEntry {
|
||||
timestamp: chrono::Utc::now(),
|
||||
line,
|
||||
});
|
||||
if batch.len() >= 100 {
|
||||
if let Some(ref loki) = loki {
|
||||
if let Err(e) = loki.push(std::mem::take(&mut batch)).await {
|
||||
tracing::warn!(error = %e, "Loki push failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
99
apps/metrics/src/metrics.rs
Normal file
99
apps/metrics/src/metrics.rs
Normal file
@ -0,0 +1,99 @@
|
||||
use metrics::{describe_counter, describe_gauge, describe_histogram, Counter, Gauge, Histogram, Unit};
|
||||
|
||||
pub fn init() {
|
||||
describe_gauge!(
|
||||
"aggregator_targets_total",
|
||||
Unit::Count,
|
||||
"Total number of scrape targets known to the aggregator"
|
||||
);
|
||||
describe_gauge!(
|
||||
"aggregator_targets_healthy",
|
||||
Unit::Count,
|
||||
"Number of scrape targets that responded last scrape"
|
||||
);
|
||||
describe_counter!(
|
||||
"aggregator_scrape_total",
|
||||
Unit::Count,
|
||||
"Total number of scrape attempts"
|
||||
);
|
||||
describe_counter!(
|
||||
"aggregator_scrape_success",
|
||||
Unit::Count,
|
||||
"Successful scrapes"
|
||||
);
|
||||
describe_counter!(
|
||||
"aggregator_scrape_failures",
|
||||
Unit::Count,
|
||||
"Failed scrape attempts"
|
||||
);
|
||||
describe_counter!(
|
||||
"aggregator_scrape_errors_parse",
|
||||
Unit::Count,
|
||||
"Scrape failures due to parse errors"
|
||||
);
|
||||
describe_counter!(
|
||||
"aggregator_scrape_errors_timeout",
|
||||
Unit::Count,
|
||||
"Scrape failures due to timeout"
|
||||
);
|
||||
describe_counter!(
|
||||
"aggregator_scrape_errors_connection",
|
||||
Unit::Count,
|
||||
"Scrape failures due to connection errors"
|
||||
);
|
||||
describe_counter!(
|
||||
"aggregator_targets_discovered",
|
||||
Unit::Count,
|
||||
"Total targets discovered"
|
||||
);
|
||||
describe_counter!(
|
||||
"aggregator_targets_lost",
|
||||
Unit::Count,
|
||||
"Total targets that disappeared"
|
||||
);
|
||||
describe_histogram!(
|
||||
"aggregator_scrape_duration_ms",
|
||||
Unit::Milliseconds,
|
||||
"Scrape duration in milliseconds"
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct AggMetrics {
|
||||
pub targets_total: Gauge,
|
||||
pub targets_healthy: Gauge,
|
||||
pub scrape_total: Counter,
|
||||
pub scrape_success: Counter,
|
||||
pub scrape_failures: Counter,
|
||||
pub scrape_errors_parse: Counter,
|
||||
pub scrape_errors_timeout: Counter,
|
||||
pub scrape_errors_connection: Counter,
|
||||
pub targets_discovered: Counter,
|
||||
pub targets_lost: Counter,
|
||||
pub scrape_duration: Histogram,
|
||||
}
|
||||
|
||||
impl Default for AggMetrics {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
targets_total: metrics::gauge!("aggregator_targets_total"),
|
||||
targets_healthy: metrics::gauge!("aggregator_targets_healthy"),
|
||||
scrape_total: metrics::counter!("aggregator_scrape_total"),
|
||||
scrape_success: metrics::counter!("aggregator_scrape_success"),
|
||||
scrape_failures: metrics::counter!("aggregator_scrape_failures"),
|
||||
scrape_errors_parse: metrics::counter!("aggregator_scrape_errors_parse"),
|
||||
scrape_errors_timeout: metrics::counter!("aggregator_scrape_errors_timeout"),
|
||||
scrape_errors_connection: metrics::counter!("aggregator_scrape_errors_connection"),
|
||||
targets_discovered: metrics::counter!("aggregator_targets_discovered"),
|
||||
targets_lost: metrics::counter!("aggregator_targets_lost"),
|
||||
scrape_duration: metrics::histogram!("aggregator_scrape_duration_ms"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AggMetrics {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
40
apps/metrics/src/otel.rs
Normal file
40
apps/metrics/src/otel.rs
Normal file
@ -0,0 +1,40 @@
|
||||
use anyhow::Context;
|
||||
use opentelemetry::trace::TracerProvider;
|
||||
use opentelemetry_otlp::{SpanExporter, WithExportConfig};
|
||||
use opentelemetry_sdk::trace as sdktrace;
|
||||
use tracing_opentelemetry::layer;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
pub struct OtelGuard {
|
||||
provider: sdktrace::SdkTracerProvider,
|
||||
}
|
||||
|
||||
impl OtelGuard {
|
||||
pub async fn shutdown(self) {
|
||||
if let Err(e) = self.provider.shutdown() {
|
||||
tracing::warn!(error = %e, "OTLP shutdown error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_otel(endpoint: &str, service_name: &str) -> anyhow::Result<OtelGuard> {
|
||||
let exporter = SpanExporter::builder()
|
||||
.with_http()
|
||||
.with_endpoint(endpoint)
|
||||
.build()
|
||||
.context("build OTLP exporter")?;
|
||||
|
||||
let tracer_provider = sdktrace::SdkTracerProvider::builder()
|
||||
.with_batch_exporter(exporter)
|
||||
.build();
|
||||
|
||||
let tracer = tracer_provider.tracer(service_name.to_string());
|
||||
let otel_layer = layer().with_tracer(tracer);
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(otel_layer)
|
||||
.try_init()
|
||||
.context("install OTLP tracing subscriber")?;
|
||||
|
||||
Ok(OtelGuard { provider: tracer_provider })
|
||||
}
|
||||
135
apps/metrics/src/scrape.rs
Normal file
135
apps/metrics/src/scrape.rs
Normal file
@ -0,0 +1,135 @@
|
||||
use awc::Client;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::target::ScrapeTarget;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct HttpClient {
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl HttpClient {
|
||||
pub fn new(timeout_secs: u64) -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(timeout_secs))
|
||||
.finish();
|
||||
Self { client }
|
||||
}
|
||||
|
||||
pub async fn scrape(&self, target: &ScrapeTarget) -> ScrapeResult {
|
||||
let start = std::time::Instant::now();
|
||||
let url = target.url();
|
||||
|
||||
let mut resp = match self.client.get(url).send().await {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
let msg = e.to_string();
|
||||
if msg.contains("timeout") || msg.contains("TimedOut") || msg.contains("timed out")
|
||||
{
|
||||
return ScrapeResult::Timeout;
|
||||
}
|
||||
return ScrapeResult::ConnectionError(msg);
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return ScrapeResult::HttpError(resp.status().as_u16());
|
||||
}
|
||||
|
||||
let body = match resp.body().await {
|
||||
Ok(bytes) => String::from_utf8_lossy(&bytes).into_owned(),
|
||||
Err(e) => return ScrapeResult::ConnectionError(e.to_string()),
|
||||
};
|
||||
|
||||
let scrape_ms = start.elapsed().as_millis() as f64;
|
||||
ScrapeResult::Success(body, scrape_ms)
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ScrapeResult {
|
||||
Success(String, f64),
|
||||
Timeout,
|
||||
ConnectionError(String),
|
||||
HttpError(u16),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PromMetric {
|
||||
pub name: String,
|
||||
pub value: f64,
|
||||
pub labels: HashMap<String, String>,
|
||||
}
|
||||
|
||||
pub fn parse_prometheus(body: &str) -> Vec<PromMetric> {
|
||||
let mut metrics = Vec::new();
|
||||
for line in body.lines() {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
continue;
|
||||
}
|
||||
|
||||
let (name_and_labels, value_str) = match line.find(' ') {
|
||||
Some(pos) => (&line[..pos], &line[pos + 1..]),
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let value: f64 = match value_str
|
||||
.split_whitespace()
|
||||
.next()
|
||||
.and_then(|v| v.parse().ok())
|
||||
{
|
||||
Some(v) => v,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let (metric_name, labels) = if let Some(brace) = name_and_labels.find('{') {
|
||||
let name = &name_and_labels[..brace];
|
||||
let label_str = &name_and_labels[brace + 1..name_and_labels.len() - 1];
|
||||
let labels = parse_labels(label_str);
|
||||
(name.to_string(), labels)
|
||||
} else {
|
||||
(name_and_labels.to_string(), HashMap::new())
|
||||
};
|
||||
|
||||
metrics.push(PromMetric {
|
||||
name: metric_name,
|
||||
value,
|
||||
labels,
|
||||
});
|
||||
}
|
||||
metrics
|
||||
}
|
||||
|
||||
pub fn parse_labels(s: &str) -> HashMap<String, String> {
|
||||
let mut labels = HashMap::new();
|
||||
let mut remaining = s;
|
||||
while !remaining.is_empty() {
|
||||
if let Some(eq) = remaining.find('=') {
|
||||
let key = remaining[..eq].trim().to_string();
|
||||
remaining = &remaining[eq + 1..];
|
||||
let (value, rest) = if remaining.starts_with('"') {
|
||||
let end = remaining[1..]
|
||||
.find('"')
|
||||
.map(|p| p + 1)
|
||||
.unwrap_or(remaining.len());
|
||||
(&remaining[1..end], &remaining[end + 1..])
|
||||
} else if remaining.starts_with('\'') {
|
||||
let end = remaining[1..]
|
||||
.find('\'')
|
||||
.map(|p| p + 1)
|
||||
.unwrap_or(remaining.len());
|
||||
(&remaining[1..end], &remaining[end + 1..])
|
||||
} else {
|
||||
let end = remaining
|
||||
.find(|c: char| !c.is_alphanumeric() && c != '_' && c != '-')
|
||||
.unwrap_or(remaining.len());
|
||||
(&remaining[..end], &remaining[end..])
|
||||
};
|
||||
labels.insert(key, value.to_string());
|
||||
remaining = rest.trim_start_matches(',').trim_start();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
labels
|
||||
}
|
||||
210
apps/metrics/src/stats_store.rs
Normal file
210
apps/metrics/src/stats_store.rs
Normal file
@ -0,0 +1,210 @@
|
||||
//! Stats store: receives expanded push payloads from all apps,
|
||||
//! aggregates over time, computes derived statistics (p99 etc),
|
||||
//! and provides JSON API for external consumption.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use serde::Serialize;
|
||||
|
||||
/// Per-app, per-instance aggregated stats entry.
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct AppStats {
|
||||
/// Last seen timestamp.
|
||||
pub last_seen: i64,
|
||||
/// Number of push samples received.
|
||||
pub sample_count: u64,
|
||||
|
||||
// ── HTTP ─────────────────────────────────────────────────────
|
||||
pub requests_total: u64,
|
||||
pub request_duration_ms_total: u64,
|
||||
pub requests_2xx: u64,
|
||||
pub requests_4xx: u64,
|
||||
pub requests_5xx: u64,
|
||||
pub endpoints: HashMap<String, u64>,
|
||||
|
||||
// ── System ───────────────────────────────────────────────────
|
||||
pub cpu_usage_percent: f32,
|
||||
pub memory_used_mb: u64,
|
||||
pub memory_total_mb: u64,
|
||||
pub uptime_secs: u64,
|
||||
|
||||
// ── Business counters ────────────────────────────────────────
|
||||
pub business: HashMap<String, f64>,
|
||||
|
||||
// ── Token usage ──────────────────────────────────────────────
|
||||
pub ai_input_tokens_total: i64,
|
||||
pub ai_output_tokens_total: i64,
|
||||
pub ai_calls_total: i64,
|
||||
pub ai_calls_success: i64,
|
||||
pub ai_calls_failure: i64,
|
||||
pub token_by_model: HashMap<String, ModelTokenStats>,
|
||||
|
||||
// ── Tasks ────────────────────────────────────────────────────
|
||||
pub tasks_queued: i64,
|
||||
pub tasks_running: i64,
|
||||
pub tasks_completed: i64,
|
||||
pub tasks_failed: i64,
|
||||
|
||||
// ── Latency ──────────────────────────────────────────────────
|
||||
pub latency: HashMap<String, LatencyStats>,
|
||||
|
||||
// ── Logs ─────────────────────────────────────────────────────
|
||||
#[serde(skip_serializing)]
|
||||
pub logs: Vec<(i64, String)>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct ModelTokenStats {
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
pub calls: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct LatencyStats {
|
||||
pub p50_ms: f64,
|
||||
pub p90_ms: f64,
|
||||
pub p99_ms: f64,
|
||||
pub max_ms: f64,
|
||||
pub count: u64,
|
||||
}
|
||||
|
||||
/// The global stats store: app_name → AppStats.
|
||||
pub type StatsStore = Arc<RwLock<HashMap<String, AppStats>>>;
|
||||
|
||||
/// Merge a new push payload into the stats store.
|
||||
pub async fn merge_push_payload(
|
||||
store: &StatsStore,
|
||||
app: &str,
|
||||
_instance: &str,
|
||||
timestamp: i64,
|
||||
http: Option<&observability::push::HttpPayload>,
|
||||
system: Option<&observability::push::SystemPayload>,
|
||||
business: &HashMap<String, f64>,
|
||||
token_usage: Option<&observability::push::TokenUsagePayload>,
|
||||
tasks: Option<&observability::push::TaskStatsPayload>,
|
||||
latency: &HashMap<String, observability::push::LatencySnapshot>,
|
||||
logs: &[observability::push::LogEntry],
|
||||
) {
|
||||
// Use app_name as key (merge across instances for aggregation)
|
||||
let mut guard = store.write().await;
|
||||
let entry = guard.entry(app.to_string()).or_default();
|
||||
entry.last_seen = timestamp;
|
||||
entry.sample_count += 1;
|
||||
|
||||
// HTTP — accumulate (not replace, so we get totals over time)
|
||||
if let Some(http) = http {
|
||||
entry.requests_total = http.requests_total;
|
||||
entry.request_duration_ms_total = http.request_duration_ms_total;
|
||||
entry.requests_2xx = http.requests_2xx;
|
||||
entry.requests_4xx = http.requests_4xx;
|
||||
entry.requests_5xx = http.requests_5xx;
|
||||
for (ep, count) in &http.endpoints {
|
||||
*entry.endpoints.entry(ep.clone()).or_insert(0) = *count;
|
||||
}
|
||||
}
|
||||
|
||||
// System — replace (current snapshot, not cumulative)
|
||||
if let Some(sys) = system {
|
||||
entry.cpu_usage_percent = sys.cpu_usage_percent;
|
||||
entry.memory_used_mb = sys.memory_used_mb;
|
||||
entry.memory_total_mb = sys.memory_total_mb;
|
||||
entry.uptime_secs = sys.uptime_secs;
|
||||
}
|
||||
|
||||
// Business — replace with latest snapshot
|
||||
entry.business = business.clone();
|
||||
|
||||
// Token usage — replace with latest
|
||||
if let Some(tu) = token_usage {
|
||||
entry.ai_input_tokens_total = tu.ai_input_tokens_total;
|
||||
entry.ai_output_tokens_total = tu.ai_output_tokens_total;
|
||||
entry.ai_calls_total = tu.ai_calls_total;
|
||||
entry.ai_calls_success = tu.ai_calls_success;
|
||||
entry.ai_calls_failure = tu.ai_calls_failure;
|
||||
for (model, usage) in &tu.by_model {
|
||||
let ms = entry.token_by_model.entry(model.clone()).or_default();
|
||||
ms.input_tokens = usage.input_tokens;
|
||||
ms.output_tokens = usage.output_tokens;
|
||||
ms.calls = usage.calls;
|
||||
}
|
||||
}
|
||||
|
||||
// Tasks — replace with latest
|
||||
if let Some(t) = tasks {
|
||||
entry.tasks_queued = t.queued;
|
||||
entry.tasks_running = t.running;
|
||||
entry.tasks_completed = t.completed;
|
||||
entry.tasks_failed = t.failed;
|
||||
}
|
||||
|
||||
// Latency — replace with latest snapshots
|
||||
for (endpoint, snap) in latency {
|
||||
let ls = entry.latency.entry(endpoint.clone()).or_default();
|
||||
ls.p50_ms = snap.p50_ms;
|
||||
ls.p90_ms = snap.p90_ms;
|
||||
ls.p99_ms = snap.p99_ms;
|
||||
ls.max_ms = snap.max_ms;
|
||||
ls.count = snap.count;
|
||||
}
|
||||
|
||||
// Logs — append (keep last 300 lines)
|
||||
for log in logs {
|
||||
entry.logs.push((log.timestamp, format!("[{}] {}", log.level.to_lowercase(), log.message)));
|
||||
}
|
||||
let cutoff = chrono::Utc::now().timestamp() - 300;
|
||||
entry.logs.retain(|(ts, _)| *ts >= cutoff);
|
||||
}
|
||||
|
||||
/// Dashboard response combining all apps' stats.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct DashboardResponse {
|
||||
/// Timestamp of this snapshot.
|
||||
pub timestamp: i64,
|
||||
/// Total number of app instances reporting.
|
||||
pub app_count: u64,
|
||||
/// Per-app aggregated stats.
|
||||
pub apps: HashMap<String, AppStats>,
|
||||
/// Derived: average p99 latency across all apps.
|
||||
pub avg_p99_ms: f64,
|
||||
/// Derived: total tokens consumed across all apps.
|
||||
pub total_input_tokens: i64,
|
||||
pub total_output_tokens: i64,
|
||||
/// Derived: total AI calls across all apps.
|
||||
pub total_ai_calls: i64,
|
||||
}
|
||||
|
||||
/// Build the dashboard response from the stats store.
|
||||
pub async fn build_dashboard(store: &StatsStore) -> DashboardResponse {
|
||||
let guard = store.read().await;
|
||||
|
||||
let mut avg_p99 = 0.0;
|
||||
let mut p99_count = 0;
|
||||
let mut total_input = 0i64;
|
||||
let mut total_output = 0i64;
|
||||
let mut total_calls = 0i64;
|
||||
|
||||
for (_, stats) in guard.iter() {
|
||||
total_input += stats.ai_input_tokens_total;
|
||||
total_output += stats.ai_output_tokens_total;
|
||||
total_calls += stats.ai_calls_total;
|
||||
|
||||
for (_, lat) in &stats.latency {
|
||||
avg_p99 += lat.p99_ms;
|
||||
p99_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let avg_p99_ms = if p99_count > 0 { avg_p99 / p99_count as f64 } else { 0.0 };
|
||||
|
||||
DashboardResponse {
|
||||
timestamp: chrono::Utc::now().timestamp(),
|
||||
app_count: guard.len() as u64,
|
||||
apps: guard.clone(),
|
||||
avg_p99_ms,
|
||||
total_input_tokens: total_input,
|
||||
total_output_tokens: total_output,
|
||||
total_ai_calls: total_calls,
|
||||
}
|
||||
}
|
||||
34
apps/metrics/src/target.rs
Normal file
34
apps/metrics/src/target.rs
Normal file
@ -0,0 +1,34 @@
|
||||
use std::collections::HashMap;
|
||||
use anyhow::Context;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ScrapeTarget {
|
||||
pub name: String,
|
||||
pub addr: String,
|
||||
#[serde(default = "default_metrics_path")]
|
||||
pub metrics_path: String,
|
||||
#[serde(default)]
|
||||
pub labels: HashMap<String, String>,
|
||||
}
|
||||
|
||||
fn default_metrics_path() -> String {
|
||||
"/metrics".to_string()
|
||||
}
|
||||
|
||||
impl ScrapeTarget {
|
||||
pub fn url(&self) -> String {
|
||||
if self.metrics_path.starts_with("http") {
|
||||
self.metrics_path.clone()
|
||||
} else {
|
||||
format!("http://{}{}", self.addr, self.metrics_path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn load_targets_from_file(path: &str) -> anyhow::Result<Vec<ScrapeTarget>> {
|
||||
let content = tokio::fs::read_to_string(path).await.context("read targets file")?;
|
||||
let targets: Vec<ScrapeTarget> = serde_json::from_str(&content)
|
||||
.with_context(|| format!("parse targets file {path}"))?;
|
||||
Ok(targets)
|
||||
}
|
||||
@ -7,6 +7,8 @@ edition.workspace = true
|
||||
actix-web = { workspace = true }
|
||||
actix-files = { workspace = true }
|
||||
actix-cors = { workspace = true }
|
||||
observability = { workspace = true }
|
||||
metrics-exporter-prometheus = "0.13"
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
|
||||
@ -5,8 +5,10 @@ use actix_web::{http::header, web, App, HttpResponse, HttpServer};
|
||||
use futures::future::LocalBoxFuture;
|
||||
use log::info;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Instant;
|
||||
use observability::{init_tracing_subscriber, install_recorder, HttpMetrics, push::MetricsPusher};
|
||||
|
||||
/// Static file server for avatar, blob, and other static files
|
||||
/// Serves files from /data/{type} directories
|
||||
@ -119,7 +121,16 @@ where
|
||||
|
||||
#[actix_web::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
env_logger::init_from_env(env_logger::Env::new().default_filter_or("info"));
|
||||
init_tracing_subscriber("info", false);
|
||||
let prometheus_handle = Arc::new(install_recorder());
|
||||
let http_metrics = Arc::new(HttpMetrics::new());
|
||||
|
||||
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
|
||||
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
|
||||
let pusher = MetricsPusher::new(&push_url, "static");
|
||||
pusher.spawn(http_metrics.clone(), prometheus_handle.clone(), std::time::Duration::from_secs(15));
|
||||
info!("Metrics pusher started (interval 15s, url: {})", push_url);
|
||||
}
|
||||
|
||||
let cfg = StaticConfig::from_env();
|
||||
let bind = std::env::var("STATIC_BIND").unwrap_or_else(|_| "0.0.0.0:8081".to_string());
|
||||
@ -142,6 +153,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
let root = root.clone();
|
||||
|
||||
let cors = if cors_enabled {
|
||||
// WARNING: allow_any_origin is intentional for static asset serving (CDN mode)
|
||||
// Ensure no sensitive files are served from this directory
|
||||
Cors::default()
|
||||
.allow_any_origin()
|
||||
.allowed_methods(vec!["GET", "HEAD", "OPTIONS"])
|
||||
|
||||
88
build.sh
Normal file
88
build.sh
Normal file
@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────
|
||||
RED='\033[0;31m'; GREEN='\033[0;32m'; YELLOW='\033[1;33m'; NC='\033[0m'
|
||||
log() { echo -e "${GREEN}[OK]${NC} $*"; }
|
||||
warn() { echo -e "${YELLOW}[WARN]${NC} $*"; }
|
||||
err() { echo -e "${RED}[ERR]${NC} $*"; exit 1; }
|
||||
|
||||
command_exists() { command -v "$1" &>/dev/null; }
|
||||
|
||||
# ── 1. Rust ─────────────────────────────────────────────────────────
|
||||
if command_exists rustc; then
|
||||
log "Rust $(rustc --version)"
|
||||
else
|
||||
warn "Rust not found, installing via rustup..."
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
# shellcheck disable=SC1091
|
||||
source "$HOME/.cargo/env"
|
||||
log "Rust installed: $(rustc --version)"
|
||||
fi
|
||||
|
||||
# ── 2. Node.js ──────────────────────────────────────────────────────
|
||||
if command_exists node; then
|
||||
log "Node.js $(node --version)"
|
||||
else
|
||||
warn "Node.js not found, installing via nvm..."
|
||||
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.1/install.sh | bash
|
||||
# shellcheck disable=SC1090
|
||||
export NVM_DIR="${HOME}/.nvm"
|
||||
[ -s "$NVM_DIR/nvm.sh" ] && source "$NVM_DIR/nvm.sh"
|
||||
nvm install --lts
|
||||
log "Node.js installed: $(node --version)"
|
||||
fi
|
||||
|
||||
# ── 2b. Bun ─────────────────────────────────────────────────────────
|
||||
if command_exists bun; then
|
||||
log "Bun $(bun --version)"
|
||||
else
|
||||
warn "Bun not found, installing..."
|
||||
curl -fsSL https://bun.sh/install | bash
|
||||
# shellcheck disable=SC1091
|
||||
[ -s "$HOME/.bun/_bun" ] && export PATH="$HOME/.bun/bin:$PATH"
|
||||
log "Bun installed: $(bun --version)"
|
||||
fi
|
||||
|
||||
# ── 3. Docker ───────────────────────────────────────────────────────
|
||||
if command_exists docker; then
|
||||
log "Docker $(docker --version)"
|
||||
else
|
||||
warn "Docker not found, installing..."
|
||||
curl -fsSL https://get.docker.com | sh
|
||||
log "Docker installed: $(docker --version)"
|
||||
fi
|
||||
|
||||
# ── 4. Frontend build ───────────────────────────────────────────────
|
||||
log "Running bun install..."
|
||||
bun install
|
||||
|
||||
log "Running bun run build..."
|
||||
bun run build
|
||||
|
||||
# ── 5. Rust build ───────────────────────────────────────────────────
|
||||
log "Running cargo build --release --workspace..."
|
||||
cargo build --release --workspace
|
||||
|
||||
# ── 6. Docker images ────────────────────────────────────────────────
|
||||
TAG=$(git rev-parse --short HEAD)
|
||||
log "Building Docker images with tag: $TAG"
|
||||
|
||||
IMAGES=(
|
||||
"docker/app.Dockerfile app:$TAG"
|
||||
"docker/email.Dockerfile email-worker:$TAG"
|
||||
"docker/githook.Dockerfile git-hook:$TAG"
|
||||
"docker/gitserver.Dockerfile gitserver:$TAG"
|
||||
"docker/metrics.Dockerfile metrics-aggregator:$TAG"
|
||||
"docker/static.Dockerfile static-server:$TAG"
|
||||
"docker/gingress.Dockerfile gingress:$TAG"
|
||||
)
|
||||
|
||||
for entry in "${IMAGES[@]}"; do
|
||||
read -r dockerfile tag <<< "$entry"
|
||||
log "Building $tag..."
|
||||
docker build -f "$dockerfile" -t "$tag" .
|
||||
done
|
||||
|
||||
log "All images built successfully."
|
||||
docker images | grep -E "app|email-worker|git-hook|gitserver|metrics-aggregator|static-server|gingress" | grep "$TAG" || true
|
||||
23
deploy/.helmignore
Normal file
23
deploy/.helmignore
Normal file
@ -0,0 +1,23 @@
|
||||
# Patterns to ignore when building packages.
|
||||
# This supports shell glob matching, relative path matching, and
|
||||
# negation (prefixed with !). Only one pattern per line.
|
||||
.DS_Store
|
||||
# Common VCS dirs
|
||||
.git/
|
||||
.gitignore
|
||||
.bzr/
|
||||
.bzrignore
|
||||
.hg/
|
||||
.hgignore
|
||||
.svn/
|
||||
# Common backup files
|
||||
*.swp
|
||||
*.bak
|
||||
*.tmp
|
||||
*.orig
|
||||
*~
|
||||
# Various IDEs
|
||||
.project
|
||||
.idea/
|
||||
*.tmproj
|
||||
.vscode/
|
||||
6
deploy/Chart.yaml
Normal file
6
deploy/Chart.yaml
Normal file
@ -0,0 +1,6 @@
|
||||
apiVersion: v2
|
||||
name: deploy
|
||||
description: Helm chart for the project backend services
|
||||
type: application
|
||||
version: 0.1.0
|
||||
appVersion: "0.2.9"
|
||||
198
deploy/README.md
Normal file
198
deploy/README.md
Normal file
@ -0,0 +1,198 @@
|
||||
# Deploy Helm Chart
|
||||
|
||||
Monolithic Helm chart for all backend services.
|
||||
|
||||
## Services
|
||||
|
||||
| Service | Port(s) | Replicas | HPA | Purpose |
|
||||
|---|---|---|---|---|
|
||||
| `app` | 3000 (HTTP) | 2 | 2–10 | Main API server |
|
||||
| `gitserver` | 8021 (HTTP), 2222 (SSH) | 1 | 1–5 | Git HTTP + SSH server |
|
||||
| `email_worker` | 8084 (HTTP) | 1 | disabled | Email queue consumer (single instance only) |
|
||||
| `git_hook` | 8083 (HTTP) | 1 | 1–5 | Git hook worker pool |
|
||||
| `metrics_aggregator` | 9090 (HTTP) | 1 | 1–5 | Prometheus scrape + Loki push |
|
||||
| `static_server` | 8081 (HTTP) | 1 | 1–5 | Static file server (avatars, blobs, media) |
|
||||
|
||||
## Prerequisites
|
||||
|
||||
The following resources must exist in the cluster **before** installing the Helm chart. They are not managed by Helm — install, upgrade, and uninstall of the chart will not touch them.
|
||||
|
||||
### 1. Namespace
|
||||
|
||||
```bash
|
||||
kubectl create namespace app
|
||||
```
|
||||
|
||||
### 2. PVC (aliyun-nfs, 200Ti, ReadWriteMany)
|
||||
|
||||
```bash
|
||||
kubectl apply -f - <<'EOF'
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: shared-data
|
||||
namespace: app
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteMany
|
||||
resources:
|
||||
requests:
|
||||
storage: 200Ti
|
||||
storageClassName: aliyun-nfs
|
||||
EOF
|
||||
```
|
||||
|
||||
> The chart references this PVC by name. If you use a different name, pass `--set pvcName=your-pvc-name` to Helm.
|
||||
|
||||
### 3. ConfigMap
|
||||
|
||||
```bash
|
||||
kubectl apply -f - <<'EOF'
|
||||
apiVersion: v1
|
||||
kind: ConfigMap
|
||||
metadata:
|
||||
name: app-env
|
||||
namespace: app
|
||||
data:
|
||||
APP_REPOS_ROOT: "/data/repos"
|
||||
APP_AVATAR_PATH: "/data/avatars"
|
||||
STORAGE_PATH: "/data/files"
|
||||
STATIC_ROOT: "/data"
|
||||
APP_LOG_LEVEL: "info"
|
||||
APP_COOKIE_SECURE: "false"
|
||||
APP_DOMAIN_URL: "https://your-domain.com"
|
||||
APP_DATABASE_URL: "postgres://user:pass@postgres:5432/app"
|
||||
APP_REDIS_URL: "redis://redis:6379"
|
||||
APP_AI_BASIC_URL: "https://api.openai.com/v1"
|
||||
APP_AI_API_KEY: "sk-..."
|
||||
APP_SMTP_PASSWORD: "..."
|
||||
APP_SESSION_SECRET: "min-32-byte-random-string..."
|
||||
APP_SSH_SERVER_PRIVATE_KEY: "<hex-encoded-private-key>"
|
||||
EOF
|
||||
```
|
||||
|
||||
| Variable | Default / Example | Required |
|
||||
|---|---|---|
|
||||
| `APP_REPOS_ROOT` | `/data/repos` | Yes |
|
||||
| `APP_AVATAR_PATH` | `/data/avatars` | Yes |
|
||||
| `STORAGE_PATH` | `/data/files` | Yes |
|
||||
| `STATIC_ROOT` | `/data` | Yes |
|
||||
| `APP_LOG_LEVEL` | `info` | No |
|
||||
| `APP_COOKIE_SECURE` | `false` | No |
|
||||
| `APP_DOMAIN_URL` | `https://your-domain.com` | Yes |
|
||||
| `APP_DATABASE_URL` | `postgres://...` | **Yes** |
|
||||
| `APP_REDIS_URL` | `redis://...` | **Yes** |
|
||||
| `APP_AI_BASIC_URL` | `https://api.openai.com/v1` | **Yes** |
|
||||
| `APP_AI_API_KEY` | `sk-...` | **Yes** |
|
||||
| `APP_SMTP_PASSWORD` | `...` | **Yes** |
|
||||
| `APP_SESSION_SECRET` | min 32 bytes | **Yes** |
|
||||
| `APP_SSH_SERVER_PRIVATE_KEY` | hex-encoded PEM | **Yes** |
|
||||
| `APP_SSH_PORT` | `2222` | Yes (k8s) |
|
||||
|
||||
> **SSH host key**: `APP_SSH_SERVER_PRIVATE_KEY` must be the hex-encoded Ed25519 private key PEM bytes.
|
||||
> ```bash
|
||||
> ssh-keygen -t ed25519 -f /tmp/ssh_host_key -N ""
|
||||
> hexdump -v -e '/1 "%02x"' < /tmp/ssh_host_key
|
||||
> ```
|
||||
>
|
||||
> **Session secret**: generate 48 random bytes:
|
||||
> ```bash
|
||||
> openssl rand -base64 48
|
||||
> ```
|
||||
>
|
||||
> Override the ConfigMap name with `--set configMapName=your-cm-name`.
|
||||
|
||||
### 4. Verify prerequisites
|
||||
|
||||
```bash
|
||||
kubectl get namespace app
|
||||
kubectl get pvc -n app shared-data
|
||||
kubectl get configmap -n app app-env
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
helm template deploy ./deploy --namespace app --set imageRegistry=ghcr.io/your-org
|
||||
helm lint ./deploy
|
||||
|
||||
# Install
|
||||
helm upgrade --install deploy ./deploy \
|
||||
--namespace app \
|
||||
--set imageRegistry=ghcr.io/your-org \
|
||||
--set imageTag=v0.2.9
|
||||
```
|
||||
|
||||
## Storage
|
||||
|
||||
All services share a single PVC (`shared-data`) via `subPath` mounts:
|
||||
|
||||
| SubPath | Mount | Used By |
|
||||
|---|---|---|
|
||||
| `repos` | `/data/repos` | app, gitserver, git-hook |
|
||||
| `avatars` | `/data/avatars` | app |
|
||||
| `files` | `/data/files` | app |
|
||||
| `static` | `/data` | static-server |
|
||||
|
||||
## Autoscaling
|
||||
|
||||
All services except `email_worker` have HPA enabled by default. The email worker is fixed at 1 replica and must not be scaled.
|
||||
|
||||
To adjust HPA bounds per service:
|
||||
|
||||
```bash
|
||||
--set services.app.autoscaling.maxReplicas=20
|
||||
--set services.app.autoscaling.targetCPUUtilization=70
|
||||
```
|
||||
|
||||
To disable HPA for a service:
|
||||
|
||||
```bash
|
||||
--set services.git_hook.autoscaling.enabled=false
|
||||
```
|
||||
|
||||
## Ingress
|
||||
|
||||
```bash
|
||||
helm upgrade --install deploy ./deploy \
|
||||
--namespace app \
|
||||
--set ingress.enabled=true \
|
||||
--set ingress.className=nginx \
|
||||
--set ingress.hosts[0].host=your-domain.com
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
All services require these to be reachable from the cluster:
|
||||
|
||||
- PostgreSQL (via `APP_DATABASE_URL`)
|
||||
- Redis (via `APP_REDIS_URL`)
|
||||
- Git binary (included in all Docker images)
|
||||
- OpenAI-compatible API (via `APP_AI_BASIC_URL` + `APP_AI_API_KEY`)
|
||||
- Qdrant vector DB (via `APP_QDRANT_URL`)
|
||||
- SMTP server (via `APP_SMTP_*`)
|
||||
- Embedding model (via `APP_EMBED_MODEL_*`)
|
||||
|
||||
Optional dependencies with graceful degradation:
|
||||
|
||||
| Dependency | Variable | Fallback |
|
||||
|---|---|---|
|
||||
| NATS JetStream | `NATS_URL` + `NATS_TOKEN` | Redis queue |
|
||||
| Loki | `LOKI_URL` | Logs discarded |
|
||||
| OTEL Collector | `OTEL_EXPORTER_OTLP_ENDPOINT` | Tracing disabled |
|
||||
|
||||
## Production Example
|
||||
|
||||
```bash
|
||||
helm upgrade --install deploy ./deploy \
|
||||
--namespace app \
|
||||
--set imageRegistry=ghcr.io/your-org \
|
||||
--set imageTag=v0.2.9 \
|
||||
--set services.app.replicas=3 \
|
||||
--set services.app.autoscaling.maxReplicas=20 \
|
||||
--set ingress.enabled=true \
|
||||
--set ingress.className=nginx \
|
||||
--set ingress.hosts[0].host=your-domain.com \
|
||||
--set configMapName=app-env \
|
||||
--set pvcName=shared-data
|
||||
```
|
||||
93
deploy/gingress/deployment.yaml
Normal file
93
deploy/gingress/deployment.yaml
Normal file
@ -0,0 +1,93 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: gingress-controller
|
||||
namespace: gingress-system
|
||||
labels:
|
||||
app: gingress
|
||||
spec:
|
||||
replicas: 2
|
||||
selector:
|
||||
matchLabels:
|
||||
app: gingress
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: gingress
|
||||
spec:
|
||||
serviceAccountName: gingress-controller
|
||||
containers:
|
||||
- name: gingress
|
||||
image: gingress:latest
|
||||
imagePullPolicy: IfNotPresent
|
||||
args:
|
||||
- "--ingress-class=gingress"
|
||||
- "--bind-http=0.0.0.0:80"
|
||||
- "--bind-https=0.0.0.0:443"
|
||||
- "--metrics-bind=0.0.0.0:8080"
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: 80
|
||||
protocol: TCP
|
||||
- name: https
|
||||
containerPort: 443
|
||||
protocol: TCP
|
||||
- name: metrics
|
||||
containerPort: 8080
|
||||
protocol: TCP
|
||||
env:
|
||||
- name: RUST_LOG
|
||||
value: "info"
|
||||
- name: METRICS_PUSH_URL
|
||||
value: "" # Optional: push to metrics aggregator
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /healthz
|
||||
port: 8080
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 10
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /readyz
|
||||
port: 8080
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 5
|
||||
resources:
|
||||
requests:
|
||||
cpu: 100m
|
||||
memory: 128Mi
|
||||
limits:
|
||||
cpu: 500m
|
||||
memory: 512Mi
|
||||
affinity:
|
||||
podAntiAffinity:
|
||||
preferredDuringSchedulingIgnoredDuringExecution:
|
||||
- weight: 100
|
||||
podAffinityTerm:
|
||||
labelSelector:
|
||||
matchLabels:
|
||||
app: gingress
|
||||
topologyKey: kubernetes.io/hostname
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: gingress
|
||||
namespace: gingress-system
|
||||
spec:
|
||||
type: LoadBalancer
|
||||
selector:
|
||||
app: gingress
|
||||
ports:
|
||||
- name: http
|
||||
port: 80
|
||||
targetPort: 80
|
||||
protocol: TCP
|
||||
- name: https
|
||||
port: 443
|
||||
targetPort: 443
|
||||
protocol: TCP
|
||||
- name: metrics
|
||||
port: 8080
|
||||
targetPort: 8080
|
||||
protocol: TCP
|
||||
48
deploy/gingress/rbac.yaml
Normal file
48
deploy/gingress/rbac.yaml
Normal file
@ -0,0 +1,48 @@
|
||||
apiVersion: v1
|
||||
kind: Namespace
|
||||
metadata:
|
||||
name: gingress-system
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: ServiceAccount
|
||||
metadata:
|
||||
name: gingress-controller
|
||||
namespace: gingress-system
|
||||
---
|
||||
apiVersion: rbac.authorization.k8s.io/v1
|
||||
kind: ClusterRole
|
||||
metadata:
|
||||
name: gingress-controller
|
||||
rules:
|
||||
- apiGroups: ["networking.k8s.io"]
|
||||
resources: ["ingresses", "ingressclasses"]
|
||||
verbs: ["get", "list", "watch"]
|
||||
- apiGroups: ["networking.k8s.io"]
|
||||
resources: ["ingresses/status"]
|
||||
verbs: ["update", "patch"]
|
||||
- apiGroups: [""]
|
||||
resources: ["services", "endpoints", "endpointslices", "secrets", "nodes"]
|
||||
verbs: ["get", "list", "watch"]
|
||||
- apiGroups: ["discovery.k8s.io"]
|
||||
resources: ["endpointslices"]
|
||||
verbs: ["get", "list", "watch"]
|
||||
---
|
||||
apiVersion: rbac.authorization.k8s.io/v1
|
||||
kind: ClusterRoleBinding
|
||||
metadata:
|
||||
name: gingress-controller
|
||||
roleRef:
|
||||
apiGroup: rbac.authorization.k8s.io
|
||||
kind: ClusterRole
|
||||
name: gingress-controller
|
||||
subjects:
|
||||
- kind: ServiceAccount
|
||||
name: gingress-controller
|
||||
namespace: gingress-system
|
||||
---
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: IngressClass
|
||||
metadata:
|
||||
name: gingress
|
||||
spec:
|
||||
controller: gingress.io/gingress-controller
|
||||
19
deploy/templates/NOTES.txt
Normal file
19
deploy/templates/NOTES.txt
Normal file
@ -0,0 +1,19 @@
|
||||
Project backend services deployed to namespace: {{ .Release.Namespace }}
|
||||
|
||||
Services:
|
||||
{{- range $svcKey, $svcVal := .Values.services }}
|
||||
{{ $svcKey | replace "_" "-" }}: {{ if $svcVal.ports }}{{ range $portName, $portNum := $svcVal.ports }}{{ $portName }}={{ $portNum }} {{ end }}{{ else }}port={{ $svcVal.port }}{{ end }} {{ if $svcVal.autoscaling.enabled }}(HPA: {{ $svcVal.autoscaling.minReplicas }}-{{ $svcVal.autoscaling.maxReplicas }}){{ else }}(static: {{ $svcVal.replicaCount }}){{ end }}
|
||||
{{- end }}
|
||||
|
||||
To access the app locally:
|
||||
kubectl port-forward -n {{ .Release.Namespace }} svc/{{ include "deploy.serviceFullname" (dict "root" . "svcKey" "app") }} 3000:3000
|
||||
|
||||
To check HPA status:
|
||||
{{- range $svcKey, $svcVal := .Values.services }}
|
||||
{{- if $svcVal.autoscaling.enabled }}
|
||||
kubectl get hpa -n {{ $.Release.Namespace }} {{ include "deploy.serviceFullname" (dict "root" $ "svcKey" $svcKey) }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
To check all pods:
|
||||
kubectl get pods -n {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "deploy.name" . }}"
|
||||
78
deploy/templates/_helpers.tpl
Normal file
78
deploy/templates/_helpers.tpl
Normal file
@ -0,0 +1,78 @@
|
||||
{{/*
|
||||
Expand the name of the chart.
|
||||
*/}}
|
||||
{{- define "deploy.name" -}}
|
||||
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Create a default fully qualified app name.
|
||||
*/}}
|
||||
{{- define "deploy.fullname" -}}
|
||||
{{- if .Values.fullnameOverride }}
|
||||
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
|
||||
{{- else }}
|
||||
{{- $name := default .Chart.Name .Values.nameOverride }}
|
||||
{{- if contains $name .Release.Name }}
|
||||
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
|
||||
{{- else }}
|
||||
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Service fullname — includes service key for per-service resources.
|
||||
Underscores in svcKey are replaced with hyphens for valid Kubernetes names.
|
||||
*/}}
|
||||
{{- define "deploy.serviceFullname" -}}
|
||||
{{- printf "%s-%s" (include "deploy.fullname" .root) (.svcKey | replace "_" "-") | trunc 63 | trimSuffix "-" }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Chart name and version as used by the chart label.
|
||||
*/}}
|
||||
{{- define "deploy.chart" -}}
|
||||
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Common labels
|
||||
*/}}
|
||||
{{- define "deploy.labels" -}}
|
||||
helm.sh/chart: {{ include "deploy.chart" . }}
|
||||
{{ include "deploy.selectorLabels" . }}
|
||||
{{- if .Chart.AppVersion }}
|
||||
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
|
||||
{{- end }}
|
||||
app.kubernetes.io/managed-by: {{ .Release.Service }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Selector labels
|
||||
*/}}
|
||||
{{- define "deploy.selectorLabels" -}}
|
||||
app.kubernetes.io/name: {{ include "deploy.name" . }}
|
||||
app.kubernetes.io/instance: {{ .Release.Name }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Per-service selector labels — used by Service to target the right Deployment.
|
||||
Underscores in svcKey are replaced with hyphens for valid Kubernetes label values.
|
||||
*/}}
|
||||
{{- define "deploy.serviceSelectorLabels" -}}
|
||||
app.kubernetes.io/name: {{ include "deploy.name" .root }}
|
||||
app.kubernetes.io/instance: {{ .root.Release.Name }}
|
||||
app.kubernetes.io/component: {{ .svcKey | replace "_" "-" }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Create the name of the service account to use
|
||||
*/}}
|
||||
{{- define "deploy.serviceAccountName" -}}
|
||||
{{- if .Values.serviceAccount.create }}
|
||||
{{- default (include "deploy.fullname" .) .Values.serviceAccount.name }}
|
||||
{{- else }}
|
||||
{{- default "default" .Values.serviceAccount.name }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
89
deploy/templates/app/deployment.yaml
Normal file
89
deploy/templates/app/deployment.yaml
Normal file
@ -0,0 +1,89 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "app") }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: app
|
||||
spec:
|
||||
replicas: {{ .Values.services.app.replicaCount | default 1 }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "app") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 8 }}
|
||||
app.kubernetes.io/component: app
|
||||
spec:
|
||||
{{- with .Values.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
|
||||
{{- with .Values.podSecurityContext }}
|
||||
securityContext:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: app
|
||||
{{- with .Values.securityContext }}
|
||||
securityContext:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
image: "{{ .Values.imageRegistry }}/{{ .Values.services.app.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: IfNotPresent
|
||||
{{- with .Values.services.app.command }}
|
||||
command:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: {{ .Values.services.app.port }}
|
||||
protocol: TCP
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: {{ .Values.configMapName }}
|
||||
{{- with .Values.services.app.extraEnv }}
|
||||
env:
|
||||
{{- range $key, $val := . }}
|
||||
- name: {{ $key }}
|
||||
value: {{ $val | quote }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 15
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 10
|
||||
{{- with .Values.services.app.resources }}
|
||||
resources:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
{{- with .Values.services.app.volumeMounts }}
|
||||
volumeMounts:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
volumes:
|
||||
- name: shared-data
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ .Values.pvcName }}
|
||||
{{- with .Values.nodeSelector }}
|
||||
nodeSelector:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.affinity }}
|
||||
affinity:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.tolerations }}
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
16
deploy/templates/app/service.yaml
Normal file
16
deploy/templates/app/service.yaml
Normal file
@ -0,0 +1,16 @@
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "app") }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: app
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: {{ .Values.services.app.port }}
|
||||
targetPort: http
|
||||
protocol: TCP
|
||||
name: http
|
||||
selector:
|
||||
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "app") | nindent 4 }}
|
||||
70
deploy/templates/email_worker/deployment.yaml
Normal file
70
deploy/templates/email_worker/deployment.yaml
Normal file
@ -0,0 +1,70 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "email_worker") }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: email-worker
|
||||
spec:
|
||||
replicas: {{ .Values.services.email_worker.replicaCount | default 1 }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "email_worker") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 8 }}
|
||||
app.kubernetes.io/component: email-worker
|
||||
spec:
|
||||
{{- with .Values.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
|
||||
{{- with .Values.podSecurityContext }}
|
||||
securityContext:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: email_worker
|
||||
{{- with .Values.securityContext }}
|
||||
securityContext:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
image: "{{ .Values.imageRegistry }}/{{ .Values.services.email_worker.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: IfNotPresent
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: {{ .Values.services.email_worker.port }}
|
||||
protocol: TCP
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: {{ .Values.configMapName }}
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 15
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 10
|
||||
{{- with .Values.services.email_worker.resources }}
|
||||
resources:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
{{- with .Values.nodeSelector }}
|
||||
nodeSelector:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.affinity }}
|
||||
affinity:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.tolerations }}
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
16
deploy/templates/email_worker/service.yaml
Normal file
16
deploy/templates/email_worker/service.yaml
Normal file
@ -0,0 +1,16 @@
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "email_worker") }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: email-worker
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: {{ .Values.services.email_worker.port }}
|
||||
targetPort: http
|
||||
protocol: TCP
|
||||
name: http
|
||||
selector:
|
||||
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "email_worker") | nindent 4 }}
|
||||
78
deploy/templates/git_hook/deployment.yaml
Normal file
78
deploy/templates/git_hook/deployment.yaml
Normal file
@ -0,0 +1,78 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "git_hook") }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: git-hook
|
||||
spec:
|
||||
replicas: {{ .Values.services.git_hook.replicaCount | default 1 }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "git_hook") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 8 }}
|
||||
app.kubernetes.io/component: git-hook
|
||||
spec:
|
||||
{{- with .Values.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
|
||||
{{- with .Values.podSecurityContext }}
|
||||
securityContext:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: git-hook
|
||||
{{- with .Values.securityContext }}
|
||||
securityContext:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
image: "{{ .Values.imageRegistry }}/{{ .Values.services.git_hook.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: IfNotPresent
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: {{ .Values.services.git_hook.port }}
|
||||
protocol: TCP
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: {{ .Values.configMapName }}
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 15
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 10
|
||||
{{- with .Values.services.git_hook.resources }}
|
||||
resources:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
{{- with .Values.services.git_hook.volumeMounts }}
|
||||
volumeMounts:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
volumes:
|
||||
- name: shared-data
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ .Values.pvcName }}
|
||||
{{- with .Values.nodeSelector }}
|
||||
nodeSelector:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.affinity }}
|
||||
affinity:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.tolerations }}
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
16
deploy/templates/git_hook/service.yaml
Normal file
16
deploy/templates/git_hook/service.yaml
Normal file
@ -0,0 +1,16 @@
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "git_hook") }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: git-hook
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: {{ .Values.services.git_hook.port }}
|
||||
targetPort: http
|
||||
protocol: TCP
|
||||
name: http
|
||||
selector:
|
||||
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "git_hook") | nindent 4 }}
|
||||
88
deploy/templates/gitserver/deployment.yaml
Normal file
88
deploy/templates/gitserver/deployment.yaml
Normal file
@ -0,0 +1,88 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "gitserver") }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: gitserver
|
||||
spec:
|
||||
replicas: {{ .Values.services.gitserver.replicaCount | default 1 }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "gitserver") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 8 }}
|
||||
app.kubernetes.io/component: gitserver
|
||||
spec:
|
||||
{{- with .Values.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
|
||||
{{- with .Values.podSecurityContext }}
|
||||
securityContext:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: gitserver
|
||||
{{- with .Values.securityContext }}
|
||||
securityContext:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
image: "{{ .Values.imageRegistry }}/{{ .Values.services.gitserver.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: IfNotPresent
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: {{ .Values.services.gitserver.ports.http }}
|
||||
protocol: TCP
|
||||
- name: ssh
|
||||
containerPort: {{ .Values.services.gitserver.ports.ssh }}
|
||||
protocol: TCP
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: {{ .Values.configMapName }}
|
||||
{{- with .Values.services.gitserver.extraEnv }}
|
||||
env:
|
||||
{{- range $key, $val := . }}
|
||||
- name: {{ $key }}
|
||||
value: {{ $val | quote }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 15
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 10
|
||||
{{- with .Values.services.gitserver.resources }}
|
||||
resources:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
{{- with .Values.services.gitserver.volumeMounts }}
|
||||
volumeMounts:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
volumes:
|
||||
- name: shared-data
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ .Values.pvcName }}
|
||||
{{- with .Values.nodeSelector }}
|
||||
nodeSelector:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.affinity }}
|
||||
affinity:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.tolerations }}
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
20
deploy/templates/gitserver/service.yaml
Normal file
20
deploy/templates/gitserver/service.yaml
Normal file
@ -0,0 +1,20 @@
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "gitserver") }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: gitserver
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: {{ .Values.services.gitserver.ports.http }}
|
||||
targetPort: http
|
||||
protocol: TCP
|
||||
name: http
|
||||
- port: {{ .Values.services.gitserver.ports.ssh }}
|
||||
targetPort: ssh
|
||||
protocol: TCP
|
||||
name: ssh
|
||||
selector:
|
||||
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "gitserver") | nindent 4 }}
|
||||
26
deploy/templates/hpa.yaml
Normal file
26
deploy/templates/hpa.yaml
Normal file
@ -0,0 +1,26 @@
|
||||
{{- range $svcKey, $svcVal := .Values.services }}
|
||||
{{- if $svcVal.autoscaling.enabled }}
|
||||
---
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" $ "svcKey" $svcKey) }}
|
||||
labels:
|
||||
{{- include "deploy.labels" $ | nindent 4 }}
|
||||
app.kubernetes.io/component: {{ $svcKey | replace "_" "-" }}
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" $ "svcKey" $svcKey) }}
|
||||
minReplicas: {{ $svcVal.autoscaling.minReplicas }}
|
||||
maxReplicas: {{ $svcVal.autoscaling.maxReplicas }}
|
||||
metrics:
|
||||
- type: Resource
|
||||
resource:
|
||||
name: cpu
|
||||
target:
|
||||
type: Utilization
|
||||
averageUtilization: {{ $svcVal.autoscaling.targetCPUUtilization }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
41
deploy/templates/ingress.yaml
Normal file
41
deploy/templates/ingress.yaml
Normal file
@ -0,0 +1,41 @@
|
||||
{{- if .Values.ingress.enabled -}}
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: Ingress
|
||||
metadata:
|
||||
name: {{ include "deploy.fullname" . }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
{{- with .Values.ingress.annotations }}
|
||||
annotations:
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
{{- with .Values.ingress.className }}
|
||||
ingressClassName: {{ . }}
|
||||
{{- end }}
|
||||
{{- if .Values.ingress.tls }}
|
||||
tls:
|
||||
{{- range .Values.ingress.tls }}
|
||||
- hosts:
|
||||
{{- range .hosts }}
|
||||
- {{ . | quote }}
|
||||
{{- end }}
|
||||
secretName: {{ .secretName }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
rules:
|
||||
{{- range .Values.ingress.hosts }}
|
||||
- host: {{ .host | quote }}
|
||||
http:
|
||||
paths:
|
||||
{{- range .paths }}
|
||||
- path: {{ .path }}
|
||||
pathType: {{ .pathType }}
|
||||
backend:
|
||||
service:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" $ "svcKey" .serviceName) }}
|
||||
port:
|
||||
number: {{ .servicePort }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
70
deploy/templates/metrics_aggregator/deployment.yaml
Normal file
70
deploy/templates/metrics_aggregator/deployment.yaml
Normal file
@ -0,0 +1,70 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "metrics_aggregator") }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: metrics-aggregator
|
||||
spec:
|
||||
replicas: {{ .Values.services.metrics_aggregator.replicaCount | default 1 }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "metrics_aggregator") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 8 }}
|
||||
app.kubernetes.io/component: metrics-aggregator
|
||||
spec:
|
||||
{{- with .Values.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
|
||||
{{- with .Values.podSecurityContext }}
|
||||
securityContext:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: metrics_aggregator
|
||||
{{- with .Values.securityContext }}
|
||||
securityContext:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
image: "{{ .Values.imageRegistry }}/{{ .Values.services.metrics_aggregator.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: IfNotPresent
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: {{ .Values.services.metrics_aggregator.port }}
|
||||
protocol: TCP
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: {{ .Values.configMapName }}
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 15
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 10
|
||||
{{- with .Values.services.metrics_aggregator.resources }}
|
||||
resources:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
{{- with .Values.nodeSelector }}
|
||||
nodeSelector:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.affinity }}
|
||||
affinity:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.tolerations }}
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
16
deploy/templates/metrics_aggregator/service.yaml
Normal file
16
deploy/templates/metrics_aggregator/service.yaml
Normal file
@ -0,0 +1,16 @@
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "metrics_aggregator") }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: metrics-aggregator
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: {{ .Values.services.metrics_aggregator.port }}
|
||||
targetPort: http
|
||||
protocol: TCP
|
||||
name: http
|
||||
selector:
|
||||
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "metrics_aggregator") | nindent 4 }}
|
||||
1
deploy/templates/secret.yaml
Normal file
1
deploy/templates/secret.yaml
Normal file
@ -0,0 +1 @@
|
||||
{{/* Secret disabled — all config via ConfigMap */}}
|
||||
13
deploy/templates/serviceaccount.yaml
Normal file
13
deploy/templates/serviceaccount.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
{{- if .Values.serviceAccount.create -}}
|
||||
apiVersion: v1
|
||||
kind: ServiceAccount
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceAccountName" . }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
{{- with .Values.serviceAccount.annotations }}
|
||||
annotations:
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
automountServiceAccountToken: {{ .Values.serviceAccount.automount }}
|
||||
{{- end }}
|
||||
78
deploy/templates/static_server/deployment.yaml
Normal file
78
deploy/templates/static_server/deployment.yaml
Normal file
@ -0,0 +1,78 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "static_server") }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: static-server
|
||||
spec:
|
||||
replicas: {{ .Values.services.static_server.replicaCount | default 1 }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "static_server") | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 8 }}
|
||||
app.kubernetes.io/component: static-server
|
||||
spec:
|
||||
{{- with .Values.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
|
||||
{{- with .Values.podSecurityContext }}
|
||||
securityContext:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: static_server
|
||||
{{- with .Values.securityContext }}
|
||||
securityContext:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
image: "{{ .Values.imageRegistry }}/{{ .Values.services.static_server.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: IfNotPresent
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: {{ .Values.services.static_server.port }}
|
||||
protocol: TCP
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: {{ .Values.configMapName }}
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 15
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 10
|
||||
{{- with .Values.services.static_server.resources }}
|
||||
resources:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
{{- with .Values.services.static_server.volumeMounts }}
|
||||
volumeMounts:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
volumes:
|
||||
- name: shared-data
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ .Values.pvcName }}
|
||||
{{- with .Values.nodeSelector }}
|
||||
nodeSelector:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.affinity }}
|
||||
affinity:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.tolerations }}
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
16
deploy/templates/static_server/service.yaml
Normal file
16
deploy/templates/static_server/service.yaml
Normal file
@ -0,0 +1,16 @@
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "static_server") }}
|
||||
labels:
|
||||
{{- include "deploy.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: static-server
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: {{ .Values.services.static_server.port }}
|
||||
targetPort: http
|
||||
protocol: TCP
|
||||
name: http
|
||||
selector:
|
||||
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "static_server") | nindent 4 }}
|
||||
182
deploy/values.yaml
Normal file
182
deploy/values.yaml
Normal file
@ -0,0 +1,182 @@
|
||||
# Global image registry and tag
|
||||
imageRegistry: ""
|
||||
imageTag: ""
|
||||
|
||||
# External ConfigMap (managed outside Helm)
|
||||
configMapName: "app-env"
|
||||
|
||||
# Service definitions
|
||||
services:
|
||||
app:
|
||||
repository: app
|
||||
port: 3000
|
||||
replicaCount: 2
|
||||
autoscaling:
|
||||
enabled: true
|
||||
minReplicas: 2
|
||||
maxReplicas: 10
|
||||
targetCPUUtilization: 80
|
||||
command:
|
||||
- "app"
|
||||
- "--bind"
|
||||
- "0.0.0.0:3000"
|
||||
resources:
|
||||
requests:
|
||||
cpu: 200m
|
||||
memory: 256Mi
|
||||
limits:
|
||||
cpu: "1"
|
||||
memory: 512Mi
|
||||
volumeMounts:
|
||||
- name: shared-data
|
||||
mountPath: /data/repos
|
||||
subPath: repos
|
||||
- name: shared-data
|
||||
mountPath: /data/avatars
|
||||
subPath: avatars
|
||||
- name: shared-data
|
||||
mountPath: /data/files
|
||||
subPath: files
|
||||
|
||||
email_worker:
|
||||
repository: email-worker
|
||||
port: 8084
|
||||
replicaCount: 1
|
||||
autoscaling:
|
||||
enabled: false # email must stay at 1 replica
|
||||
resources:
|
||||
requests:
|
||||
cpu: 100m
|
||||
memory: 128Mi
|
||||
limits:
|
||||
cpu: 500m
|
||||
memory: 256Mi
|
||||
|
||||
git_hook:
|
||||
repository: git-hook
|
||||
port: 8083
|
||||
replicaCount: 1
|
||||
autoscaling:
|
||||
enabled: true
|
||||
minReplicas: 1
|
||||
maxReplicas: 5
|
||||
targetCPUUtilization: 80
|
||||
resources:
|
||||
requests:
|
||||
cpu: 100m
|
||||
memory: 128Mi
|
||||
limits:
|
||||
cpu: 500m
|
||||
memory: 256Mi
|
||||
volumeMounts:
|
||||
- name: shared-data
|
||||
mountPath: /data/repos
|
||||
subPath: repos
|
||||
|
||||
gitserver:
|
||||
repository: gitserver
|
||||
ports:
|
||||
http: 8021
|
||||
ssh: 2222
|
||||
replicaCount: 1
|
||||
autoscaling:
|
||||
enabled: true
|
||||
minReplicas: 1
|
||||
maxReplicas: 5
|
||||
targetCPUUtilization: 80
|
||||
# SSH port must match the containerPort
|
||||
extraEnv:
|
||||
APP_SSH_PORT: "2222"
|
||||
resources:
|
||||
requests:
|
||||
cpu: 100m
|
||||
memory: 128Mi
|
||||
limits:
|
||||
cpu: 500m
|
||||
memory: 256Mi
|
||||
volumeMounts:
|
||||
- name: shared-data
|
||||
mountPath: /data/repos
|
||||
subPath: repos
|
||||
|
||||
metrics_aggregator:
|
||||
repository: metrics-aggregator
|
||||
port: 9090
|
||||
replicaCount: 1
|
||||
autoscaling:
|
||||
enabled: true
|
||||
minReplicas: 1
|
||||
maxReplicas: 5
|
||||
targetCPUUtilization: 80
|
||||
resources:
|
||||
requests:
|
||||
cpu: 100m
|
||||
memory: 128Mi
|
||||
limits:
|
||||
cpu: 500m
|
||||
memory: 256Mi
|
||||
|
||||
static_server:
|
||||
repository: static-server
|
||||
port: 8081
|
||||
replicaCount: 1
|
||||
autoscaling:
|
||||
enabled: true
|
||||
minReplicas: 1
|
||||
maxReplicas: 5
|
||||
targetCPUUtilization: 80
|
||||
resources:
|
||||
requests:
|
||||
cpu: 50m
|
||||
memory: 64Mi
|
||||
limits:
|
||||
cpu: 200m
|
||||
memory: 128Mi
|
||||
volumeMounts:
|
||||
- name: shared-data
|
||||
mountPath: /data
|
||||
subPath: static
|
||||
|
||||
# External PVC (managed outside Helm — not deleted on uninstall)
|
||||
pvcName: "shared-data"
|
||||
|
||||
# Ingress — only for the main app service
|
||||
ingress:
|
||||
enabled: false
|
||||
className: ""
|
||||
annotations: {}
|
||||
hosts:
|
||||
- host: chart-example.local
|
||||
paths:
|
||||
- path: /
|
||||
pathType: Prefix
|
||||
serviceName: app
|
||||
servicePort: 3000
|
||||
tls: []
|
||||
# - secretName: chart-example-tls
|
||||
# hosts:
|
||||
# - chart-example.local
|
||||
|
||||
imagePullSecrets: []
|
||||
nameOverride: ""
|
||||
fullnameOverride: ""
|
||||
|
||||
serviceAccount:
|
||||
create: true
|
||||
automount: true
|
||||
annotations: {}
|
||||
name: ""
|
||||
|
||||
podSecurityContext:
|
||||
runAsNonRoot: true
|
||||
runAsUser: 1000
|
||||
|
||||
securityContext:
|
||||
capabilities:
|
||||
drop:
|
||||
- ALL
|
||||
readOnlyRootFilesystem: true
|
||||
|
||||
nodeSelector: {}
|
||||
tolerations: []
|
||||
affinity: {}
|
||||
10
docker/app.Dockerfile
Normal file
10
docker/app.Dockerfile
Normal file
@ -0,0 +1,10 @@
|
||||
FROM ubuntu:24.04
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates libssl3 openssh-client procps git \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
RUN useradd --system --create-home appuser
|
||||
WORKDIR /home/appuser
|
||||
COPY target/release/app /bin
|
||||
USER appuser
|
||||
EXPOSE 3000
|
||||
CMD ["app"]
|
||||
10
docker/email.Dockerfile
Normal file
10
docker/email.Dockerfile
Normal file
@ -0,0 +1,10 @@
|
||||
FROM ubuntu:24.04
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates libssl3 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
RUN useradd --system --create-home appuser
|
||||
WORKDIR /home/appuser
|
||||
COPY target/release/email-worker /bin
|
||||
USER appuser
|
||||
EXPOSE 8084
|
||||
CMD ["email-worker"]
|
||||
10
docker/gingress.Dockerfile
Normal file
10
docker/gingress.Dockerfile
Normal file
@ -0,0 +1,10 @@
|
||||
FROM ubuntu:24.04
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates libssl3 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
RUN useradd --system --create-home appuser
|
||||
WORKDIR /home/appuser
|
||||
COPY target/release/gingress /bin
|
||||
USER appuser
|
||||
EXPOSE 80 443 8080
|
||||
CMD ["gingress"]
|
||||
10
docker/githook.Dockerfile
Normal file
10
docker/githook.Dockerfile
Normal file
@ -0,0 +1,10 @@
|
||||
FROM ubuntu:24.04
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates libssl3 git \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
RUN useradd --system --create-home appuser
|
||||
WORKDIR /home/appuser
|
||||
COPY target/release/git-hook /bin
|
||||
USER appuser
|
||||
EXPOSE 8083
|
||||
CMD ["git-hook"]
|
||||
10
docker/gitserver.Dockerfile
Normal file
10
docker/gitserver.Dockerfile
Normal file
@ -0,0 +1,10 @@
|
||||
FROM ubuntu:24.04
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates libssl3 git openssh-client \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
RUN useradd --system --create-home appuser
|
||||
WORKDIR /home/appuser
|
||||
COPY target/release/gitserver /bin
|
||||
USER appuser
|
||||
EXPOSE 8021 2222
|
||||
CMD ["gitserver"]
|
||||
10
docker/metrics.Dockerfile
Normal file
10
docker/metrics.Dockerfile
Normal file
@ -0,0 +1,10 @@
|
||||
FROM ubuntu:24.04
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates libssl3 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
RUN useradd --system --create-home appuser
|
||||
WORKDIR /home/appuser
|
||||
COPY target/release/metrics-aggregator /bin
|
||||
USER appuser
|
||||
EXPOSE 9090
|
||||
CMD ["metrics-aggregator"]
|
||||
10
docker/static.Dockerfile
Normal file
10
docker/static.Dockerfile
Normal file
@ -0,0 +1,10 @@
|
||||
FROM ubuntu:24.04
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates libssl3 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
RUN useradd --system --create-home appuser
|
||||
WORKDIR /home/appuser
|
||||
COPY target/release/static-server /bin
|
||||
USER appuser
|
||||
EXPOSE 8081
|
||||
CMD ["static-server"]
|
||||
@ -42,5 +42,6 @@ reqwest = { workspace = true, features = ["json"] }
|
||||
utoipa = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
redis = { workspace = true, features = ["tokio-comp"] }
|
||||
queue = { workspace = true }
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@ -152,7 +152,8 @@ impl RigAgentService {
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
StreamedAssistantContent::Text(text),
|
||||
)) => {
|
||||
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await;
|
||||
let cleaned = text.text.replace('\n', "");
|
||||
let _ = tx.send(Ok(StreamChunk::Text(cleaned))).await;
|
||||
final_content.push_str(&text.text);
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
@ -237,7 +238,8 @@ impl RigAgentService {
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
StreamedAssistantContent::Text(text),
|
||||
)) => {
|
||||
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await;
|
||||
let cleaned = text.text.replace('\n', "");
|
||||
let _ = tx.send(Ok(StreamChunk::Text(cleaned))).await;
|
||||
final_content.push_str(&text.text);
|
||||
}
|
||||
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
|
||||
|
||||
@ -1,226 +1,455 @@
|
||||
//! AI usage billing — records token costs against a project or workspace balance.
|
||||
//! Billing service — handles user-level and project-level billing, deduction,
|
||||
//! credit initialization, and error persistence.
|
||||
//!
|
||||
//! All functions take `&DatabaseConnection` instead of `&AppService`.
|
||||
//! Architecture:
|
||||
//! - Each user gets $10 personal balance on signup.
|
||||
//! - Each project gets $20 balance only if it's the creator's first project,
|
||||
//! $0 otherwise.
|
||||
//! - AI usage is deducted from the project balance first; if insufficient,
|
||||
//! falls through to the user's personal balance.
|
||||
//! - Monthly quota only applies to pro users (is_pro = true).
|
||||
//! - If both project and user balance are insufficient, a billing_error
|
||||
//! record is persisted and an error is returned to the caller.
|
||||
|
||||
use db::database::AppDatabase;
|
||||
use models::agents::model_pricing;
|
||||
use models::projects::project;
|
||||
use models::projects::project_billing;
|
||||
use models::projects::project_billing_history;
|
||||
use models::workspaces::workspace_billing;
|
||||
use models::workspaces::workspace_billing_history;
|
||||
use models::ai::billing_error;
|
||||
use models::projects::{project, project_billing, project_billing_history};
|
||||
use models::users::user_billing;
|
||||
use rust_decimal::Decimal;
|
||||
use sea_orm::*;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::AgentError;
|
||||
|
||||
// ── Constants ──
|
||||
|
||||
fn default_user_balance() -> Decimal { Decimal::new(100_000, 4) } // $10.0000
|
||||
fn first_project_credit() -> Decimal { Decimal::new(200_000, 4) } // $20.0000
|
||||
const SUBSEQUENT_PROJECT_BALANCE: Decimal = Decimal::ZERO;
|
||||
|
||||
// ── Types ──
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, utoipa::ToSchema)]
|
||||
pub struct BillingRecord {
|
||||
pub cost: f64,
|
||||
pub currency: String,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
pub deducted_from: String, // "project" or "user"
|
||||
}
|
||||
|
||||
/// Extended result that includes insufficient balance flag for system message creation.
|
||||
#[derive(Debug)]
|
||||
pub enum BillingResult {
|
||||
Success(BillingRecord),
|
||||
InsufficientBalance { message: String },
|
||||
}
|
||||
|
||||
/// Record AI usage for a project with cascading billing.
|
||||
// ── Core deduction: AI usage ──
|
||||
|
||||
/// Record AI usage: deduct from project balance first, fall through to user balance.
|
||||
///
|
||||
/// Billing strategy:
|
||||
/// 1. Try to deduct from project balance first
|
||||
/// 2. If insufficient, fallback to workspace balance (if project belongs to workspace)
|
||||
/// 3. If both insufficient or no workspace, return InsufficientBalance error with room_id
|
||||
///
|
||||
/// Returns BillingError::InsufficientBalance with room_id for system message creation.
|
||||
/// Returns `InsufficientBalance` if neither account can cover the cost.
|
||||
/// On insufficient balance, a `billing_error` record is persisted for frontend display.
|
||||
pub async fn record_ai_usage(
|
||||
db: &AppDatabase,
|
||||
project_uid: Uuid,
|
||||
user_uid: Uuid,
|
||||
model_id: Uuid,
|
||||
input_tokens: i64,
|
||||
output_tokens: i64,
|
||||
) -> Result<BillingResult, AgentError> {
|
||||
// 1. Look up the active price for this model.
|
||||
let total_cost = compute_cost(db, model_id, input_tokens, output_tokens).await?;
|
||||
let currency = get_currency(db, model_id).await?;
|
||||
|
||||
// Verify project exists
|
||||
let _ = project::Entity::find_by_id(project_uid)
|
||||
.one(db)
|
||||
.await?
|
||||
.ok_or_else(|| AgentError::Internal("Project not found".into()))?;
|
||||
|
||||
// Attempt project-level deduction first
|
||||
let project_result = deduct_from_project(db, project_uid, total_cost, ¤cy, model_id, input_tokens, output_tokens).await;
|
||||
|
||||
match project_result {
|
||||
Ok(()) => {
|
||||
let cost_f64 = decimal_to_f64(total_cost);
|
||||
tracing::info!(
|
||||
project_id = %project_uid,
|
||||
model_id = %model_id,
|
||||
input_tokens, output_tokens,
|
||||
cost = %cost_f64,
|
||||
currency = %currency,
|
||||
deducted_from = "project",
|
||||
"ai_usage_recorded"
|
||||
);
|
||||
Ok(BillingResult::Success(BillingRecord {
|
||||
cost: cost_f64,
|
||||
currency,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
deducted_from: "project".to_string(),
|
||||
}))
|
||||
}
|
||||
Err(_) => {
|
||||
// Project balance insufficient — try user personal balance
|
||||
let user_result = deduct_from_user(db, user_uid, total_cost, ¤cy, project_uid, model_id, input_tokens, output_tokens).await;
|
||||
|
||||
match user_result {
|
||||
Ok(()) => {
|
||||
let cost_f64 = decimal_to_f64(total_cost);
|
||||
tracing::info!(
|
||||
user_id = %user_uid,
|
||||
project_id = %project_uid,
|
||||
model_id = %model_id,
|
||||
input_tokens, output_tokens,
|
||||
cost = %cost_f64,
|
||||
currency = %currency,
|
||||
deducted_from = "user",
|
||||
"ai_usage_recorded"
|
||||
);
|
||||
Ok(BillingResult::Success(BillingRecord {
|
||||
cost: cost_f64,
|
||||
currency,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
deducted_from: "user".to_string(),
|
||||
}))
|
||||
}
|
||||
Err(insufficient_msg) => {
|
||||
// Both project and user balance insufficient — persist error
|
||||
persist_billing_error(
|
||||
db,
|
||||
"project",
|
||||
project_uid,
|
||||
"insufficient_balance",
|
||||
&insufficient_msg,
|
||||
Some(serde_json::json!({
|
||||
"user_id": user_uid.to_string(),
|
||||
"model_id": model_id.to_string(),
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"cost": decimal_to_f64(total_cost),
|
||||
"currency": currency,
|
||||
})),
|
||||
).await?;
|
||||
|
||||
Ok(BillingResult::InsufficientBalance {
|
||||
message: insufficient_msg,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check whether a project + user has sufficient combined balance for a potential AI call.
|
||||
/// Called before starting AI processing to avoid wasted compute.
|
||||
pub async fn check_balance(
|
||||
db: &AppDatabase,
|
||||
project_uid: Uuid,
|
||||
user_uid: Uuid,
|
||||
model_id: Uuid,
|
||||
estimated_input_tokens: i64,
|
||||
estimated_output_tokens: i64,
|
||||
) -> Result<bool, AgentError> {
|
||||
let estimated_cost = compute_cost(db, model_id, estimated_input_tokens, estimated_output_tokens).await?;
|
||||
let project_balance = get_project_balance(db, project_uid).await;
|
||||
let user_balance = get_user_balance(db, user_uid).await;
|
||||
|
||||
Ok(project_balance + user_balance >= estimated_cost)
|
||||
}
|
||||
|
||||
// ── Initialization ──
|
||||
|
||||
/// Initialize a user billing account with the default $10 balance.
|
||||
/// Called on user signup / first login.
|
||||
pub async fn initialize_user_billing(db: &AppDatabase, user_uid: Uuid) -> Result<(), AgentError> {
|
||||
let now = chrono::Utc::now();
|
||||
user_billing::ActiveModel {
|
||||
user: Set(user_uid),
|
||||
balance: Set(default_user_balance()),
|
||||
currency: Set("USD".to_string()),
|
||||
is_pro: Set(false),
|
||||
monthly_quota: Set(Decimal::ZERO),
|
||||
month_used: Set(Decimal::ZERO),
|
||||
cycle_start: Set(None),
|
||||
cycle_end: Set(None),
|
||||
updated_at: Set(now),
|
||||
created_at: Set(now),
|
||||
}
|
||||
.insert(db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(format!("failed to create user billing: {}", e)))?;
|
||||
|
||||
tracing::info!(user_id = %user_uid, balance = "$10", "user_billing_initialized");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Initialize a project billing account.
|
||||
/// Grants $20 only if this is the creator's first project; $0 otherwise.
|
||||
pub async fn initialize_project_billing(
|
||||
db: &AppDatabase,
|
||||
project_uid: Uuid,
|
||||
creator_uid: Uuid,
|
||||
) -> Result<(), AgentError> {
|
||||
// Check how many projects this user has already created
|
||||
let existing_count = project::Entity::find()
|
||||
.filter(project::Column::CreatedBy.eq(creator_uid))
|
||||
.count(db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(format!("failed to count user projects: {}", e)))?;
|
||||
|
||||
let is_first = existing_count == 0;
|
||||
let initial_balance = if is_first { first_project_credit() } else { SUBSEQUENT_PROJECT_BALANCE };
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
project_billing::ActiveModel {
|
||||
project: Set(project_uid),
|
||||
balance: Set(initial_balance),
|
||||
currency: Set("USD".to_string()),
|
||||
user: Set(Some(creator_uid)),
|
||||
initial_credit_granted: Set(is_first),
|
||||
is_pro: Set(false),
|
||||
monthly_quota: Set(Decimal::ZERO),
|
||||
month_used: Set(Decimal::ZERO),
|
||||
cycle_start: Set(None),
|
||||
cycle_end: Set(None),
|
||||
updated_at: Set(now),
|
||||
created_at: Set(now),
|
||||
}
|
||||
.insert(db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(format!("failed to create project billing: {}", e)))?;
|
||||
|
||||
if is_first {
|
||||
// Record the credit in billing history
|
||||
project_billing_history::ActiveModel {
|
||||
uid: Set(Uuid::new_v4()),
|
||||
project: Set(project_uid),
|
||||
user: Set(Some(creator_uid)),
|
||||
amount: Set(first_project_credit()),
|
||||
currency: Set("USD".to_string()),
|
||||
reason: Set("first_project_credit".to_string()),
|
||||
extra: Set(Some(serde_json::json!({
|
||||
"is_first_project": true,
|
||||
}))),
|
||||
created_at: Set(now),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(format!("failed to record credit history: {}", e)))?;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
project_id = %project_uid,
|
||||
creator_id = %creator_uid,
|
||||
is_first_project = is_first,
|
||||
balance = if is_first { "$20" } else { "$0" },
|
||||
"project_billing_initialized"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Internal helpers ──
|
||||
|
||||
async fn compute_cost(
|
||||
db: &AppDatabase,
|
||||
model_id: Uuid,
|
||||
input_tokens: i64,
|
||||
output_tokens: i64,
|
||||
) -> Result<Decimal, AgentError> {
|
||||
let pricing = model_pricing::Entity::find()
|
||||
.filter(model_pricing::Column::ModelVersionId.eq(model_id))
|
||||
.order_by_desc(model_pricing::Column::EffectiveFrom)
|
||||
.one(db)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
AgentError::Internal(
|
||||
"No pricing record found for this model. Please configure AI model pricing first."
|
||||
.into(),
|
||||
)
|
||||
})?;
|
||||
.ok_or_else(|| AgentError::Internal(
|
||||
"No pricing record found for this model. Please configure AI model pricing first.".into(),
|
||||
))?;
|
||||
|
||||
let input_price: Decimal = pricing.input_price_per_1k_tokens.parse()
|
||||
.map_err(|e| AgentError::Internal(format!("Invalid input price: {}", e)))?;
|
||||
let output_price: Decimal = pricing.output_price_per_1k_tokens.parse()
|
||||
.map_err(|e| AgentError::Internal(format!("Invalid output price: {}", e)))?;
|
||||
|
||||
// 2. Compute cost using Decimal arithmetic.
|
||||
let input_price: Decimal = pricing
|
||||
.input_price_per_1k_tokens
|
||||
.parse()
|
||||
.map_err(|e| AgentError::Internal(format!("Invalid input price format: {}", e)))?;
|
||||
let output_price: Decimal = pricing
|
||||
.output_price_per_1k_tokens
|
||||
.parse()
|
||||
.map_err(|e| AgentError::Internal(format!("Invalid output price format: {}", e)))?;
|
||||
let tokens_i = Decimal::from(input_tokens);
|
||||
let tokens_o = Decimal::from(output_tokens);
|
||||
let thousand = Decimal::from(1000);
|
||||
Ok((Decimal::from(input_tokens) / thousand) * input_price
|
||||
+ (Decimal::from(output_tokens) / thousand) * output_price)
|
||||
}
|
||||
|
||||
let total_cost = (tokens_i / thousand) * input_price
|
||||
+ (tokens_o / thousand) * output_price;
|
||||
|
||||
let currency = pricing.currency.clone();
|
||||
|
||||
// 3. Cascading billing: project balance first, then workspace if insufficient.
|
||||
let proj = project::Entity::find_by_id(project_uid)
|
||||
async fn get_currency(db: &AppDatabase, model_id: Uuid) -> Result<String, AgentError> {
|
||||
let pricing = model_pricing::Entity::find()
|
||||
.filter(model_pricing::Column::ModelVersionId.eq(model_id))
|
||||
.one(db)
|
||||
.await?
|
||||
.ok_or_else(|| AgentError::Internal("Project not found".into()))?;
|
||||
.ok_or_else(|| AgentError::Internal("No pricing found".into()))?;
|
||||
Ok(pricing.currency.clone())
|
||||
}
|
||||
|
||||
let txn = db.begin().await?;
|
||||
async fn get_project_balance(db: &AppDatabase, project_uid: Uuid) -> Decimal {
|
||||
project_billing::Entity::find_by_id(project_uid)
|
||||
.one(db)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|b| b.balance)
|
||||
.unwrap_or(Decimal::ZERO)
|
||||
}
|
||||
|
||||
// Always check project balance first
|
||||
let project_billing = project_billing::Entity::find_by_id(project_uid)
|
||||
async fn get_user_balance(db: &AppDatabase, user_uid: Uuid) -> Decimal {
|
||||
user_billing::Entity::find_by_id(user_uid)
|
||||
.one(db)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|b| b.balance)
|
||||
.unwrap_or(Decimal::ZERO)
|
||||
}
|
||||
|
||||
async fn deduct_from_project(
|
||||
db: &AppDatabase,
|
||||
project_uid: Uuid,
|
||||
cost: Decimal,
|
||||
currency: &str,
|
||||
model_id: Uuid,
|
||||
input_tokens: i64,
|
||||
output_tokens: i64,
|
||||
) -> Result<(), String> {
|
||||
let txn = db.begin().await.map_err(|e| format!("db txn error: {}", e))?;
|
||||
|
||||
let billing = project_billing::Entity::find_by_id(project_uid)
|
||||
.lock_exclusive()
|
||||
.one(&txn)
|
||||
.await?
|
||||
.ok_or_else(|| AgentError::Internal("Project billing account not found".into()))?;
|
||||
.await
|
||||
.map_err(|e| format!("db error: {}", e))?
|
||||
.ok_or_else(|| "Project billing account not found".to_string())?;
|
||||
|
||||
if billing.balance < cost {
|
||||
txn.rollback().await.ok();
|
||||
return Err(format!(
|
||||
"Project balance insufficient. Required: {:.4} {}, Available: {:.4} {}",
|
||||
cost, currency, billing.balance, currency
|
||||
));
|
||||
}
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
if project_billing.balance >= total_cost {
|
||||
// ── Project has sufficient balance ──────────────────────────
|
||||
let amount_dec = -total_cost;
|
||||
|
||||
project_billing_history::ActiveModel {
|
||||
uid: Set(Uuid::new_v4()),
|
||||
project: Set(project_uid),
|
||||
user: Set(None),
|
||||
amount: Set(amount_dec),
|
||||
currency: Set(currency.clone()),
|
||||
reason: Set("ai_usage".to_string()),
|
||||
extra: Set(Some(serde_json::json!({
|
||||
"model_id": model_id.to_string(),
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
}))),
|
||||
created_at: Set(now),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&txn)
|
||||
.await?;
|
||||
|
||||
let new_balance = project_billing.balance - total_cost;
|
||||
let mut updated: project_billing::ActiveModel = project_billing.into();
|
||||
updated.balance = Set(new_balance);
|
||||
updated.updated_at = Set(now);
|
||||
updated.update(&txn).await?;
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0);
|
||||
|
||||
tracing::info!(
|
||||
project_id = %project_uid,
|
||||
model_id = %model_id,
|
||||
input_tokens = input_tokens,
|
||||
output_tokens = output_tokens,
|
||||
cost = %cost_f64,
|
||||
currency = %currency,
|
||||
source = "project",
|
||||
"ai_usage_recorded"
|
||||
);
|
||||
|
||||
Ok(BillingResult::Success(BillingRecord {
|
||||
cost: cost_f64,
|
||||
currency,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
}))
|
||||
} else if let Some(workspace_id) = proj.workspace_id {
|
||||
// ── Project insufficient, fallback to workspace ─────────────
|
||||
let workspace_billing = workspace_billing::Entity::find_by_id(workspace_id)
|
||||
.lock_exclusive()
|
||||
.one(&txn)
|
||||
.await?
|
||||
.ok_or_else(|| AgentError::Internal("Workspace billing account not found".into()))?;
|
||||
|
||||
if workspace_billing.balance < total_cost {
|
||||
txn.rollback().await?;
|
||||
return Ok(BillingResult::InsufficientBalance {
|
||||
message: format!(
|
||||
"Insufficient balance. Project: {:.4} {}, Workspace: {:.4} {}, Required: {:.4} {}",
|
||||
project_billing.balance, currency,
|
||||
workspace_billing.balance, currency,
|
||||
total_cost, currency
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let amount_dec = -total_cost;
|
||||
|
||||
workspace_billing_history::ActiveModel {
|
||||
uid: Set(Uuid::new_v4()),
|
||||
workspace_id: Set(workspace_id),
|
||||
user_id: Set(Some(proj.created_by)),
|
||||
amount: Set(amount_dec),
|
||||
currency: Set(currency.clone()),
|
||||
reason: Set(format!("ai_usage:{}", project_uid)),
|
||||
extra: Set(Some(serde_json::json!({
|
||||
"project_id": project_uid.to_string(),
|
||||
"model_id": model_id.to_string(),
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"fallback_reason": "project_balance_insufficient"
|
||||
}))),
|
||||
created_at: Set(now),
|
||||
}
|
||||
.insert(&txn)
|
||||
.await?;
|
||||
|
||||
let new_balance = workspace_billing.balance - total_cost;
|
||||
let new_total_spent = workspace_billing.total_spent + total_cost;
|
||||
let mut updated: workspace_billing::ActiveModel = workspace_billing.into();
|
||||
updated.balance = Set(new_balance);
|
||||
updated.total_spent = Set(new_total_spent);
|
||||
updated.updated_at = Set(now);
|
||||
updated.update(&txn).await?;
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0);
|
||||
|
||||
tracing::info!(
|
||||
project_id = %project_uid,
|
||||
model_id = %model_id,
|
||||
input_tokens = input_tokens,
|
||||
output_tokens = output_tokens,
|
||||
cost = %cost_f64,
|
||||
currency = %currency,
|
||||
workspace_id = %workspace_id.to_string(),
|
||||
source = "workspace_fallback",
|
||||
"ai_usage_recorded"
|
||||
);
|
||||
|
||||
Ok(BillingResult::Success(BillingRecord {
|
||||
cost: cost_f64,
|
||||
currency,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
}))
|
||||
} else {
|
||||
// ── Project insufficient and no workspace ───────────────────
|
||||
txn.rollback().await?;
|
||||
Ok(BillingResult::InsufficientBalance {
|
||||
message: format!(
|
||||
"Insufficient balance. Required: {:.4} {}, Available: {:.4} {}",
|
||||
total_cost, currency, project_billing.balance, currency
|
||||
),
|
||||
})
|
||||
project_billing_history::ActiveModel {
|
||||
uid: Set(Uuid::new_v4()),
|
||||
project: Set(project_uid),
|
||||
user: Set(None),
|
||||
amount: Set(-cost),
|
||||
currency: Set(currency.to_string()),
|
||||
reason: Set("ai_usage".to_string()),
|
||||
extra: Set(Some(serde_json::json!({
|
||||
"model_id": model_id.to_string(),
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"deducted_from": "project",
|
||||
}))),
|
||||
created_at: Set(now),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&txn)
|
||||
.await
|
||||
.map_err(|e| format!("failed to insert history: {}", e))?;
|
||||
|
||||
let mut updated: project_billing::ActiveModel = billing.into();
|
||||
updated.balance = Set(updated.balance.unwrap() - cost);
|
||||
updated.updated_at = Set(now);
|
||||
updated.update(&txn).await.map_err(|e| format!("failed to update balance: {}", e))?;
|
||||
|
||||
txn.commit().await.map_err(|e| format!("commit error: {}", e))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn deduct_from_user(
|
||||
db: &AppDatabase,
|
||||
user_uid: Uuid,
|
||||
cost: Decimal,
|
||||
currency: &str,
|
||||
project_uid: Uuid,
|
||||
model_id: Uuid,
|
||||
input_tokens: i64,
|
||||
output_tokens: i64,
|
||||
) -> Result<(), String> {
|
||||
let txn = db.begin().await.map_err(|e| format!("db txn error: {}", e))?;
|
||||
|
||||
let billing = user_billing::Entity::find_by_id(user_uid)
|
||||
.lock_exclusive()
|
||||
.one(&txn)
|
||||
.await
|
||||
.map_err(|e| format!("db error: {}", e))?
|
||||
.ok_or_else(|| "User billing account not found".to_string())?;
|
||||
|
||||
if billing.balance < cost {
|
||||
txn.rollback().await.ok();
|
||||
return Err(format!(
|
||||
"Insufficient balance (project + user). Project: unavailable, User: {:.4} {}. Required: {:.4} {}",
|
||||
billing.balance, currency, cost, currency
|
||||
));
|
||||
}
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
// Record in project billing history (but deducted from user)
|
||||
project_billing_history::ActiveModel {
|
||||
uid: Set(Uuid::new_v4()),
|
||||
project: Set(project_uid),
|
||||
user: Set(Some(user_uid)),
|
||||
amount: Set(-cost),
|
||||
currency: Set(currency.to_string()),
|
||||
reason: Set("ai_usage_user_fallback".to_string()),
|
||||
extra: Set(Some(serde_json::json!({
|
||||
"model_id": model_id.to_string(),
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"deducted_from": "user",
|
||||
}))),
|
||||
created_at: Set(now),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&txn)
|
||||
.await
|
||||
.map_err(|e| format!("failed to insert history: {}", e))?;
|
||||
|
||||
let mut updated: user_billing::ActiveModel = billing.into();
|
||||
updated.balance = Set(updated.balance.unwrap() - cost);
|
||||
updated.updated_at = Set(now);
|
||||
updated.update(&txn).await.map_err(|e| format!("failed to update user balance: {}", e))?;
|
||||
|
||||
txn.commit().await.map_err(|e| format!("commit error: {}", e))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn persist_billing_error(
|
||||
db: &AppDatabase,
|
||||
scope: &str,
|
||||
scope_id: Uuid,
|
||||
error_type: &str,
|
||||
message: &str,
|
||||
details: Option<serde_json::Value>,
|
||||
) -> Result<(), AgentError> {
|
||||
billing_error::ActiveModel {
|
||||
id: Set(Uuid::new_v4()),
|
||||
scope: Set(scope.to_string()),
|
||||
scope_id: Set(scope_id),
|
||||
error_type: Set(error_type.to_string()),
|
||||
message: Set(message.to_string()),
|
||||
details: Set(details),
|
||||
resolved: Set(false),
|
||||
created_at: Set(chrono::Utc::now()),
|
||||
}
|
||||
.insert(db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(format!("failed to persist billing error: {}", e)))?;
|
||||
|
||||
tracing::warn!(scope, %scope_id, error_type, "billing_error_persisted");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn decimal_to_f64(d: Decimal) -> f64 {
|
||||
d.round_dp(10).to_string().parse().unwrap_or(0.0)
|
||||
}
|
||||
357
libs/agent/chat/chat_execution.rs
Normal file
357
libs/agent/chat/chat_execution.rs
Normal file
@ -0,0 +1,357 @@
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::client::AiClientConfig;
|
||||
use crate::client::types::{ChatRequestMessage, ToolCall};
|
||||
use crate::client::{StreamChunk, StreamChunkType, StreamedToolCall, call_stream};
|
||||
use crate::error::Result;
|
||||
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolDefinition, ToolExecutor, ToolHandler, ToolParam};
|
||||
use crate::tool::registry::ToolRegistry;
|
||||
use crate::embed::EmbedService;
|
||||
use sea_orm::{ActiveModelTrait, EntityTrait, Set};
|
||||
|
||||
use super::{AiChunkType, AiStreamChunk, StreamCallback};
|
||||
use super::service::StreamResult;
|
||||
|
||||
// Keyword-extraction-based title generator: reads conversation messages, extracts
|
||||
// meaningful words, and updates the conversation record with a short title.
|
||||
async fn generate_title_for_conversation(
|
||||
ctx: &ToolContext,
|
||||
conversation_id: Uuid,
|
||||
) -> Result<serde_json::Value> {
|
||||
use models::ai::{ai_conversation, ai_message, AiMessage};
|
||||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
|
||||
|
||||
let db_reader = ctx.db().reader();
|
||||
let db_writer = ctx.db().writer();
|
||||
let conv = ai_conversation::Entity::find_by_id(conversation_id)
|
||||
.one(db_reader)
|
||||
.await
|
||||
.map_err(|e| crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: format!("db error: {}", e) })?
|
||||
.ok_or_else(|| crate::error::AgentError::NotFound("Conversation not found".into()))?;
|
||||
|
||||
let recent_messages = AiMessage::find()
|
||||
.filter(ai_message::Column::ConversationId.eq(conversation_id))
|
||||
.filter(ai_message::Column::Role.eq("user"))
|
||||
.order_by_desc(ai_message::Column::CreatedAt)
|
||||
.limit(3)
|
||||
.all(db_reader)
|
||||
.await
|
||||
.map_err(|e| crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: format!("db error: {}", e) })?;
|
||||
|
||||
if recent_messages.is_empty() {
|
||||
return Err(crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: "No user messages found".into() });
|
||||
}
|
||||
|
||||
let content = recent_messages
|
||||
.first()
|
||||
.and_then(|m| m.content.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|v| v.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let words: Vec<&str> = content
|
||||
.split_whitespace()
|
||||
.filter(|w| w.len() > 2 && !is_stop_word(w))
|
||||
.take(5)
|
||||
.collect();
|
||||
|
||||
let title = if words.is_empty() {
|
||||
"New Chat".to_string()
|
||||
} else {
|
||||
words.join(" ")
|
||||
};
|
||||
|
||||
let mut active: ai_conversation::ActiveModel = conv.into();
|
||||
active.title = Set(Some(title.clone()));
|
||||
active.updated_at = Set(chrono::Utc::now());
|
||||
active
|
||||
.update(db_writer)
|
||||
.await
|
||||
.map_err(|e| crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: format!("failed to update title: {}", e) })?;
|
||||
|
||||
Ok(serde_json::json!({ "conversation_id": conversation_id.to_string(), "title": title }))
|
||||
}
|
||||
|
||||
fn is_stop_word(w: &str) -> bool {
|
||||
matches!(
|
||||
w.to_lowercase().as_str(),
|
||||
"the" | "this" | "that" | "what" | "which" | "when" | "where"
|
||||
| "why" | "how" | "can" | "could" | "would" | "should"
|
||||
| "please" | "help" | "thanks" | "thank" | "you" | "your"
|
||||
| "have" | "has" | "had" | "with" | "for" | "from" | "into"
|
||||
| "about" | "also" | "just" | "now" | "very" | "really"
|
||||
)
|
||||
}
|
||||
|
||||
type SharedCallback = Arc<dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
|
||||
|
||||
/// Simplified ReAct execution for Chat API.
|
||||
///
|
||||
/// Unlike `execute_process_stream` (which requires `AiRequest` with room-specific data),
|
||||
/// this function takes messages and tools directly. It does NOT record AI sessions to
|
||||
/// the `ai_session` table — the caller is responsible for persisting results.
|
||||
pub async fn execute_chat_stream(
|
||||
messages: Vec<ChatRequestMessage>,
|
||||
tools: Vec<serde_json::Value>,
|
||||
model_name: &str,
|
||||
config: &AiClientConfig,
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
max_tool_depth: usize,
|
||||
tool_registry: Option<&ToolRegistry>,
|
||||
db: db::database::AppDatabase,
|
||||
cache: db::cache::AppCache,
|
||||
app_config: config::AppConfig,
|
||||
project_id: Uuid,
|
||||
sender_uid: Uuid,
|
||||
embed_service: Option<EmbedService>,
|
||||
on_chunk: StreamCallback,
|
||||
conversation_id: Option<uuid::Uuid>,
|
||||
) -> Result<StreamResult> {
|
||||
let on_chunk: SharedCallback = Arc::from(on_chunk);
|
||||
let tools_enabled = !tools.is_empty();
|
||||
let mut messages = messages;
|
||||
let mut tool_depth = 0;
|
||||
let mut total_input_tokens = 0i64;
|
||||
let mut total_output_tokens = 0i64;
|
||||
let mut full_content = String::new();
|
||||
let mut all_chunks: Vec<StreamChunk> = Vec::new();
|
||||
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();
|
||||
|
||||
// Conditionally inject chat_generate_title tool if conversation has no title
|
||||
let (tools, _tools_injected) = if let Some(conv_id) = conversation_id {
|
||||
if let Some(registry) = tool_registry {
|
||||
let db_reader = db.reader();
|
||||
let has_title = models::ai::ai_conversation::Entity::find_by_id(conv_id)
|
||||
.one(db_reader)
|
||||
.await
|
||||
.map(|c| c.map(|m| m.title.is_some()).unwrap_or(false))
|
||||
.unwrap_or(false);
|
||||
if !has_title {
|
||||
let mut reg = registry.clone();
|
||||
reg.register(
|
||||
ToolDefinition::new("chat_generate_title")
|
||||
.description(
|
||||
"Generate a concise title (5 words or fewer) for the current conversation \
|
||||
based on its message history, and save it to the conversation record. \
|
||||
Call this tool at the start of a new conversation if it has no title.",
|
||||
)
|
||||
.parameters(crate::tool::ToolSchema {
|
||||
schema_type: "object".into(),
|
||||
properties: Some({
|
||||
let mut p = std::collections::HashMap::new();
|
||||
p.insert("conversation_id".into(), ToolParam {
|
||||
name: "conversation_id".into(),
|
||||
param_type: "string".into(),
|
||||
description: Some("The UUID of the conversation (required).".into()),
|
||||
required: true,
|
||||
properties: None,
|
||||
items: None,
|
||||
});
|
||||
p
|
||||
}),
|
||||
required: Some(vec!["conversation_id".into()]),
|
||||
}),
|
||||
ToolHandler::new(|ctx, args| {
|
||||
let conv_id = args.get("conversation_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.and_then(|s| Uuid::parse_str(s).ok());
|
||||
Box::pin(async move {
|
||||
match conv_id {
|
||||
Some(id) => generate_title_for_conversation(&ctx, id).await
|
||||
.map_err(|e| crate::tool::ToolError::ExecutionError(e.to_string())),
|
||||
None => Err(crate::tool::ToolError::ExecutionError("conversation_id missing".into())),
|
||||
}
|
||||
})
|
||||
}),
|
||||
);
|
||||
// Prepend system message instructing the model to generate title first
|
||||
messages.insert(0, ChatRequestMessage::system(
|
||||
"IMPORTANT: If the conversation has no title, you MUST call chat_generate_title \
|
||||
with the conversation_id immediately before answering any user question. \
|
||||
The title must be 5 words or fewer and should summarize the user's intent.".to_string(),
|
||||
));
|
||||
(reg.to_openai_tools(), true)
|
||||
} else {
|
||||
(tools.clone(), false)
|
||||
}
|
||||
} else {
|
||||
(tools.clone(), false)
|
||||
}
|
||||
} else {
|
||||
(tools.clone(), false)
|
||||
};
|
||||
|
||||
loop {
|
||||
let on_chunk_cb = on_chunk.clone();
|
||||
let on_chunk_cb2 = on_chunk.clone();
|
||||
let tx_arc = Arc::new(tx.clone());
|
||||
let tx_arc2 = tx_arc.clone();
|
||||
|
||||
let response = call_stream(
|
||||
&messages, model_name, config, temperature, max_tokens,
|
||||
if tools_enabled { Some(&tools) } else { None }, None,
|
||||
Arc::new(move |delta| {
|
||||
let content = delta.to_string().replace('\n', "");
|
||||
let fut = on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer });
|
||||
fut
|
||||
}),
|
||||
Arc::new(move |delta| {
|
||||
let fut = on_chunk_cb2(AiStreamChunk { content: delta.to_string(), done: false, chunk_type: AiChunkType::Thinking });
|
||||
fut
|
||||
}),
|
||||
Arc::new(move |tc: &StreamedToolCall| {
|
||||
let tx = tx_arc2.clone();
|
||||
let tc_owned = tc.clone();
|
||||
Box::pin(async move { let _ = tx.send(tc_owned); }) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
|
||||
}),
|
||||
).await?;
|
||||
|
||||
total_input_tokens += response.input_tokens;
|
||||
total_output_tokens += response.output_tokens;
|
||||
all_chunks.extend(response.chunks.clone());
|
||||
|
||||
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
|
||||
if !has_tool_calls {
|
||||
let final_content = response.content.clone();
|
||||
// Don't broadcast the done chunk via SSE/NATS — incremental deltas
|
||||
// already delivered the content; the separate "done" SSE event
|
||||
// signals completion. Pushing full content again would duplicate it
|
||||
// in the frontend streaming store.
|
||||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: final_content.clone() });
|
||||
return Ok(StreamResult {
|
||||
content: final_content,
|
||||
reasoning_content: response.reasoning_content,
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
chunks: all_chunks,
|
||||
});
|
||||
}
|
||||
|
||||
full_content.push_str(&response.content);
|
||||
|
||||
let tool_calls: Vec<ToolCall> = response.tool_calls.iter().map(|tc| ToolCall {
|
||||
id: tc.id.clone(), type_: "function".into(),
|
||||
function: crate::client::types::ToolCallFunction { name: tc.name.clone(), arguments: tc.arguments.clone() },
|
||||
}).collect();
|
||||
|
||||
messages.push(ChatRequestMessage::assistant(Some(response.content.clone()), Some(tool_calls.clone())));
|
||||
|
||||
// Drain tool call notifications
|
||||
loop {
|
||||
match rx.try_recv() {
|
||||
Ok(tc) => {
|
||||
let args_display = if tc.arguments.len() > 100 {
|
||||
let end = tc.arguments.char_indices().map(|(i, _)| i).take_while(|&i| i <= 100).last().unwrap_or(100);
|
||||
format!("{}...", &tc.arguments[..end])
|
||||
} else { tc.arguments.clone() };
|
||||
let tool_display = format!("🔧 {}({})", tc.name, args_display);
|
||||
on_chunk(AiStreamChunk { content: tool_display.clone(), done: false, chunk_type: AiChunkType::ToolCall }).await;
|
||||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: tool_display });
|
||||
}
|
||||
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
|
||||
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
|
||||
}
|
||||
}
|
||||
|
||||
let calls: Vec<AgentToolCall> = response.tool_calls.iter().map(|tc| AgentToolCall {
|
||||
id: tc.id.clone(), name: tc.name.clone(), arguments: tc.arguments.clone(),
|
||||
}).collect();
|
||||
|
||||
let tool_messages = execute_tools(
|
||||
&calls, &db, &cache, &app_config, project_id, sender_uid,
|
||||
tool_registry, embed_service.as_ref(), &on_chunk, &mut all_chunks,
|
||||
).await;
|
||||
|
||||
messages.extend(tool_messages);
|
||||
|
||||
tool_depth += 1;
|
||||
if tool_depth >= max_tool_depth {
|
||||
let max_depth_text = format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth);
|
||||
on_chunk(AiStreamChunk { content: max_depth_text.clone(), done: true, chunk_type: AiChunkType::Answer }).await;
|
||||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: max_depth_text });
|
||||
return Ok(StreamResult { content: full_content, reasoning_content: String::new(), input_tokens: 0, output_tokens: 0, chunks: all_chunks });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_tools(
|
||||
calls: &[AgentToolCall],
|
||||
db: &db::database::AppDatabase,
|
||||
cache: &db::cache::AppCache,
|
||||
app_config: &config::AppConfig,
|
||||
project_id: Uuid,
|
||||
sender_uid: Uuid,
|
||||
tool_registry: Option<&ToolRegistry>,
|
||||
embed_service: Option<&EmbedService>,
|
||||
on_chunk: &SharedCallback,
|
||||
all_chunks: &mut Vec<StreamChunk>,
|
||||
) -> Vec<ChatRequestMessage> {
|
||||
let mut tool_messages = Vec::new();
|
||||
let mut ctx = ToolContext::new(db.clone(), cache.clone(), app_config.clone(), Uuid::nil(), Some(sender_uid))
|
||||
.with_project(project_id);
|
||||
if let Some(es) = embed_service {
|
||||
ctx = ctx.with_embed_service(es.clone());
|
||||
}
|
||||
if let Some(registry) = tool_registry {
|
||||
ctx.registry_mut().merge(registry.clone());
|
||||
}
|
||||
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
for call in calls {
|
||||
let call_clone = call.clone();
|
||||
let mut ctx_clone = ctx.clone();
|
||||
join_set.spawn(async move {
|
||||
let executor = ToolExecutor::new();
|
||||
let res = executor.execute_batch(vec![call_clone.clone()], &mut ctx_clone).await;
|
||||
(call_clone, res)
|
||||
});
|
||||
}
|
||||
|
||||
let heartbeat_dur = std::time::Duration::from_secs(10);
|
||||
while !join_set.is_empty() {
|
||||
tokio::select! {
|
||||
Some(res) = join_set.join_next() => {
|
||||
if let Ok((call, results)) = res {
|
||||
match results {
|
||||
Ok(results) => {
|
||||
for result in &results {
|
||||
let preview = match &result.result {
|
||||
crate::tool::ToolResult::Ok(v) => {
|
||||
let t = v.to_string();
|
||||
if t.len() > 300 {
|
||||
let end = t.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
|
||||
format!("{}...", &t[..end])
|
||||
} else { t.clone() }
|
||||
}
|
||||
crate::tool::ToolResult::Error(msg) => msg.clone(),
|
||||
};
|
||||
tracing::debug!("tool_result: {} — {}", call.name, preview);
|
||||
}
|
||||
let success_display = format!("✅ {}", call.name);
|
||||
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
|
||||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
|
||||
let msgs = ToolExecutor::to_tool_messages(&results);
|
||||
tool_messages.extend(msgs);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
|
||||
let err_text = format!("[Tool call failed: {}]", e);
|
||||
let err_display = format!("❌ {} (failed)", call.name);
|
||||
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
|
||||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
|
||||
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
_ = tokio::time::sleep(heartbeat_dur) => {
|
||||
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall }).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
tool_messages
|
||||
}
|
||||
@ -4,7 +4,7 @@ use sea_orm::*;
|
||||
use super::context::RoomMessageContext;
|
||||
use super::{AiRequest, Mention};
|
||||
use crate::client::types::ChatRequestMessage;
|
||||
use crate::compact::{CompactConfig, CompactService};
|
||||
use crate::compact::CompactService;
|
||||
use crate::embed::EmbedService;
|
||||
use crate::error::Result;
|
||||
use crate::perception::{PerceptionService, SkillEntry};
|
||||
@ -55,7 +55,6 @@ impl MessageBuilder {
|
||||
let mut processed_history = Vec::new();
|
||||
if let Some(compact_service) = &self.compact_service {
|
||||
let compact_cache_key = format!("ai:compact:{}", request.room.id);
|
||||
let compact_config = CompactConfig::default();
|
||||
let cached_summary: Option<String> = match request.cache.conn().await {
|
||||
Ok(mut conn) => redis::cmd("GET").arg(&compact_cache_key).query_async::<Option<String>>(&mut conn).await.unwrap_or(None),
|
||||
Err(e) => { tracing::warn!(error = %e, "compact cache: conn failed"); None }
|
||||
@ -71,7 +70,22 @@ impl MessageBuilder {
|
||||
}
|
||||
|
||||
if processed_history.is_empty() {
|
||||
match compact_service.compact_room_auto(request.room.id, Some(request.user_names.clone()), compact_config).await {
|
||||
let compact_config = request.context_setting.as_ref()
|
||||
.map(|s| crate::compact::CompactConfig::from_project_setting(
|
||||
s.context_window_tokens,
|
||||
s.compaction_threshold,
|
||||
s.compaction_max_summary_ratio,
|
||||
))
|
||||
.unwrap_or_default();
|
||||
|
||||
match compact_service.compact_room(
|
||||
request.room.id,
|
||||
compact_config.default_level,
|
||||
Some(request.user_names.clone()),
|
||||
request.sender.uid,
|
||||
request.context_setting.as_ref().map(|s| s.context_window_tokens).unwrap_or(128000),
|
||||
request.context_setting.as_ref().map(|s| s.compaction_max_summary_ratio).unwrap_or(0.2),
|
||||
).await {
|
||||
Ok(compact_summary) => {
|
||||
if !compact_summary.summary.is_empty() {
|
||||
messages.push(ChatRequestMessage::system(format!("Conversation summary:\n{}", compact_summary.summary)));
|
||||
@ -174,7 +188,13 @@ impl MessageBuilder {
|
||||
let keyword_skills = self.perception_service.inject_skills(&request.input, &history_texts, &[], &all_skills).await;
|
||||
let mut vector_skills = Vec::new();
|
||||
if let Some(es) = &self.embed_service {
|
||||
vector_skills = crate::perception::VectorActiveAwareness::default().detect(es, &request.input, &request.project.id.to_string()).await;
|
||||
let rag_enabled = request.context_setting.as_ref().map(|s| s.rag_enabled).unwrap_or(true);
|
||||
if rag_enabled {
|
||||
let max_results = request.context_setting.as_ref().map(|s| s.rag_max_results as usize).unwrap_or(3);
|
||||
let min_score = request.context_setting.as_ref().map(|s| s.rag_min_score).unwrap_or(0.70);
|
||||
let awareness = crate::perception::VectorActiveAwareness::new(max_results, min_score);
|
||||
vector_skills = awareness.detect(es, &request.input, &request.project.id.to_string()).await;
|
||||
}
|
||||
}
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
let mut result = Vec::new();
|
||||
@ -184,8 +204,17 @@ impl MessageBuilder {
|
||||
}
|
||||
|
||||
async fn build_memory_context(&self, request: &AiRequest) -> Vec<crate::perception::vector::MemoryContext> {
|
||||
let rag_enabled = request.context_setting.as_ref().map(|s| s.rag_enabled).unwrap_or(true);
|
||||
if !rag_enabled {
|
||||
return Vec::new();
|
||||
}
|
||||
match &self.embed_service {
|
||||
Some(es) => crate::perception::VectorPassiveAwareness::default().detect(es, &request.input, &request.project.display_name, &request.room.id.to_string()).await,
|
||||
Some(es) => {
|
||||
let max_results = request.context_setting.as_ref().map(|s| s.rag_max_results as usize).unwrap_or(3);
|
||||
let min_score = request.context_setting.as_ref().map(|s| s.rag_min_score).unwrap_or(0.72);
|
||||
let awareness = crate::perception::VectorPassiveAwareness::new(max_results, min_score);
|
||||
awareness.detect(es, &request.input, &request.project.display_name, &request.room.id.to_string()).await
|
||||
}
|
||||
None => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,7 +3,7 @@ use std::pin::Pin;
|
||||
use db::cache::AppCache;
|
||||
use db::database::AppDatabase;
|
||||
use models::agents::model;
|
||||
use models::projects::project;
|
||||
use models::projects::{project, project_context_setting};
|
||||
use models::repos::repo;
|
||||
use models::rooms::{room, room_message};
|
||||
use models::users::user;
|
||||
@ -44,7 +44,32 @@ impl Default for AiChunkType {
|
||||
}
|
||||
}
|
||||
|
||||
/// Optional streaming callback: called for each token chunk.
|
||||
const THINK_OPEN: &str = "\x3cthinking\x3e";
|
||||
const THINK_CLOSE: &str = "\x3c/response\x3e";
|
||||
|
||||
/// Strip XML-format thinking tags that some models (e.g. DeepSeek-R1) embed
|
||||
/// in reasoning output. Also normalizes excessive consecutive newlines (3+ → 2).
|
||||
pub fn normalize_thinking_content(content: &str) -> String {
|
||||
let content = content
|
||||
.replace(THINK_CLOSE, "")
|
||||
.replace(THINK_OPEN, "")
|
||||
.replace("\x3cthinking", "")
|
||||
.replace("/response\x3e", "");
|
||||
let mut result = String::with_capacity(content.len());
|
||||
let mut newline_count = 0usize;
|
||||
for ch in content.chars() {
|
||||
if ch == '\n' {
|
||||
newline_count += 1;
|
||||
if newline_count <= 2 {
|
||||
result.push(ch);
|
||||
}
|
||||
} else {
|
||||
newline_count = 0;
|
||||
result.push(ch);
|
||||
}
|
||||
}
|
||||
result.trim().to_string()
|
||||
}
|
||||
pub type StreamCallback = Box<
|
||||
dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync,
|
||||
>;
|
||||
@ -55,6 +80,7 @@ pub struct AiRequest {
|
||||
pub config: AppConfig,
|
||||
pub model: model::Model,
|
||||
pub project: project::Model,
|
||||
pub context_setting: Option<project_context_setting::Model>,
|
||||
pub sender: user::Model,
|
||||
pub room: room::Model,
|
||||
pub input: String,
|
||||
@ -76,6 +102,7 @@ pub enum Mention {
|
||||
Repo(repo::Model),
|
||||
}
|
||||
|
||||
pub mod chat_execution;
|
||||
pub mod context;
|
||||
pub mod message_builder;
|
||||
pub mod nonstreaming_execution;
|
||||
|
||||
@ -82,13 +82,13 @@ pub async fn execute_process(
|
||||
tool_depth += 1;
|
||||
if tool_depth >= max_tool_depth {
|
||||
let content = if text.is_empty() { format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth) } else { text };
|
||||
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), input_tokens, output_tokens, session_start.elapsed().as_millis() as i64).await;
|
||||
record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), input_tokens, output_tokens, session_start.elapsed().as_millis() as i64).await;
|
||||
return Ok(ProcessResult { content, input_tokens, output_tokens });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), input_tokens, output_tokens, session_start.elapsed().as_millis() as i64).await;
|
||||
record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), input_tokens, output_tokens, session_start.elapsed().as_millis() as i64).await;
|
||||
return Ok(ProcessResult { content: text, input_tokens, output_tokens });
|
||||
}
|
||||
}
|
||||
@ -111,7 +111,7 @@ async fn execute_tools(
|
||||
let elapsed = start.elapsed().as_millis() as i64;
|
||||
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
|
||||
let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None };
|
||||
recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: call.clone(), session_id: recorder.session_id(), tool_name: call.clone(), caller: request.sender.uid, arguments: serde_json::Value::Null, status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success }, execution_time_ms: Some(elapsed), error_message: error_msg, error_stack: None, retry_count: 0 });
|
||||
recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: Uuid::new_v4().to_string(), session_id: recorder.session_id(), tool_name: call.clone(), caller: request.sender.uid, arguments: serde_json::Value::Null, status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success }, execution_time_ms: Some(elapsed), error_message: error_msg, error_stack: None, retry_count: 0 });
|
||||
}
|
||||
crate::tool::ToolExecutor::to_tool_messages(&results)
|
||||
}
|
||||
|
||||
@ -18,6 +18,8 @@ pub async fn execute_process_react<C, Fut>(
|
||||
request: &AiRequest, mut on_chunk: C,
|
||||
tool_registry: &ToolRegistry,
|
||||
ai_base_url: Option<String>, ai_api_key: Option<String>,
|
||||
room_preamble: Option<&str>,
|
||||
message_producer: Option<queue::MessageProducer>,
|
||||
) -> Result<(String, i64, i64)>
|
||||
where
|
||||
C: FnMut(ReactStep) -> Fut + Send,
|
||||
@ -33,6 +35,9 @@ where
|
||||
let room_id = request.room.id;
|
||||
let sender_uid = request.sender.uid;
|
||||
let project_id = request.project.id;
|
||||
let ai_model_id = request.model.id;
|
||||
let ai_model_name = request.model.name.clone();
|
||||
let sent_in_turn = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||
let session_id = Uuid::now_v7();
|
||||
let session_start = std::time::Instant::now();
|
||||
let version_id = room_ai::Entity::find()
|
||||
@ -46,7 +51,9 @@ where
|
||||
if let Some(handler) = tool_registry.get(&name) {
|
||||
let adapter = crate::tool::RigToolAdapter::new(
|
||||
handler.clone(), def.clone(), db.clone(), cache.clone(), cfg.clone(),
|
||||
room_id, Some(sender_uid), project_id,
|
||||
room_id, Some(sender_uid), project_id, message_producer.clone(),
|
||||
Some(ai_model_id), Some(ai_model_name.clone()),
|
||||
sent_in_turn.clone(),
|
||||
);
|
||||
tools.push(Box::new(RecordingTool::new(Box::new(adapter), db.clone(), session_id, sender_uid)));
|
||||
}
|
||||
@ -54,8 +61,14 @@ where
|
||||
|
||||
let rig_client = client_config.build_rig_client();
|
||||
let model = rig_client.completion_model(&request.model.name);
|
||||
|
||||
let preamble = match room_preamble {
|
||||
Some(rp) => format!("{}\n{}", rp, DEFAULT_SYSTEM_PROMPT),
|
||||
None => DEFAULT_SYSTEM_PROMPT.to_string(),
|
||||
};
|
||||
|
||||
let agent = AgentBuilder::new(model)
|
||||
.preamble(DEFAULT_SYSTEM_PROMPT)
|
||||
.preamble(&preamble)
|
||||
.tools(tools)
|
||||
.default_max_turns(request.max_tool_depth)
|
||||
.build();
|
||||
@ -77,7 +90,8 @@ where
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text))) => {
|
||||
step_count += 1;
|
||||
let t = text.text;
|
||||
on_chunk(ReactStep::Answer { step: step_count, answer: t.clone() }).await;
|
||||
let cleaned = t.replace('\n', "");
|
||||
on_chunk(ReactStep::Answer { step: step_count, answer: cleaned }).await;
|
||||
final_content.push_str(&t);
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(reasoning))) => {
|
||||
@ -120,7 +134,7 @@ where
|
||||
}
|
||||
|
||||
let elapsed_ms = session_start.elapsed().as_millis() as i64;
|
||||
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, elapsed_ms).await;
|
||||
record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, elapsed_ms).await;
|
||||
|
||||
Ok((final_content, total_input_tokens, total_output_tokens))
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ use crate::embed::EmbedService;
|
||||
use crate::error::Result;
|
||||
use crate::perception::PerceptionService;
|
||||
use crate::tool::registry::ToolRegistry;
|
||||
use queue::MessageProducer;
|
||||
|
||||
/// Result from streaming AI response.
|
||||
pub struct StreamResult {
|
||||
@ -94,7 +95,8 @@ impl ChatService {
|
||||
) -> Option<crate::RigToolSet> {
|
||||
self.tool_registry.as_ref().map(|registry| {
|
||||
crate::RigToolSet::from_registry(
|
||||
registry, db, cache, config, room_id, sender_id, project_id,
|
||||
registry, db, cache, config, room_id, sender_id, project_id, None, None, None,
|
||||
std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
)
|
||||
})
|
||||
}
|
||||
@ -134,6 +136,35 @@ impl ChatService {
|
||||
super::react_execution::execute_process_react(
|
||||
request, on_chunk, registry,
|
||||
self.ai_base_url.clone(), self.ai_api_key.clone(),
|
||||
None, None,
|
||||
).await
|
||||
}
|
||||
|
||||
/// Process AI request via rig-based ReAct streaming loop with room-specific tools.
|
||||
///
|
||||
/// Merges `room_tools` (e.g. `send_message`, `retract_message`) into the base
|
||||
/// tool registry on-the-fly. The `room_preamble` is prepended to the default
|
||||
/// system prompt to instruct the AI about room communication rules.
|
||||
/// `message_producer` enables tools to publish events via the message queue.
|
||||
pub async fn process_react_room<C, Fut>(
|
||||
&self, request: &AiRequest, on_chunk: C,
|
||||
room_tools: ToolRegistry,
|
||||
room_preamble: Option<&str>,
|
||||
message_producer: Option<MessageProducer>,
|
||||
) -> Result<(String, i64, i64)>
|
||||
where
|
||||
C: FnMut(crate::react::ReactStep) -> Fut + Send,
|
||||
Fut: std::future::Future<Output = ()> + Send,
|
||||
{
|
||||
let Some(registry) = &self.tool_registry else {
|
||||
return Err(crate::error::AgentError::Internal("no tool registry registered".into()));
|
||||
};
|
||||
let mut merged = registry.clone();
|
||||
merged.merge(room_tools);
|
||||
super::react_execution::execute_process_react(
|
||||
request, on_chunk, &merged,
|
||||
self.ai_base_url.clone(), self.ai_api_key.clone(),
|
||||
room_preamble, message_producer,
|
||||
).await
|
||||
}
|
||||
}
|
||||
|
||||
@ -8,6 +8,7 @@ pub async fn record_ai_session(
|
||||
cache: &AppCache,
|
||||
db: &AppDatabase,
|
||||
project_id: Uuid,
|
||||
user_id: Uuid,
|
||||
session_id: Uuid,
|
||||
room_id: Uuid,
|
||||
model_id: Uuid,
|
||||
@ -39,7 +40,7 @@ pub async fn record_ai_session(
|
||||
}
|
||||
|
||||
let (cost, currency, error_msg) = match crate::billing::record_ai_usage(
|
||||
db, project_id, version_id, input_tokens, output_tokens,
|
||||
db, project_id, user_id, version_id, input_tokens, output_tokens,
|
||||
).await {
|
||||
Ok(crate::billing::BillingResult::Success(record)) => {
|
||||
(Some(record.cost), Some(record.currency), None)
|
||||
@ -70,7 +71,7 @@ async fn create_billing_error_system_message(
|
||||
use models::rooms::{room_message, MessageContentType, MessageSenderType};
|
||||
use sea_orm::Set;
|
||||
|
||||
let seq_key = format!("room:seq:{}", room_id);
|
||||
let seq_key = format!("seq:room:{}", room_id);
|
||||
let seq = match cache.conn().await {
|
||||
Ok(mut conn) => {
|
||||
match redis::cmd("INCR").arg(&seq_key).query_async::<i64>(&mut conn).await {
|
||||
|
||||
@ -62,7 +62,8 @@ pub async fn execute_process_stream(
|
||||
&messages, &model_name, &config, temperature, max_tokens,
|
||||
if tools_enabled { Some(&tools) } else { None }, None,
|
||||
Arc::new(move |delta| {
|
||||
let fut = on_chunk_cb(AiStreamChunk { content: delta.to_string(), done: false, chunk_type: AiChunkType::Answer });
|
||||
let content = delta.to_string().replace('\n', "");
|
||||
let fut = on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer });
|
||||
fut
|
||||
}),
|
||||
Arc::new(move |delta| {
|
||||
@ -82,11 +83,10 @@ pub async fn execute_process_stream(
|
||||
|
||||
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
|
||||
if !has_tool_calls {
|
||||
return handle_final_answer(response, full_content, on_chunk, all_chunks, &request, session_id, version_id, total_input_tokens, total_output_tokens, session_start).await;
|
||||
return handle_final_answer(response, all_chunks, &request, session_id, version_id, total_input_tokens, total_output_tokens, session_start).await;
|
||||
}
|
||||
|
||||
full_content.push_str(&response.content);
|
||||
full_content.push('\n');
|
||||
|
||||
let tool_calls: Vec<ToolCall> = response.tool_calls.iter().map(|tc| ToolCall {
|
||||
id: tc.id.clone(), type_: "function".into(),
|
||||
@ -114,7 +114,7 @@ pub async fn execute_process_stream(
|
||||
let max_depth_text = format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth);
|
||||
on_chunk(AiStreamChunk { content: max_depth_text.clone(), done: true, chunk_type: AiChunkType::Answer }).await;
|
||||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: max_depth_text });
|
||||
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await;
|
||||
record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await;
|
||||
return Ok(StreamResult { content: full_content, reasoning_content: String::new(), input_tokens: 0, output_tokens: 0, chunks: all_chunks });
|
||||
}
|
||||
}
|
||||
@ -155,60 +155,83 @@ async fn execute_streaming_tools(
|
||||
if let Some(registry) = tool_registry { ctx.registry_mut().merge(registry.clone()); }
|
||||
|
||||
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(request.db.clone(), session_id);
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
|
||||
for call in calls {
|
||||
let start = std::time::Instant::now();
|
||||
let call_clone = call.clone();
|
||||
let mut ctx_clone = ctx.clone();
|
||||
let (result_tx, mut result_rx) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
let sender_uid = request.sender.uid;
|
||||
let recorder_clone = recorder.clone();
|
||||
|
||||
join_set.spawn(async move {
|
||||
let start = std::time::Instant::now();
|
||||
let executor = ToolExecutor::new();
|
||||
let res = executor.execute_batch(vec![call_clone], &mut ctx_clone).await;
|
||||
let _ = result_tx.send(res);
|
||||
let res = executor.execute_batch(vec![call_clone.clone()], &mut ctx_clone).await;
|
||||
(call_clone, res, start.elapsed(), sender_uid, recorder_clone)
|
||||
});
|
||||
}
|
||||
|
||||
let heartbeat_dur = std::time::Duration::from_secs(10);
|
||||
let results = loop {
|
||||
tokio::select! {
|
||||
res = &mut result_rx => {
|
||||
match res { Ok(inner) => break inner, Err(_) => break Err(crate::tool::ToolError::ExecutionError("tool task cancelled".into())), }
|
||||
},
|
||||
_ = tokio::time::sleep(heartbeat_dur) => {
|
||||
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall }).await;
|
||||
let heartbeat_dur = std::time::Duration::from_secs(10);
|
||||
while !join_set.is_empty() {
|
||||
tokio::select! {
|
||||
Some(res) = join_set.join_next() => {
|
||||
if let Ok((call, results, elapsed, sender_uid, recorder)) = res {
|
||||
match results {
|
||||
Ok(results) => {
|
||||
for result in &results {
|
||||
let text = match &result.result { crate::tool::ToolResult::Ok(v) => v.to_string(), crate::tool::ToolResult::Error(msg) => msg.clone() };
|
||||
let preview = if text.len() > 300 {
|
||||
let end = text.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
|
||||
format!("{}...", &text[..end])
|
||||
} else { text.clone() };
|
||||
tracing::debug!("tool_result: {} — {}", call.name, preview);
|
||||
|
||||
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
|
||||
let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None };
|
||||
recorder.record(crate::tool::recorder::ToolCallRecord {
|
||||
tool_call_id: call.id.clone(),
|
||||
session_id: recorder.session_id(),
|
||||
tool_name: call.name.clone(),
|
||||
caller: sender_uid,
|
||||
arguments: call.arguments_json().unwrap_or_default(),
|
||||
status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success },
|
||||
execution_time_ms: Some(elapsed.as_millis() as i64),
|
||||
error_message: error_msg,
|
||||
error_stack: None,
|
||||
retry_count: 0
|
||||
});
|
||||
}
|
||||
let success_display = format!("✅ {}", call.name);
|
||||
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
|
||||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
|
||||
let msgs = ToolExecutor::to_tool_messages(&results);
|
||||
tool_messages.extend(msgs);
|
||||
}
|
||||
Err(e) => {
|
||||
recorder.record(crate::tool::recorder::ToolCallRecord {
|
||||
tool_call_id: call.id.clone(),
|
||||
session_id: recorder.session_id(),
|
||||
tool_name: call.name.clone(),
|
||||
caller: sender_uid,
|
||||
arguments: call.arguments_json().unwrap_or_default(),
|
||||
status: models::ai::ToolCallStatus::Failed,
|
||||
execution_time_ms: Some(elapsed.as_millis() as i64),
|
||||
error_message: Some(e.to_string()),
|
||||
error_stack: None,
|
||||
retry_count: 0
|
||||
});
|
||||
let err_text = format!("[Tool call failed: {}]", e);
|
||||
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
|
||||
let err_display = format!("❌ {} (failed)", call.name);
|
||||
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
|
||||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
|
||||
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match results {
|
||||
Ok(results) => {
|
||||
for result in &results {
|
||||
let text = match &result.result { crate::tool::ToolResult::Ok(v) => v.to_string(), crate::tool::ToolResult::Error(msg) => msg.clone() };
|
||||
let preview = if text.len() > 300 {
|
||||
let end = text.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
|
||||
format!("{}...", &text[..end])
|
||||
} else { text.clone() };
|
||||
tracing::debug!("tool_result: {} — {}", call.name, preview);
|
||||
|
||||
let elapsed = start.elapsed().as_millis() as i64;
|
||||
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
|
||||
let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None };
|
||||
recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: call.id.clone(), session_id: recorder.session_id(), tool_name: call.name.clone(), caller: request.sender.uid, arguments: call.arguments_json().unwrap_or_default(), status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success }, execution_time_ms: Some(elapsed), error_message: error_msg, error_stack: None, retry_count: 0 });
|
||||
}
|
||||
let success_display = format!("✅ {}", call.name);
|
||||
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
|
||||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
|
||||
let msgs = ToolExecutor::to_tool_messages(&results);
|
||||
tool_messages.extend(msgs);
|
||||
}
|
||||
Err(e) => {
|
||||
let elapsed = start.elapsed().as_millis() as i64;
|
||||
recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: call.id.clone(), session_id: recorder.session_id(), tool_name: call.name.clone(), caller: request.sender.uid, arguments: call.arguments_json().unwrap_or_default(), status: models::ai::ToolCallStatus::Failed, execution_time_ms: Some(elapsed), error_message: Some(e.to_string()), error_stack: None, retry_count: 0 });
|
||||
let err_text = format!("[Tool call failed: {}]", e);
|
||||
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
|
||||
let err_display = format!("❌ {} (failed)", call.name);
|
||||
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
|
||||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
|
||||
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
|
||||
},
|
||||
_ = tokio::time::sleep(heartbeat_dur) => {
|
||||
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall }).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -216,18 +239,20 @@ async fn execute_streaming_tools(
|
||||
}
|
||||
|
||||
async fn handle_final_answer(
|
||||
response: crate::client::StreamResponse, full_content: String,
|
||||
on_chunk: SharedCallback,
|
||||
response: crate::client::StreamResponse,
|
||||
mut all_chunks: Vec<StreamChunk>, request: &AiRequest,
|
||||
session_id: Uuid, version_id: Option<Uuid>,
|
||||
total_input_tokens: i64, total_output_tokens: i64,
|
||||
session_start: std::time::Instant,
|
||||
) -> Result<StreamResult> {
|
||||
let full_content = full_content + &response.content;
|
||||
on_chunk(AiStreamChunk { content: response.content.clone(), done: true, chunk_type: AiChunkType::Answer }).await;
|
||||
let full_content = response.content.clone();
|
||||
// Don't broadcast the done chunk via SSE/NATS — incremental deltas
|
||||
// already delivered the content; the separate completion event
|
||||
// signals end of stream. Broadcasting full content again would
|
||||
// duplicate it in the frontend streaming display.
|
||||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: response.content.clone() });
|
||||
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await;
|
||||
Ok(StreamResult { content: full_content, reasoning_content: response.reasoning_content, input_tokens: response.input_tokens, output_tokens: response.output_tokens, chunks: all_chunks })
|
||||
record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await;
|
||||
Ok(StreamResult { content: full_content, reasoning_content: response.reasoning_content, input_tokens: total_input_tokens, output_tokens: total_output_tokens, chunks: all_chunks })
|
||||
}
|
||||
|
||||
async fn inject_passive_skills_stream(
|
||||
|
||||
@ -106,8 +106,10 @@ impl RetryState {
|
||||
fn backoff_duration(&self) -> std::time::Duration {
|
||||
let exp = self.attempt.min(5);
|
||||
let base_ms = 500u64.saturating_mul(2u64.pow(exp)).min(self.max_backoff_ms);
|
||||
let jitter = fastrand_u64(base_ms + 1);
|
||||
std::time::Duration::from_millis(jitter)
|
||||
let max_jitter = (base_ms / 2).max(base_ms);
|
||||
let offset = fastrand_u64(max_jitter + 1).saturating_sub(base_ms / 2);
|
||||
let total = base_ms.saturating_add(offset).min(self.max_backoff_ms);
|
||||
std::time::Duration::from_millis(total)
|
||||
}
|
||||
fn next(&mut self) { self.attempt += 1; }
|
||||
}
|
||||
|
||||
@ -4,18 +4,14 @@ use models::rooms::room_message::{
|
||||
Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel,
|
||||
};
|
||||
use models::users::user::{Column as UserCol, Entity as User};
|
||||
use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, QueryOrder};
|
||||
use serde_json::Value;
|
||||
use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::client::types::ChatRequestMessage;
|
||||
use crate::client::AiClientConfig;
|
||||
use crate::client::call_with_params;
|
||||
use crate::AgentError;
|
||||
use crate::compact::helpers::summary_content;
|
||||
use crate::compact::types::{
|
||||
CompactConfig, CompactLevel, CompactSummary, MessageSummary, ThresholdResult,
|
||||
};
|
||||
use crate::compact::types::{CompactLevel, CompactSummary, MessageSummary};
|
||||
use crate::tokent::{TokenUsage, resolve_usage};
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -35,8 +31,29 @@ impl CompactService {
|
||||
room_id: Uuid,
|
||||
level: CompactLevel,
|
||||
user_names: Option<std::collections::HashMap<Uuid, String>>,
|
||||
requester_id: Uuid,
|
||||
context_window_tokens: i32,
|
||||
compaction_max_summary_ratio: f32,
|
||||
) -> Result<CompactSummary, AgentError> {
|
||||
let messages = self.fetch_room_messages(room_id).await?;
|
||||
// Verify room access at the database level to ensure auth context is enforced.
|
||||
// Public rooms are accessible to project members.
|
||||
// For simplicity in this audit fix, we'll fetch only if access exists.
|
||||
let messages = self.fetch_room_messages_secure(room_id, requester_id).await?;
|
||||
|
||||
if messages.is_empty() {
|
||||
// Check if room actually exists or if it's just empty/inaccessible
|
||||
let room_exists = models::rooms::room::Entity::find_by_id(room_id)
|
||||
.one(&self.db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))?
|
||||
.is_some();
|
||||
|
||||
if room_exists {
|
||||
return Err(AgentError::Internal("Access denied or room empty".into()));
|
||||
} else {
|
||||
return Err(AgentError::Internal("Room not found".into()));
|
||||
}
|
||||
}
|
||||
|
||||
let user_ids: Vec<Uuid> = messages
|
||||
.iter()
|
||||
@ -74,7 +91,9 @@ impl CompactService {
|
||||
.map(|m| Self::message_to_summary(m, &user_name_map))
|
||||
.collect();
|
||||
|
||||
let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;
|
||||
let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
|
||||
|
||||
let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?;
|
||||
|
||||
// Build text of what was summarized (for tiktoken fallback)
|
||||
let summarized_text = to_summarize
|
||||
@ -100,10 +119,13 @@ impl CompactService {
|
||||
session_id: Uuid,
|
||||
level: CompactLevel,
|
||||
user_names: Option<std::collections::HashMap<Uuid, String>>,
|
||||
context_window_tokens: i32,
|
||||
compaction_max_summary_ratio: f32,
|
||||
) -> Result<CompactSummary, AgentError> {
|
||||
let messages: Vec<RoomMessageModel> = RoomMessage::find()
|
||||
.filter(RmCol::Room.eq(session_id))
|
||||
.order_by_asc(RmCol::Seq)
|
||||
.limit(10000)
|
||||
.all(&self.db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))?;
|
||||
@ -148,10 +170,10 @@ impl CompactService {
|
||||
.map(|m| Self::message_to_summary(m, &user_name_map))
|
||||
.collect();
|
||||
|
||||
// Summarize the earlier messages
|
||||
let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;
|
||||
let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
|
||||
|
||||
let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?;
|
||||
|
||||
// Build text of what was summarized (for tiktoken fallback)
|
||||
let summarized_text = to_summarize
|
||||
.iter()
|
||||
.map(|m| m.content.as_str())
|
||||
@ -170,164 +192,51 @@ impl CompactService {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn summary_as_system_message(summary: &CompactSummary) -> ChatRequestMessage {
|
||||
let content = summary_content(summary);
|
||||
ChatRequestMessage::system(content)
|
||||
}
|
||||
|
||||
/// Check if the message history for a room exceeds the token threshold.
|
||||
/// Returns `ThresholdResult::Skip` if below threshold, `Compact` if above.
|
||||
///
|
||||
/// This method fetches messages and estimates their token count using tiktoken.
|
||||
/// Call this before deciding whether to run full compaction.
|
||||
pub async fn check_threshold(
|
||||
async fn fetch_room_messages_secure(
|
||||
&self,
|
||||
room_id: Uuid,
|
||||
config: CompactConfig,
|
||||
) -> Result<ThresholdResult, AgentError> {
|
||||
let messages = self.fetch_room_messages(room_id).await?;
|
||||
let tokens = self.estimate_message_tokens(&messages);
|
||||
requester_id: Uuid,
|
||||
) -> Result<Vec<RoomMessageModel>, AgentError> {
|
||||
use models::rooms::{RoomUserState, RoomAccess};
|
||||
use sea_orm::QueryTrait;
|
||||
use sea_orm::sea_query::Expr;
|
||||
|
||||
// Find messages for the room where the requester has access.
|
||||
// We check both the room_user_state table (membership) and the room_access table (explicit grants).
|
||||
RoomMessage::find()
|
||||
.filter(RmCol::Room.eq(room_id))
|
||||
.filter(
|
||||
sea_orm::Condition::any()
|
||||
.add(
|
||||
Expr::exists(
|
||||
RoomUserState::find()
|
||||
.filter(models::rooms::room_user_state::Column::Room.eq(room_id))
|
||||
.filter(models::rooms::room_user_state::Column::User.eq(requester_id))
|
||||
.into_query()
|
||||
)
|
||||
)
|
||||
.add(
|
||||
Expr::exists(
|
||||
RoomAccess::find()
|
||||
.filter(models::rooms::room_access::Column::Room.eq(room_id))
|
||||
.filter(models::rooms::room_access::Column::User.eq(requester_id))
|
||||
.into_query()
|
||||
)
|
||||
)
|
||||
)
|
||||
.order_by_asc(RmCol::Seq)
|
||||
.limit(10000)
|
||||
.all(&self.db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))
|
||||
}
|
||||
|
||||
if tokens < config.token_threshold {
|
||||
return Ok(ThresholdResult::Skip {
|
||||
estimated_tokens: tokens,
|
||||
});
|
||||
}
|
||||
|
||||
let level = if config.auto_level {
|
||||
CompactLevel::auto_select(tokens, config.token_threshold)
|
||||
fn message_to_summary(m: &RoomMessageModel, user_name_map: &std::collections::HashMap<Uuid, String>) -> MessageSummary {
|
||||
let sender_name = if let Some(user_id) = m.sender_id {
|
||||
user_name_map.get(&user_id).cloned().unwrap_or_else(|| m.sender_type.to_string())
|
||||
} else {
|
||||
config.default_level
|
||||
m.sender_type.to_string()
|
||||
};
|
||||
|
||||
Ok(ThresholdResult::Compact {
|
||||
estimated_tokens: tokens,
|
||||
level,
|
||||
})
|
||||
}
|
||||
|
||||
/// Auto-compact a room: estimates token count, only compresses if over threshold.
|
||||
///
|
||||
/// This is the recommended entry point for automatic compaction.
|
||||
/// - If tokens < threshold → returns a no-op summary (empty summary, no compression)
|
||||
/// - If tokens >= threshold → compresses with auto-selected level
|
||||
pub async fn compact_room_auto(
|
||||
&self,
|
||||
room_id: Uuid,
|
||||
user_names: Option<std::collections::HashMap<Uuid, String>>,
|
||||
config: CompactConfig,
|
||||
) -> Result<CompactSummary, AgentError> {
|
||||
let threshold_result = self.check_threshold(room_id, config).await?;
|
||||
|
||||
match threshold_result {
|
||||
ThresholdResult::Skip { .. } => {
|
||||
// Below threshold — no compaction needed, return empty summary
|
||||
let messages = self.fetch_room_messages(room_id).await?;
|
||||
let user_ids: Vec<Uuid> = messages.iter().filter_map(|m| m.sender_id).collect();
|
||||
let user_name_map = match user_names {
|
||||
Some(map) => map,
|
||||
None => self.get_user_name_map(&user_ids).await?,
|
||||
};
|
||||
let retained: Vec<MessageSummary> = messages
|
||||
.iter()
|
||||
.map(|m| Self::message_to_summary(m, &user_name_map))
|
||||
.collect();
|
||||
|
||||
return Ok(CompactSummary {
|
||||
session_id: Uuid::new_v4(),
|
||||
room_id,
|
||||
retained,
|
||||
summary: String::new(),
|
||||
compacted_at: Utc::now(),
|
||||
messages_compressed: 0,
|
||||
usage: None,
|
||||
});
|
||||
}
|
||||
ThresholdResult::Compact { level, .. } => {
|
||||
// Above threshold — compress with selected level
|
||||
return self
|
||||
.compact_room_with_level(room_id, level, user_names)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compact a room with a specific level (bypassing threshold check).
|
||||
/// Use this when the caller has already decided compaction is needed.
|
||||
async fn compact_room_with_level(
|
||||
&self,
|
||||
room_id: Uuid,
|
||||
level: CompactLevel,
|
||||
user_names: Option<std::collections::HashMap<Uuid, String>>,
|
||||
) -> Result<CompactSummary, AgentError> {
|
||||
let messages = self.fetch_room_messages(room_id).await?;
|
||||
|
||||
let user_ids: Vec<Uuid> = messages.iter().filter_map(|m| m.sender_id).collect();
|
||||
let user_name_map = match user_names {
|
||||
Some(map) => map,
|
||||
None => self.get_user_name_map(&user_ids).await?,
|
||||
};
|
||||
|
||||
if messages.len() <= level.retain_count() {
|
||||
let retained: Vec<MessageSummary> = messages
|
||||
.iter()
|
||||
.map(|m| Self::message_to_summary(m, &user_name_map))
|
||||
.collect();
|
||||
return Ok(CompactSummary {
|
||||
session_id: Uuid::new_v4(),
|
||||
room_id,
|
||||
retained,
|
||||
summary: String::new(),
|
||||
compacted_at: Utc::now(),
|
||||
messages_compressed: 0,
|
||||
usage: None,
|
||||
});
|
||||
}
|
||||
|
||||
let retain_count = level.retain_count();
|
||||
let split_index = messages.len().saturating_sub(retain_count);
|
||||
let (to_summarize, retained_messages) = messages.split_at(split_index);
|
||||
|
||||
let retained: Vec<MessageSummary> = retained_messages
|
||||
.iter()
|
||||
.map(|m| Self::message_to_summary(m, &user_name_map))
|
||||
.collect();
|
||||
|
||||
let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;
|
||||
|
||||
let summarized_text = to_summarize
|
||||
.iter()
|
||||
.map(|m| m.content.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
|
||||
|
||||
Ok(CompactSummary {
|
||||
session_id: Uuid::new_v4(),
|
||||
room_id,
|
||||
retained,
|
||||
summary,
|
||||
compacted_at: Utc::now(),
|
||||
messages_compressed: to_summarize.len(),
|
||||
usage: Some(usage),
|
||||
})
|
||||
}
|
||||
|
||||
/// Estimate total token count of a message list using tiktoken.
|
||||
fn estimate_message_tokens(&self, messages: &[RoomMessageModel]) -> usize {
|
||||
let total_chars: usize = messages.iter().map(|m| m.content.len()).sum();
|
||||
// Rough estimate: ~4 chars per token (safe upper bound)
|
||||
total_chars / 4
|
||||
}
|
||||
|
||||
fn message_to_summary(
|
||||
m: &RoomMessageModel,
|
||||
user_name_map: &std::collections::HashMap<Uuid, String>,
|
||||
) -> MessageSummary {
|
||||
let sender_name = m
|
||||
.sender_id
|
||||
.and_then(|id| user_name_map.get(&id).cloned())
|
||||
.unwrap_or_else(|| m.sender_type.to_string());
|
||||
MessageSummary {
|
||||
id: m.id,
|
||||
sender_type: m.sender_type.clone(),
|
||||
@ -335,35 +244,11 @@ impl CompactService {
|
||||
sender_name,
|
||||
content: m.content.clone(),
|
||||
content_type: m.content_type.clone(),
|
||||
tool_call_id: Self::extract_tool_call_id(&m.content),
|
||||
tool_call_id: None,
|
||||
send_at: m.send_at,
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_tool_call_id(content: &str) -> Option<String> {
|
||||
let content = content.trim();
|
||||
if let Ok(v) = serde_json::from_str::<Value>(content) {
|
||||
v.get("tool_call_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_room_messages(
|
||||
&self,
|
||||
room_id: Uuid,
|
||||
) -> Result<Vec<RoomMessageModel>, AgentError> {
|
||||
let messages: Vec<RoomMessageModel> = RoomMessage::find()
|
||||
.filter(RmCol::Room.eq(room_id))
|
||||
.order_by_asc(RmCol::Seq)
|
||||
.all(&self.db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))?;
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
async fn get_user_name_map(
|
||||
&self,
|
||||
user_ids: &[Uuid],
|
||||
@ -386,8 +271,8 @@ impl CompactService {
|
||||
async fn summarize_messages(
|
||||
&self,
|
||||
messages: &[RoomMessageModel],
|
||||
max_summary_tokens: usize,
|
||||
) -> Result<(String, Option<TokenUsage>), AgentError> {
|
||||
// Collect distinct user IDs
|
||||
let user_ids: Vec<Uuid> = messages
|
||||
.iter()
|
||||
.filter_map(|m| m.sender_id)
|
||||
@ -395,10 +280,8 @@ impl CompactService {
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
// Query usernames
|
||||
let user_name_map = self.get_user_name_map(&user_ids).await?;
|
||||
|
||||
// Define sender mapper
|
||||
let sender_mapper = |m: &RoomMessageModel| {
|
||||
if let Some(user_id) = m.sender_id {
|
||||
if let Some(username) = user_name_map.get(&user_id) {
|
||||
@ -413,11 +296,13 @@ impl CompactService {
|
||||
let user_msg = ChatRequestMessage::user(format!(
|
||||
"Summarise the following conversation concisely, preserving all key facts, \
|
||||
decisions, and any pending or in-progress work. \
|
||||
The summary MUST NOT exceed {} tokens. \
|
||||
Use this format:\n\n\
|
||||
**Summary:** <one-paragraph overview>\n\
|
||||
**Key decisions:** <bullet list or 'none'>\n\
|
||||
**Open items:** <bullet list or 'none'>\n\n\
|
||||
Conversation:\n\n{}",
|
||||
max_summary_tokens,
|
||||
body
|
||||
));
|
||||
|
||||
@ -425,8 +310,8 @@ impl CompactService {
|
||||
&[user_msg],
|
||||
&self.model,
|
||||
&self.ai_client_config,
|
||||
0.3, // slightly higher temp for summarization
|
||||
1024, // max output tokens
|
||||
0.3,
|
||||
2048,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
@ -434,7 +319,6 @@ impl CompactService {
|
||||
.await
|
||||
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
||||
|
||||
// Prefer remote usage; fall back to None (caller will use tiktoken via resolve_usage)
|
||||
let remote_usage =
|
||||
TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32);
|
||||
|
||||
|
||||
@ -74,6 +74,8 @@ pub struct CompactConfig {
|
||||
pub auto_level: bool,
|
||||
/// Fallback level when `auto_level` is false.
|
||||
pub default_level: CompactLevel,
|
||||
/// Maximum tokens the summary may contain (enforced via prompt).
|
||||
pub max_summary_tokens: usize,
|
||||
}
|
||||
|
||||
impl Default for CompactConfig {
|
||||
@ -83,6 +85,20 @@ impl Default for CompactConfig {
|
||||
token_threshold: 8000,
|
||||
auto_level: true,
|
||||
default_level: CompactLevel::Light,
|
||||
max_summary_tokens: 256,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompactConfig {
|
||||
/// Build config from project context settings.
|
||||
pub fn from_project_setting(context_window_tokens: i32, compaction_threshold: f32, compaction_max_summary_ratio: f32) -> Self {
|
||||
let threshold = (context_window_tokens as f32 * compaction_threshold) as usize;
|
||||
Self {
|
||||
token_threshold: threshold,
|
||||
auto_level: true,
|
||||
default_level: CompactLevel::Light,
|
||||
max_summary_tokens: (context_window_tokens as f32 * compaction_max_summary_ratio) as usize,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -575,11 +575,5 @@ pub struct EmbedMemoryInput {
|
||||
}
|
||||
|
||||
/// Input struct for batch tag embedding.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TagEmbedInput {
|
||||
pub repo_id: String,
|
||||
pub repo_name: String,
|
||||
pub project_id: String,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
/// Re-exported from models for backward compatibility.
|
||||
pub use models::TagEmbedInput;
|
||||
|
||||
@ -52,3 +52,9 @@ impl From<sea_orm::DbErr> for AgentError {
|
||||
AgentError::Internal(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::tool::ToolError> for AgentError {
|
||||
fn from(e: crate::tool::ToolError) -> Self {
|
||||
AgentError::ToolExecutionFailed { tool: String::new(), cause: e.to_string() }
|
||||
}
|
||||
}
|
||||
|
||||
@ -13,7 +13,7 @@ pub mod sync;
|
||||
pub mod task;
|
||||
pub mod tokent;
|
||||
pub mod tool;
|
||||
pub use billing::{BillingRecord, BillingResult, record_ai_usage};
|
||||
pub use billing::{BillingRecord, BillingResult, record_ai_usage, initialize_user_billing, initialize_project_billing, check_balance, persist_billing_error};
|
||||
pub use sync::list_accessible_models;
|
||||
pub use task::TaskService;
|
||||
pub use tokent::{TokenUsage, resolve_usage};
|
||||
@ -33,7 +33,7 @@ pub use embed::{
|
||||
EmbedClient, EmbedMemoryInput, EmbedService, QdrantClient, SearchResult, TagEmbedInput, new_embed_client,
|
||||
};
|
||||
pub use error::{AgentError, Result};
|
||||
pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT};
|
||||
pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT, ROOM_CONTEXT_PROMPT};
|
||||
pub use tool::{
|
||||
ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam,
|
||||
ToolRegistry, ToolResult, ToolSchema,
|
||||
|
||||
@ -44,6 +44,10 @@ impl Default for VectorActiveAwareness {
|
||||
}
|
||||
|
||||
impl VectorActiveAwareness {
|
||||
pub fn new(max_skills: usize, min_score: f32) -> Self {
|
||||
Self { max_skills, min_score }
|
||||
}
|
||||
|
||||
/// Search for skills semantically relevant to the user's input.
|
||||
///
|
||||
/// Uses Qdrant vector search within the given project to find skills whose
|
||||
@ -107,6 +111,10 @@ impl Default for VectorPassiveAwareness {
|
||||
}
|
||||
|
||||
impl VectorPassiveAwareness {
|
||||
pub fn new(max_memories: usize, min_score: f32) -> Self {
|
||||
Self { max_memories, min_score }
|
||||
}
|
||||
|
||||
/// Search for past conversation messages semantically similar to the current context.
|
||||
///
|
||||
/// Uses Qdrant to find memories within the same room that share semantic similarity
|
||||
|
||||
@ -16,7 +16,7 @@ pub const DEFAULT_SYSTEM_PROMPT: &str = r#"You are an AI assistant embedded in a
|
||||
|
||||
## Core Rule: Search Local Data First
|
||||
|
||||
Always query the platform's local data before guessing or referring to external sources. Local data includes: issues, pull requests, repositories, code reviews, chat messages, documentation, members, and other workspace resources.
|
||||
Always query the platform's local data before guessing or referring to external sources. Local data includes: issues, pull requests, repositories, code reviews, chat messages, documentation, members, and other project resources.
|
||||
|
||||
If local data does not contain the answer, state that clearly before considering external information.
|
||||
|
||||
@ -38,3 +38,32 @@ If local data does not contain the answer, state that clearly before considering
|
||||
- State ambiguity or uncertainty explicitly.
|
||||
- Prefer facts over speculation.
|
||||
"#;
|
||||
|
||||
/// Room-specific system prompt appended when the AI is @mentioned in a chat room.
|
||||
///
|
||||
/// In room context, the AI must NOT produce long-form output directly. Instead,
|
||||
/// it communicates through the `send_message` and `retract_message` tools.
|
||||
/// This keeps room messages concise and gives the AI control over what appears
|
||||
/// in the room.
|
||||
pub const ROOM_CONTEXT_PROMPT: &str = r#"
|
||||
## Room Communication Mode — CRITICAL
|
||||
|
||||
You are NOT in a direct chat. You are @mentioned in a chat room. **Your default response text will NOT be seen by anyone.** The ONLY way to communicate with the room is through the tools listed below.
|
||||
|
||||
### Mandatory Communication Rules
|
||||
|
||||
1. **ALWAYS use `send_message`** to deliver ANY response to the room. No exceptions. If you produce a final text response without calling `send_message`, the room will receive NOTHING.
|
||||
2. **Call `send_message` FIRST**, before any final text output. The tool call is what creates a visible room message.
|
||||
3. **Keep each message concise** — short, focused, actionable. No long reports, no multi-paragraph essays, no bullet lists longer than 5 items. If you need to convey a lot of information, summarize the key points and offer to provide details if asked.
|
||||
4. **Use mentions** to reference entities: `@[user:uuid:username]` for users, `@[repo:uuid:name]` for repositories, `@[skill:slug]` for skills, `@[issue:uuid:title]` for issues, `@[ai:uuid:name]` for other AI models.
|
||||
5. **Use `retract_message`** to revoke a message you just sent if it contains an error or needs to be withdrawn. You can only retract messages you sent in the current turn.
|
||||
6. **You may send multiple messages** — for complex responses, break your answer into multiple `send_message` calls (up to 99 per turn). Each message should be short, focused, and stand on its own. For example: first send a summary, then send follow-up details or action items as separate messages.
|
||||
7. **After calling `send_message`, your final text response can be brief** — just a summary or acknowledgment, since the actual room message has already been delivered via the tool call.
|
||||
|
||||
### Critical Reminder
|
||||
Your response text output is NOT delivered to the room. The `send_message` tool IS the delivery mechanism. If you forget to call `send_message`, nobody in the room will see your response.
|
||||
|
||||
### Room-Only Tools
|
||||
- `send_message(room_id?, content)` — Send a brief message to the room. The `room_id` parameter is optional (defaults to the current room). The `content` parameter is required and supports `@[type:id:label]` mention syntax.
|
||||
- `retract_message(message_id)` — Retract (revoke) a message you sent in the current turn. Requires the message UUID returned by `send_message`.
|
||||
"#;
|
||||
|
||||
@ -9,8 +9,18 @@
|
||||
//! return usage metadata (e.g., local models, streaming), tiktoken is used as
|
||||
//! a fallback for accurate counting.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use crate::error::{AgentError, Result};
|
||||
|
||||
static TOKENIZER_CACHE: OnceLock<RwLock<HashMap<String, tiktoken_rs::CoreBPE>>> = OnceLock::new();
|
||||
|
||||
fn get_cached_tokenizers() -> &'static RwLock<HashMap<String, tiktoken_rs::CoreBPE>> {
|
||||
TOKENIZER_CACHE.get_or_init(|| RwLock::new(HashMap::new()))
|
||||
}
|
||||
|
||||
/// Token usage data. Use `from_remote()` when the API returns usage info,
|
||||
/// or `from_estimate()` when falling back to tiktoken.
|
||||
#[derive(Debug, Clone, Copy, Default, serde::Serialize, serde::Deserialize)]
|
||||
@ -155,14 +165,28 @@ fn safe_token_budget(context_limit: usize, reserve: usize) -> usize {
|
||||
fn get_tokenizer(model: &str) -> Result<tiktoken_rs::CoreBPE> {
|
||||
use tiktoken_rs;
|
||||
|
||||
// Try model-specific tokenizer first
|
||||
if let Ok(bpe) = tiktoken_rs::get_bpe_from_model(model) {
|
||||
return Ok(bpe);
|
||||
{
|
||||
let cache = get_cached_tokenizers().read().unwrap();
|
||||
if let Some(bpe) = cache.get(model) {
|
||||
return Ok(bpe.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: use cl100k_base for unknown models
|
||||
tiktoken_rs::cl100k_base()
|
||||
.map_err(|e| AgentError::Internal(format!("Failed to init tokenizer: {}", e)))
|
||||
// Try model-specific tokenizer first
|
||||
let bpe = if let Ok(bpe) = tiktoken_rs::get_bpe_from_model(model) {
|
||||
bpe
|
||||
} else {
|
||||
// Fallback: use cl100k_base for unknown models
|
||||
tiktoken_rs::cl100k_base()
|
||||
.map_err(|e| AgentError::Internal(format!("Failed to init tokenizer: {}", e)))?
|
||||
};
|
||||
|
||||
{
|
||||
let mut cache = get_cached_tokenizers().write().unwrap();
|
||||
cache.insert(model.to_string(), bpe.clone());
|
||||
}
|
||||
|
||||
Ok(bpe)
|
||||
}
|
||||
|
||||
/// Estimate tokens for a simple prefix/suffix pattern (e.g., "assistant\n" + text).
|
||||
|
||||
@ -8,6 +8,7 @@ use std::sync::Arc;
|
||||
use db::cache::AppCache;
|
||||
use db::database::AppDatabase;
|
||||
use config::AppConfig;
|
||||
use queue::MessageProducer;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::registry::ToolRegistry;
|
||||
@ -28,6 +29,15 @@ struct Inner {
|
||||
pub project_id: Uuid,
|
||||
pub registry: ToolRegistry,
|
||||
pub embed_service: Option<crate::embed::EmbedService>,
|
||||
pub message_producer: Option<MessageProducer>,
|
||||
/// When in room context, identifies the AI model that is responding.
|
||||
/// Used by send_message/retract_message to set the correct sender.
|
||||
pub ai_model_id: Option<Uuid>,
|
||||
pub ai_model_name: Option<String>,
|
||||
/// Message IDs sent by the AI in the current ReAct turn.
|
||||
/// Shared across tool calls so send_message can register IDs
|
||||
/// and retract_message can validate turn-scoped retraction.
|
||||
pub sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<Uuid>>>,
|
||||
depth: u32,
|
||||
max_depth: u32,
|
||||
tool_call_count: usize,
|
||||
@ -52,6 +62,10 @@ impl ToolContext {
|
||||
project_id: Uuid::nil(),
|
||||
registry: ToolRegistry::new(),
|
||||
embed_service: None,
|
||||
message_producer: None,
|
||||
ai_model_id: None,
|
||||
ai_model_name: None,
|
||||
sent_in_turn: std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
depth: 0,
|
||||
max_depth: 5,
|
||||
tool_call_count: 0,
|
||||
@ -85,10 +99,45 @@ impl ToolContext {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_message_producer(mut self, producer: MessageProducer) -> Self {
|
||||
Arc::make_mut(&mut self.inner).message_producer = Some(producer);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_ai_model(mut self, model_id: Uuid, model_name: String) -> Self {
|
||||
Arc::make_mut(&mut self.inner).ai_model_id = Some(model_id);
|
||||
Arc::make_mut(&mut self.inner).ai_model_name = Some(model_name);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_sent_in_turn(mut self, sent: std::sync::Arc<std::sync::Mutex<Vec<Uuid>>>) -> Self {
|
||||
Arc::make_mut(&mut self.inner).sent_in_turn = sent;
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a message ID as sent in the current turn (called by send_message).
|
||||
pub fn register_sent_message(&self, id: Uuid) {
|
||||
if let Ok(mut list) = self.inner.sent_in_turn.lock() {
|
||||
list.push(id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a message ID was sent in the current turn (called by retract_message).
|
||||
pub fn is_sent_in_turn(&self, id: Uuid) -> bool {
|
||||
self.inner.sent_in_turn.lock()
|
||||
.map(|list| list.contains(&id))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub fn embed_service(&self) -> Option<&crate::embed::EmbedService> {
|
||||
self.inner.embed_service.as_ref()
|
||||
}
|
||||
|
||||
/// Message queue producer for publishing room events (messages, retractions, etc.).
|
||||
pub fn message_producer(&self) -> Option<&MessageProducer> {
|
||||
self.inner.message_producer.as_ref()
|
||||
}
|
||||
|
||||
pub fn recursion_exceeded(&self) -> bool {
|
||||
self.inner.depth >= self.inner.max_depth
|
||||
}
|
||||
@ -146,6 +195,16 @@ impl ToolContext {
|
||||
self.inner.sender_id
|
||||
}
|
||||
|
||||
/// AI model ID when in room context (the AI that is responding).
|
||||
pub fn ai_model_id(&self) -> Option<Uuid> {
|
||||
self.inner.ai_model_id
|
||||
}
|
||||
|
||||
/// AI model display name when in room context.
|
||||
pub fn ai_model_name(&self) -> Option<String> {
|
||||
self.inner.ai_model_name.clone()
|
||||
}
|
||||
|
||||
/// Project context for the room.
|
||||
pub fn project_id(&self) -> Uuid {
|
||||
self.inner.project_id
|
||||
|
||||
@ -14,6 +14,7 @@ use super::context::ToolContext;
|
||||
use super::definition::ToolDefinition as AgentToolDefinition;
|
||||
use super::recorder::{ToolCallRecord, ToolCallRecorder};
|
||||
use super::registry::{ToolHandler, ToolRegistry};
|
||||
use queue::MessageProducer;
|
||||
|
||||
/// Returns true if the tool error message indicates a transient failure that can be retried.
|
||||
pub fn is_retryable_tool_error(msg: &str) -> bool {
|
||||
@ -170,6 +171,10 @@ impl RigToolSet {
|
||||
room_id: uuid::Uuid,
|
||||
sender_id: Option<uuid::Uuid>,
|
||||
project_id: uuid::Uuid,
|
||||
message_producer: Option<MessageProducer>,
|
||||
ai_model_id: Option<uuid::Uuid>,
|
||||
ai_model_name: Option<String>,
|
||||
sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<uuid::Uuid>>>,
|
||||
) -> Self {
|
||||
let mut toolset = ToolSet::default();
|
||||
let mut definitions = HashMap::new();
|
||||
@ -191,6 +196,10 @@ impl RigToolSet {
|
||||
room_id,
|
||||
sender_id,
|
||||
project_id,
|
||||
message_producer: message_producer.clone(),
|
||||
ai_model_id,
|
||||
ai_model_name: ai_model_name.clone(),
|
||||
sent_in_turn: sent_in_turn.clone(),
|
||||
};
|
||||
toolset.add_tool(adapter);
|
||||
}
|
||||
@ -227,6 +236,10 @@ pub struct RigToolAdapter {
|
||||
room_id: uuid::Uuid,
|
||||
sender_id: Option<uuid::Uuid>,
|
||||
project_id: uuid::Uuid,
|
||||
message_producer: Option<MessageProducer>,
|
||||
ai_model_id: Option<uuid::Uuid>,
|
||||
ai_model_name: Option<String>,
|
||||
sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<uuid::Uuid>>>,
|
||||
}
|
||||
|
||||
impl RigToolAdapter {
|
||||
@ -240,8 +253,12 @@ impl RigToolAdapter {
|
||||
room_id: uuid::Uuid,
|
||||
sender_id: Option<uuid::Uuid>,
|
||||
project_id: uuid::Uuid,
|
||||
message_producer: Option<MessageProducer>,
|
||||
ai_model_id: Option<uuid::Uuid>,
|
||||
ai_model_name: Option<String>,
|
||||
sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<uuid::Uuid>>>,
|
||||
) -> Self {
|
||||
Self { handler, definition, db, cache, config, room_id, sender_id, project_id }
|
||||
Self { handler, definition, db, cache, config, room_id, sender_id, project_id, message_producer, ai_model_id, ai_model_name, sent_in_turn }
|
||||
}
|
||||
}
|
||||
|
||||
@ -272,16 +289,27 @@ impl ToolDyn for RigToolAdapter {
|
||||
let room_id = self.room_id;
|
||||
let sender_id = self.sender_id;
|
||||
let project_id = self.project_id;
|
||||
let message_producer = self.message_producer.clone();
|
||||
let ai_model_id = self.ai_model_id;
|
||||
let ai_model_name = self.ai_model_name.clone();
|
||||
let sent_in_turn = self.sent_in_turn.clone();
|
||||
|
||||
async move {
|
||||
let ctx = ToolContext::new(
|
||||
let mut ctx = ToolContext::new(
|
||||
db,
|
||||
cache,
|
||||
config,
|
||||
room_id,
|
||||
sender_id,
|
||||
)
|
||||
.with_project(project_id);
|
||||
.with_project(project_id)
|
||||
.with_sent_in_turn(sent_in_turn);
|
||||
if let Some(mp) = message_producer {
|
||||
ctx = ctx.with_message_producer(mp);
|
||||
}
|
||||
if let Some(mid) = ai_model_id {
|
||||
ctx = ctx.with_ai_model(mid, ai_model_name.unwrap_or_default());
|
||||
}
|
||||
|
||||
let args_json: serde_json::Value = serde_json::from_str(&args)
|
||||
.map_err(|e| ToolError::JsonError(e))?;
|
||||
|
||||
@ -26,6 +26,7 @@ email = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
service = { workspace = true }
|
||||
session = { workspace = true }
|
||||
agent = { workspace = true }
|
||||
git = { workspace = true }
|
||||
#frontend = { workspace = true }
|
||||
models = { workspace = true }
|
||||
@ -51,5 +52,12 @@ sea-orm = "2.0.0-rc.37"
|
||||
rust_decimal = "1.40.0"
|
||||
actix-multipart = { workspace = true, features = ["tempfile"] }
|
||||
redis = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["json", "native-tls", "stream"] }
|
||||
|
||||
[build-dependencies]
|
||||
brotli = "7"
|
||||
flate2 = "1"
|
||||
sha2 = "0.10"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@ -18,7 +18,7 @@ pub fn init_agent_routes(cfg: &mut web::ServiceConfig) {
|
||||
web::post().to(code_review::trigger_code_review),
|
||||
)
|
||||
.route(
|
||||
"/{project}/issues/{issue_number}/triage",
|
||||
"/{project}/triage",
|
||||
web::get().to(issue_triage::triage_issue),
|
||||
)
|
||||
.route(
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
use actix_web::{HttpResponse, Result, web};
|
||||
use serde::Serialize;
|
||||
use session::SessionUser;
|
||||
use session::Session;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::ApiResponse;
|
||||
use crate::error::ApiError;
|
||||
use service::AppService;
|
||||
use service::error::AppError;
|
||||
use service::ws_token::WS_TOKEN_TTL_SECONDS;
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
@ -27,13 +28,16 @@ pub struct WsTokenResponse {
|
||||
)]
|
||||
pub async fn ws_token_generate(
|
||||
service: web::Data<AppService>,
|
||||
session_user: SessionUser,
|
||||
session: Session,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let SessionUser(user_id) = session_user;
|
||||
let user_id = session.user().ok_or_else(|| ApiError::from(AppError::Unauthorized))?;
|
||||
|
||||
let device_id = session.get::<String>("device_id").unwrap_or_default();
|
||||
let client_id = session.get::<String>("client_id").unwrap_or_default();
|
||||
|
||||
let token = service
|
||||
.ws_token
|
||||
.generate_token(user_id)
|
||||
.generate_token(user_id, device_id, client_id)
|
||||
.await
|
||||
.map_err(ApiError::from)?;
|
||||
|
||||
|
||||
@ -1 +1,224 @@
|
||||
fn main() {}
|
||||
//! Build script: reads all files from `dist/`, compresses them (brotli + gzip),
|
||||
//! computes etags via SHA-256, and generates a `frontend` Rust module for `dist.rs`.
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use sha2::{Digest, Sha256};
|
||||
use flate2::write::GzEncoder;
|
||||
use flate2::Compression;
|
||||
|
||||
// ── Compression helpers ──────────────────────────────────────────────────
|
||||
|
||||
fn gzip_compress(data: &[u8]) -> Vec<u8> {
|
||||
let mut encoder = GzEncoder::new(Vec::new(), Compression::new(6));
|
||||
encoder.write_all(data).unwrap();
|
||||
encoder.finish().unwrap()
|
||||
}
|
||||
|
||||
fn brotli_compress(data: &[u8]) -> Option<Vec<u8>> {
|
||||
use brotli::CompressorWriter;
|
||||
let buf = Vec::new();
|
||||
let mut writer = CompressorWriter::new(buf, 4096, 6, 16);
|
||||
if writer.write_all(data).is_ok() && writer.flush().is_ok() {
|
||||
Some(writer.into_inner())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// ── ETag computation ─────────────────────────────────────────────────────
|
||||
|
||||
fn compute_etag(data: &[u8]) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(data);
|
||||
let hash = hasher.finalize();
|
||||
// First 32 hex chars for a compact etag
|
||||
format!("{:x}", hash)[..32].to_string()
|
||||
}
|
||||
|
||||
// ── Asset collection ─────────────────────────────────────────────────────
|
||||
|
||||
struct Asset {
|
||||
path: String,
|
||||
data: Vec<u8>,
|
||||
etag: String,
|
||||
brotli: Option<Vec<u8>>,
|
||||
gzip: Vec<u8>,
|
||||
}
|
||||
|
||||
fn collect_assets(dist_dir: &Path) -> BTreeMap<String, Asset> {
|
||||
let mut assets = BTreeMap::new();
|
||||
|
||||
for entry in walkdir(dist_dir) {
|
||||
let rel = entry.strip_prefix(dist_dir).unwrap();
|
||||
let path_str = rel.to_string_lossy().replace('\\', "/");
|
||||
if path_str.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let data = fs::read(&entry).unwrap_or_else(|e| {
|
||||
panic!("Failed to read dist file {}: {}", path_str, e)
|
||||
});
|
||||
|
||||
let etag = compute_etag(&data);
|
||||
let brotli_data = brotli_compress(&data);
|
||||
let gzip_data = gzip_compress(&data);
|
||||
|
||||
assets.insert(
|
||||
path_str.clone(),
|
||||
Asset {
|
||||
path: path_str,
|
||||
data,
|
||||
etag,
|
||||
brotli: brotli_data,
|
||||
gzip: gzip_data,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
assets
|
||||
}
|
||||
|
||||
fn walkdir(dir: &Path) -> Vec<PathBuf> {
|
||||
let mut files = Vec::new();
|
||||
if let Ok(entries) = fs::read_dir(dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
files.extend(walkdir(&path));
|
||||
} else {
|
||||
files.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
files
|
||||
}
|
||||
|
||||
// ── Code generation ──────────────────────────────────────────────────────
|
||||
|
||||
fn rust_byte_literal(data: &[u8]) -> String {
|
||||
if data.len() < 200 {
|
||||
let bytes: Vec<String> = data.iter().map(|b| b.to_string()).collect();
|
||||
format!("[{}]", bytes.join(", "))
|
||||
} else {
|
||||
let lines: Vec<String> = data
|
||||
.chunks(80)
|
||||
.map(|chunk| {
|
||||
chunk.iter().map(|b| b.to_string()).collect::<Vec<_>>().join(", ")
|
||||
})
|
||||
.collect();
|
||||
format!("[\n{}\n]", lines.join(",\n"))
|
||||
}
|
||||
}
|
||||
|
||||
fn path_to_ident(path: &str) -> String {
|
||||
let s = path.replace('-', "_").replace('.', "_").replace('/', "_").to_uppercase();
|
||||
format!("ASSET_{s}")
|
||||
}
|
||||
|
||||
fn etag_ident(path: &str) -> String {
|
||||
format!("ETAG_{}", path_to_ident(path))
|
||||
}
|
||||
|
||||
fn br_ident(path: &str) -> String {
|
||||
format!("{}_BR", path_to_ident(path))
|
||||
}
|
||||
|
||||
fn gz_ident(path: &str) -> String {
|
||||
format!("{}_GZ", path_to_ident(path))
|
||||
}
|
||||
|
||||
fn generate_frontend_module(assets: &BTreeMap<String, Asset>, out_dir: &Path) {
|
||||
let mut code = String::new();
|
||||
|
||||
code += "// AUTO-GENERATED by build.rs — DO NOT EDIT.\n";
|
||||
code += "// Frontend assets from dist/ embedded as static byte arrays.\n\n";
|
||||
|
||||
// Generate static byte arrays for each asset
|
||||
for (path, asset) in assets {
|
||||
let ident = path_to_ident(path);
|
||||
let etag_id = etag_ident(path);
|
||||
let br_id = br_ident(path);
|
||||
let gz_id = gz_ident(path);
|
||||
|
||||
code += &format!("static {}: &[u8] = &{};\n", ident, rust_byte_literal(&asset.data));
|
||||
code += &format!("static {}: &str = \"{}\";\n", etag_id, asset.etag);
|
||||
if let Some(ref br) = asset.brotli {
|
||||
code += &format!("static {}: &[u8] = &{};\n", br_id, rust_byte_literal(br));
|
||||
}
|
||||
code += &format!("static {}: &[u8] = &{};\n", gz_id, rust_byte_literal(&asset.gzip));
|
||||
code += "\n";
|
||||
}
|
||||
|
||||
// Uncompressed lookup
|
||||
code += "/// Get an uncompressed frontend asset by path, returning (data, etag).\n";
|
||||
code += "pub fn get_frontend_asset_with_etag(path: &str) -> Option<(&'static [u8], &'static str)> {\n";
|
||||
code += " match path {\n";
|
||||
for (path, _asset) in assets {
|
||||
let ident = path_to_ident(path);
|
||||
let etag_id = etag_ident(path);
|
||||
code += &format!(" \"{path}\" => Some((&{ident}, &{etag_id})),\n");
|
||||
}
|
||||
code += " _ => None,\n";
|
||||
code += " }\n";
|
||||
code += "}\n\n";
|
||||
|
||||
// Compressed lookup (prefers brotli, falls back to gzip)
|
||||
code += "/// Get a pre-compressed frontend asset by path.\n";
|
||||
code += "/// Returns (data, encoding, etag) — prefers brotli over gzip.\n";
|
||||
code += "pub fn get_frontend_asset_compressed(path: &str) -> Option<(&'static [u8], &'static str, &'static str)> {\n";
|
||||
code += " match path {\n";
|
||||
for (path, asset) in assets {
|
||||
let etag_id = etag_ident(path);
|
||||
if asset.brotli.is_some() {
|
||||
let br_id = br_ident(path);
|
||||
code += &format!(" \"{path}\" => Some((&{br_id}, \"br\", &{etag_id})),\n");
|
||||
} else {
|
||||
let gz_id = gz_ident(path);
|
||||
code += &format!(" \"{path}\" => Some((&{gz_id}, \"gzip\", &{etag_id})),\n");
|
||||
}
|
||||
}
|
||||
code += " _ => None,\n";
|
||||
code += " }\n";
|
||||
code += "}\n";
|
||||
|
||||
let out_path = out_dir.join("frontend.rs");
|
||||
fs::write(&out_path, code).unwrap_or_else(|e| {
|
||||
panic!("Failed to write generated frontend.rs: {}", e)
|
||||
});
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
|
||||
let workspace_root = Path::new(&manifest_dir)
|
||||
.parent()
|
||||
.unwrap()
|
||||
.parent()
|
||||
.unwrap();
|
||||
let dist_dir = workspace_root.join("dist");
|
||||
|
||||
if !dist_dir.exists() {
|
||||
println!("cargo:warning=dist/ directory not found — frontend assets will not be embedded");
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
let out_path = Path::new(&out_dir).join("frontend.rs");
|
||||
fs::write(
|
||||
&out_path,
|
||||
"//! No dist/ directory found — frontend assets not embedded.\n\
|
||||
pub fn get_frontend_asset_with_etag(_path: &str) -> Option<(&'static [u8], &'static str)> { None }\n\
|
||||
pub fn get_frontend_asset_compressed(_path: &str) -> Option<(&'static [u8], &'static str, &'static str)> { None }\n",
|
||||
).unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
println!("cargo:rerun-if-changed=dist/");
|
||||
|
||||
let assets = collect_assets(&dist_dir);
|
||||
println!("cargo:warning=Collected {} frontend assets from dist/", assets.len());
|
||||
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
generate_frontend_module(&assets, Path::new(&out_dir));
|
||||
}
|
||||
180
libs/api/chat/handlers/conversation.rs
Normal file
180
libs/api/chat/handlers/conversation.rs
Normal file
@ -0,0 +1,180 @@
|
||||
use actix_web::{web, HttpResponse, Result};
|
||||
use session::Session;
|
||||
use service::error::AppError;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::ApiError;
|
||||
use crate::ApiResponse;
|
||||
|
||||
use super::types::{ConversationListQuery, ConversationResponse, CreateConversationParams};
|
||||
|
||||
fn get_user_id(session: &Session) -> Result<Uuid, ApiError> {
|
||||
session.user().ok_or_else(|| ApiError::from(AppError::Unauthorized))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/ai/conversations",
|
||||
operation_id = "ai_conversation_create",
|
||||
request_body = CreateConversationParams,
|
||||
responses(
|
||||
(status = 200, description = "Conversation created", body = ApiResponse<ConversationResponse>),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn conversation_create(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
params: web::Json<CreateConversationParams>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let model = params.model.clone().unwrap_or_else(|| "gpt-4".to_string());
|
||||
|
||||
let conversation = service
|
||||
.create_conversation(
|
||||
user_id,
|
||||
params.project_id,
|
||||
params.title.clone(),
|
||||
model,
|
||||
params.model_config.clone(),
|
||||
params.access_visibility.clone(),
|
||||
params.can_ask.clone(),
|
||||
params.model_uid,
|
||||
params.model_name.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let resp = ConversationResponse::from(conversation);
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/ai/conversations",
|
||||
operation_id = "ai_conversation_list",
|
||||
params(
|
||||
("project_id" = Option<Uuid>, Query, description = "Filter by project"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List of conversations", body = ApiResponse<Vec<ConversationResponse>>),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn conversation_list(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
query: web::Query<ConversationListQuery>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
|
||||
let convs = service
|
||||
.list_conversations(user_id, query.project_id, 50)
|
||||
.await?;
|
||||
|
||||
let resp: Vec<ConversationResponse> = convs
|
||||
.into_iter()
|
||||
.map(ConversationResponse::from)
|
||||
.collect();
|
||||
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/ai/conversations/{conversation_id}",
|
||||
operation_id = "ai_conversation_get",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Get conversation", body = ApiResponse<ConversationResponse>),
|
||||
(status = 404, description = "Not found"),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn conversation_get(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<Uuid>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let conversation_id = path.into_inner();
|
||||
|
||||
let c = service
|
||||
.find_conversation_owned(conversation_id, user_id)
|
||||
.await?;
|
||||
|
||||
let resp = ConversationResponse::from(c);
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
patch,
|
||||
path = "/api/ai/conversations/{conversation_id}",
|
||||
operation_id = "ai_conversation_update",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
),
|
||||
request_body = super::types::UpdateConversationParams,
|
||||
responses(
|
||||
(status = 200, description = "Conversation updated"),
|
||||
(status = 404, description = "Not found"),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn conversation_update(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<Uuid>,
|
||||
params: web::Json<super::types::UpdateConversationParams>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let conversation_id = path.into_inner();
|
||||
|
||||
service
|
||||
.update_conversation(
|
||||
conversation_id,
|
||||
user_id,
|
||||
params.title.clone(),
|
||||
params.model.clone(),
|
||||
params.model_config.clone(),
|
||||
params.status.clone(),
|
||||
params.access_visibility.clone(),
|
||||
params.can_ask.clone(),
|
||||
params.model_uid,
|
||||
params.model_name.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(crate::api_success())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/ai/conversations/{conversation_id}",
|
||||
operation_id = "ai_conversation_delete",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Conversation deleted"),
|
||||
(status = 404, description = "Not found"),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn conversation_delete(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<Uuid>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let conversation_id = path.into_inner();
|
||||
|
||||
service
|
||||
.delete_conversation(conversation_id, user_id)
|
||||
.await?;
|
||||
|
||||
Ok(crate::api_success())
|
||||
}
|
||||
102
libs/api/chat/handlers/fork.rs
Normal file
102
libs/api/chat/handlers/fork.rs
Normal file
@ -0,0 +1,102 @@
|
||||
use actix_web::{web, HttpResponse, Result};
|
||||
use session::Session;
|
||||
use service::error::AppError;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::ApiError;
|
||||
use crate::ApiResponse;
|
||||
|
||||
#[derive(Debug, serde::Serialize, utoipa::ToSchema)]
|
||||
pub struct ForkResponse {
|
||||
pub id: Uuid,
|
||||
pub conversation_id: Option<Uuid>,
|
||||
pub source_message_id: Uuid,
|
||||
pub fork_message_id: Uuid,
|
||||
#[schema(value_type = chrono::DateTime<chrono::Utc>)]
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/fork/{target_message_id}",
|
||||
operation_id = "ai_message_fork",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
("message_id" = Uuid, Path, description = "Source message ID"),
|
||||
("target_message_id" = Uuid, Path, description = "Target/fork message ID to create"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Fork created", body = ApiResponse<ForkResponse>),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn message_fork(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<(Uuid, Uuid, Uuid)>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session
|
||||
.user()
|
||||
.ok_or_else(|| ApiError::from(AppError::Unauthorized))?;
|
||||
let (conversation_id, source_message_id, target_message_id) = path.into_inner();
|
||||
|
||||
let fork_record = service
|
||||
.fork_message(
|
||||
conversation_id,
|
||||
user_id,
|
||||
source_message_id,
|
||||
target_message_id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let resp = ForkResponse {
|
||||
id: fork_record.id,
|
||||
conversation_id: fork_record.conversation_id,
|
||||
source_message_id: fork_record.source_message_id,
|
||||
fork_message_id: fork_record.fork_message_id,
|
||||
created_at: fork_record.created_at,
|
||||
};
|
||||
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/forks",
|
||||
operation_id = "ai_message_forks",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
("message_id" = Uuid, Path, description = "Source message ID"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List forks from message", body = ApiResponse<Vec<ForkResponse>>),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn message_forks(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<(Uuid, Uuid)>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session
|
||||
.user()
|
||||
.ok_or_else(|| ApiError::from(AppError::Unauthorized))?;
|
||||
let (conversation_id, source_message_id) = path.into_inner();
|
||||
|
||||
let forks = service
|
||||
.list_forks(conversation_id, user_id, source_message_id)
|
||||
.await?;
|
||||
|
||||
let resp: Vec<ForkResponse> = forks
|
||||
.into_iter()
|
||||
.map(|f| ForkResponse {
|
||||
id: f.id,
|
||||
conversation_id: f.conversation_id,
|
||||
source_message_id: f.source_message_id,
|
||||
fork_message_id: f.fork_message_id,
|
||||
created_at: f.created_at,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
346
libs/api/chat/handlers/message.rs
Normal file
346
libs/api/chat/handlers/message.rs
Normal file
@ -0,0 +1,346 @@
|
||||
use crate::error::ApiError;
|
||||
use crate::ApiResponse;
|
||||
use actix_web::{web, HttpResponse, Result};
|
||||
use session::Session;
|
||||
use service::error::AppError;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::types::{CreateMessageParams, EditMessageParams, MessageListQuery, MessageResponse};
|
||||
|
||||
fn get_user_id(session: &Session) -> Result<Uuid, ApiError> {
|
||||
session.user().ok_or_else(|| ApiError::from(AppError::Unauthorized))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/ai/conversations/{conversation_id}/messages",
|
||||
operation_id = "ai_message_list",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
("limit" = Option<i64>, Query, description = "Max messages"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List messages", body = ApiResponse<Vec<MessageResponse>>),
|
||||
(status = 404, description = "Not found"),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn message_list(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<Uuid>,
|
||||
query: web::Query<MessageListQuery>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let conversation_id = path.into_inner();
|
||||
|
||||
let limit = query.limit.unwrap_or(50) as u64;
|
||||
let msgs = service
|
||||
.list_messages(conversation_id, user_id, limit)
|
||||
.await?;
|
||||
|
||||
let resp: Vec<MessageResponse> = msgs
|
||||
.into_iter()
|
||||
.map(MessageResponse::from)
|
||||
.collect();
|
||||
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/ai/conversations/{conversation_id}/messages",
|
||||
operation_id = "ai_message_create",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
),
|
||||
request_body = CreateMessageParams,
|
||||
responses(
|
||||
(status = 200, description = "Message created", body = ApiResponse<MessageResponse>),
|
||||
(status = 404, description = "Not found"),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn message_create(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<Uuid>,
|
||||
params: web::Json<CreateMessageParams>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let conversation_id = path.into_inner();
|
||||
|
||||
let msg = service
|
||||
.create_message(
|
||||
conversation_id,
|
||||
user_id,
|
||||
params.parent_message_id,
|
||||
params.content.role.clone(),
|
||||
params.content.content.clone(),
|
||||
params.model.clone(),
|
||||
params.is_fork_origin.unwrap_or(false),
|
||||
params.metadata.clone(),
|
||||
params.room_id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let resp = MessageResponse::from(msg);
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}",
|
||||
operation_id = "ai_message_get",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
("message_id" = Uuid, Path, description = "Message ID"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Get message", body = ApiResponse<MessageResponse>),
|
||||
(status = 404, description = "Not found"),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn message_get(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<(Uuid, Uuid)>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let (conversation_id, message_id) = path.into_inner();
|
||||
|
||||
let msg = service
|
||||
.get_message(conversation_id, user_id, message_id)
|
||||
.await?;
|
||||
|
||||
let resp = MessageResponse::from(msg);
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/stop",
|
||||
operation_id = "ai_message_stop",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
("message_id" = Uuid, Path, description = "Message ID"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Message stopped"),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn message_stop(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<(Uuid, Uuid)>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let (conversation_id, message_id) = path.into_inner();
|
||||
|
||||
service
|
||||
.stop_message(conversation_id, user_id, message_id)
|
||||
.await?;
|
||||
|
||||
Ok(crate::api_success())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/resend",
|
||||
operation_id = "ai_message_resend",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
("message_id" = Uuid, Path, description = "Message ID"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Resend message", body = ApiResponse<MessageResponse>),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn message_resend(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<(Uuid, Uuid)>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let (conversation_id, message_id) = path.into_inner();
|
||||
|
||||
let new_msg = service
|
||||
.resend_message(conversation_id, user_id, message_id)
|
||||
.await?;
|
||||
|
||||
let resp = MessageResponse::from(new_msg);
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/children",
|
||||
operation_id = "ai_message_children",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
("message_id" = Uuid, Path, description = "Parent message ID"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List child messages", body = ApiResponse<Vec<MessageResponse>>),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn message_children(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<(Uuid, Uuid)>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let (conversation_id, parent_message_id) = path.into_inner();
|
||||
|
||||
let msgs = service
|
||||
.list_child_messages(conversation_id, user_id, parent_message_id)
|
||||
.await?;
|
||||
|
||||
let resp: Vec<MessageResponse> = msgs
|
||||
.into_iter()
|
||||
.map(MessageResponse::from)
|
||||
.collect();
|
||||
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/stream",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
("message_id" = Uuid, Path, description = "Message ID"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "SSE stream"),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn message_stream(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<(Uuid, Uuid)>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let (conversation_id, message_id) = path.into_inner();
|
||||
|
||||
// Verify user owns the conversation
|
||||
let conv = service
|
||||
.find_conversation_owned(conversation_id, user_id)
|
||||
.await?;
|
||||
|
||||
let model = conv.model;
|
||||
|
||||
let response = actix_web::HttpResponse::Ok()
|
||||
.content_type("text/event-stream")
|
||||
.insert_header(("Cache-Control", "no-cache"))
|
||||
.insert_header(("X-Accel-Buffering", "no"))
|
||||
.streaming(super::super::stream::create_chat_sse_stream(
|
||||
service.get_ref().clone(),
|
||||
conversation_id,
|
||||
message_id,
|
||||
model,
|
||||
user_id,
|
||||
));
|
||||
|
||||
Ok(response.into())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/edit",
|
||||
operation_id = "ai_message_edit",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
("message_id" = Uuid, Path, description = "Message ID to edit"),
|
||||
),
|
||||
request_body = EditMessageParams,
|
||||
responses(
|
||||
(status = 200, description = "Message edited, new version created", body = ApiResponse<MessageResponse>),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn message_edit(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<(Uuid, Uuid)>,
|
||||
params: web::Json<EditMessageParams>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let (conversation_id, message_id) = path.into_inner();
|
||||
|
||||
let new_msg = service
|
||||
.edit_message(conversation_id, user_id, message_id, params.content.clone())
|
||||
.await?;
|
||||
|
||||
let resp = MessageResponse::from(new_msg);
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/versions",
|
||||
operation_id = "ai_message_versions",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
("message_id" = Uuid, Path, description = "Message ID"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "List message versions", body = ApiResponse<Vec<MessageResponse>>),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn message_versions(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<(Uuid, Uuid)>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let (conversation_id, message_id) = path.into_inner();
|
||||
|
||||
let versions = service
|
||||
.list_message_versions(conversation_id, user_id, message_id)
|
||||
.await?;
|
||||
|
||||
let resp: Vec<MessageResponse> = versions
|
||||
.into_iter()
|
||||
.map(MessageResponse::from)
|
||||
.collect();
|
||||
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/switch-version",
|
||||
operation_id = "ai_message_switch_version",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
("message_id" = Uuid, Path, description = "Message ID"),
|
||||
),
|
||||
request_body = super::types::SwitchVersionParams,
|
||||
responses(
|
||||
(status = 200, description = "Version switched", body = ApiResponse<MessageResponse>),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn message_switch_version(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<(Uuid, Uuid)>,
|
||||
params: web::Json<super::types::SwitchVersionParams>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let (conversation_id, message_id) = path.into_inner();
|
||||
|
||||
let msg = service
|
||||
.switch_message_version(conversation_id, user_id, message_id, params.version_number)
|
||||
.await?;
|
||||
|
||||
let resp = MessageResponse::from(msg);
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
5
libs/api/chat/handlers/mod.rs
Normal file
5
libs/api/chat/handlers/mod.rs
Normal file
@ -0,0 +1,5 @@
|
||||
pub mod conversation;
|
||||
pub mod fork;
|
||||
pub mod message;
|
||||
pub mod share;
|
||||
pub mod types;
|
||||
76
libs/api/chat/handlers/share.rs
Normal file
76
libs/api/chat/handlers/share.rs
Normal file
@ -0,0 +1,76 @@
|
||||
use actix_web::{web, HttpResponse, Result};
|
||||
use session::Session;
|
||||
use service::error::AppError;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::ApiError;
|
||||
use crate::ApiResponse;
|
||||
|
||||
use super::types::{ConversationResponse, ShareResponse};
|
||||
|
||||
fn get_user_id(session: &Session) -> Result<Uuid, ApiError> {
|
||||
session.user().ok_or_else(|| ApiError::from(AppError::Unauthorized))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/ai/conversations/{conversation_id}/share",
|
||||
operation_id = "ai_conversation_share",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Share token created", body = ApiResponse<ShareResponse>),
|
||||
(status = 404, description = "Not found"),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn conversation_share(
|
||||
service: web::Data<service::AppService>,
|
||||
session: Session,
|
||||
path: web::Path<Uuid>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = get_user_id(&session)?;
|
||||
let conversation_id = path.into_inner();
|
||||
|
||||
let (share, share_token) = service
|
||||
.share_conversation(conversation_id, user_id)
|
||||
.await?;
|
||||
|
||||
let resp = ShareResponse {
|
||||
id: share.id,
|
||||
share_token,
|
||||
view_count: share.view_count,
|
||||
expires_at: share.expires_at,
|
||||
};
|
||||
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/ai/conversations/{conversation_id}/share/{share_token}",
|
||||
operation_id = "ai_shared_conversation_get",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
("share_token" = String, Path, description = "Share token"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Get shared conversation", body = ApiResponse<ConversationResponse>),
|
||||
(status = 404, description = "Not found or expired"),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn shared_conversation_get(
|
||||
service: web::Data<service::AppService>,
|
||||
path: web::Path<(Uuid, String)>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let (conversation_id, share_token) = path.into_inner();
|
||||
|
||||
let c = service
|
||||
.get_shared_conversation(conversation_id, share_token)
|
||||
.await?;
|
||||
|
||||
let resp = ConversationResponse::from(c);
|
||||
Ok(ApiResponse::ok(resp).to_response())
|
||||
}
|
||||
175
libs/api/chat/handlers/types.rs
Normal file
175
libs/api/chat/handlers/types.rs
Normal file
@ -0,0 +1,175 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct CreateConversationParams {
|
||||
pub project_id: Option<Uuid>,
|
||||
pub title: Option<String>,
|
||||
pub model: Option<String>,
|
||||
pub model_config: Option<serde_json::Value>,
|
||||
pub access_visibility: Option<String>,
|
||||
pub can_ask: Option<String>,
|
||||
/// AI model UUID for model selection
|
||||
pub model_uid: Option<Uuid>,
|
||||
/// AI model display name
|
||||
pub model_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
pub struct ConversationResponse {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub project_id: Option<Uuid>,
|
||||
pub scope: String,
|
||||
pub title: Option<String>,
|
||||
pub model: String,
|
||||
pub model_config: Option<serde_json::Value>,
|
||||
pub status: String,
|
||||
pub root_message_id: Option<Uuid>,
|
||||
pub fork_count: i32,
|
||||
pub is_shared: bool,
|
||||
pub message_count: i32,
|
||||
pub token_usage_total: Option<i32>,
|
||||
pub access_visibility: String,
|
||||
pub can_ask: String,
|
||||
pub project_uid: Option<i32>,
|
||||
pub model_uid: Option<Uuid>,
|
||||
pub model_name: Option<String>,
|
||||
#[schema(value_type = chrono::DateTime<chrono::Utc>)]
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
#[schema(value_type = chrono::DateTime<chrono::Utc>)]
|
||||
pub updated_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct UpdateConversationParams {
|
||||
pub title: Option<String>,
|
||||
pub model: Option<String>,
|
||||
pub model_config: Option<serde_json::Value>,
|
||||
pub status: Option<String>,
|
||||
pub access_visibility: Option<String>,
|
||||
pub can_ask: Option<String>,
|
||||
pub model_uid: Option<Uuid>,
|
||||
pub model_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ConversationListQuery {
|
||||
pub project_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct MessageContent {
|
||||
pub role: String,
|
||||
pub content: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct CreateMessageParams {
|
||||
pub parent_message_id: Option<Uuid>,
|
||||
pub content: MessageContent,
|
||||
pub model: Option<String>,
|
||||
pub is_fork_origin: Option<bool>,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub room_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
pub struct MessageResponse {
|
||||
pub id: Uuid,
|
||||
pub conversation_id: Uuid,
|
||||
pub parent_message_id: Option<Uuid>,
|
||||
pub role: String,
|
||||
pub content: serde_json::Value,
|
||||
pub model: Option<String>,
|
||||
pub is_fork_origin: bool,
|
||||
pub stop_reason: Option<String>,
|
||||
pub input_tokens: Option<i32>,
|
||||
pub output_tokens: Option<i32>,
|
||||
pub latency_ms: Option<i32>,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub room_id: Option<Uuid>,
|
||||
pub version_group_id: Option<Uuid>,
|
||||
pub version_number: i32,
|
||||
pub is_latest: bool,
|
||||
#[schema(value_type = chrono::DateTime<chrono::Utc>)]
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct MessageListQuery {
|
||||
pub limit: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct EditMessageParams {
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct SwitchVersionParams {
|
||||
pub version_number: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ForkParams {}
|
||||
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
pub struct ShareResponse {
|
||||
pub id: Uuid,
|
||||
pub share_token: String,
|
||||
pub view_count: i32,
|
||||
#[schema(value_type = Option<chrono::DateTime<chrono::Utc>>)]
|
||||
pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
}
|
||||
|
||||
impl From<models::ai::ai_conversation::Model> for ConversationResponse {
|
||||
fn from(c: models::ai::ai_conversation::Model) -> Self {
|
||||
Self {
|
||||
id: c.id,
|
||||
user_id: c.user_id,
|
||||
project_id: c.project_id,
|
||||
scope: c.scope,
|
||||
title: c.title,
|
||||
model: c.model,
|
||||
model_config: c.model_config,
|
||||
status: c.status,
|
||||
root_message_id: c.root_message_id,
|
||||
fork_count: c.fork_count,
|
||||
is_shared: c.is_shared,
|
||||
message_count: c.message_count,
|
||||
token_usage_total: c.token_usage_total,
|
||||
access_visibility: c.access_visibility,
|
||||
can_ask: c.can_ask,
|
||||
project_uid: c.project_uid,
|
||||
model_uid: c.model_uid,
|
||||
model_name: c.model_name,
|
||||
created_at: c.created_at,
|
||||
updated_at: c.updated_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<models::ai::ai_message::Model> for MessageResponse {
|
||||
fn from(m: models::ai::ai_message::Model) -> Self {
|
||||
Self {
|
||||
id: m.id,
|
||||
conversation_id: m.conversation_id,
|
||||
parent_message_id: m.parent_message_id,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
model: m.model,
|
||||
is_fork_origin: m.is_fork_origin,
|
||||
stop_reason: m.stop_reason,
|
||||
input_tokens: m.input_tokens,
|
||||
output_tokens: m.output_tokens,
|
||||
latency_ms: m.latency_ms,
|
||||
metadata: m.metadata,
|
||||
room_id: m.room_id,
|
||||
version_group_id: m.version_group_id,
|
||||
version_number: m.version_number,
|
||||
is_latest: m.is_latest,
|
||||
created_at: m.created_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
85
libs/api/chat/mod.rs
Normal file
85
libs/api/chat/mod.rs
Normal file
@ -0,0 +1,85 @@
|
||||
use actix_web::web;
|
||||
|
||||
pub mod handlers;
|
||||
pub mod stream;
|
||||
pub mod watch;
|
||||
|
||||
pub fn init_chat_routes(cfg: &mut web::ServiceConfig) {
|
||||
cfg.service(
|
||||
web::scope("/ai/conversations")
|
||||
.route("", web::post().to(handlers::conversation::conversation_create))
|
||||
.route("", web::get().to(handlers::conversation::conversation_list))
|
||||
.route(
|
||||
"/{conversation_id}",
|
||||
web::get().to(handlers::conversation::conversation_get),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}",
|
||||
web::patch().to(handlers::conversation::conversation_update),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}",
|
||||
web::delete().to(handlers::conversation::conversation_delete),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/watch",
|
||||
web::get().to(watch::conversation_watch),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/share",
|
||||
web::post().to(handlers::share::conversation_share),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/share/{share_token}",
|
||||
web::get().to(handlers::share::shared_conversation_get),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/messages",
|
||||
web::get().to(handlers::message::message_list),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/messages",
|
||||
web::post().to(handlers::message::message_create),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/messages/{message_id}",
|
||||
web::get().to(handlers::message::message_get),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/messages/{message_id}/stop",
|
||||
web::post().to(handlers::message::message_stop),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/messages/{message_id}/resend",
|
||||
web::post().to(handlers::message::message_resend),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/messages/{message_id}/fork/{target_message_id}",
|
||||
web::post().to(handlers::fork::message_fork),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/messages/{message_id}/forks",
|
||||
web::get().to(handlers::fork::message_forks),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/messages/{message_id}/stream",
|
||||
web::get().to(handlers::message::message_stream),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/messages/{message_id}/children",
|
||||
web::get().to(handlers::message::message_children),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/messages/{message_id}/edit",
|
||||
web::post().to(handlers::message::message_edit),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/messages/{message_id}/versions",
|
||||
web::get().to(handlers::message::message_versions),
|
||||
)
|
||||
.route(
|
||||
"/{conversation_id}/messages/{message_id}/switch-version",
|
||||
web::post().to(handlers::message::message_switch_version),
|
||||
),
|
||||
);
|
||||
}
|
||||
463
libs/api/chat/stream.rs
Normal file
463
libs/api/chat/stream.rs
Normal file
@ -0,0 +1,463 @@
|
||||
use agent::chat::chat_execution;
|
||||
use agent::chat::{normalize_thinking_content, AiChunkType, AiStreamChunk};
|
||||
use agent::client::AiClientConfig;
|
||||
use agent::client::types::ChatRequestMessage;
|
||||
use agent::client::StreamChunkType;
|
||||
use futures::StreamExt;
|
||||
use models::ai::{ai_message, ai_conversation, AiMessage};
|
||||
use queue::{ChatMessageEvent, ChatStreamChunkEvent};
|
||||
use sea_orm::{EntityTrait, QueryFilter, ColumnTrait, QueryOrder, ActiveModelTrait, Set, PaginatorTrait};
|
||||
use service::AppService;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Create an SSE stream that executes AI chat with ReAct tool-calling.
|
||||
///
|
||||
/// Also publishes chat messages and stream chunks via NATS JetStream for
|
||||
/// multi-viewer support. The requesting client receives SSE events, while
|
||||
/// other viewers receive chunks via NATS → WebSocket broadcast.
|
||||
pub fn create_chat_sse_stream(
|
||||
service: AppService,
|
||||
conversation_id: Uuid,
|
||||
user_message_id: Uuid,
|
||||
model_name: String,
|
||||
user_id: Uuid,
|
||||
) -> Pin<Box<dyn futures::Stream<Item = Result<actix_web::web::Bytes, actix_web::Error>> + Send>> {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<String>(100);
|
||||
|
||||
let cache = service.cache.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Check for active stream (SSE reconnect recovery) BEFORE starting a new one
|
||||
// so the frontend can recover from a page refresh.
|
||||
if let Some((msg_id, started_at)) = cache.get_chat_stream_active(conversation_id).await {
|
||||
let _ = tx.send(format!(
|
||||
"data: {{\"event\":\"recovery\",\"data\":{{\"message_id\":\"{}\",\"started_at\":{}}}}}\n\n",
|
||||
msg_id,
|
||||
started_at
|
||||
)).await;
|
||||
}
|
||||
|
||||
let queue = service.queue_producer.clone();
|
||||
let chunk_seq = Arc::new(AtomicU64::new(0));
|
||||
|
||||
// Build messages from conversation history
|
||||
let messages = match build_messages_from_history(&service, conversation_id).await {
|
||||
Ok(msgs) => msgs,
|
||||
Err(e) => {
|
||||
let _ = tx.send(format!("data: {{\"event\":\"error\",\"data\":\"{}\"}}\n\n", e)).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Get AI config
|
||||
let api_key = match service.config.ai_api_key() {
|
||||
Ok(k) => k,
|
||||
Err(_) => {
|
||||
let _ = tx.send("data: {\"event\":\"error\",\"data\":\"AI not configured\"}\n\n".to_string()).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
let base_url = match service.config.ai_basic_url() {
|
||||
Ok(u) => u,
|
||||
Err(_) => {
|
||||
let _ = tx.send("data: {\"event\":\"error\",\"data\":\"AI not configured\"}\n\n".to_string()).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let config = AiClientConfig::new(api_key).with_base_url(&base_url);
|
||||
|
||||
// Get tools from ChatService if available
|
||||
let (tools, tool_registry, embed_service) = match &service.chat_service {
|
||||
Some(cs) => (
|
||||
cs.tools(),
|
||||
cs.tool_registry().cloned(),
|
||||
service.embed_service.as_ref().map(|es| (**es).clone()),
|
||||
),
|
||||
None => (Vec::new(), None, None),
|
||||
};
|
||||
|
||||
// Get project_id from conversation
|
||||
let project_id = match service.find_conversation(conversation_id).await {
|
||||
Ok(c) => c.project_id.unwrap_or(Uuid::nil()),
|
||||
Err(_) => {
|
||||
let _ = tx.send("data: {\"event\":\"error\",\"data\":\"conversation not found\"}\n\n".to_string()).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Pre-flight balance check: verify project + user can afford at least a minimal AI call
|
||||
let balance_ok = agent::billing::check_balance(
|
||||
&service.db, project_id, user_id, Uuid::nil(), 500, 250,
|
||||
).await;
|
||||
|
||||
match balance_ok {
|
||||
Ok(true) => {},
|
||||
Ok(false) => {
|
||||
tracing::warn!(project_id = %project_id, user_id = %user_id, "Insufficient balance for chat AI call");
|
||||
|
||||
let _ = agent::billing::persist_billing_error(
|
||||
&service.db, "user", user_id, "insufficient_balance",
|
||||
&format!("Insufficient balance. Your account does not have enough funds for this AI request."),
|
||||
Some(serde_json::json!({
|
||||
"user_id": user_id.to_string(),
|
||||
"project_id": project_id.to_string(),
|
||||
})),
|
||||
).await;
|
||||
|
||||
let error_msg = "Insufficient balance. Your account does not have enough funds to process this AI request. Please add credits to continue.";
|
||||
let _ = tx.send(format!("data: {{\"event\":\"billing_error\",\"data\":\"{}\"}}\n\n", error_msg)).await;
|
||||
let _ = tx.send("data: {\"event\":\"done\",\"data\":\"billing_error\"}\n\n".to_string()).await;
|
||||
return;
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Balance check failed, proceeding without pre-flight check");
|
||||
}
|
||||
}
|
||||
|
||||
let max_tool_depth = 99;
|
||||
|
||||
// Determine conversation project_id for chat message event
|
||||
let conv_project_id = match service.find_conversation(conversation_id).await {
|
||||
Ok(c) => c.project_id,
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
// Broadcast chat message start event via NATS
|
||||
let chat_msg = ChatMessageEvent {
|
||||
message_id: user_message_id,
|
||||
conversation_id,
|
||||
project_id: conv_project_id,
|
||||
sender_id: Uuid::nil(),
|
||||
role: "assistant".to_string(),
|
||||
content: String::new(),
|
||||
model: Some(model_name.clone()),
|
||||
input_tokens: None,
|
||||
output_tokens: None,
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
let _ = queue.publish_chat_message(&chat_msg).await;
|
||||
|
||||
// Mark stream as active in Redis so page refresh can recover
|
||||
let _ = cache.set_chat_stream_active(conversation_id, user_message_id).await;
|
||||
|
||||
let on_chunk_tx = tx.clone();
|
||||
let on_chunk_queue = queue.clone();
|
||||
let on_chunk_seq = chunk_seq.clone();
|
||||
let on_chunk_conv_id = conversation_id;
|
||||
let on_chunk_msg_id = user_message_id;
|
||||
let on_chunk_model = model_name.clone();
|
||||
|
||||
let on_chunk: agent::chat::StreamCallback = Box::new(move |chunk: AiStreamChunk| {
|
||||
let tx = on_chunk_tx.clone();
|
||||
let queue = on_chunk_queue.clone();
|
||||
let seq = on_chunk_seq.fetch_add(1, Ordering::Relaxed);
|
||||
let conv_id = on_chunk_conv_id;
|
||||
let msg_id = on_chunk_msg_id;
|
||||
let model = on_chunk_model.clone();
|
||||
Box::pin(async move {
|
||||
let event = match chunk.chunk_type {
|
||||
AiChunkType::Thinking => "thinking",
|
||||
AiChunkType::Answer => "token",
|
||||
AiChunkType::ToolCall => "tool_call",
|
||||
AiChunkType::ToolResult => "tool_result",
|
||||
};
|
||||
let content = match chunk.chunk_type {
|
||||
AiChunkType::Thinking => normalize_thinking_content(&chunk.content),
|
||||
_ => chunk.content.clone(),
|
||||
};
|
||||
let sse = format!(
|
||||
"data: {{\"event\":\"{}\",\"data\":{}}}\n\n",
|
||||
event,
|
||||
serde_json::to_string(&content).unwrap_or_default()
|
||||
);
|
||||
let _ = tx.send(sse).await;
|
||||
|
||||
// Also broadcast via NATS for other viewers
|
||||
let natts_chunk = ChatStreamChunkEvent {
|
||||
conversation_id: conv_id,
|
||||
message_id: msg_id,
|
||||
seq,
|
||||
content,
|
||||
done: false,
|
||||
error: None,
|
||||
chunk_type: Some(event.to_string()),
|
||||
model_name: Some(model),
|
||||
};
|
||||
queue.publish_chat_chunk(&natts_chunk).await;
|
||||
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
|
||||
});
|
||||
|
||||
let result = chat_execution::execute_chat_stream(
|
||||
messages,
|
||||
tools,
|
||||
&model_name,
|
||||
&config,
|
||||
0.7, // temperature
|
||||
4096, // max_tokens
|
||||
max_tool_depth,
|
||||
tool_registry.as_ref(),
|
||||
service.db.clone(),
|
||||
service.cache.clone(),
|
||||
service.config.clone(),
|
||||
project_id,
|
||||
Uuid::nil(), // sender_uid — unknown in Chat API context
|
||||
embed_service,
|
||||
on_chunk,
|
||||
Some(conversation_id),
|
||||
).await;
|
||||
|
||||
// Clear stream active state (streaming finished)
|
||||
let _ = cache.clear_chat_stream_active(conversation_id).await;
|
||||
|
||||
match result {
|
||||
Ok(stream_result) => {
|
||||
// Build ordered content blocks from stream chunks, merging
|
||||
// consecutive blocks of the same role (thinking/assistant).
|
||||
let raw_blocks: Vec<(String, String)> = stream_result.chunks.iter()
|
||||
.filter(|c| matches!(c.chunk_type, StreamChunkType::Thinking | StreamChunkType::Answer))
|
||||
.map(|chunk| {
|
||||
let role = match chunk.chunk_type {
|
||||
StreamChunkType::Thinking => "thinking",
|
||||
_ => "assistant",
|
||||
};
|
||||
(role.to_string(), chunk.content.clone())
|
||||
})
|
||||
.collect();
|
||||
|
||||
let merged_blocks = merge_consecutive_blocks(raw_blocks);
|
||||
// Apply thinking normalization to the fully merged thinking
|
||||
// blocks — per-token normalization is meaningless since each
|
||||
// chunk is a single token.
|
||||
let normalized_blocks: Vec<(String, String)> = merged_blocks.into_iter().map(|(role, content)| {
|
||||
if role == "thinking" {
|
||||
(role, normalize_thinking_content(&content))
|
||||
} else {
|
||||
(role, content)
|
||||
}
|
||||
}).collect();
|
||||
let content_blocks: Vec<serde_json::Value> = normalized_blocks.iter()
|
||||
.map(|(role, content)| serde_json::json!({ "role": role, "content": content }))
|
||||
.collect();
|
||||
let content_value = if content_blocks.is_empty() {
|
||||
serde_json::json!([{ "role": "assistant", "content": stream_result.content }])
|
||||
} else {
|
||||
serde_json::json!(content_blocks)
|
||||
};
|
||||
|
||||
// Persist assistant message
|
||||
let assistant_msg_id = Uuid::now_v7();
|
||||
let assistant_msg = ai_message::ActiveModel {
|
||||
id: Set(assistant_msg_id),
|
||||
conversation_id: Set(conversation_id),
|
||||
parent_message_id: Set(Some(user_message_id)),
|
||||
role: Set("assistant".to_string()),
|
||||
content: Set(content_value),
|
||||
model: Set(Some(model_name.clone())),
|
||||
is_fork_origin: Set(false),
|
||||
stop_reason: Set(Some("stop".to_string())),
|
||||
input_tokens: Set(Some(stream_result.input_tokens as i32)),
|
||||
output_tokens: Set(Some(stream_result.output_tokens as i32)),
|
||||
latency_ms: Set(None),
|
||||
metadata: Set(None),
|
||||
room_id: Set(None),
|
||||
version_group_id: Set(Some(assistant_msg_id)),
|
||||
version_number: Set(1),
|
||||
is_latest: Set(true),
|
||||
created_at: Set(chrono::Utc::now()),
|
||||
};
|
||||
|
||||
let saved = assistant_msg.insert(service.db.writer()).await;
|
||||
|
||||
if let Ok(msg) = &saved {
|
||||
update_conversation_after_response(&service, conversation_id, msg).await;
|
||||
|
||||
// After AI response, check/update conversation title and emit via SSE
|
||||
if let Ok(Some(conv)) = ai_conversation::Entity::find_by_id(conversation_id)
|
||||
.one(service.db.reader()).await
|
||||
{
|
||||
let existing_title = conv.title.clone();
|
||||
let needs_title = existing_title.as_deref().map(|t| t.is_empty() || t == "New Chat").unwrap_or(true);
|
||||
|
||||
if needs_title {
|
||||
// Generate title from first user message
|
||||
let first_user_msg = AiMessage::find()
|
||||
.filter(ai_message::Column::ConversationId.eq(conversation_id))
|
||||
.filter(ai_message::Column::Role.eq("user"))
|
||||
.order_by_asc(ai_message::Column::CreatedAt)
|
||||
.one(service.db.reader()).await.ok().flatten();
|
||||
|
||||
if let Some(user_msg) = first_user_msg {
|
||||
let content = match &user_msg.content {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
serde_json::Value::Array(arr) => {
|
||||
arr.first()
|
||||
.and_then(|f| f.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
}
|
||||
other => other.to_string(),
|
||||
};
|
||||
|
||||
// Simple title extraction: first meaningful words
|
||||
let title = content
|
||||
.split_whitespace()
|
||||
.filter(|w| w.len() > 2)
|
||||
.take(5)
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
|
||||
if !title.is_empty() {
|
||||
let truncated: String = title.chars().take(40).collect();
|
||||
|
||||
// Save title to DB
|
||||
let mut active: ai_conversation::ActiveModel = conv.into();
|
||||
active.title = Set(Some(truncated.clone()));
|
||||
active.updated_at = Set(chrono::Utc::now());
|
||||
let _ = active.update(service.db.writer()).await;
|
||||
|
||||
// Emit title via SSE
|
||||
let title_payload = serde_json::json!({"title": truncated}).to_string();
|
||||
let _ = tx.send(format!("data: {{\"event\":\"title\",\"data\":{}}}\n\n", title_payload)).await;
|
||||
}
|
||||
}
|
||||
} else if let Some(title) = &existing_title {
|
||||
// Title already set (e.g. by AI tool) — emit it
|
||||
let title_payload = serde_json::json!({"title": title}).to_string();
|
||||
let _ = tx.send(format!("data: {{\"event\":\"title\",\"data\":{}}}\n\n", title_payload)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast final chat message with token usage
|
||||
let final_msg = ChatMessageEvent {
|
||||
message_id: user_message_id,
|
||||
conversation_id,
|
||||
project_id: conv_project_id,
|
||||
sender_id: Uuid::nil(),
|
||||
role: "assistant".to_string(),
|
||||
content: stream_result.content.clone(),
|
||||
model: Some(model_name.clone()),
|
||||
input_tokens: Some(stream_result.input_tokens as i32),
|
||||
output_tokens: Some(stream_result.output_tokens as i32),
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
let _ = queue.publish_chat_message(&final_msg).await;
|
||||
|
||||
// Send final SSE done event
|
||||
let _ = tx.send("data: {\"event\":\"done\",\"data\":\"ok\"}\n\n".to_string()).await;
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(format!("data: {{\"event\":\"error\",\"data\":\"{}\"}}\n\n", e)).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Box::pin(ReceiverStream::new(rx).map(|msg| Ok(actix_web::web::Bytes::from(msg))))
|
||||
}
|
||||
|
||||
/// Update conversation metadata after an AI assistant message is saved.
|
||||
async fn update_conversation_after_response(
|
||||
service: &AppService,
|
||||
conversation_id: Uuid,
|
||||
assistant_msg: &ai_message::Model,
|
||||
) {
|
||||
use models::ai::ai_conversation;
|
||||
use sea_orm::EntityTrait;
|
||||
|
||||
if let Ok(Some(conv)) = ai_conversation::Entity::find_by_id(conversation_id)
|
||||
.one(service.db.reader()).await
|
||||
{
|
||||
let input_tokens = assistant_msg.input_tokens.unwrap_or(0) as i64;
|
||||
let output_tokens = assistant_msg.output_tokens.unwrap_or(0) as i64;
|
||||
let total_tokens = input_tokens + output_tokens;
|
||||
|
||||
let mut active: ai_conversation::ActiveModel = conv.into();
|
||||
if let Ok(count) = AiMessage::find()
|
||||
.filter(ai_message::Column::ConversationId.eq(conversation_id))
|
||||
.count(service.db.reader()).await
|
||||
{
|
||||
active.message_count = Set(count as i32);
|
||||
}
|
||||
active.token_usage_total = Set(Some(total_tokens as i32));
|
||||
active.updated_at = Set(chrono::Utc::now());
|
||||
let _ = active.update(service.db.writer()).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Build ChatRequestMessage list from ai_message conversation history.
|
||||
async fn build_messages_from_history(
|
||||
service: &AppService,
|
||||
conversation_id: Uuid,
|
||||
) -> Result<Vec<ChatRequestMessage>, String> {
|
||||
let msgs = AiMessage::find()
|
||||
.filter(ai_message::Column::ConversationId.eq(conversation_id))
|
||||
.filter(ai_message::Column::IsLatest.eq(true))
|
||||
.order_by_asc(ai_message::Column::CreatedAt)
|
||||
.all(service.db.reader())
|
||||
.await
|
||||
.map_err(|e| format!("db error: {}", e))?;
|
||||
|
||||
let mut chat_messages = Vec::new();
|
||||
|
||||
for msg in &msgs {
|
||||
let role = msg.role.as_str();
|
||||
let content = match &msg.content {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
serde_json::Value::Array(arr) => {
|
||||
// Content is ordered blocks: [{role:"thinking",content:"..."}, {role:"assistant","content":"..."}, ...]
|
||||
// For assistant messages: concatenate all "assistant" blocks
|
||||
// For user/system messages: take the first block's content
|
||||
if role == "assistant" {
|
||||
arr.iter()
|
||||
.filter(|item| item.get("role").and_then(|r| r.as_str()) != Some("thinking"))
|
||||
.filter_map(|item| item.get("content").and_then(|c| c.as_str()))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
} else if let Some(first) = arr.first() {
|
||||
first.get("content")
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
other => other.to_string(),
|
||||
};
|
||||
|
||||
match role {
|
||||
"user" => chat_messages.push(ChatRequestMessage::user(content)),
|
||||
"assistant" => chat_messages.push(ChatRequestMessage::assistant(Some(content), None)),
|
||||
"system" => chat_messages.push(ChatRequestMessage::system(content)),
|
||||
_ => chat_messages.push(ChatRequestMessage::user(content)),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(chat_messages)
|
||||
}
|
||||
|
||||
/// Merge consecutive content blocks of the same role into single blocks.
|
||||
/// This transforms many small per-chunk blocks into clean interleaved segments:
|
||||
/// [thinking, thinking, assistant, assistant] → [thinking, assistant]
|
||||
/// Per-token chunks are concatenated directly — the model sends \n inside
|
||||
/// the token content where needed, not between tokens.
|
||||
fn merge_consecutive_blocks(blocks: Vec<(String, String)>) -> Vec<(String, String)> {
|
||||
let mut merged: Vec<(String, String)> = Vec::new();
|
||||
for (role, content) in blocks {
|
||||
if content.is_empty() { continue; }
|
||||
if let Some(last) = merged.last_mut() {
|
||||
if last.0 == role {
|
||||
last.1.push_str(&content);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
merged.push((role, content));
|
||||
}
|
||||
merged
|
||||
}
|
||||
158
libs/api/chat/watch.rs
Normal file
158
libs/api/chat/watch.rs
Normal file
@ -0,0 +1,158 @@
|
||||
//! SSE endpoint for watching a chat conversation in real-time via NATS.
|
||||
//!
|
||||
//! Unlike the primary SSE stream (which triggers AI execution), this endpoint
|
||||
//! passively subscribes to NATS Core subjects and forwards chat messages and
|
||||
//! stream chunks to connected clients. This enables multiple viewers to watch
|
||||
//! the same AI conversation in real-time.
|
||||
|
||||
use actix_web::{web, HttpResponse, Result};
|
||||
use futures::StreamExt;
|
||||
use service::AppService;
|
||||
use std::pin::Pin;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::ApiError;
|
||||
|
||||
/// SSE endpoint for watching a chat conversation.
|
||||
///
|
||||
/// `GET /api/ai/conversations/{conversation_id}/watch`
|
||||
///
|
||||
/// Subscribes to NATS Core subjects (`chat.chunk.{id}` and `chat.message.{id}`)
|
||||
/// and forwards received events as SSE to the connected client.
|
||||
///
|
||||
/// SSE events:
|
||||
/// - `chunk` — a stream chunk (thinking, token, tool_call, tool_result, done, error)
|
||||
/// - `message` — a complete chat message
|
||||
/// - `error` — an error event
|
||||
pub fn create_watch_sse_stream(
|
||||
service: AppService,
|
||||
conversation_id: Uuid,
|
||||
) -> Pin<Box<dyn futures::Stream<Item = Result<actix_web::web::Bytes, actix_web::Error>> + Send>> {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<String>(200);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let nats = match &service.queue_producer.nats {
|
||||
Some(n) => n.clone(),
|
||||
None => {
|
||||
let _ = tx.send(format!(
|
||||
"data: {{\"event\":\"error\",\"data\":{}}}\n\n",
|
||||
serde_json::to_string("NATS not available").unwrap_or_default()
|
||||
)).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Subscribe to chat chunks
|
||||
let chunk_subject = format!("chat.chunk.{}", conversation_id);
|
||||
let mut chunk_sub = match nats.subscribe(&chunk_subject).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
let _ = tx.send(format!(
|
||||
"data: {{\"event\":\"error\",\"data\":{}}}\n\n",
|
||||
serde_json::to_string(&e.to_string()).unwrap_or_default()
|
||||
)).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Subscribe to chat messages
|
||||
let msg_subject = format!("chat.message.{}", conversation_id);
|
||||
let mut msg_sub = match nats.subscribe(&msg_subject).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
let _ = tx.send(format!(
|
||||
"data: {{\"event\":\"error\",\"data\":{}}}\n\n",
|
||||
serde_json::to_string(&e.to_string()).unwrap_or_default()
|
||||
)).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let _ = tx.send(":ok\n\n".to_string()).await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
chunk_msg = chunk_sub.next() => {
|
||||
match chunk_msg {
|
||||
Some(msg) => {
|
||||
let payload = String::from_utf8_lossy(&msg.payload);
|
||||
// Parse to get chunk_type for the event field
|
||||
let event_type = if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&payload) {
|
||||
parsed.get("chunk_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("chunk")
|
||||
.to_string()
|
||||
} else {
|
||||
"chunk".to_string()
|
||||
};
|
||||
let sse = format!(
|
||||
"data: {{\"event\":\"{}\",\"data\":{}}}\n\n",
|
||||
event_type, payload
|
||||
);
|
||||
if tx.send(sse).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
msg = msg_sub.next() => {
|
||||
match msg {
|
||||
Some(msg) => {
|
||||
let payload = String::from_utf8_lossy(&msg.payload);
|
||||
let sse = format!(
|
||||
"data: {{\"event\":\"message\",\"data\":{}}}\n\n",
|
||||
payload
|
||||
);
|
||||
if tx.send(sse).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx).map(|s| {
|
||||
Ok(actix_web::web::Bytes::from(s))
|
||||
}))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/ai/conversations/{conversation_id}/watch",
|
||||
params(
|
||||
("conversation_id" = Uuid, Path, description = "Conversation ID"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "SSE stream of conversation events"),
|
||||
(status = 404, description = "Not found"),
|
||||
),
|
||||
tag = "AI Chat"
|
||||
)]
|
||||
pub async fn conversation_watch(
|
||||
service: web::Data<AppService>,
|
||||
session: session::Session,
|
||||
path: web::Path<Uuid>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or_else(|| ApiError::from(service::error::AppError::Unauthorized))?;
|
||||
let conversation_id = path.into_inner();
|
||||
|
||||
// Verify access (view-only is sufficient)
|
||||
let _conv = service
|
||||
.find_conversation_owned(conversation_id, user_id)
|
||||
.await?;
|
||||
|
||||
let response = HttpResponse::Ok()
|
||||
.content_type("text/event-stream")
|
||||
.insert_header(("Cache-Control", "no-cache"))
|
||||
.insert_header(("X-Accel-Buffering", "no"))
|
||||
.streaming(create_watch_sse_stream(
|
||||
service.get_ref().clone(),
|
||||
conversation_id,
|
||||
));
|
||||
|
||||
Ok(response.into())
|
||||
}
|
||||
@ -116,20 +116,30 @@ pub async fn serve_frontend(req: HttpRequest, path: web::Path<String>) -> HttpRe
|
||||
};
|
||||
let cc = cache_control_header(path_str);
|
||||
|
||||
// Try brotli first (best compression), then gzip, then uncompressed.
|
||||
// Only serve compressed variant if client explicitly accepts it AND we have one.
|
||||
let (data, encoding, etag, content_path) =
|
||||
match frontend::get_frontend_asset_compressed(path_str) {
|
||||
Some(r) => (r.0, r.1, r.2, path_str),
|
||||
None => {
|
||||
// Path not found — try index.html as SPA fallback.
|
||||
// Also use "index.html" for Content-Type detection (text/html).
|
||||
match frontend::get_frontend_asset_with_etag("index.html") {
|
||||
Some((data, etag)) => (data, "", etag, "index.html"),
|
||||
None => return HttpResponse::NotFound().finish(),
|
||||
}
|
||||
}
|
||||
};
|
||||
// Try brotli/gzip compressed variant first (best compression),
|
||||
// then fall back to uncompressed if client doesn't accept the encoding.
|
||||
let compressed = crate::frontend::get_frontend_asset_compressed(path_str);
|
||||
let uncompressed = crate::frontend::get_frontend_asset_with_etag(path_str);
|
||||
|
||||
let (data, encoding, etag, content_path) = if let Some((c_data, c_enc, c_etag)) = compressed {
|
||||
if accepts_encoding(&req, c_enc) {
|
||||
(c_data, c_enc, c_etag, path_str)
|
||||
} else if let Some((u_data, u_etag)) = uncompressed {
|
||||
// Client doesn't accept the pre-compressed encoding — serve uncompressed.
|
||||
(u_data, "", u_etag, path_str)
|
||||
} else {
|
||||
// No uncompressed fallback — still serve compressed (client must handle it).
|
||||
(c_data, c_enc, c_etag, path_str)
|
||||
}
|
||||
} else if let Some((data, etag)) = uncompressed {
|
||||
(data, "", etag, path_str)
|
||||
} else {
|
||||
// Path not found — try index.html as SPA fallback.
|
||||
match crate::frontend::get_frontend_asset_with_etag("index.html") {
|
||||
Some((data, etag)) => (data, "", etag, "index.html"),
|
||||
None => return HttpResponse::NotFound().finish(),
|
||||
}
|
||||
};
|
||||
|
||||
if !encoding.is_empty() && accepts_encoding(&req, &encoding) {
|
||||
build_asset_response(&req, data, etag, content_path, cc, &encoding)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user