77 lines
2.4 KiB
Rust
77 lines
2.4 KiB
Rust
use rig::client::EmbeddingsClient;
|
|
use rig::embeddings::EmbeddingModel;
|
|
|
|
use crate::{client::AiClient, error::{AiError, AiResult}};
|
|
|
|
#[derive(Clone)]
|
|
pub struct EmbedClient {
|
|
model_name: String,
|
|
client: rig::providers::openai::Client,
|
|
}
|
|
|
|
impl EmbedClient {
|
|
pub fn new(ai_client: &AiClient) -> AiResult<Self> {
|
|
Ok(Self {
|
|
model_name: ai_client.embed_model().to_string(),
|
|
client: ai_client.embed_client().clone(),
|
|
})
|
|
}
|
|
|
|
fn embedding_model(&self) -> impl EmbeddingModel + '_ {
|
|
self.client.embedding_model(&self.model_name)
|
|
}
|
|
|
|
pub async fn embed_text(&self, text: String) -> AiResult<Vec<f32>> {
|
|
let model = self.embedding_model();
|
|
let mut embeddings = model.embed_texts(vec![text])
|
|
.await
|
|
.map_err(|e| AiError::Api(e.to_string()))?;
|
|
embeddings.pop()
|
|
.map(|e| e.vec.into_iter().map(|v| v as f32).collect())
|
|
.ok_or_else(|| AiError::Response("no embedding returned".to_string()))
|
|
}
|
|
|
|
pub async fn embed_texts(&self, texts: Vec<String>) -> AiResult<Vec<Vec<f32>>> {
|
|
if texts.is_empty() {
|
|
return Ok(Vec::new());
|
|
}
|
|
let model = self.embedding_model();
|
|
let embeddings = model.embed_texts(texts)
|
|
.await
|
|
.map_err(|e| AiError::Api(e.to_string()))?;
|
|
Ok(embeddings.into_iter()
|
|
.map(|e| e.vec.into_iter().map(|v| v as f32).collect())
|
|
.collect())
|
|
}
|
|
|
|
pub async fn embed_texts_chunked(
|
|
&self,
|
|
texts: Vec<String>,
|
|
batch_size: usize,
|
|
) -> AiResult<Vec<Vec<f32>>> {
|
|
if batch_size == 0 {
|
|
return Err(AiError::Config("batch_size must be > 0".to_string()));
|
|
}
|
|
let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
|
|
for chunk in texts.chunks(batch_size) {
|
|
let model = self.embedding_model();
|
|
let chunk_embeddings = model.embed_texts(chunk.to_vec())
|
|
.await
|
|
.map_err(|e| AiError::Api(e.to_string()))?;
|
|
embeddings.extend(chunk_embeddings.into_iter()
|
|
.map(|e| e.vec.into_iter().map(|v| v as f32).collect()));
|
|
}
|
|
Ok(embeddings)
|
|
}
|
|
}
|
|
|
|
pub trait AiClientEmbedExt {
|
|
fn embedder(&self) -> AiResult<EmbedClient>;
|
|
}
|
|
|
|
impl AiClientEmbedExt for AiClient {
|
|
fn embedder(&self) -> AiResult<EmbedClient> {
|
|
EmbedClient::new(self)
|
|
}
|
|
}
|