375 lines
13 KiB
Rust
375 lines
13 KiB
Rust
//! AI-powered PR description generation.
|
|
//!
|
|
//! Generates a structured description for pull requests based on the diff.
|
|
|
|
use crate::AppService;
|
|
use crate::error::AppError;
|
|
use crate::git::GitDomain;
|
|
use chrono::Utc;
|
|
use models::agents::ModelStatus;
|
|
use models::agents::model::{Column as MColumn, Entity as MEntity};
|
|
use models::pull_request::pull_request;
|
|
use models::repos::repo;
|
|
use sea_orm::*;
|
|
use serde::{Deserialize, Serialize};
|
|
use session::Session;
|
|
use utoipa::ToSchema;
|
|
use uuid::Uuid;
|
|
|
|
use super::billing::BillingRecord;
|
|
|
|
/// Structured PR description generated by AI.
|
|
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
|
pub struct PrDescription {
|
|
/// 3-5 line summary of what this PR does.
|
|
pub summary: String,
|
|
/// Key changes made in this PR.
|
|
pub changes: Vec<String>,
|
|
/// Potential risks or things to watch out for.
|
|
#[serde(default)]
|
|
pub risks: Vec<String>,
|
|
/// Suggested test scenarios.
|
|
#[serde(default)]
|
|
pub tests: Vec<String>,
|
|
}
|
|
|
|
/// Response from the AI description generation endpoint.
|
|
#[derive(Debug, Clone, Serialize, ToSchema)]
|
|
pub struct GeneratePrDescriptionResponse {
|
|
pub description: PrDescription,
|
|
/// Markdown-formatted description ready to paste into the PR body.
|
|
pub markdown_body: String,
|
|
pub billing: Option<BillingRecord>,
|
|
}
|
|
|
|
/// Request body for generating a PR description.
|
|
#[derive(Debug, Clone, Deserialize, ToSchema)]
|
|
pub struct GeneratePrDescriptionRequest {
|
|
/// PR number to generate description for.
|
|
pub pr_number: Option<i64>,
|
|
/// Override the default AI model for this generation.
|
|
pub model_id: Option<Uuid>,
|
|
}
|
|
|
|
/// Internal response (passed from PR creation background task).
|
|
#[derive(Debug, Clone, Serialize, ToSchema)]
|
|
pub struct GeneratedPrDescription {
|
|
pub body: String,
|
|
pub created_by_ai: bool,
|
|
}
|
|
|
|
/// Build the prompt asking the model for a JSON-structured PR description.
///
/// * `title` - the PR title, inserted verbatim.
/// * `body` - the author's existing description. It is included as context only when it
///   contains non-whitespace text, so an empty body does not leave a dangling
///   "Existing user description:" header in the prompt.
/// * `diff` - the unified diff, inserted verbatim between `---` markers.
///
/// Returns the complete prompt text. The expected reply is a bare JSON object with
/// `summary`, `changes`, `risks` and `tests` keys.
fn build_description_prompt(title: &str, body: Option<&str>, diff: &str) -> String {
    let existing_desc = body
        .filter(|b| !b.trim().is_empty())
        .map(|b| format!("Existing user description:\n{}", b))
        .unwrap_or_default();

    format!(
        r#"You are an expert code reviewer. Generate a clear, concise pull request description.

PR Title: {title}
{existing_desc}

Changed files diff (truncated to key portions):
---
{diff}
---

Based on the PR title and diff, generate a structured description in this exact JSON format:
{{
  "summary": "A 3-5 line summary of what this PR does and why",
  "changes": ["List of key changes, one per item"],
  "risks": ["Potential risks or considerations, if any"],
  "tests": ["Suggested test scenarios to verify this PR"]
}}

Respond with ONLY the JSON object, no markdown code fences or extra text."#
    )
}
|
|
|
|
/// Format the structured description as a markdown PR body.
|
|
fn format_as_markdown(_title: &str, desc: &PrDescription) -> String {
|
|
let mut lines = vec![format!("## Summary\n\n{}", desc.summary)];
|
|
|
|
if !desc.changes.is_empty() {
|
|
lines.push("\n## Changes\n\n".to_string());
|
|
for change in &desc.changes {
|
|
lines.push(format!("- {}", change));
|
|
}
|
|
}
|
|
|
|
if !desc.risks.is_empty() {
|
|
lines.push("\n## Risks & Considerations\n\n".to_string());
|
|
for risk in &desc.risks {
|
|
lines.push(format!("- ⚠️ {}", risk));
|
|
}
|
|
}
|
|
|
|
if !desc.tests.is_empty() {
|
|
lines.push("\n## Testing\n\n".to_string());
|
|
for test in &desc.tests {
|
|
lines.push(format!("- {}", test));
|
|
}
|
|
}
|
|
|
|
lines.push("\n---\n".to_string());
|
|
lines.push(format!(
|
|
"*🤖 Generated by AI · {}",
|
|
Utc::now().format("%Y-%m-%d")
|
|
));
|
|
|
|
lines.join("\n")
|
|
}
|
|
|
|
/// Call the AI model with a prompt and return the text response.
|
|
async fn call_ai_model_for_description(
|
|
model_name: &str,
|
|
prompt: &str,
|
|
app_config: &config::AppConfig,
|
|
) -> Result<agent::AiCallResponse, AppError> {
|
|
let api_key = app_config
|
|
.ai_api_key()
|
|
.map_err(|e| AppError::InternalServerError(format!("AI API key not configured: {}", e)))?;
|
|
|
|
let base_url = app_config
|
|
.ai_basic_url()
|
|
.unwrap_or_else(|_| "https://api.openai.com".into());
|
|
|
|
let client_config = agent::AiClientConfig::new(api_key).with_base_url(base_url);
|
|
|
|
let messages = vec![
|
|
async_openai::types::chat::ChatCompletionRequestMessage::User(
|
|
async_openai::types::chat::ChatCompletionRequestUserMessage {
|
|
content: async_openai::types::chat::ChatCompletionRequestUserMessageContent::Text(
|
|
prompt.to_string(),
|
|
),
|
|
..Default::default()
|
|
},
|
|
),
|
|
];
|
|
|
|
agent::call_with_params(&messages, model_name, &client_config, 0.3, 4096, None, None)
|
|
.await
|
|
.map_err(|e| AppError::InternalServerError(format!("AI call failed: {}", e)))
|
|
}
|
|
|
|
/// Extract the JSON payload from an AI response.
///
/// Models often ignore "no code fences" instructions, so candidates are tried in order:
/// 1. the first ```` ```json ```` fenced block,
/// 2. the first fenced block of any kind; a language tag on the opening fence line (for
///    example `JSON` or `jsonc`) is skipped,
/// 3. the whole trimmed response when it already starts with `{` or `[`,
/// 4. the span from the first `{` to the last `}` when the object is wrapped in prose.
///
/// Returns `None` when no candidate is found. The result is not validated as JSON; the
/// caller parses it.
fn extract_json(s: &str) -> Option<String> {
    /// Whether `text`, ignoring leading whitespace, starts like a JSON object or array.
    fn looks_like_json_start(text: &str) -> bool {
        let text = text.trim_start();
        text.starts_with('{') || text.starts_with('[')
    }

    /// Drop the info string (language tag) that can follow an opening fence, keeping the
    /// first line when it already holds JSON (e.g. "```{"a":1}```").
    fn strip_info_string(block: &str) -> &str {
        match block.split_once('\n') {
            Some((tag, body)) if !looks_like_json_start(tag) => body,
            _ => block,
        }
    }

    const JSON_FENCE: &str = "```json";
    const FENCE: &str = "```";

    if let Some(start) = s.find(JSON_FENCE) {
        let rest = &s[start + JSON_FENCE.len()..];
        if let Some(end) = rest.find(FENCE) {
            return Some(strip_info_string(&rest[..end]).trim().to_string());
        }
    }

    if let Some(start) = s.find(FENCE) {
        let rest = &s[start + FENCE.len()..];
        if let Some(end) = rest.find(FENCE) {
            return Some(strip_info_string(&rest[..end]).trim().to_string());
        }
    }

    let trimmed = s.trim();
    if looks_like_json_start(trimmed) {
        return Some(trimmed.to_string());
    }

    // Object embedded in surrounding prose, e.g. "Here is the description: {...}".
    match (trimmed.find('{'), trimmed.rfind('}')) {
        (Some(open), Some(close)) if open < close => Some(trimmed[open..=close].to_string()),
        _ => None,
    }
}
|
|
|
|
impl AppService {
|
|
/// Public entry point — performs session auth then delegates.
|
|
pub async fn generate_pr_description(
|
|
&self,
|
|
namespace: String,
|
|
repo_name: String,
|
|
request: GeneratePrDescriptionRequest,
|
|
ctx: &Session,
|
|
) -> Result<GeneratePrDescriptionResponse, AppError> {
|
|
let repo = self
|
|
.utils_find_repo(namespace.clone(), repo_name.clone(), ctx)
|
|
.await?;
|
|
|
|
let pr = match request.pr_number {
|
|
Some(n) => pull_request::Entity::find()
|
|
.filter(pull_request::Column::Repo.eq(repo.id))
|
|
.filter(pull_request::Column::Number.eq(n))
|
|
.one(&self.db)
|
|
.await?
|
|
.ok_or_else(|| AppError::NotFound("Pull request not found".to_string()))?,
|
|
None => pull_request::Entity::find()
|
|
.filter(pull_request::Column::Repo.eq(repo.id))
|
|
.order_by_desc(pull_request::Column::CreatedAt)
|
|
.one(&self.db)
|
|
.await?
|
|
.ok_or_else(|| AppError::NotFound("No pull request found".to_string()))?,
|
|
};
|
|
|
|
self.generate_pr_description_internal(pr, repo, request.model_id)
|
|
.await
|
|
}
|
|
|
|
/// Internal entry point — skips auth. Used by background tasks.
|
|
pub async fn generate_pr_description_internal(
|
|
&self,
|
|
pr: pull_request::Model,
|
|
repo: repo::Model,
|
|
model_id: Option<Uuid>,
|
|
) -> Result<GeneratePrDescriptionResponse, AppError> {
|
|
// Find a model first so we can use its context limit for diff truncation
|
|
let model = match model_id {
|
|
Some(id) => MEntity::find_by_id(id)
|
|
.one(&self.db)
|
|
.await?
|
|
.ok_or(AppError::NotFound("Model not found".to_string()))?,
|
|
None => MEntity::find()
|
|
.filter(MColumn::Status.eq(ModelStatus::Active.to_string()))
|
|
.order_by_asc(MColumn::Name)
|
|
.one(&self.db)
|
|
.await?
|
|
.ok_or_else(|| {
|
|
AppError::InternalServerError(
|
|
"No active AI model found. Please configure an AI model first.".into(),
|
|
)
|
|
})?,
|
|
};
|
|
|
|
// Get the diff with token-aware truncation
|
|
let diff = self
|
|
.get_pr_description_diff(&repo, &pr, &model.name, model.context_length)
|
|
.await?;
|
|
|
|
// Build prompt and call AI
|
|
let prompt = build_description_prompt(&pr.title, pr.body.as_deref(), &diff);
|
|
let ai_response = call_ai_model_for_description(&model.name, &prompt, &self.config).await?;
|
|
|
|
// Record billing (non-fatal).
|
|
let billing = self
|
|
.record_ai_usage(
|
|
repo.project,
|
|
model.id,
|
|
ai_response.input_tokens,
|
|
ai_response.output_tokens,
|
|
)
|
|
.await
|
|
.inspect_err(|e| {
|
|
slog::warn!(
|
|
self.logs,
|
|
"failed to record AI billing for PR description";
|
|
"project" => %repo.project,
|
|
"error" => ?e
|
|
);
|
|
})
|
|
.ok();
|
|
|
|
// Parse JSON response
|
|
let json_str =
|
|
extract_json(&ai_response.content).unwrap_or_else(|| ai_response.content.clone());
|
|
|
|
let pr_desc: PrDescription = serde_json::from_str(&json_str).map_err(|e| {
|
|
AppError::InternalServerError(format!(
|
|
"Failed to parse AI response as JSON: {}. Raw: {}",
|
|
e,
|
|
&ai_response.content[..ai_response.content.len().min(200)]
|
|
))
|
|
})?;
|
|
|
|
let markdown_body = format_as_markdown(&pr.title, &pr_desc);
|
|
|
|
Ok(GeneratePrDescriptionResponse {
|
|
description: pr_desc,
|
|
markdown_body,
|
|
billing,
|
|
})
|
|
}
|
|
|
|
/// Get the diff for PR description generation (unified format, truncated).
|
|
async fn get_pr_description_diff(
|
|
&self,
|
|
repo: &repo::Model,
|
|
pr: &pull_request::Model,
|
|
model_name: &str,
|
|
context_limit: i64,
|
|
) -> Result<String, AppError> {
|
|
let oids_result = {
|
|
let base = pr.base.clone();
|
|
let head = pr.head.clone();
|
|
let repo_model = repo.clone();
|
|
let handle: tokio::task::JoinHandle<Result<(git2::Oid, git2::Oid), AppError>> =
|
|
tokio::task::spawn_blocking(move || {
|
|
let domain = GitDomain::from_model(repo_model)?;
|
|
let base_commit_oid = domain
|
|
.branch_target(&base)
|
|
.map_err(|e| crate::git::GitError::Internal(e.to_string()))?
|
|
.ok_or_else(|| {
|
|
AppError::NotFound(format!("Branch '{}' not found", base))
|
|
})?;
|
|
let head_commit_oid = domain
|
|
.branch_target(&head)
|
|
.map_err(|e| crate::git::GitError::Internal(e.to_string()))?
|
|
.ok_or_else(|| {
|
|
AppError::NotFound(format!("Branch '{}' not found", head))
|
|
})?;
|
|
let base_oid = base_commit_oid.to_oid().map_err(|e| {
|
|
AppError::InternalServerError(format!("Invalid OID: {}", e))
|
|
})?;
|
|
let head_oid = head_commit_oid.to_oid().map_err(|e| {
|
|
AppError::InternalServerError(format!("Invalid OID: {}", e))
|
|
})?;
|
|
Ok((base_oid, head_oid))
|
|
});
|
|
handle
|
|
}
|
|
.await
|
|
.map_err(|e| AppError::InternalServerError(format!("Join error: {}", e)))?;
|
|
let (base_oid, head_oid) = oids_result.map_err(|e| {
|
|
AppError::InternalServerError(format!("Failed to resolve branch OIDs: {:?}", e))
|
|
})?;
|
|
|
|
let repo_for_diff = repo.clone();
|
|
let diff_text = tokio::task::spawn_blocking(move || -> Result<String, AppError> {
|
|
let domain = GitDomain::from_model(repo_for_diff)?;
|
|
let old_tree = domain
|
|
.repo()
|
|
.find_tree(base_oid)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
let new_tree = domain
|
|
.repo()
|
|
.find_tree(head_oid)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
let diff = domain
|
|
.repo()
|
|
.diff_tree_to_tree(Some(&old_tree), Some(&new_tree), None)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
let mut patch_buf: Vec<u8> = Vec::new();
|
|
diff.print(git2::DiffFormat::Patch, |_delta, _hunk, line| {
|
|
patch_buf.extend_from_slice(line.content());
|
|
patch_buf.push(b'\n');
|
|
true
|
|
})
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
String::from_utf8(patch_buf).map_err(|e| AppError::InternalServerError(e.to_string()))
|
|
})
|
|
.await
|
|
.map_err(|e| AppError::InternalServerError(format!("Task join error: {e}")))?
|
|
.map_err(AppError::from)?;
|
|
|
|
// Truncate to avoid token limits with token-aware budgeting.
|
|
// Reserve 4096 tokens for output + system overhead.
|
|
let reserve = 4096;
|
|
match agent::tokent::truncate_to_token_budget(
|
|
&diff_text,
|
|
model_name,
|
|
context_limit as usize,
|
|
reserve,
|
|
) {
|
|
Ok(truncated) if truncated.len() < diff_text.len() => Ok(format!(
|
|
"{}...\n[diff truncated to fit token budget]",
|
|
truncated
|
|
)),
|
|
_ => Ok(diff_text),
|
|
}
|
|
}
|
|
}
|