feat(ssh): enhance ref update handling and add push queue

- Add post-receive refs tracking with mutex-protected storage
- Improve branch protection error messages with actionable guidance
- Add push queue slot waiting mechanism for concurrent push control
- Support for checking push queue availability before push operations
This commit is contained in:
ZhenYi 2026-05-15 11:48:40 +08:00
parent 2b6b4af3db
commit 0703816482
5 changed files with 657 additions and 63 deletions

View File

@ -23,7 +23,7 @@ pub fn check_branch_protection(
if r#ref.new_oid == "0000000000000000000000000000000000000000" { if r#ref.new_oid == "0000000000000000000000000000000000000000" {
if protection.forbid_deletion { if protection.forbid_deletion {
return Some(format!( return Some(format!(
"Deletion of protected branch '{}' is forbidden", "GitData: 🛡️ protected branch rejected. Deletion of '{}' is forbidden. Create a PR or ask a maintainer to update branch protection.",
r#ref.name r#ref.name
)); ));
} }
@ -34,7 +34,7 @@ pub fn check_branch_protection(
if r#ref.name.starts_with("refs/tags/") { if r#ref.name.starts_with("refs/tags/") {
if protection.forbid_tag_push { if protection.forbid_tag_push {
return Some(format!( return Some(format!(
"Tag push to protected branch '{}' is forbidden", "GitData: 🛡️ protected ref rejected. Tag push to '{}' is forbidden by branch protection.",
r#ref.name r#ref.name
)); ));
} }
@ -49,7 +49,7 @@ pub fn check_branch_protection(
&& protection.forbid_force_push && protection.forbid_force_push
{ {
return Some(format!( return Some(format!(
"Force push to protected branch '{}' is forbidden", "GitData: 🛡️ protected branch rejected. Force push to '{}' is forbidden. Create a PR instead of rewriting protected history.",
r#ref.name r#ref.name
)); ));
} }
@ -57,7 +57,7 @@ pub fn check_branch_protection(
// Check push // Check push
if protection.forbid_push { if protection.forbid_push {
return Some(format!( return Some(format!(
"Push to protected branch '{}' is forbidden", "GitData: 🛡️ protected branch rejected. Direct push to '{}' is forbidden. Please push to a feature branch and create a PR.",
r#ref.name r#ref.name
)); ));
} }

View File

@ -5,6 +5,7 @@ use crate::ssh::authz::SshAuthService;
use crate::ssh::branch_protect::check_branch_protection; use crate::ssh::branch_protect::check_branch_protection;
use crate::ssh::forward::forward; use crate::ssh::forward::forward;
use crate::ssh::git_service::{GitService, build_git_command, parse_git_command, parse_repo_path}; use crate::ssh::git_service::{GitService, build_git_command, parse_git_command, parse_repo_path};
use crate::ssh::push_queue::{PushQueueEvent, PushQueueWaitError, wait_for_push_queue_slot};
use crate::ssh::ref_update::RefUpdate; use crate::ssh::ref_update::RefUpdate;
use db::cache::AppCache; use db::cache::AppCache;
use db::database::AppDatabase; use db::database::AppDatabase;
@ -21,13 +22,15 @@ use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::PathBuf; use std::path::PathBuf;
use std::process::Stdio; use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio_util::bytes::Bytes; use tokio_util::bytes::Bytes;
const PRE_PACK_LIMIT: usize = 1_048_576; const PRE_PACK_LIMIT: usize = 1_048_576;
const ZERO_OID: &str = "0000000000000000000000000000000000000000";
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tokio::process::ChildStdin; use tokio::process::ChildStdin;
use tokio::sync::mpsc::Sender; use tokio::sync::{Mutex, mpsc::Sender};
use tokio::time::sleep; use tokio::time::sleep;
pub struct SSHandle { pub struct SSHandle {
pub repo: Option<PathBuf>, pub repo: Option<PathBuf>,
@ -39,6 +42,7 @@ pub struct SSHandle {
pub auth: SshAuthService, pub auth: SshAuthService,
pub buffer: HashMap<ChannelId, Vec<u8>>, pub buffer: HashMap<ChannelId, Vec<u8>>,
pub branch: HashMap<ChannelId, Vec<RefUpdate>>, pub branch: HashMap<ChannelId, Vec<RefUpdate>>,
pub post_receive_refs: HashMap<ChannelId, Arc<Mutex<Vec<RefUpdate>>>>,
pub service: Option<GitService>, pub service: Option<GitService>,
pub cache: AppCache, pub cache: AppCache,
pub sync: ReceiveSyncService, pub sync: ReceiveSyncService,
@ -70,6 +74,7 @@ impl SSHandle {
auth, auth,
buffer: HashMap::new(), buffer: HashMap::new(),
branch: HashMap::new(), branch: HashMap::new(),
post_receive_refs: HashMap::new(),
service: None, service: None,
cache, cache,
sync, sync,
@ -94,8 +99,37 @@ impl SSHandle {
}); });
} }
self.eof.remove(&channel_id); self.eof.remove(&channel_id);
self.post_receive_refs.remove(&channel_id);
self.upload_pack_eof_sent.remove(&channel_id); self.upload_pack_eof_sent.remove(&channel_id);
} }
fn format_post_receive_hints(
namespace: &str,
repo: &repo::Model,
refs: &[RefUpdate],
queue: Option<(usize, usize)>,
) -> String {
let mut lines = Vec::new();
for r#ref in refs {
if r#ref.old_oid == ZERO_OID && r#ref.name.starts_with("refs/heads/") {
let branch = r#ref.name.trim_start_matches("refs/heads/");
lines.push(format!(
"remote: GitData: 🌱 new branch '{}' pushed. Create a PR: /{}/repo/{}/pulls/new?head={}\r\n",
branch,
namespace,
repo.repo_name,
branch
));
}
}
if let Some((position, total)) = queue {
lines.push(format!(
"remote: GitData: ⏳ repository sync queued ({}/{}). Metadata, webhooks and search indexes will update shortly.\r\n",
position, total
));
}
lines.concat()
}
} }
impl Drop for SSHandle { impl Drop for SSHandle {
@ -489,10 +523,16 @@ impl russh::server::Handler for SSHandle {
} }
} }
} }
if let Some(refs_for_hints) = self.post_receive_refs.get(&channel) {
*refs_for_hints.lock().await = refs.clone();
}
self.branch.insert(channel, refs); self.branch.insert(channel, refs);
} }
Err(e) => { Err(e) => {
tracing::warn!("ref_update_parse_error error={:?}", e); tracing::warn!("ref_update_parse_error error={:?}", e);
if let Some(refs_for_hints) = self.post_receive_refs.get(&channel) {
refs_for_hints.lock().await.clear();
}
self.branch.insert(channel, vec![]); self.branch.insert(channel, vec![]);
} }
} }
@ -606,6 +646,7 @@ impl russh::server::Handler for SSHandle {
return Err(russh::Error::Disconnect); return Err(russh::Error::Disconnect);
} }
}; };
let namespace = owner.to_string();
let repo = repo.strip_suffix(".git").unwrap_or(repo).to_string(); let repo = repo.strip_suffix(".git").unwrap_or(repo).to_string();
let repo = match self.auth.find_repo(owner, &repo).await { let repo = match self.auth.find_repo(owner, &repo).await {
@ -660,6 +701,98 @@ impl russh::server::Handler for SSHandle {
is_write is_write
); );
let mut push_queue_lease = if is_write {
let repo_id = repo.id;
let queue_result =
wait_for_push_queue_slot(self.sync.clone(), repo_id, |event, request_id| {
let request_id = request_id.to_string();
match event {
PushQueueEvent::Waiting(position) => {
let msg = format!(
"remote: GitData: ⏳ another push is running for this repository. Queued {}/{}.\r\n",
position.position, position.total
);
let _ = session.extended_data(
channel_id,
1,
Bytes::copy_from_slice(msg.as_bytes()),
);
let _ = session.flush();
tracing::info!(
repo_id = %repo_id,
request_id = %request_id,
position = position.position,
total = position.total,
"push_queue_waiting"
);
}
PushQueueEvent::Acquired => {
let msg = "remote: GitData: 🚀 push queue slot acquired. Processing now.\r\n";
let _ = session.extended_data(
channel_id,
1,
Bytes::copy_from_slice(msg.as_bytes()),
);
let _ = session.flush();
tracing::info!(
repo_id = %repo_id,
request_id = %request_id,
"push_queue_acquired"
);
}
}
})
.await;
match queue_result {
Ok(lease) => Some(lease),
Err(error) => {
match &error {
PushQueueWaitError::Join(e) => {
tracing::error!(error = %e, repo = %repo.repo_name, "push_queue_join_failed");
let msg = "remote: GitData: ⛔ push queue is temporarily unavailable. Please retry later.\r\n";
let _ = session.extended_data(
channel_id,
1,
Bytes::copy_from_slice(msg.as_bytes()),
);
}
PushQueueWaitError::Lock(e) => {
tracing::error!(error = %e, repo_id = %repo.id, "push_queue_lock_failed");
let msg = "remote: GitData: ⛔ push queue lock failed. Please retry later.\r\n";
let _ = session.extended_data(
channel_id,
1,
Bytes::copy_from_slice(msg.as_bytes()),
);
}
PushQueueWaitError::Timeout => {
tracing::warn!(repo_id = %repo.id, "push_queue_timeout");
let msg = "remote: GitData: ⏱️ push queue timed out. Please retry in a moment.\r\n";
let _ = session.extended_data(
channel_id,
1,
Bytes::copy_from_slice(msg.as_bytes()),
);
}
}
let _ = session.channel_failure(channel_id);
let _ = session.close(channel_id);
self.cleanup_channel(channel_id);
return if matches!(error, PushQueueWaitError::Timeout) {
Ok(())
} else {
Err(russh::Error::IO(io::Error::new(
io::ErrorKind::Other,
error.to_string(),
)))
};
}
}
} else {
None
};
let repo_path = PathBuf::from(&repo.storage_path); let repo_path = PathBuf::from(&repo.storage_path);
if !repo_path.exists() { if !repo_path.exists() {
tracing::error!("repo_path_not_found path={}", repo.storage_path); tracing::error!("repo_path_not_found path={}", repo.storage_path);
@ -683,6 +816,9 @@ impl russh::server::Handler for SSHandle {
} }
Err(e) => { Err(e) => {
tracing::error!("process_spawn_failed error={}", e); tracing::error!("process_spawn_failed error={}", e);
if let Some(lease) = &mut push_queue_lease {
lease.release().await;
}
let _ = session.channel_failure(channel_id); let _ = session.channel_failure(channel_id);
self.cleanup_channel(channel_id); self.cleanup_channel(channel_id);
return Err(russh::Error::IO(e)); return Err(russh::Error::IO(e));
@ -693,6 +829,9 @@ impl russh::server::Handler for SSHandle {
Some(s) => s, Some(s) => s,
None => { None => {
tracing::error!("stdin pipe unavailable for channel={:?}", channel_id); tracing::error!("stdin pipe unavailable for channel={:?}", channel_id);
if let Some(lease) = &mut push_queue_lease {
lease.release().await;
}
let _ = session_handle.channel_failure(channel_id).await; let _ = session_handle.channel_failure(channel_id).await;
return Err(russh::Error::IO(io::Error::new( return Err(russh::Error::IO(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
@ -705,6 +844,9 @@ impl russh::server::Handler for SSHandle {
Some(s) => s, Some(s) => s,
None => { None => {
tracing::error!("stdout pipe unavailable for channel={:?}", channel_id); tracing::error!("stdout pipe unavailable for channel={:?}", channel_id);
if let Some(lease) = &mut push_queue_lease {
lease.release().await;
}
return Err(russh::Error::IO(io::Error::new( return Err(russh::Error::IO(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"stdout unavailable", "stdout unavailable",
@ -715,6 +857,9 @@ impl russh::server::Handler for SSHandle {
Some(s) => s, Some(s) => s,
None => { None => {
tracing::error!("stderr pipe unavailable for channel={:?}", channel_id); tracing::error!("stderr pipe unavailable for channel={:?}", channel_id);
if let Some(lease) = &mut push_queue_lease {
lease.release().await;
}
return Err(russh::Error::IO(io::Error::new( return Err(russh::Error::IO(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"stderr unavailable", "stderr unavailable",
@ -724,9 +869,15 @@ impl russh::server::Handler for SSHandle {
let (eof_tx, mut eof_rx) = tokio::sync::mpsc::channel::<bool>(10); let (eof_tx, mut eof_rx) = tokio::sync::mpsc::channel::<bool>(10);
self.eof.insert(channel_id, eof_tx); self.eof.insert(channel_id, eof_tx);
let refs_for_hints = Arc::new(Mutex::new(Vec::new()));
self.post_receive_refs
.insert(channel_id, refs_for_hints.clone());
let repo_uid = repo.id; let repo_uid = repo.id;
let repo_for_hints = repo.clone();
let namespace_for_hints = namespace.clone();
let should_sync = service == GitService::ReceivePack; let should_sync = service == GitService::ReceivePack;
let sync = self.sync.clone(); let sync = self.sync.clone();
let mut push_queue_lease = push_queue_lease;
let fut = async move { let fut = async move {
tracing::info!(channel = ?channel_id, "git_task_started"); tracing::info!(channel = ?channel_id, "git_task_started");
@ -753,11 +904,23 @@ impl russh::server::Handler for SSHandle {
loop { loop {
tokio::select! { tokio::select! {
result = shell.wait() => { result = shell.wait() => {
let status = result?; let status = match result {
Ok(status) => status,
Err(e) => {
if let Some(lease) = &mut push_queue_lease {
lease.release().await;
}
return Err(russh::Error::IO(e));
}
};
let status_code = status.code().unwrap_or(128) as u32; let status_code = status.code().unwrap_or(128) as u32;
tracing::info!("git_process_exited channel={:?} status={}", channel_id, status_code); tracing::info!("git_process_exited channel={:?} status={}", channel_id, status_code);
if let Some(lease) = &mut push_queue_lease {
lease.release().await;
}
if !stdout_done || !stderr_done { if !stdout_done || !stderr_done {
let _ = tokio::time::timeout(Duration::from_millis(100), async { let _ = tokio::time::timeout(Duration::from_millis(100), async {
tokio::join!( tokio::join!(
@ -775,11 +938,20 @@ impl russh::server::Handler for SSHandle {
}).await; }).await;
} }
if should_sync { if should_sync && status_code == 0 {
let sync = sync.clone(); let queue = sync.send(RepoReceiveSyncTask { repo_uid }).await;
tokio::spawn(async move { let refs_for_hints = refs_for_hints.lock().await.clone();
sync.send(RepoReceiveSyncTask { repo_uid }).await let msg = SSHandle::format_post_receive_hints(
}); &namespace_for_hints,
&repo_for_hints,
&refs_for_hints,
queue,
);
if !msg.is_empty() {
let _ = session_handle
.extended_data(channel_id, 1, Bytes::copy_from_slice(msg.as_bytes()))
.await;
}
} }
let _ = session_handle.exit_status_request(channel_id, status_code).await; let _ = session_handle.exit_status_request(channel_id, status_code).await;

View File

@ -1,12 +1,14 @@
use crate::error::GitError; use crate::error::GitError;
use crate::hook::pool::types::{HookTask, TaskType}; use crate::hook::pool::types::{HookTask, TaskType};
use crate::http::utils::hash_access_key;
use anyhow::Context; use anyhow::Context;
use argon2::Argon2;
use argon2::password_hash::{PasswordHash, PasswordVerifier};
use config::AppConfig; use config::AppConfig;
use db::cache::AppCache; use db::cache::AppCache;
use db::database::AppDatabase; use db::database::AppDatabase;
use deadpool_redis::cluster::Pool as RedisPool; use deadpool_redis::cluster::Pool as RedisPool;
use models::users::{user, user_token}; use models::users::{user, user_token};
use redis::AsyncCommands;
use russh::keys::PrivateKey; use russh::keys::PrivateKey;
use russh::server::Server; use russh::server::Server;
use russh::{MethodKind, MethodSet, SshId, server::Config}; use russh::{MethodKind, MethodSet, SshId, server::Config};
@ -20,6 +22,7 @@ pub mod branch_protect;
pub mod forward; pub mod forward;
pub mod git_service; pub mod git_service;
pub mod handle; pub mod handle;
pub mod push_queue;
pub mod rate_limit; pub mod rate_limit;
pub mod ref_update; pub mod ref_update;
pub mod server; pub mod server;
@ -194,7 +197,159 @@ impl ReceiveSyncService {
} }
/// Enqueue a sync task. Fire-and-forget — logs errors but does not block. /// Enqueue a sync task. Fire-and-forget — logs errors but does not block.
pub async fn send(&self, task: RepoReceiveSyncTask) { pub async fn queue_position(&self, repo_uid: uuid::Uuid) -> Option<(usize, usize)> {
let queue_key = format!("{}:sync", self.redis_prefix);
let work_key = format!("{}:work", queue_key);
let redis = self.pool.get().await.ok()?;
let mut conn: deadpool_redis::cluster::Connection = redis;
let queue_items: Vec<String> = conn.lrange(&queue_key, 0, -1).await.ok()?;
let work_items: Vec<String> = conn.lrange(&work_key, 0, -1).await.unwrap_or_default();
let repo_id = repo_uid.to_string();
let queued_before = queue_items
.iter()
.rev()
.take_while(|item| {
serde_json::from_str::<HookTask>(item)
.map(|task| task.repo_id != repo_id)
.unwrap_or(true)
})
.count();
let total = work_items.len() + queue_items.len() + 1;
Some((work_items.len() + queued_before + 1, total))
}
fn push_queue_keys(repo_uid: uuid::Uuid) -> (String, String) {
let hash_tag = format!("{{push:{}}}", repo_uid);
(
format!("git:{}:queue", hash_tag),
format!("git:{}:lock", hash_tag),
)
}
pub async fn join_push_queue(
&self,
repo_uid: uuid::Uuid,
request_id: &str,
) -> redis::RedisResult<()> {
let (queue_key, _) = Self::push_queue_keys(repo_uid);
let redis = self.pool.get().await.map_err(|e| {
redis::RedisError::from((
redis::ErrorKind::Io,
"failed to get Redis connection",
e.to_string(),
))
})?;
let mut conn: deadpool_redis::cluster::Connection = redis;
redis::cmd("RPUSH")
.arg(&queue_key)
.arg(request_id)
.query_async::<()>(&mut conn)
.await
}
pub async fn push_queue_position(
&self,
repo_uid: uuid::Uuid,
request_id: &str,
) -> Option<(usize, usize)> {
let (queue_key, _) = Self::push_queue_keys(repo_uid);
let redis = self.pool.get().await.ok()?;
let mut conn: deadpool_redis::cluster::Connection = redis;
let queue_items: Vec<String> = conn.lrange(&queue_key, 0, -1).await.ok()?;
let position = queue_items.iter().position(|item| item == request_id)? + 1;
Some((position, queue_items.len()))
}
pub async fn try_acquire_push_lock(
&self,
repo_uid: uuid::Uuid,
request_id: &str,
ttl_secs: usize,
) -> redis::RedisResult<bool> {
let (_, lock_key) = Self::push_queue_keys(repo_uid);
let redis = self.pool.get().await.map_err(|e| {
redis::RedisError::from((
redis::ErrorKind::Io,
"failed to get Redis connection",
e.to_string(),
))
})?;
let mut conn: deadpool_redis::cluster::Connection = redis;
let acquired: Option<String> = redis::cmd("SET")
.arg(&lock_key)
.arg(request_id)
.arg("NX")
.arg("EX")
.arg(ttl_secs)
.query_async(&mut conn)
.await?;
Ok(acquired.is_some())
}
pub async fn release_push_queue(&self, repo_uid: uuid::Uuid, request_id: &str) {
let (queue_key, lock_key) = Self::push_queue_keys(repo_uid);
let redis = match self.pool.get().await {
Ok(c) => c,
Err(e) => {
tracing::warn!(error = %e, repo_id = %repo_uid, "push_queue_release_redis_connection_failed");
return;
}
};
let mut conn: deadpool_redis::cluster::Connection = redis;
let script = redis::Script::new(
r#"
redis.call("LREM", KEYS[1], 0, ARGV[1])
if redis.call("GET", KEYS[2]) == ARGV[1] then
redis.call("DEL", KEYS[2])
end
return 1
"#,
);
if let Err(e) = script
.key(&queue_key)
.key(&lock_key)
.arg(request_id)
.invoke_async::<()>(&mut conn)
.await
{
tracing::warn!(error = %e, repo_id = %repo_uid, "push_queue_release_failed");
}
}
pub async fn refresh_push_lock(
&self,
repo_uid: uuid::Uuid,
request_id: &str,
ttl_secs: usize,
) -> redis::RedisResult<bool> {
let (_, lock_key) = Self::push_queue_keys(repo_uid);
let redis = self.pool.get().await.map_err(|e| {
redis::RedisError::from((
redis::ErrorKind::Io,
"failed to get Redis connection",
e.to_string(),
))
})?;
let mut conn: deadpool_redis::cluster::Connection = redis;
let refreshed: i32 = redis::Script::new(
r#"
if redis.call("GET", KEYS[1]) == ARGV[1] then
redis.call("EXPIRE", KEYS[1], ARGV[2])
return 1
end
return 0
"#,
)
.key(&lock_key)
.arg(request_id)
.arg(ttl_secs)
.invoke_async(&mut conn)
.await?;
Ok(refreshed == 1)
}
pub async fn send(&self, task: RepoReceiveSyncTask) -> Option<(usize, usize)> {
let position = self.queue_position(task.repo_uid).await;
let hook_task = HookTask { let hook_task = HookTask {
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
repo_id: task.repo_uid.to_string(), repo_id: task.repo_uid.to_string(),
@ -210,7 +365,7 @@ impl ReceiveSyncService {
Ok(p) => p, Ok(p) => p,
Err(e) => { Err(e) => {
tracing::error!("failed to serialize hook task: {}", e); tracing::error!("failed to serialize hook task: {}", e);
return; return position;
} }
}; };
@ -218,7 +373,7 @@ impl ReceiveSyncService {
Ok(seq) => { Ok(seq) => {
tracing::info!(repo_id = %task.repo_uid, seq = seq, "hook task queued to NATS"); tracing::info!(repo_id = %task.repo_uid, seq = seq, "hook task queued to NATS");
metrics::counter!("hook_task_queued_total", "backend" => "nats").increment(1); metrics::counter!("hook_task_queued_total", "backend" => "nats").increment(1);
return; return position;
} }
Err(e) => { Err(e) => {
tracing::warn!(error = %e, "NATS publish failed, falling back to Redis"); tracing::warn!(error = %e, "NATS publish failed, falling back to Redis");
@ -231,7 +386,7 @@ impl ReceiveSyncService {
Ok(j) => j, Ok(j) => j,
Err(e) => { Err(e) => {
tracing::error!("failed to serialize hook task: {}", e); tracing::error!("failed to serialize hook task: {}", e);
return; return position;
} }
}; };
@ -241,7 +396,7 @@ impl ReceiveSyncService {
Ok(c) => c, Ok(c) => c,
Err(e) => { Err(e) => {
tracing::error!("failed to get Redis connection: {}", e); tracing::error!("failed to get Redis connection: {}", e);
return; return position;
} }
}; };
@ -261,6 +416,7 @@ impl ReceiveSyncService {
tracing::info!(repo_id = %task.repo_uid, "hook task queued to Redis"); tracing::info!(repo_id = %task.repo_uid, "hook task queued to Redis");
metrics::counter!("hook_task_queued_total", "backend" => "redis").increment(1); metrics::counter!("hook_task_queued_total", "backend" => "redis").increment(1);
} }
position
} }
} }
@ -281,40 +437,44 @@ impl SshTokenService {
Self { db } Self { db }
} }
fn hash_token(token: &str) -> Result<String, argon2::password_hash::Error> {
hash_access_key(token)
}
pub async fn find_user_by_token(&self, token: &str) -> Result<Option<user::Model>, GitError> { pub async fn find_user_by_token(&self, token: &str) -> Result<Option<user::Model>, GitError> {
let token_hash = Self::hash_token(token) let token_models = user_token::Entity::find()
.map_err(|e| GitError::Internal(format!("Token hash failed: {}", e)))?;
let token_model = user_token::Entity::find()
.filter(user_token::Column::TokenHash.eq(&token_hash))
.filter(user_token::Column::IsRevoked.eq(false)) .filter(user_token::Column::IsRevoked.eq(false))
.one(self.db.reader()) .all(self.db.reader())
.await .await
.map_err(|e| GitError::Internal(e.to_string()))?; .map_err(|e| GitError::Internal(e.to_string()))?;
let token_model = match token_model { for token_model in token_models {
Some(t) => t, if token_model
None => return Ok(None), .expires_at
}; .map(|expires_at| expires_at < chrono::Utc::now())
.unwrap_or(false)
// Check expiry {
if let Some(expires_at) = token_model.expires_at { continue;
if expires_at < chrono::Utc::now() {
return Ok(None);
} }
let Ok(hash) = PasswordHash::new(&token_model.token_hash) else {
tracing::warn!(token_id = token_model.id, "invalid stored SSH token hash");
continue;
};
if Argon2::default()
.verify_password(token.as_bytes(), &hash)
.is_err()
{
continue;
}
let user_model = user::Entity::find()
.filter(user::Column::Uid.eq(token_model.user))
.one(self.db.reader())
.await
.map_err(|e| GitError::Internal(e.to_string()))?;
return Ok(user_model);
} }
let user_model = user::Entity::find() Ok(None)
.filter(user::Column::Uid.eq(token_model.user))
.one(self.db.reader())
.await
.map_err(|e| GitError::Internal(e.to_string()))?;
Ok(user_model)
} }
} }

186
libs/git/ssh/push_queue.rs Normal file
View File

@ -0,0 +1,186 @@
use crate::ssh::ReceiveSyncService;
use std::fmt;
use std::time::{Duration, Instant};
use tokio::task::JoinHandle;
use tokio::time::sleep;
pub const PUSH_QUEUE_TIMEOUT: Duration = Duration::from_secs(120);
pub const PUSH_LOCK_TTL_SECS: usize = 300;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PushQueuePosition {
pub position: usize,
pub total: usize,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PushQueueEvent {
Waiting(PushQueuePosition),
Acquired,
}
#[derive(Debug)]
pub enum PushQueueWaitError {
Join(redis::RedisError),
Lock(redis::RedisError),
Timeout,
}
impl fmt::Display for PushQueueWaitError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Join(e) => write!(f, "failed to join push queue: {e}"),
Self::Lock(e) => write!(f, "failed to acquire push queue lock: {e}"),
Self::Timeout => write!(f, "push queue timed out"),
}
}
}
impl std::error::Error for PushQueueWaitError {}
pub struct PushQueueLease {
service: ReceiveSyncService,
repo_uid: uuid::Uuid,
request_id: String,
heartbeat: Option<JoinHandle<()>>,
released: bool,
}
impl PushQueueLease {
fn new(service: ReceiveSyncService, repo_uid: uuid::Uuid, request_id: String) -> Self {
let heartbeat = Some(start_lock_heartbeat(
service.clone(),
repo_uid,
request_id.clone(),
));
Self {
service,
repo_uid,
request_id,
heartbeat,
released: false,
}
}
pub fn request_id(&self) -> &str {
&self.request_id
}
pub async fn release(&mut self) {
if self.released {
return;
}
self.service
.release_push_queue(self.repo_uid, &self.request_id)
.await;
if let Some(heartbeat) = self.heartbeat.take() {
heartbeat.abort();
}
self.released = true;
}
}
impl Drop for PushQueueLease {
fn drop(&mut self) {
if self.released {
return;
}
if let Some(heartbeat) = self.heartbeat.take() {
heartbeat.abort();
}
let service = self.service.clone();
let repo_uid = self.repo_uid;
let request_id = self.request_id.clone();
tokio::spawn(async move {
service.release_push_queue(repo_uid, &request_id).await;
});
}
}
fn start_lock_heartbeat(
service: ReceiveSyncService,
repo_uid: uuid::Uuid,
request_id: String,
) -> JoinHandle<()> {
tokio::spawn(async move {
let interval = Duration::from_secs((PUSH_LOCK_TTL_SECS as u64 / 3).max(30));
loop {
sleep(interval).await;
match service
.refresh_push_lock(repo_uid, &request_id, PUSH_LOCK_TTL_SECS)
.await
{
Ok(true) => {}
Ok(false) => {
tracing::warn!(
repo_id = %repo_uid,
request_id = %request_id,
"push_queue_lock_lost"
);
break;
}
Err(e) => {
tracing::warn!(
error = %e,
repo_id = %repo_uid,
request_id = %request_id,
"push_queue_lock_refresh_failed"
);
}
}
}
})
}
pub async fn wait_for_push_queue_slot<F>(
service: ReceiveSyncService,
repo_uid: uuid::Uuid,
mut on_event: F,
) -> Result<PushQueueLease, PushQueueWaitError>
where
F: FnMut(PushQueueEvent, &str),
{
let request_id = uuid::Uuid::new_v4().to_string();
service
.join_push_queue(repo_uid, &request_id)
.await
.map_err(PushQueueWaitError::Join)?;
let deadline = Instant::now() + PUSH_QUEUE_TIMEOUT;
let mut last_position = None;
loop {
let position = service.push_queue_position(repo_uid, &request_id).await;
if let Some((position, total)) = position {
let position = PushQueuePosition { position, total };
if last_position != Some(position) && position.position > 1 {
on_event(PushQueueEvent::Waiting(position), &request_id);
}
last_position = Some(position);
if position.position == 1 {
match service
.try_acquire_push_lock(repo_uid, &request_id, PUSH_LOCK_TTL_SECS)
.await
{
Ok(true) => {
on_event(PushQueueEvent::Acquired, &request_id);
return Ok(PushQueueLease::new(service, repo_uid, request_id));
}
Ok(false) => {}
Err(e) => {
service.release_push_queue(repo_uid, &request_id).await;
return Err(PushQueueWaitError::Lock(e));
}
}
}
}
if Instant::now() >= deadline {
service.release_push_queue(repo_uid, &request_id).await;
return Err(PushQueueWaitError::Timeout);
}
sleep(Duration::from_secs(1)).await;
}
}

View File

@ -6,32 +6,108 @@ pub struct RefUpdate {
} }
impl RefUpdate { impl RefUpdate {
/// Parse git reference update commands from SSH protocol text. /// Parse git receive-pack reference update commands from pkt-line data.
/// Format: "<old-oid> <new-oid> <ref-name>\n" /// Payload format: "<old-oid> <new-oid> <ref-name>\0capabilities\n".
pub fn parse_ref_updates(data: &[u8]) -> Result<Vec<Self>, String> { pub fn parse_ref_updates(data: &[u8]) -> Result<Vec<Self>, String> {
let text = String::from_utf8_lossy(data);
let mut refs = Vec::new(); let mut refs = Vec::new();
for line in text.lines() {
let line = line.trim(); for payload in parse_pkt_line_payloads(data)? {
if line.is_empty() || line.starts_with('#') || line.starts_with("PACK") { let line = String::from_utf8_lossy(payload);
let line = line.trim_end_matches(['\r', '\n']);
if line.is_empty() {
continue; continue;
} }
let mut parts = line.split_whitespace();
let old_oid = parts.next().map(|s| s.to_string()).unwrap_or_default(); let mut parts = line.splitn(3, ' ');
let new_oid = parts.next().map(|s| s.to_string()).unwrap_or_default(); let old_oid = parts.next().unwrap_or_default();
let name = parts let new_oid = parts.next().unwrap_or_default();
.next() let raw_name = parts.next().unwrap_or_default();
.unwrap_or("") let name = raw_name
.trim_start_matches('\0') .split_once('\0')
.to_string(); .map(|(name, _)| name)
if !name.is_empty() { .unwrap_or(raw_name)
refs.push(RefUpdate { .trim();
old_oid,
new_oid, if old_oid.len() != 40 || new_oid.len() != 40 || name.is_empty() {
name, continue;
});
} }
refs.push(RefUpdate {
old_oid: old_oid.to_string(),
new_oid: new_oid.to_string(),
name: name.to_string(),
});
} }
Ok(refs) Ok(refs)
} }
} }
fn parse_pkt_line_payloads(data: &[u8]) -> Result<Vec<&[u8]>, String> {
let mut payloads = Vec::new();
let mut offset = 0;
while offset + 4 <= data.len() {
let header = std::str::from_utf8(&data[offset..offset + 4])
.map_err(|_| "invalid pkt-line header encoding".to_string())?;
let len = usize::from_str_radix(header, 16)
.map_err(|_| format!("invalid pkt-line length: {header}"))?;
offset += 4;
match len {
0 => break,
1..=3 => return Err(format!("invalid pkt-line length: {len}")),
_ => {
let payload_len = len - 4;
if offset + payload_len > data.len() {
return Err("truncated pkt-line payload".to_string());
}
payloads.push(&data[offset..offset + payload_len]);
offset += payload_len;
}
}
}
Ok(payloads)
}
#[cfg(test)]
mod tests {
use super::RefUpdate;
fn pkt(payload: &str) -> Vec<u8> {
let len = payload.len() + 4;
let mut out = format!("{len:04x}").into_bytes();
out.extend_from_slice(payload.as_bytes());
out
}
#[test]
fn parses_receive_pack_ref_with_capabilities() {
let mut data = pkt(
"0000000000000000000000000000000000000000 1111111111111111111111111111111111111111 refs/heads/feature\0 report-status\n",
);
data.extend_from_slice(b"0000");
let refs = RefUpdate::parse_ref_updates(&data).unwrap();
assert_eq!(refs.len(), 1);
assert_eq!(refs[0].old_oid, "0000000000000000000000000000000000000000");
assert_eq!(refs[0].new_oid, "1111111111111111111111111111111111111111");
assert_eq!(refs[0].name, "refs/heads/feature");
}
#[test]
fn parses_receive_pack_ref_without_pack_payload() {
let mut data = pkt(
"2222222222222222222222222222222222222222 0000000000000000000000000000000000000000 refs/heads/old\n",
);
data.extend_from_slice(b"0000");
let refs = RefUpdate::parse_ref_updates(&data).unwrap();
assert_eq!(refs.len(), 1);
assert_eq!(refs[0].name, "refs/heads/old");
assert_eq!(refs[0].new_oid, "0000000000000000000000000000000000000000");
}
}