114 lines
3.3 KiB
Rust
114 lines
3.3 KiB
Rust
use std::future::Future;
|
|
use std::time::Duration;
|
|
|
|
use crate::agent::request::AgentRequest;
|
|
use crate::error::{AiError, AiResult};
|
|
|
|
pub fn build_input_string(request: &AgentRequest) -> String {
|
|
let mut input = String::new();
|
|
|
|
if !request.context.is_empty() {
|
|
input.push_str("<retrieved_context>\n");
|
|
for chunk in &request.context {
|
|
let source = chunk.source.as_deref().unwrap_or("unknown");
|
|
let score = chunk
|
|
.score
|
|
.map(|s| format!("{s:.4}"))
|
|
.unwrap_or_else(|| "n/a".to_string());
|
|
input.push_str(&format!(
|
|
"\n<chunk id=\"{}\" source=\"{}\" score=\"{}\">\n{}\n</chunk>\n",
|
|
chunk.id, source, score, chunk.content
|
|
));
|
|
}
|
|
input.push_str("</retrieved_context>\n\n");
|
|
}
|
|
|
|
for message in &request.messages {
|
|
match message {
|
|
super::request::AgentMessage::User(content) => {
|
|
input.push_str(&format!("User: {content}\n"));
|
|
}
|
|
super::request::AgentMessage::Assistant(content) => {
|
|
input.push_str(&format!("Assistant: {content}\n"));
|
|
}
|
|
}
|
|
}
|
|
|
|
input.push_str(&format!("User: {}", request.input));
|
|
|
|
input
|
|
}
|
|
|
|
/// Returns a rough token estimate for `text`.
///
/// The estimate is the number of `char`s (Unicode scalar values) in `text` divided by
/// 2.5 and rounded up. An empty string yields 0.
pub fn estimate_tokens(text: &str) -> u64 {
    // Characters-per-token ratio used by this heuristic.
    const CHARS_PER_TOKEN: f64 = 2.5;

    match text.chars().count() {
        0 => 0,
        char_count => (char_count as f64 / CHARS_PER_TOKEN).ceil() as u64,
    }
}
|
|
|
|
/// Reports whether the estimated token usage has gone over `limit`.
///
/// Output tokens are estimated from `accumulated_output_chars` at 2.5 characters per
/// token (rounded up) and added to `estimated_input_tokens`.
///
/// # Returns
///
/// `true` when the combined estimate is strictly greater than `limit`, `false`
/// otherwise. A negative `limit` is exceeded by any usage, so it always yields `true`.
pub fn check_token_budget(
    estimated_input_tokens: u64,
    accumulated_output_chars: usize,
    limit: i64,
) -> bool {
    // A negative budget can never be satisfied. Casting it straight to `u64` would
    // instead wrap it into an enormous limit and silently disable the check.
    if limit < 0 {
        return true;
    }
    // Lossless: `limit` is known to be non-negative at this point.
    let limit = limit as u64;

    let output_estimate = (accumulated_output_chars as f64 / 2.5).ceil() as u64;

    // Saturate instead of overflowing: a total that does not fit in `u64` is above any
    // limit an `i64` can express.
    estimated_input_tokens.saturating_add(output_estimate) > limit
}
|
|
|
|
pub async fn with_retry<F, Fut, T>(
|
|
max_attempts: usize,
|
|
base_delay_ms: u64,
|
|
f: F,
|
|
) -> AiResult<T>
|
|
where
|
|
F: Fn() -> Fut,
|
|
Fut: Future<Output = AiResult<T>>,
|
|
{
|
|
let mut last_error: Option<AiError> = None;
|
|
for attempt in 0..max_attempts {
|
|
match f().await {
|
|
Ok(result) => return Ok(result),
|
|
Err(e) if is_retryable(&e) && attempt + 1 < max_attempts => {
|
|
let delay = Duration::from_millis(base_delay_ms * 2u64.pow(attempt as u32));
|
|
tracing::warn!(
|
|
error = %e,
|
|
attempt = attempt + 1,
|
|
max_attempts,
|
|
delay_ms = delay.as_millis(),
|
|
"retrying after transient error"
|
|
);
|
|
tokio::time::sleep(delay).await;
|
|
last_error = Some(e);
|
|
}
|
|
Err(e) => return Err(e),
|
|
}
|
|
}
|
|
Err(AiError::ModelRetriesExhausted {
|
|
attempts: max_attempts,
|
|
last_error: last_error
|
|
.map(|e| e.to_string())
|
|
.unwrap_or_else(|| "unknown".to_string()),
|
|
})
|
|
}
|
|
|
|
fn is_retryable(error: &AiError) -> bool {
|
|
matches!(
|
|
error,
|
|
AiError::Api(_) | AiError::Response(_) | AiError::ModelRetriesExhausted { .. }
|
|
)
|
|
}
|
|
|
|
pub fn tool_result_content_to_string(
|
|
content: &rig::one_or_many::OneOrMany<rig::completion::message::ToolResultContent>,
|
|
) -> String {
|
|
use rig::completion::message::ToolResultContent;
|
|
content
|
|
.iter()
|
|
.filter_map(|item| match item {
|
|
ToolResultContent::Text(t) => Some(t.text.clone()),
|
|
_ => None,
|
|
})
|
|
.collect::<Vec<_>>()
|
|
.join("\n")
|
|
}
|