// gitdataai/libs/agent/orao/reason.rs (216 lines, 7.2 KiB, Rust)

//! Reason phase: structured reasoning and plan generation via LLM.
//!
//! Takes the current [`PerceptionSnapshot`] and the task goal, sends them to
//! the configured LLM, and produces a structured [`ReasoningOutput`] with an
//! executable plan.
use crate::client::AiClientConfig;
use crate::error::AgentError;
use rig::agent::AgentBuilder;
use rig::client::CompletionClient;
use rig::completion::Prompt;
use super::types::{OraoConfig, PerceptionSnapshot, ReasoningOutput};
/// System prompt (preamble) for the ORAO reasoning phase.
///
/// Tells the model it is in the REASON phase, that it will receive a task
/// goal and an observation snapshot, and that it must answer with a JSON
/// object holding an `analysis` string and a `plan` array of steps. It also
/// lists planning rules (step granularity, plan size, fallbacks) and safety
/// guidance.
///
/// NOTE(review): the schema spelled out below presumably mirrors the
/// `ReasoningOutput` type the model reply is deserialized into — confirm the
/// two are kept in sync when either changes.
const REASON_SYSTEM_PROMPT: &str = r#"You are an expert software engineering agent using the ORAO (Observe-Reason-Act-Observe) framework.
## Your Role: REASON Phase
You are currently in the REASON phase. You will receive:
1. A **task goal** — what needs to be accomplished
2. An **observation snapshot** — the current state of the environment
Your job is to produce a structured analysis and a step-by-step action plan.
## Output Format
You MUST respond with a valid JSON object matching this schema:
```json
{
"analysis": "<concise analysis of the current state, problems identified, constraints>",
"plan": [
{
"step_id": 1,
"description": "<what this step does>",
"action_type": "shell_command | file_write | file_edit | git_operation | tool_invoke | user_dialog",
"command_or_content": "<the exact command to run or content to write>",
"expected_result": "<what success looks like>",
"fallback_on_failure": "<what to do if this fails, or null>"
}
]
}
```
## Rules
1. **Be specific**: commands must be exact and complete (include necessary flags, paths, etc.)
2. **One action per step**: each step should do one thing
3. **Validate assumptions**: don't assume dependencies are installed or files exist
4. **Prefer small, safe steps**: each change should be minimal and reversible
5. **Include verification**: build/test after changes to verify correctness
6. **Plan size**: 1-10 steps is typical. Don't plan more than 15 steps without asking the user.
7. **Fallbacks**: for risky steps, specify what to try if the step fails
## Safety
- Prefer read-only verification before making changes
- Use version control (git) for reversible operations
- Flag any step that requires network access or system-level privileges"#;
/// Execute the REASON phase for one ORAO round.
///
/// Renders the task goal, observation snapshot, round number and recent
/// history into a user prompt, sends it to the model named `model_name` with
/// [`REASON_SYSTEM_PROMPT`] as the preamble, and deserializes the reply into
/// a [`ReasoningOutput`].
///
/// `_orao_config` is accepted but not read by this function.
///
/// # Errors
///
/// * [`AgentError::OpenAi`] when the prompt request fails.
/// * [`AgentError::Internal`] when the reply, after any markdown fence has
///   been stripped, cannot be parsed as a [`ReasoningOutput`]; the message
///   carries both the parse error and the raw model output.
pub async fn reason(
    config: &AiClientConfig,
    model_name: &str,
    _orao_config: &OraoConfig,
    task_goal: &str,
    snapshot: &PerceptionSnapshot,
    round: usize,
    history_rounds: &[super::types::RoundRecord],
) -> Result<ReasoningOutput, AgentError> {
    let user_prompt = build_reason_prompt(task_goal, snapshot, round, history_rounds);

    // Assemble an agent around the requested completion model.
    let client = config.build_rig_client();
    let agent = AgentBuilder::new(client.completion_model(model_name))
        .preamble(REASON_SYSTEM_PROMPT)
        .build();

    let reply = agent
        .prompt(&user_prompt)
        .extended_details()
        .await
        .map_err(|err: rig::completion::PromptError| AgentError::OpenAi(err.to_string()))?;

    // The model may wrap its JSON in ```json fences; unwrap before parsing.
    let raw = reply.output.trim().to_string();
    match serde_json::from_str::<ReasoningOutput>(extract_json(&raw)) {
        Ok(parsed) => Ok(parsed),
        Err(parse_err) => Err(AgentError::Internal(format!(
            "Failed to parse reasoning output as JSON: {}. Raw output: {}",
            parse_err, raw
        ))),
    }
}
/// Build the user prompt for the reason phase.
///
/// The prompt is markdown made of these sections, in order:
///
/// 1. `## Task Goal` and `## Round Number`.
/// 2. `## Current Observation`, containing — each only when present or
///    non-empty — git status, project structure, relevant files (path, size
///    and optional content preview), errors/warnings, and the previous
///    action result.
/// 3. `## Previous Rounds`: one line for each of the last three entries of
///    `history_rounds`, emitted in reverse order.
/// 4. `## Instructions`.
///
/// File previews longer than 2000 bytes are cut to at most 2000 bytes and
/// suffixed with `... [truncated]`. The cut is moved back to the nearest
/// UTF-8 character boundary so multi-byte content cannot trigger a slicing
/// panic.
fn build_reason_prompt(
    task_goal: &str,
    snapshot: &PerceptionSnapshot,
    round: usize,
    history_rounds: &[super::types::RoundRecord],
) -> String {
    /// Maximum number of preview bytes included per file.
    const MAX_PREVIEW_BYTES: usize = 2000;

    let mut prompt = format!(
        "## Task Goal\n\n{}\n\n## Round Number\n\n{}\n\n",
        task_goal, round
    );

    // ── Observation snapshot ────────────────────────────────────────────
    prompt.push_str("## Current Observation\n\n");

    if let Some(ref git_status) = snapshot.git_status {
        prompt.push_str(&format!("### Git Status\n```\n{}\n```\n\n", git_status));
    }

    if let Some(ref project_structure) = snapshot.project_structure {
        prompt.push_str(&format!(
            "### Project Structure\n```\n{}\n```\n\n",
            project_structure
        ));
    }

    if !snapshot.files.is_empty() {
        prompt.push_str("### Relevant Files\n\n");
        for file in &snapshot.files {
            prompt.push_str(&format!("- `{}` ({} bytes)\n", file.path, file.size_bytes));
            if let Some(ref preview) = file.content_preview {
                prompt.push_str("```\n");
                if preview.len() > MAX_PREVIEW_BYTES {
                    // Byte MAX_PREVIEW_BYTES may fall inside a multi-byte
                    // character; back off to a boundary before slicing.
                    // Index 0 is always a boundary, so this terminates.
                    let mut cut = MAX_PREVIEW_BYTES;
                    while !preview.is_char_boundary(cut) {
                        cut -= 1;
                    }
                    prompt.push_str(&preview[..cut]);
                    prompt.push_str("... [truncated]");
                } else {
                    prompt.push_str(preview);
                }
                prompt.push_str("\n```\n\n");
            }
        }
    }

    if !snapshot.errors.is_empty() {
        prompt.push_str("### Errors / Warnings\n\n");
        for err in &snapshot.errors {
            prompt.push_str(&format!("- {}\n", err));
        }
        prompt.push('\n');
    }

    if let Some(ref prev_result) = snapshot.previous_action_result {
        prompt.push_str(&format!(
            "### Previous Action Result\n- Verdict: {:?}\n- Exit code: {:?}\n- stdout: {}\n- stderr: {}\n\n",
            prev_result.verdict,
            prev_result.exit_code,
            truncate_str(&prev_result.stdout, 1000),
            truncate_str(&prev_result.stderr, 1000),
        ));
    }

    // ── History summary ─────────────────────────────────────────────────
    if !history_rounds.is_empty() {
        prompt.push_str("## Previous Rounds\n\n");
        // Only the last three records are summarized, walking backwards
        // from the end of the slice.
        for record in history_rounds.iter().rev().take(3) {
            prompt.push_str(&format!(
                "- Round {}: {} | {} tokens | {}ms\n",
                record.round,
                truncate_str(&record.reasoning_summary, 120),
                record.tokens_input + record.tokens_output,
                record.duration_ms,
            ));
        }
        prompt.push('\n');
    }

    prompt.push_str(
        "## Instructions\n\nProduce a structured analysis and action plan in JSON format as specified.",
    );
    prompt
}
/// Extract JSON content from a string that may be wrapped in markdown fences.
///
/// The input is trimmed first, then:
///
/// * If it starts with a ```` ```json ```` or bare ```` ``` ```` fence, the
///   text between that opening fence and the last ```` ``` ```` (or the end
///   of the string if there is no closing fence) is returned, trimmed.
/// * Otherwise, if a `{` appears before the last `}`, the slice from the
///   first `{` through the last `}` is returned.
/// * Otherwise the trimmed input is returned unchanged.
///
/// The returned slice always borrows from `s`; no validation of the JSON is
/// performed here.
fn extract_json(s: &str) -> &str {
    let s = s.trim();

    // Try the more specific fence first so the `json` tag is not left behind.
    let fenced = s.strip_prefix("```json").or_else(|| s.strip_prefix("```"));
    if let Some(inner) = fenced {
        let body = match inner.rfind("```") {
            Some(end) => &inner[..end],
            None => inner,
        };
        return body.trim();
    }

    // Heuristic: first `{` through last `}`. The ordering check matters:
    // input such as "} {" has the last `}` before the first `{`, and slicing
    // with start > end would panic.
    if let (Some(start), Some(end)) = (s.find('{'), s.rfind('}')) {
        if start < end {
            return &s[start..=end];
        }
    }
    s
}
/// Truncate `s` to at most `max_len` bytes, appending `...` when shortened.
///
/// Strings whose byte length is `max_len` or less are returned unchanged.
/// Otherwise the cut point is moved back from `max_len` to the nearest UTF-8
/// character boundary, so the kept prefix may be a few bytes shorter than
/// `max_len` but slicing never panics on multi-byte text.
fn truncate_str(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }
    // `max_len < s.len()` here, and index 0 is always a boundary, so the
    // loop terminates with a valid slice index.
    let mut cut = max_len;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &s[..cut])
}