gitdataai/libs/service/agent/pr_summary.rs
2026-04-14 19:02:01 +08:00

375 lines
13 KiB
Rust

//! AI-powered PR description generation.
//!
//! Generates a structured description for pull requests based on the diff.
use crate::AppService;
use crate::error::AppError;
use crate::git::GitDomain;
use chrono::Utc;
use models::agents::ModelStatus;
use models::agents::model::{Column as MColumn, Entity as MEntity};
use models::pull_request::pull_request;
use models::repos::repo;
use sea_orm::*;
use serde::{Deserialize, Serialize};
use session::Session;
use utoipa::ToSchema;
use uuid::Uuid;
use super::billing::BillingRecord;
/// Structured PR description generated by AI.
///
/// Deserialized directly from the model's JSON reply (the schema is
/// dictated by the prompt in `build_description_prompt`).
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct PrDescription {
    /// 3-5 line summary of what this PR does.
    pub summary: String,
    /// Key changes made in this PR.
    pub changes: Vec<String>,
    /// Potential risks or things to watch out for.
    /// `#[serde(default)]` tolerates replies that omit this field.
    #[serde(default)]
    pub risks: Vec<String>,
    /// Suggested test scenarios.
    /// Defaults to empty when the model omits the field.
    #[serde(default)]
    pub tests: Vec<String>,
}
/// Response from the AI description generation endpoint.
#[derive(Debug, Clone, Serialize, ToSchema)]
pub struct GeneratePrDescriptionResponse {
    /// Parsed, structured description as returned by the model.
    pub description: PrDescription,
    /// Markdown-formatted description ready to paste into the PR body.
    pub markdown_body: String,
    /// Usage/billing record; `None` when recording failed (billing
    /// failures are treated as non-fatal by the generator).
    pub billing: Option<BillingRecord>,
}
/// Request body for generating a PR description.
#[derive(Debug, Clone, Deserialize, ToSchema)]
pub struct GeneratePrDescriptionRequest {
    /// PR number to generate description for.
    /// `None` selects the repository's most recently created PR.
    pub pr_number: Option<i64>,
    /// Override the default AI model for this generation.
    /// `None` falls back to the first active model (ordered by name).
    pub model_id: Option<Uuid>,
}
/// Internal response (passed from PR creation background task).
#[derive(Debug, Clone, Serialize, ToSchema)]
pub struct GeneratedPrDescription {
    /// The generated description text.
    pub body: String,
    /// True when the body was produced by AI rather than a user.
    pub created_by_ai: bool,
}
/// Build a prompt for PR description generation.
///
/// Instructs the model to reply with a bare JSON object matching
/// `PrDescription`. Any pre-existing, user-written PR body is included as
/// additional context.
fn build_description_prompt(title: &str, body: Option<&str>, diff: &str) -> String {
    // Only mention an existing description when the author actually wrote one.
    let user_section = match body {
        Some(text) => format!("Existing user description:\n{text}"),
        None => String::new(),
    };
    format!(
        r#"You are an expert code reviewer. Generate a clear, concise pull request description.
PR Title: {title}
{existing_desc}
Changed files diff (truncated to key portions):
---
{diff}
---
Based on the PR title and diff, generate a structured description in this exact JSON format:
{{
"summary": "A 3-5 line summary of what this PR does and why",
"changes": ["List of key changes, one per item"],
"risks": ["Potential risks or considerations, if any"],
"tests": ["Suggested test scenarios to verify this PR"]
}}
Respond with ONLY the JSON object, no markdown code fences or extra text."#,
        title = title,
        existing_desc = user_section,
        diff = diff,
    )
}
/// Format the structured description as a markdown PR body.
///
/// Empty sections (changes/risks/tests) are omitted entirely; a trailing
/// AI-attribution footer is always appended.
fn format_as_markdown(_title: &str, desc: &PrDescription) -> String {
    let mut lines = vec![format!("## Summary\n\n{}", desc.summary)];
    if !desc.changes.is_empty() {
        lines.push("\n## Changes\n\n".to_string());
        for change in &desc.changes {
            lines.push(format!("- {}", change));
        }
    }
    if !desc.risks.is_empty() {
        lines.push("\n## Risks & Considerations\n\n".to_string());
        for risk in &desc.risks {
            lines.push(format!("- ⚠️ {}", risk));
        }
    }
    if !desc.tests.is_empty() {
        lines.push("\n## Testing\n\n".to_string());
        for test in &desc.tests {
            lines.push(format!("- {}", test));
        }
    }
    lines.push("\n---\n".to_string());
    // Fix: close the italic span — the previous format string opened `*`
    // but never closed it, leaving unbalanced markdown emphasis in the
    // rendered PR body.
    lines.push(format!(
        "*🤖 Generated by AI · {}*",
        Utc::now().format("%Y-%m-%d")
    ));
    lines.join("\n")
}
/// Call the AI model with a prompt and return the text response.
///
/// Credentials and base URL come from the application config; the prompt
/// is sent as a single user message with a fixed temperature (0.3) and a
/// 4096-token output cap.
async fn call_ai_model_for_description(
    model_name: &str,
    prompt: &str,
    app_config: &config::AppConfig,
) -> Result<agent::AiCallResponse, AppError> {
    // A missing API key is a deployment/configuration problem, surfaced as
    // an internal error.
    let api_key = app_config
        .ai_api_key()
        .map_err(|e| AppError::InternalServerError(format!("AI API key not configured: {}", e)))?;
    // Fall back to the public OpenAI endpoint when no base URL is set.
    let base_url = app_config
        .ai_basic_url()
        .unwrap_or_else(|_| "https://api.openai.com".into());
    let client_config = agent::AiClientConfig::new(api_key).with_base_url(base_url);
    // Single-turn conversation: the whole prompt travels in one user message.
    let user_message = async_openai::types::chat::ChatCompletionRequestMessage::User(
        async_openai::types::chat::ChatCompletionRequestUserMessage {
            content: async_openai::types::chat::ChatCompletionRequestUserMessageContent::Text(
                prompt.to_string(),
            ),
            ..Default::default()
        },
    );
    let messages = vec![user_message];
    agent::call_with_params(&messages, model_name, &client_config, 0.3, 4096, None, None)
        .await
        .map_err(|e| AppError::InternalServerError(format!("AI call failed: {}", e)))
}
/// Extract JSON from a response that may contain markdown code fences.
///
/// Tries, in order: a ```` ```json ```` fenced block, any generic fenced
/// block (stripping a fence info string such as `JSON` if present), and
/// finally the raw trimmed text when it already looks like a JSON object
/// or array. Returns `None` when nothing JSON-like is found.
fn extract_json(s: &str) -> Option<String> {
    // Fenced block explicitly tagged as json.
    if let Some(start) = s.find("```json") {
        let rest = &s[start + 7..];
        if let Some(end) = rest.find("```") {
            return Some(rest[..end].trim().to_string());
        }
    }
    // Generic fenced block. Fix: the fence line may carry an info string
    // (e.g. "```JSON") which the old code kept, yielding invalid JSON —
    // drop the first line when the content doesn't start with '{' or '['.
    if let Some(start) = s.find("```") {
        let rest = &s[start + 3..];
        if let Some(end) = rest.find("```") {
            let mut inner = rest[..end].trim();
            if !(inner.starts_with('{') || inner.starts_with('[')) {
                if let Some(nl) = inner.find('\n') {
                    inner = inner[nl + 1..].trim();
                }
            }
            return Some(inner.to_string());
        }
    }
    // Bare JSON with no fences at all.
    let trimmed = s.trim();
    if trimmed.starts_with('{') || trimmed.starts_with('[') {
        return Some(trimmed.to_string());
    }
    None
}
impl AppService {
    /// Public entry point — performs session auth then delegates.
    ///
    /// Access control happens inside `utils_find_repo`, which receives the
    /// session. With `pr_number` set the exact PR is fetched; otherwise the
    /// repository's most recently created PR is used.
    pub async fn generate_pr_description(
        &self,
        namespace: String,
        repo_name: String,
        request: GeneratePrDescriptionRequest,
        ctx: &Session,
    ) -> Result<GeneratePrDescriptionResponse, AppError> {
        let repo = self
            .utils_find_repo(namespace.clone(), repo_name.clone(), ctx)
            .await?;
        let pr = match request.pr_number {
            Some(n) => pull_request::Entity::find()
                .filter(pull_request::Column::Repo.eq(repo.id))
                .filter(pull_request::Column::Number.eq(n))
                .one(&self.db)
                .await?
                .ok_or_else(|| AppError::NotFound("Pull request not found".to_string()))?,
            None => pull_request::Entity::find()
                .filter(pull_request::Column::Repo.eq(repo.id))
                .order_by_desc(pull_request::Column::CreatedAt)
                .one(&self.db)
                .await?
                .ok_or_else(|| AppError::NotFound("No pull request found".to_string()))?,
        };
        self.generate_pr_description_internal(pr, repo, request.model_id)
            .await
    }

    /// Internal entry point — skips auth. Used by background tasks.
    ///
    /// Selects a model (explicit `model_id`, or the first active model
    /// ordered by name as a deterministic default), builds a token-budgeted
    /// diff, calls the AI, records billing best-effort, then parses the
    /// JSON reply into a `PrDescription`.
    pub async fn generate_pr_description_internal(
        &self,
        pr: pull_request::Model,
        repo: repo::Model,
        model_id: Option<Uuid>,
    ) -> Result<GeneratePrDescriptionResponse, AppError> {
        // Find a model first so we can use its context limit for diff truncation.
        let model = match model_id {
            Some(id) => MEntity::find_by_id(id)
                .one(&self.db)
                .await?
                .ok_or(AppError::NotFound("Model not found".to_string()))?,
            None => MEntity::find()
                .filter(MColumn::Status.eq(ModelStatus::Active.to_string()))
                .order_by_asc(MColumn::Name)
                .one(&self.db)
                .await?
                .ok_or_else(|| {
                    AppError::InternalServerError(
                        "No active AI model found. Please configure an AI model first.".into(),
                    )
                })?,
        };
        // Get the diff with token-aware truncation.
        let diff = self
            .get_pr_description_diff(&repo, &pr, &model.name, model.context_length)
            .await?;
        // Build prompt and call AI.
        let prompt = build_description_prompt(&pr.title, pr.body.as_deref(), &diff);
        let ai_response = call_ai_model_for_description(&model.name, &prompt, &self.config).await?;
        // Record billing (non-fatal): failures are logged, then discarded
        // via `.ok()`, so description generation still succeeds.
        let billing = self
            .record_ai_usage(
                repo.project,
                model.id,
                ai_response.input_tokens,
                ai_response.output_tokens,
            )
            .await
            .inspect_err(|e| {
                slog::warn!(
                    self.logs,
                    "failed to record AI billing for PR description";
                    "project" => %repo.project,
                    "error" => ?e
                );
            })
            .ok();
        // Parse JSON response; if no fenced block or bare object was found,
        // try the raw content as-is.
        let json_str =
            extract_json(&ai_response.content).unwrap_or_else(|| ai_response.content.clone());
        let pr_desc: PrDescription = serde_json::from_str(&json_str).map_err(|e| {
            // Fix: truncate the raw snippet by characters, not bytes —
            // `&s[..200]` panics when byte 200 falls inside a multi-byte
            // UTF-8 sequence (likely with emoji/CJK in model replies).
            let snippet: String = ai_response.content.chars().take(200).collect();
            AppError::InternalServerError(format!(
                "Failed to parse AI response as JSON: {}. Raw: {}",
                e, snippet
            ))
        })?;
        let markdown_body = format_as_markdown(&pr.title, &pr_desc);
        Ok(GeneratePrDescriptionResponse {
            description: pr_desc,
            markdown_body,
            billing,
        })
    }

    /// Get the diff for PR description generation (unified format, truncated).
    ///
    /// Git work runs on the blocking thread pool in two phases: resolve the
    /// base/head branch tips to commit OIDs, then render a tree-to-tree
    /// patch. The patch text is finally truncated to fit the model's
    /// context window.
    async fn get_pr_description_diff(
        &self,
        repo: &repo::Model,
        pr: &pull_request::Model,
        model_name: &str,
        context_limit: i64,
    ) -> Result<String, AppError> {
        // Phase 1: resolve branch names to OIDs off the async runtime
        // (git2 is synchronous/blocking).
        let oids_result = {
            let base = pr.base.clone();
            let head = pr.head.clone();
            let repo_model = repo.clone();
            let handle: tokio::task::JoinHandle<Result<(git2::Oid, git2::Oid), AppError>> =
                tokio::task::spawn_blocking(move || {
                    let domain = GitDomain::from_model(repo_model)?;
                    let base_commit_oid = domain
                        .branch_target(&base)
                        .map_err(|e| crate::git::GitError::Internal(e.to_string()))?
                        .ok_or_else(|| {
                            AppError::NotFound(format!("Branch '{}' not found", base))
                        })?;
                    let head_commit_oid = domain
                        .branch_target(&head)
                        .map_err(|e| crate::git::GitError::Internal(e.to_string()))?
                        .ok_or_else(|| {
                            AppError::NotFound(format!("Branch '{}' not found", head))
                        })?;
                    let base_oid = base_commit_oid.to_oid().map_err(|e| {
                        AppError::InternalServerError(format!("Invalid OID: {}", e))
                    })?;
                    let head_oid = head_commit_oid.to_oid().map_err(|e| {
                        AppError::InternalServerError(format!("Invalid OID: {}", e))
                    })?;
                    Ok((base_oid, head_oid))
                });
            handle
        }
        .await
        .map_err(|e| AppError::InternalServerError(format!("Join error: {}", e)))?;
        let (base_oid, head_oid) = oids_result.map_err(|e| {
            AppError::InternalServerError(format!("Failed to resolve branch OIDs: {:?}", e))
        })?;
        // Phase 2: render the tree-to-tree diff as a unified patch.
        let repo_for_diff = repo.clone();
        let diff_text = tokio::task::spawn_blocking(move || -> Result<String, AppError> {
            let domain = GitDomain::from_model(repo_for_diff)?;
            let old_tree = domain
                .repo()
                .find_tree(base_oid)
                .map_err(|e| AppError::InternalServerError(e.to_string()))?;
            let new_tree = domain
                .repo()
                .find_tree(head_oid)
                .map_err(|e| AppError::InternalServerError(e.to_string()))?;
            let diff = domain
                .repo()
                .diff_tree_to_tree(Some(&old_tree), Some(&new_tree), None)
                .map_err(|e| AppError::InternalServerError(e.to_string()))?;
            let mut patch_buf: Vec<u8> = Vec::new();
            diff.print(git2::DiffFormat::Patch, |_delta, _hunk, line| {
                // Fix: git2 strips the leading '+'/'-'/' ' marker from
                // content lines and reports it via `origin()`; re-attach it
                // so the model can tell additions from deletions. Also,
                // `content()` already ends with its newline, so the old
                // unconditional `push(b'\n')` double-spaced the patch.
                match line.origin() {
                    '+' | '-' | ' ' => patch_buf.push(line.origin() as u8),
                    _ => {}
                }
                patch_buf.extend_from_slice(line.content());
                true
            })
            .map_err(|e| AppError::InternalServerError(e.to_string()))?;
            String::from_utf8(patch_buf).map_err(|e| AppError::InternalServerError(e.to_string()))
        })
        .await
        .map_err(|e| AppError::InternalServerError(format!("Task join error: {e}")))?
        .map_err(AppError::from)?;
        // Truncate to avoid token limits with token-aware budgeting.
        // Reserve 4096 tokens for output + system overhead.
        // NOTE(review): `context_limit as usize` assumes a non-negative
        // context length; a negative i64 would wrap to a huge budget —
        // confirm the model table guarantees this upstream.
        let reserve = 4096;
        match agent::tokent::truncate_to_token_budget(
            &diff_text,
            model_name,
            context_limit as usize,
            reserve,
        ) {
            Ok(truncated) if truncated.len() < diff_text.len() => Ok(format!(
                "{}...\n[diff truncated to fit token budget]",
                truncated
            )),
            _ => Ok(diff_text),
        }
    }
}