58 lines
1.9 KiB
Rust
58 lines
1.9 KiB
Rust
use rig::agent::AgentBuilder;
|
|
use rig::client::CompletionClient;
|
|
use rig::completion::Prompt;
|
|
|
|
use super::agent::RigAgent;
|
|
use super::helpers::with_retry;
|
|
use crate::error::{AiError, AiResult};
|
|
|
|
impl RigAgent {
|
|
pub async fn prompt(
|
|
&self,
|
|
system_prompt: &str,
|
|
user_input: &str,
|
|
) -> AiResult<(String, u64, u64)> {
|
|
let model_name = self.config.model.clone();
|
|
let client = self.client.llm_client().clone();
|
|
let temperature = self.config.temperature;
|
|
let max_completion_tokens = self.config.max_completion_tokens;
|
|
let retry_attempts = self.config.retry_max_attempts;
|
|
let retry_delay_ms = self.config.retry_base_delay_ms;
|
|
let sp = system_prompt.to_string();
|
|
let ui = user_input.to_string();
|
|
|
|
with_retry(retry_attempts, retry_delay_ms, || {
|
|
let client = client.clone();
|
|
let model_name = model_name.clone();
|
|
let sp = sp.clone();
|
|
let ui = ui.clone();
|
|
async move {
|
|
let model = client.completion_model(&model_name);
|
|
let mut builder = AgentBuilder::new(model).preamble(&sp);
|
|
if let Some(temp) = temperature {
|
|
builder = builder.temperature(temp);
|
|
}
|
|
if let Some(mt) = max_completion_tokens {
|
|
builder = builder.max_tokens(mt);
|
|
}
|
|
let agent = builder.build();
|
|
|
|
let response = agent
|
|
.prompt(&ui)
|
|
.extended_details()
|
|
.await
|
|
.map_err(|e: rig::completion::PromptError| {
|
|
AiError::Api(e.to_string())
|
|
})?;
|
|
|
|
Ok((
|
|
response.output,
|
|
response.usage.input_tokens,
|
|
response.usage.output_tokens,
|
|
))
|
|
}
|
|
})
|
|
.await
|
|
}
|
|
}
|