gitdataai/lib/ai/agent/prompt.rs

57 lines
1.9 KiB
Rust

use rig::agent::AgentBuilder;
use rig::client::CompletionClient;
use rig::completion::Prompt;
use super::agent::RigAgent;
use super::helpers::with_retry;
use crate::error::{AiError, AiResult};
// Single-turn prompting support for `RigAgent`.
impl RigAgent {
    /// Runs a single-turn completion where `system_prompt` becomes the agent preamble
    /// and `user_input` is sent as the prompt.
    ///
    /// Every attempt builds a fresh rig agent for the configured model name. It applies
    /// `temperature` and `max_completion_tokens` only when they are set in the config.
    /// The build-and-prompt step is wrapped in `with_retry`, which receives the config's
    /// `retry_max_attempts` and `retry_base_delay_ms`. The exact retry/backoff policy
    /// is defined by that helper.
    ///
    /// # Returns
    ///
    /// `(output, input_tokens, output_tokens)`: the model's text response, followed by
    /// the input and output token counts taken from the response's usage record.
    ///
    /// # Errors
    ///
    /// A failed prompt is converted into [`AiError::Api`] carrying the `PromptError`
    /// message. Which error surfaces once retries are exhausted is decided by
    /// `with_retry`.
    pub async fn prompt(
        &self,
        system_prompt: &str,
        user_input: &str,
    ) -> AiResult<(String, u64, u64)> {
        // Owned snapshots of everything an attempt needs, so each retry can produce a
        // self-contained future instead of borrowing `self`.
        let llm = self.client.llm_client().clone();
        let model_id = self.config.model.clone();
        let temperature = self.config.temperature;
        let token_limit = self.config.max_completion_tokens;
        let preamble_text = system_prompt.to_string();
        let input_text = user_input.to_string();

        with_retry(
            self.config.retry_max_attempts,
            self.config.retry_base_delay_ms,
            || {
                // Each invocation gets its own copies; the async block below moves them.
                let llm = llm.clone();
                let model_id = model_id.clone();
                let preamble_text = preamble_text.clone();
                let input_text = input_text.clone();

                async move {
                    let builder = AgentBuilder::new(llm.completion_model(&model_id))
                        .preamble(&preamble_text);
                    // Optional sampling settings: leave the builder untouched when unset.
                    let builder = match temperature {
                        Some(t) => builder.temperature(t),
                        None => builder,
                    };
                    let builder = match token_limit {
                        Some(limit) => builder.max_tokens(limit),
                        None => builder,
                    };
                    let agent = builder.build();

                    // `extended_details()` yields the response together with token usage.
                    let response = agent
                        .prompt(&input_text)
                        .extended_details()
                        .await
                        .map_err(|err: rig::completion::PromptError| {
                            AiError::Api(err.to_string())
                        })?;

                    let input_tokens = response.usage.input_tokens;
                    let output_tokens = response.usage.output_tokens;
                    Ok((response.output, input_tokens, output_tokens))
                }
            },
        )
        .await
    }
}