From 3b17a0493fe0d3c4dd52ccdc5415eceabaa6a11b Mon Sep 17 00:00:00 2001 From: ZhenYi <434836402@qq.com> Date: Mon, 11 May 2026 17:05:30 +0800 Subject: [PATCH] refactor(git/ssh): extract helper functions into dedicated modules Move RefUpdate, GitService, branch_protection check, and forward function from handle.rs into separate modules. --- libs/git/ssh/branch_protect.rs | 67 +++++++++ libs/git/ssh/forward.rs | 50 +++++++ libs/git/ssh/git_service.rs | 82 +++++++++++ libs/git/ssh/handle.rs | 256 ++------------------------------- libs/git/ssh/mod.rs | 8 +- libs/git/ssh/ref_update.rs | 37 +++++ 6 files changed, 254 insertions(+), 246 deletions(-) create mode 100644 libs/git/ssh/branch_protect.rs create mode 100644 libs/git/ssh/forward.rs create mode 100644 libs/git/ssh/git_service.rs create mode 100644 libs/git/ssh/ref_update.rs diff --git a/libs/git/ssh/branch_protect.rs b/libs/git/ssh/branch_protect.rs new file mode 100644 index 0000000..173f0f6 --- /dev/null +++ b/libs/git/ssh/branch_protect.rs @@ -0,0 +1,67 @@ +use crate::ssh::ref_update::RefUpdate; +use models::repos::repo_branch_protect; + +/// Ref name matches a protection rule exactly, or as a directory prefix +/// (e.g. "refs/heads/main" matches "refs/heads/main" and "refs/heads/main/*" +/// but NOT "refs/heads/main-v2"). +fn ref_matches_protection(ref_name: &str, protection_branch: &str) -> bool { + ref_name == protection_branch + || ref_name.starts_with(&format!("{}/", protection_branch)) +} + +/// Granular branch protection check (same logic as HTTP handler). +/// Returns `Some(error_message)` if the push should be rejected. 
+pub fn check_branch_protection( + branch_protects: &[repo_branch_protect::Model], + r#ref: &RefUpdate, +) -> Option<String> { + for protection in branch_protects { + if !ref_matches_protection(&r#ref.name, &protection.branch) { + continue; + } + + // Check deletion (new_oid is all zeros) + if r#ref.new_oid == "0000000000000000000000000000000000000000" { + if protection.forbid_deletion { + return Some(format!( + "Deletion of protected branch '{}' is forbidden", + r#ref.name + )); + } + continue; + } + + // Check tag push + if r#ref.name.starts_with("refs/tags/") { + if protection.forbid_tag_push { + return Some(format!( + "Tag push to protected branch '{}' is forbidden", + r#ref.name + )); + } + continue; + } + + // Check force push: old != new AND old is non-zero (non-fast-forward) + let is_new_branch = r#ref.old_oid == "0000000000000000000000000000000000000000"; + if !is_new_branch + && r#ref.old_oid != r#ref.new_oid + && r#ref.name.starts_with("refs/heads/") + && protection.forbid_force_push + { + return Some(format!( + "Force push to protected branch '{}' is forbidden", + r#ref.name + )); + } + + // Check push + if protection.forbid_push { + return Some(format!( + "Push to protected branch '{}' is forbidden", + r#ref.name + )); + } + } + None +} diff --git a/libs/git/ssh/forward.rs b/libs/git/ssh/forward.rs new file mode 100644 index 0000000..7bfb4f5 --- /dev/null +++ b/libs/git/ssh/forward.rs @@ -0,0 +1,50 @@ +use russh::server::Handle; +use russh::ChannelId; +use std::future::Future; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio::time::sleep; +use tokio_util::bytes::Bytes; + +pub async fn forward<'a, R, Fut, Fwd>( + session_handle: &'a Handle, + chan_id: ChannelId, + r: &mut R, + mut fwd: Fwd, +) -> Result<(), russh::Error> +where + R: AsyncRead + Send + Unpin, + Fut: Future<Output = Result<(), Bytes>> + 'a, + Fwd: FnMut(&'a Handle, ChannelId, Bytes) -> Fut, +{ + const BUF_SIZE: usize = 1024 * 32; + const MAX_RETRIES: usize = 5; + const RETRY_DELAY: u64 = 10; 
// ms + + let mut buf = [0u8; BUF_SIZE]; + loop { + let read = r.read(&mut buf).await?; + + if read == 0 { + break; + } + + let mut chunk = Bytes::copy_from_slice(&buf[..read]); + let mut retries = 0; + loop { + match fwd(session_handle, chan_id, chunk).await { + Ok(()) => break, + Err(unsent) => { + retries += 1; + if retries >= MAX_RETRIES { + return Ok(()); + } + chunk = unsent; + sleep(Duration::from_millis(RETRY_DELAY)).await; + } + } + } + } + + Ok(()) +} diff --git a/libs/git/ssh/git_service.rs b/libs/git/ssh/git_service.rs new file mode 100644 index 0000000..50869c4 --- /dev/null +++ b/libs/git/ssh/git_service.rs @@ -0,0 +1,82 @@ +use std::path::PathBuf; +use std::str::FromStr; + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub enum GitService { + UploadPack, + ReceivePack, + UploadArchive, +} + +impl FromStr for GitService { + type Err = (); + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "upload-pack" => Ok(Self::UploadPack), + "receive-pack" => Ok(Self::ReceivePack), + "upload-archive" => Ok(Self::UploadArchive), + _ => Err(()), + } + } +} + +pub fn parse_git_command(cmd: &str) -> Option<(GitService, &str)> { + let (svc, path) = match cmd.split_once(' ') { + Some(("git-receive-pack", path)) => (GitService::ReceivePack, path), + Some(("git-upload-pack", path)) => (GitService::UploadPack, path), + Some(("git-upload-archive", path)) => (GitService::UploadArchive, path), + _ => return None, + }; + Some((svc, strip_apostrophes(path))) +} + +pub fn parse_repo_path(path: &str) -> Option<(&str, &str)> { + let path = path.trim_matches('/'); + let mut parts = path.splitn(2, '/'); + match (parts.next(), parts.next()) { + (Some(owner), Some(repo)) if !owner.is_empty() && !repo.is_empty() => Some((owner, repo)), + _ => None, + } +} + +pub fn build_git_command(service: GitService, path: PathBuf) -> tokio::process::Command { + let mut cmd = tokio::process::Command::new("git"); + + let cwd = match path.canonicalize() { + Ok(p) => p, + Err(e) => { + 
tracing::debug!(error = %e, "path canonicalize failed, falling back to raw path"); + path.clone() + } + }; + cmd.current_dir(cwd); + + match service { + GitService::UploadPack => { cmd.arg("upload-pack"); } + GitService::ReceivePack => { cmd.arg("receive-pack"); } + GitService::UploadArchive => { cmd.arg("upload-archive"); } + } + + cmd.arg(".") + .env("GIT_CONFIG_NOSYSTEM", "1") + .env("GIT_NO_REPLACE_OBJECTS", "1"); + + #[cfg(unix)] + { + cmd.env("GIT_CONFIG_GLOBAL", "/dev/null") + .env("GIT_CONFIG_SYSTEM", "/dev/null"); + } + #[cfg(windows)] + { + let nul = "NUL"; + cmd.env("GIT_CONFIG_GLOBAL", nul) + .env("GIT_CONFIG_SYSTEM", nul); + } + + cmd +} + +fn strip_apostrophes(s: &str) -> &str { + s.trim_matches('\'') +} diff --git a/libs/git/ssh/handle.rs b/libs/git/ssh/handle.rs index e19fe7a..651be41 100644 --- a/libs/git/ssh/handle.rs +++ b/libs/git/ssh/handle.rs @@ -7,8 +7,13 @@ use db::database::AppDatabase; use models::repos::{repo, repo_branch_protect}; use models::users::user; use russh::keys::{Certificate, PublicKey}; -use russh::server::{Auth, Handle, Msg, Session}; -use russh::{Channel, ChannelId, CryptoVec, Disconnect}; +use russh::server::{Auth, Msg, Session}; +use russh::{Channel, ChannelId, Disconnect}; +use crate::ssh::ref_update::RefUpdate; +use crate::ssh::git_service::{GitService, parse_git_command, parse_repo_path, build_git_command}; +use crate::ssh::branch_protect::check_branch_protection; +use crate::ssh::forward::forward; +use tokio_util::bytes::Bytes; use sea_orm::ColumnTrait; use sea_orm::EntityTrait; use sea_orm::QueryFilter; @@ -17,53 +22,13 @@ use std::io; use std::net::SocketAddr; use std::path::PathBuf; use std::process::Stdio; -use std::str::FromStr; use std::time::Duration; const PRE_PACK_LIMIT: usize = 1_048_576; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; +use tokio::io::AsyncWriteExt; use tokio::process::ChildStdin; use tokio::sync::mpsc::Sender; use tokio::time::sleep; - -#[derive(Clone, Debug)] -pub struct RefUpdate 
{ - pub name: String, - pub old_oid: String, - pub new_oid: String, -} - -impl RefUpdate { - /// Parse git reference update commands from SSH protocol text. - /// Format: " \n" - pub fn parse_ref_updates(data: &[u8]) -> Result, String> { - let text = String::from_utf8_lossy(data); - let mut refs = Vec::new(); - for line in text.lines() { - let line = line.trim(); - if line.is_empty() || line.starts_with('#') || line.starts_with("PACK") { - continue; - } - let mut parts = line.split_whitespace(); - let old_oid = parts.next().map(|s| s.to_string()).unwrap_or_default(); - let new_oid = parts.next().map(|s| s.to_string()).unwrap_or_default(); - let name = parts - .next() - .unwrap_or("") - .trim_start_matches('\0') - .to_string(); - if !name.is_empty() { - refs.push(RefUpdate { - old_oid, - new_oid, - name, - }); - } - } - Ok(refs) - } -} - pub struct SSHandle { pub repo: Option, pub model: Option, @@ -379,7 +344,7 @@ impl russh::server::Handler for SSHandle { let _ = session.extended_data( channel, 1, - CryptoVec::from_slice(msg.as_bytes()), + Bytes::copy_from_slice(msg.as_bytes()), ); let _ = session.exit_status_request(channel, 1); let _ = session.eof(channel); @@ -416,7 +381,7 @@ impl russh::server::Handler for SSHandle { let _ = session.extended_data( channel, 1, - CryptoVec::from_slice(full_msg.as_bytes()), + Bytes::copy_from_slice(full_msg.as_bytes()), ); let _ = session.exit_status_request(channel, 1); let _ = session.eof(channel); @@ -480,7 +445,7 @@ impl russh::server::Handler for SSHandle { tracing::info!("shell_request user={}", user.username); let _ = session - .data(channel_id, CryptoVec::from_slice(welcome_msg.as_bytes())); + .data(channel_id, Bytes::copy_from_slice(welcome_msg.as_bytes())); let _ = session.exit_status_request(channel_id, 0); let _ = session.eof(channel_id); let _ = session.close(channel_id); @@ -489,7 +454,7 @@ impl russh::server::Handler for SSHandle { tracing::warn!("shell_request_unauthenticated channel={:?}", channel_id); let msg = 
"Authentication required\r\n"; let _ = session - .data(channel_id, CryptoVec::from_slice(msg.as_bytes())); + .data(channel_id, Bytes::copy_from_slice(msg.as_bytes())); let _ = session.exit_status_request(channel_id, 1); let _ = session.eof(channel_id); let _ = session.close(channel_id); @@ -733,200 +698,3 @@ impl russh::server::Handler for SSHandle { Ok(()) } } - -fn parse_git_command(cmd: &str) -> Option<(GitService, &str)> { - let (svc, path) = match cmd.split_once(' ') { - Some(("git-receive-pack", path)) => (GitService::ReceivePack, path), - Some(("git-upload-pack", path)) => (GitService::UploadPack, path), - Some(("git-upload-archive", path)) => (GitService::UploadArchive, path), - _ => return None, - }; - Some((svc, strip_apostrophes(path))) -} - -fn parse_repo_path(path: &str) -> Option<(&str, &str)> { - let path = path.trim_matches('/'); - let mut parts = path.splitn(2, '/'); - match (parts.next(), parts.next()) { - (Some(owner), Some(repo)) if !owner.is_empty() && !repo.is_empty() => Some((owner, repo)), - _ => None, - } -} - -fn build_git_command(service: GitService, path: PathBuf) -> tokio::process::Command { - let mut cmd = tokio::process::Command::new("git"); - - // Canonicalize only for validation; if it fails, fall back to the raw path. - // Using canonicalize for current_dir is safe since we validate repo existence - // before reaching this point. 
- let cwd = match path.canonicalize() { - Ok(p) => p, - Err(e) => { - tracing::debug!(error = %e, "path canonicalize failed, falling back to raw path"); - path.clone() - } - }; - cmd.current_dir(cwd); - - match service { - GitService::UploadPack => { cmd.arg("upload-pack"); } - GitService::ReceivePack => { cmd.arg("receive-pack"); } - GitService::UploadArchive => { cmd.arg("upload-archive"); } - } - - cmd.arg(".") - .env("GIT_CONFIG_NOSYSTEM", "1") - .env("GIT_NO_REPLACE_OBJECTS", "1"); - - #[cfg(unix)] - { - cmd.env("GIT_CONFIG_GLOBAL", "/dev/null") - .env("GIT_CONFIG_SYSTEM", "/dev/null"); - } - #[cfg(windows)] - { - // On Windows, /dev/null doesn't exist. Set invalid paths so git - // ignores them without crashing. GIT_CONFIG_NOSYSTEM already disables - // the system config. - let nul = "NUL"; - cmd.env("GIT_CONFIG_GLOBAL", nul) - .env("GIT_CONFIG_SYSTEM", nul); - } - - cmd -} - -fn strip_apostrophes(s: &str) -> &str { - s.trim_matches('\'') -} - -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -pub enum GitService { - UploadPack, - ReceivePack, - UploadArchive, -} - -impl FromStr for GitService { - type Err = (); - - fn from_str(s: &str) -> Result { - match s { - "upload-pack" => Ok(Self::UploadPack), - "receive-pack" => Ok(Self::ReceivePack), - "upload-archive" => Ok(Self::UploadArchive), - _ => Err(()), - } - } -} - -/// Ref name matches a protection rule exactly, or as a directory prefix -/// (e.g. "refs/heads/main" matches "refs/heads/main" and "refs/heads/main/*" -/// but NOT "refs/heads/main-v2"). -fn ref_matches_protection(ref_name: &str, protection_branch: &str) -> bool { - ref_name == protection_branch - || ref_name.starts_with(&format!("{}/", protection_branch)) -} - -/// Granular branch protection check (same logic as HTTP handler). -/// Returns `Some(error_message)` if the push should be rejected. 
-fn check_branch_protection( - branch_protects: &[repo_branch_protect::Model], - r#ref: &RefUpdate, -) -> Option { - for protection in branch_protects { - if !ref_matches_protection(&r#ref.name, &protection.branch) { - continue; - } - - // Check deletion (new_oid is all zeros) - if r#ref.new_oid == "0000000000000000000000000000000000000000" { - if protection.forbid_deletion { - return Some(format!( - "Deletion of protected branch '{}' is forbidden", - r#ref.name - )); - } - continue; - } - - // Check tag push - if r#ref.name.starts_with("refs/tags/") { - if protection.forbid_tag_push { - return Some(format!( - "Tag push to protected branch '{}' is forbidden", - r#ref.name - )); - } - continue; - } - - // Check force push: old != new AND old is non-zero (non-fast-forward) - let is_new_branch = r#ref.old_oid == "0000000000000000000000000000000000000000"; - if !is_new_branch - && r#ref.old_oid != r#ref.new_oid - && r#ref.name.starts_with("refs/heads/") - && protection.forbid_force_push - { - return Some(format!( - "Force push to protected branch '{}' is forbidden", - r#ref.name - )); - } - - // Check push - if protection.forbid_push { - return Some(format!( - "Push to protected branch '{}' is forbidden", - r#ref.name - )); - } - } - None -} - -async fn forward<'a, R, Fut, Fwd>( - session_handle: &'a Handle, - chan_id: ChannelId, - r: &mut R, - mut fwd: Fwd, -) -> Result<(), russh::Error> -where - R: AsyncRead + Send + Unpin, - Fut: Future> + 'a, - Fwd: FnMut(&'a Handle, ChannelId, CryptoVec) -> Fut, -{ - const BUF_SIZE: usize = 1024 * 32; - const MAX_RETRIES: usize = 5; - const RETRY_DELAY: u64 = 10; // ms - - let mut buf = [0u8; BUF_SIZE]; - loop { - let read = r.read(&mut buf).await?; - - if read == 0 { - break; - } - - let mut chunk = CryptoVec::from_slice(&buf[..read]); - let mut retries = 0; - loop { - match fwd(session_handle, chan_id, chunk).await { - Ok(()) => break, - Err(unsent) => { - retries += 1; - if retries >= MAX_RETRIES { - // Give up — connection is 
likely broken. Returning Ok (not Err) - // so the outer task can clean up gracefully without logging - // a spurious error for a normal disconnection. - return Ok(()); - } - chunk = unsent; - sleep(Duration::from_millis(RETRY_DELAY)).await; - } - } - } - } - - Ok(()) -} diff --git a/libs/git/ssh/mod.rs b/libs/git/ssh/mod.rs index a7cae39..20e544f 100644 --- a/libs/git/ssh/mod.rs +++ b/libs/git/ssh/mod.rs @@ -16,9 +16,13 @@ use std::sync::Arc; use std::time::Duration; pub mod authz; +pub mod branch_protect; +pub mod forward; +pub mod git_service; pub mod handle; -pub mod server; pub mod rate_limit; +pub mod ref_update; +pub mod server; #[derive(Clone)] pub struct SSHHandle { pub db: AppDatabase, @@ -106,7 +110,7 @@ impl SSHHandle { let mut config = Config::default(); config.keys = vec![private_key]; let version = format!("SSH-2.0-GitdataAI {}", env!("CARGO_PKG_VERSION")); - config.server_id = SshId::Standard(version); + config.server_id = SshId::Standard(version.into()); let mut method = MethodSet::empty(); method.push(MethodKind::PublicKey); method.push(MethodKind::KeyboardInteractive); diff --git a/libs/git/ssh/ref_update.rs b/libs/git/ssh/ref_update.rs new file mode 100644 index 0000000..41a4d0e --- /dev/null +++ b/libs/git/ssh/ref_update.rs @@ -0,0 +1,37 @@ +#[derive(Clone, Debug)] +pub struct RefUpdate { + pub name: String, + pub old_oid: String, + pub new_oid: String, +} + +impl RefUpdate { + /// Parse git reference update commands from SSH protocol text. 
+ /// Format: "<old-oid> <new-oid> <ref-name>\n" + pub fn parse_ref_updates(data: &[u8]) -> Result<Vec<RefUpdate>, String> { + let text = String::from_utf8_lossy(data); + let mut refs = Vec::new(); + for line in text.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') || line.starts_with("PACK") { + continue; + } + let mut parts = line.split_whitespace(); + let old_oid = parts.next().map(|s| s.to_string()).unwrap_or_default(); + let new_oid = parts.next().map(|s| s.to_string()).unwrap_or_default(); + let name = parts + .next() + .unwrap_or("") + .trim_start_matches('\0') + .to_string(); + if !name.is_empty() { + refs.push(RefUpdate { + old_oid, + new_oid, + name, + }); + } + } + Ok(refs) + } +}