gitdataai/lib/ai/embed/client.rs

93 lines
2.6 KiB
Rust

use rig::client::EmbeddingsClient;
use rig::embeddings::EmbeddingModel;
use crate::{
client::AiClient,
error::{AiError, AiResult},
};
/// Client for producing text embeddings with a fixed embedding model.
///
/// Pairs a model name with the provider client that is used to create the
/// embedding-model handle for each request. `Clone` is derived, so cloning
/// clones both fields.
#[derive(Clone)]
pub struct EmbedClient {
/// Name of the embedding model; handed to the provider client whenever a
/// model handle is created.
model_name: String,
/// `rig` OpenAI provider client from which the embedding model is built.
client: rig::providers::openai::Client,
}
impl EmbedClient {
    /// Builds an embedding client from the embedding settings of `ai_client`.
    ///
    /// The model name is copied into an owned `String` and the provider client
    /// is cloned, so the returned value does not borrow from `ai_client`.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`; the `AiResult` return type is
    /// kept so existing callers keep compiling.
    pub fn new(ai_client: &AiClient) -> AiResult<Self> {
        Ok(Self {
            model_name: ai_client.embed_model().to_string(),
            client: ai_client.embed_client().clone(),
        })
    }

    /// Creates the provider's embedding-model handle for the configured model name.
    fn embedding_model(&self) -> impl EmbeddingModel + '_ {
        self.client.embedding_model(&self.model_name)
    }

    /// Sends one batch of texts to the embedding model and converts every
    /// returned vector to `f32`.
    ///
    /// This is the single place where the request is issued, provider errors
    /// are mapped and the numeric conversion happens; all public methods go
    /// through it.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::Api`] carrying the provider error's text when the
    /// request fails.
    async fn embed_batch(&self, texts: Vec<String>) -> AiResult<Vec<Vec<f32>>> {
        let model = self.embedding_model();
        let embeddings = model
            .embed_texts(texts)
            .await
            .map_err(|e| AiError::Api(e.to_string()))?;

        Ok(embeddings
            .into_iter()
            // The provider's components are converted with `as`, which may
            // lose precision (presumably f64 -> f32; confirm against `rig`).
            .map(|e| e.vec.into_iter().map(|v| v as f32).collect())
            .collect())
    }

    /// Embeds a single text and returns its vector.
    ///
    /// # Errors
    ///
    /// * [`AiError::Api`] if the embedding request fails.
    /// * [`AiError::Response`] if the provider returns no embedding at all.
    pub async fn embed_text(&self, text: String) -> AiResult<Vec<f32>> {
        self.embed_batch(vec![text])
            .await?
            .pop()
            .ok_or_else(|| AiError::Response("no embedding returned".to_string()))
    }

    /// Embeds all `texts` in a single request.
    ///
    /// An empty input returns an empty vector without contacting the provider.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::Api`] if the embedding request fails.
    pub async fn embed_texts(
        &self,
        texts: Vec<String>,
    ) -> AiResult<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        self.embed_batch(texts).await
    }

    /// Embeds `texts` in consecutive batches of at most `batch_size` items,
    /// issuing one request per batch, and returns the per-batch results
    /// concatenated in batch order.
    ///
    /// Batches are awaited one after another; the first failing batch aborts
    /// the whole call and discards results gathered so far.
    ///
    /// # Errors
    ///
    /// * [`AiError::Config`] if `batch_size` is zero.
    /// * [`AiError::Api`] if any batch request fails.
    pub async fn embed_texts_chunked(
        &self,
        texts: Vec<String>,
        batch_size: usize,
    ) -> AiResult<Vec<Vec<f32>>> {
        if batch_size == 0 {
            return Err(AiError::Config("batch_size must be > 0".to_string()));
        }

        let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());

        // `texts` is owned, so move the strings into each batch instead of
        // cloning every chunk.
        let mut pending = texts.into_iter();
        loop {
            let batch: Vec<String> = pending.by_ref().take(batch_size).collect();
            if batch.is_empty() {
                break;
            }
            embeddings.extend(self.embed_batch(batch).await?);
        }

        Ok(embeddings)
    }
}
/// Extension trait that lets a client type hand out an [`EmbedClient`].
pub trait AiClientEmbedExt {
/// Returns an [`EmbedClient`] derived from `self`, or an error if one
/// cannot be built.
fn embedder(&self) -> AiResult<EmbedClient>;
}
/// Provides [`AiClientEmbedExt`] for [`AiClient`] by delegating to
/// [`EmbedClient::new`].
impl AiClientEmbedExt for AiClient {
/// Builds an [`EmbedClient`] from this client; the result is whatever
/// [`EmbedClient::new`] returns.
fn embedder(&self) -> AiResult<EmbedClient> {
EmbedClient::new(self)
}
}
}