gitdataai/libs/git/ssh/mod.rs
ZhenYi 0703816482 feat(ssh): enhance ref update handling and add push queue
- Add post-receive refs tracking with mutex-protected storage
- Improve branch protection error messages with actionable guidance
- Add push queue slot waiting mechanism for concurrent push control
- Support for checking push queue availability before push operations
2026-05-15 11:48:40 +08:00

499 lines
16 KiB
Rust

use crate::error::GitError;
use crate::hook::pool::types::{HookTask, TaskType};
use anyhow::Context;
use argon2::Argon2;
use argon2::password_hash::{PasswordHash, PasswordVerifier};
use config::AppConfig;
use db::cache::AppCache;
use db::database::AppDatabase;
use deadpool_redis::cluster::Pool as RedisPool;
use models::users::{user, user_token};
use redis::AsyncCommands;
use russh::keys::PrivateKey;
use russh::server::Server;
use russh::{MethodKind, MethodSet, SshId, server::Config};
use sea_orm::prelude::*;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
pub mod authz;
pub mod branch_protect;
pub mod forward;
pub mod git_service;
pub mod handle;
pub mod push_queue;
pub mod rate_limit;
pub mod ref_update;
pub mod server;
/// Shared state needed to start the SSH server.
///
/// Derives `Clone` so that `run` can move a copy of the handle into a spawned task.
#[derive(Clone)]
pub struct SSHHandle {
/// Database handle; cloned into the SSH server and the token service.
pub db: AppDatabase,
/// Application configuration; supplies the host key, SSH port and public domain.
pub app: AppConfig,
/// Application cache handle; cloned into the SSH server.
pub cache: AppCache,
/// Redis cluster connection pool; cloned into the SSH server.
pub redis_pool: RedisPool,
}
impl SSHHandle {
pub async fn run(&self) {
let this = self.clone();
tokio::spawn(async move {
if let Err(e) = this.run_ssh().await {
tracing::error!("SSH server error: {}", e);
}
});
}
pub fn new(db: AppDatabase, app: AppConfig, cache: AppCache, redis_pool: RedisPool) -> Self {
SSHHandle {
db,
app,
cache,
redis_pool,
}
}
pub async fn run_ssh(&self) -> anyhow::Result<()> {
tracing::info!("SSH server starting");
let private_key_content = self.app.ssh_server_private_key()?;
if private_key_content.is_empty() {
return Err(anyhow::anyhow!("SSH server private key is not configured"));
}
tracing::info!(
"Loading SSH private key (hex, {} bytes)",
private_key_content.len()
);
let private_key_bytes = hex::decode(&private_key_content).with_context(|| {
format!(
"Failed to decode hex-encoded SSH private key ({} bytes)",
private_key_content.len()
)
})?;
tracing::info!("Hex decoded to {} bytes", private_key_bytes.len());
let private_key_pem = std::str::from_utf8(&private_key_bytes)
.with_context(|| "Decoded SSH private key is not valid UTF-8")?;
if let Some(first_line) = private_key_pem.lines().next() {
tracing::info!("PEM header: {}", first_line);
}
// Do NOT log the full private key content — that would be a severe security leak
let private_key = {
match ssh_key::PrivateKey::from_openssh(private_key_pem) {
Ok(ssh_key) => {
tracing::info!("Successfully parsed with ssh-key crate");
let openssh_pem = ssh_key
.to_openssh(ssh_key::LineEnding::LF)
.with_context(|| "Failed to serialize to OpenSSH format")?;
PrivateKey::from_str(&openssh_pem)
.with_context(|| "Failed to parse with russh after ssh-key conversion")?
}
Err(e) => {
tracing::info!(
"ssh-key from_openssh failed: {}, trying direct russh parse",
e
);
PrivateKey::from_str(private_key_pem).with_context(|| {
format!("Failed to parse SSH private key with both methods")
})?
}
}
};
tracing::info!("SSH private key loaded");
let mut config = Config::default();
config.keys = vec![private_key];
let version = format!("SSH-2.0-GitdataAI {}", env!("CARGO_PKG_VERSION"));
config.server_id = SshId::Standard(version.into());
let mut method = MethodSet::empty();
method.push(MethodKind::PublicKey);
method.push(MethodKind::KeyboardInteractive);
config.methods = method;
config.inactivity_timeout = Some(Duration::from_secs(300));
config.keepalive_interval = Some(Duration::from_secs(60));
config.keepalive_max = 3;
tracing::info!("SSH server configured with methods: {:?}", config.methods);
let token_service = SshTokenService::new(self.db.clone());
let mut server = server::SSHServer::new(
self.db.clone(),
self.cache.clone(),
self.redis_pool.clone(),
token_service,
);
// Start the rate limiter cleanup background task so the HashMap
// doesn't grow unbounded over time.
let _cleanup = server.rate_limiter.clone().start_cleanup();
let ssh_port = self.app.ssh_port()?;
let bind_addr = format!("0.0.0.0:{}", ssh_port);
let public_host = self.app.ssh_domain()?;
let msg = if ssh_port == 22 {
format!(
"SSH server listening on port 22. Please use port {} for SSH connections.",
ssh_port
)
} else {
format!(
"SSH server listening on port {} (public: {}). Please use port {} for SSH connections.",
ssh_port, public_host, ssh_port
)
};
tracing::info!("{}", msg);
server.run_on_address(Arc::new(config), bind_addr).await?;
Ok(())
}
}
/// Boxed future returned by a NATS publish callback; resolves to the publish sequence
/// or an error.
type NatsPublishFuture =
    std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<u64>> + Send>>;

/// Shared NATS publish callback taking `(subject, payload)`.
type NatsPublish = Arc<dyn Fn(String, Vec<u8>) -> NatsPublishFuture + Send + Sync>;

/// Enqueues repository sync tasks onto the hook queue (NATS when configured, otherwise a
/// Redis list) and manages the per-repository push queue and push lock kept in Redis.
///
/// Queued sync tasks are meant to be picked up by a background worker, which processes
/// them with per-repo locking.
#[derive(Clone)]
pub struct ReceiveSyncService {
    /// Redis cluster pool used for every queue and lock operation.
    pool: deadpool_redis::cluster::Pool,
    /// Prefix for hook queue keys.
    redis_prefix: String,
    /// Optional NATS publish function: (subject, payload) -> Result<sequence, error>
    nats_publish: Option<NatsPublish>,
}
impl ReceiveSyncService {
/// Creates a service without a NATS publisher; `send` then enqueues through Redis only.
pub fn new(pool: deadpool_redis::cluster::Pool) -> Self {
    let redis_prefix = String::from("{hook}");
    Self {
        pool,
        redis_prefix,
        nats_publish: None,
    }
}
/// Creates a service that publishes sync tasks through `nats_publish` first and falls
/// back to the Redis list when publishing fails.
///
/// `nats_publish` is called with `(subject, payload)` and resolves to the publish
/// sequence or an error.
pub fn with_nats(
    pool: deadpool_redis::cluster::Pool,
    nats_publish: Arc<
        dyn Fn(
                String,
                Vec<u8>,
            ) -> std::pin::Pin<
                Box<dyn std::future::Future<Output = anyhow::Result<u64>> + Send>,
            > + Send
            + Sync,
    >,
) -> Self {
    let redis_prefix = String::from("{hook}");
    Self {
        pool,
        redis_prefix,
        nats_publish: Some(nats_publish),
    }
}
/// Enqueue a sync task. Fire-and-forget — logs errors but does not block.
pub async fn queue_position(&self, repo_uid: uuid::Uuid) -> Option<(usize, usize)> {
let queue_key = format!("{}:sync", self.redis_prefix);
let work_key = format!("{}:work", queue_key);
let redis = self.pool.get().await.ok()?;
let mut conn: deadpool_redis::cluster::Connection = redis;
let queue_items: Vec<String> = conn.lrange(&queue_key, 0, -1).await.ok()?;
let work_items: Vec<String> = conn.lrange(&work_key, 0, -1).await.unwrap_or_default();
let repo_id = repo_uid.to_string();
let queued_before = queue_items
.iter()
.rev()
.take_while(|item| {
serde_json::from_str::<HookTask>(item)
.map(|task| task.repo_id != repo_id)
.unwrap_or(true)
})
.count();
let total = work_items.len() + queue_items.len() + 1;
Some((work_items.len() + queued_before + 1, total))
}
/// Builds the `(queue_key, lock_key)` Redis key pair for a repository's push queue.
///
/// Both keys embed the same `{push:<repo_uid>}` hash tag, which Redis Cluster uses to
/// place keys in the same hash slot.
fn push_queue_keys(repo_uid: uuid::Uuid) -> (String, String) {
    let key_for = |suffix: &str| format!("git:{{push:{}}}:{}", repo_uid, suffix);
    (key_for("queue"), key_for("lock"))
}
/// Appends `request_id` to the tail of the repository's push queue (`RPUSH`).
///
/// # Errors
///
/// Returns an `Io`-kind Redis error if no connection can be obtained from the pool, or
/// the error returned by Redis for the `RPUSH` command.
pub async fn join_push_queue(
    &self,
    repo_uid: uuid::Uuid,
    request_id: &str,
) -> redis::RedisResult<()> {
    let (queue_key, _lock_key) = Self::push_queue_keys(repo_uid);
    let mut conn: deadpool_redis::cluster::Connection = match self.pool.get().await {
        Ok(conn) => conn,
        Err(pool_err) => {
            return Err(redis::RedisError::from((
                redis::ErrorKind::Io,
                "failed to get Redis connection",
                pool_err.to_string(),
            )));
        }
    };
    let mut rpush = redis::cmd("RPUSH");
    rpush.arg(&queue_key).arg(request_id);
    rpush.query_async::<()>(&mut conn).await
}
/// Looks up `request_id` in the repository's push queue.
///
/// Returns `(position, total)` where `position` is the 1-based index of the first
/// matching entry and `total` is the queue length. Returns `None` if Redis is
/// unavailable, the queue cannot be read, or the request is not in the queue.
pub async fn push_queue_position(
    &self,
    repo_uid: uuid::Uuid,
    request_id: &str,
) -> Option<(usize, usize)> {
    let (queue_key, _lock_key) = Self::push_queue_keys(repo_uid);
    let mut conn: deadpool_redis::cluster::Connection = self.pool.get().await.ok()?;
    let waiting: Vec<String> = conn.lrange(&queue_key, 0, -1).await.ok()?;
    let total = waiting.len();
    waiting
        .iter()
        .position(|queued| queued == request_id)
        .map(|index| (index + 1, total))
}
/// Tries to take the repository's push lock for `request_id`.
///
/// Issues `SET <lock_key> <request_id> NX EX <ttl_secs>` and returns `Ok(true)` when
/// Redis replies with a value (the key was set), `Ok(false)` when it replies with nil.
///
/// # Errors
///
/// Returns an `Io`-kind Redis error if no connection can be obtained from the pool, or
/// the error returned by Redis for the `SET` command.
pub async fn try_acquire_push_lock(
    &self,
    repo_uid: uuid::Uuid,
    request_id: &str,
    ttl_secs: usize,
) -> redis::RedisResult<bool> {
    let (_queue_key, lock_key) = Self::push_queue_keys(repo_uid);
    let mut conn: deadpool_redis::cluster::Connection = match self.pool.get().await {
        Ok(conn) => conn,
        Err(pool_err) => {
            return Err(redis::RedisError::from((
                redis::ErrorKind::Io,
                "failed to get Redis connection",
                pool_err.to_string(),
            )));
        }
    };
    let mut set_nx = redis::cmd("SET");
    set_nx
        .arg(&lock_key)
        .arg(request_id)
        .arg("NX")
        .arg("EX")
        .arg(ttl_secs);
    let reply: Option<String> = set_nx.query_async(&mut conn).await?;
    Ok(reply.is_some())
}
/// Removes `request_id` from the repository's push queue and releases the push lock if
/// this request currently holds it.
///
/// Both steps run in one Lua script. Failures (no Redis connection, script error) are
/// logged as warnings and otherwise ignored.
pub async fn release_push_queue(&self, repo_uid: uuid::Uuid, request_id: &str) {
    let (queue_key, lock_key) = Self::push_queue_keys(repo_uid);
    let mut conn: deadpool_redis::cluster::Connection = match self.pool.get().await {
        Ok(conn) => conn,
        Err(e) => {
            tracing::warn!(error = %e, repo_id = %repo_uid, "push_queue_release_redis_connection_failed");
            return;
        }
    };
    // KEYS[1] = queue, KEYS[2] = lock, ARGV[1] = request id.
    // The script text is kept verbatim.
    let script = redis::Script::new(
        r#"
redis.call("LREM", KEYS[1], 0, ARGV[1])
if redis.call("GET", KEYS[2]) == ARGV[1] then
redis.call("DEL", KEYS[2])
end
return 1
"#,
    );
    let outcome = script
        .key(&queue_key)
        .key(&lock_key)
        .arg(request_id)
        .invoke_async::<()>(&mut conn)
        .await;
    if let Err(e) = outcome {
        tracing::warn!(error = %e, repo_id = %repo_uid, "push_queue_release_failed");
    }
}
/// Extends the push lock's expiry to `ttl_secs` if `request_id` still holds the lock.
///
/// Returns `Ok(true)` when the lock value matched and `EXPIRE` was issued, `Ok(false)`
/// when the lock is missing or held by a different request.
///
/// # Errors
///
/// Returns an `Io`-kind Redis error if no connection can be obtained from the pool, or
/// the error returned by Redis when running the script.
pub async fn refresh_push_lock(
    &self,
    repo_uid: uuid::Uuid,
    request_id: &str,
    ttl_secs: usize,
) -> redis::RedisResult<bool> {
    let (_queue_key, lock_key) = Self::push_queue_keys(repo_uid);
    let mut conn: deadpool_redis::cluster::Connection = match self.pool.get().await {
        Ok(conn) => conn,
        Err(pool_err) => {
            return Err(redis::RedisError::from((
                redis::ErrorKind::Io,
                "failed to get Redis connection",
                pool_err.to_string(),
            )));
        }
    };
    // KEYS[1] = lock, ARGV[1] = request id, ARGV[2] = new TTL in seconds.
    // The script text is kept verbatim.
    let script = redis::Script::new(
        r#"
if redis.call("GET", KEYS[1]) == ARGV[1] then
redis.call("EXPIRE", KEYS[1], ARGV[2])
return 1
end
return 0
"#,
    );
    let refreshed: i32 = script
        .key(&lock_key)
        .arg(request_id)
        .arg(ttl_secs)
        .invoke_async(&mut conn)
        .await?;
    Ok(refreshed == 1)
}
/// Enqueues a sync task for `task.repo_uid`.
///
/// The task is published to NATS when a publisher is configured; if there is no
/// publisher or publishing fails, it is pushed onto the Redis list instead. Errors are
/// logged and never returned.
///
/// Returns the queue position computed by `queue_position` *before* the task was
/// enqueued, whether or not enqueueing succeeded.
pub async fn send(&self, task: RepoReceiveSyncTask) -> Option<(usize, usize)> {
    let position = self.queue_position(task.repo_uid).await;
    let hook_task = HookTask {
        id: uuid::Uuid::new_v4().to_string(),
        repo_id: task.repo_uid.to_string(),
        task_type: TaskType::Sync,
        payload: serde_json::Value::Null,
        created_at: chrono::Utc::now(),
        retry_count: 0,
    };

    // Try NATS first if available
    if let Some(publish) = self.nats_publish.as_ref() {
        let payload = match serde_json::to_vec(&hook_task) {
            Ok(bytes) => bytes,
            Err(e) => {
                tracing::error!("failed to serialize hook task: {}", e);
                return position;
            }
        };
        match publish("queue.hook.sync".to_string(), payload).await {
            Err(e) => {
                tracing::warn!(error = %e, "NATS publish failed, falling back to Redis");
            }
            Ok(seq) => {
                tracing::info!(repo_id = %task.repo_uid, seq = seq, "hook task queued to NATS");
                metrics::counter!("hook_task_queued_total", "backend" => "nats").increment(1);
                return position;
            }
        }
    }

    // Fallback to Redis List
    let task_json = match serde_json::to_string(&hook_task) {
        Ok(json) => json,
        Err(e) => {
            tracing::error!("failed to serialize hook task: {}", e);
            return position;
        }
    };
    let queue_key = format!("{}:sync", self.redis_prefix);
    let mut conn: deadpool_redis::cluster::Connection = match self.pool.get().await {
        Ok(conn) => conn,
        Err(e) => {
            tracing::error!("failed to get Redis connection: {}", e);
            return position;
        }
    };
    let pushed = redis::cmd("LPUSH")
        .arg(&queue_key)
        .arg(&task_json)
        .query_async::<()>(&mut conn)
        .await;
    match pushed {
        Ok(()) => {
            tracing::info!(repo_id = %task.repo_uid, "hook task queued to Redis");
            metrics::counter!("hook_task_queued_total", "backend" => "redis").increment(1);
        }
        Err(e) => {
            tracing::error!(
                "failed to enqueue sync task repo_id={} error={}",
                task.repo_uid,
                e
            );
        }
    }
    position
}
}
/// Request to queue a sync for one repository; consumed by `ReceiveSyncService::send`.
#[derive(Clone)]
pub struct RepoReceiveSyncTask {
/// UID of the repository to sync; `send` stores its string form as the hook task's `repo_id`.
pub repo_uid: uuid::Uuid,
}
/// SSH token authentication service.
/// Uses the same token hash algorithm as user access keys (Argon2id PHC string).
#[derive(Clone)]
pub struct SshTokenService {
db: AppDatabase,
}
impl SshTokenService {
pub fn new(db: AppDatabase) -> Self {
Self { db }
}
pub async fn find_user_by_token(&self, token: &str) -> Result<Option<user::Model>, GitError> {
let token_models = user_token::Entity::find()
.filter(user_token::Column::IsRevoked.eq(false))
.all(self.db.reader())
.await
.map_err(|e| GitError::Internal(e.to_string()))?;
for token_model in token_models {
if token_model
.expires_at
.map(|expires_at| expires_at < chrono::Utc::now())
.unwrap_or(false)
{
continue;
}
let Ok(hash) = PasswordHash::new(&token_model.token_hash) else {
tracing::warn!(token_id = token_model.id, "invalid stored SSH token hash");
continue;
};
if Argon2::default()
.verify_password(token.as_bytes(), &hash)
.is_err()
{
continue;
}
let user_model = user::Entity::find()
.filter(user::Column::Uid.eq(token_model.user))
.one(self.db.reader())
.await
.map_err(|e| GitError::Internal(e.to_string()))?;
return Ok(user_model);
}
Ok(None)
}
}
/// Initialises the database, cache and hook service from `config`, then runs the SSH
/// server until it stops.
///
/// # Errors
///
/// Returns an error if database or cache initialisation fails, or if the SSH server
/// returns an error.
pub async fn run_ssh(config: AppConfig) -> anyhow::Result<()> {
    tracing::info!("SSH server initializing");
    let db = AppDatabase::init(&config).await?;
    let cache = AppCache::init(&config).await?;
    let redis_pool = cache.redis_pool().clone();
    // Bound to a named variable (not `_`) so the hook service stays alive until this
    // function returns. NOTE(review): it is not used otherwise — confirm that
    // constructing it is what starts the hook workers.
    let _hook = crate::hook::HookService::new(
        db.clone(),
        cache.clone(),
        redis_pool.clone(),
        config.clone(),
    );
    // `config` is not needed afterwards, so it is moved into the handle instead of cloned.
    SSHHandle::new(db, config, cache, redis_pool)
        .run_ssh()
        .await?;
    Ok(())
}