//! Token counting utilities using tiktoken.
//!
//! Provides accurate token counting for OpenAI-compatible models.
//! Uses the `tiktoken-rs` crate (already in workspace dependencies).
//!
//! # Strategy
//!
//! Remote usage from API response is always preferred. When the API does not
//! return usage metadata (e.g., local models, streaming), tiktoken is used as
//! a fallback for accurate counting.
|
use crate::error::{AgentError, Result};
|
|
|
|
/// Token usage data. Use `from_remote()` when the API returns usage info,
/// or `from_estimate()` when falling back to tiktoken.
#[derive(Debug, Clone, Copy, Default, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt/input side of the request.
    pub input_tokens: i64,
    /// Tokens produced in the completion/output side of the request.
    pub output_tokens: i64,
}
|
|
|
|
impl TokenUsage {
|
|
/// Create from remote API usage data. Returns `None` if all values are zero
|
|
/// (some providers return zeroed usage on error).
|
|
pub fn from_remote(prompt_tokens: u32, completion_tokens: u32) -> Option<Self> {
|
|
if prompt_tokens == 0 && completion_tokens == 0 {
|
|
None
|
|
} else {
|
|
Some(Self {
|
|
input_tokens: prompt_tokens as i64,
|
|
output_tokens: completion_tokens as i64,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Create from tiktoken estimate.
|
|
pub fn from_estimate(input_tokens: usize, output_tokens: usize) -> Self {
|
|
Self {
|
|
input_tokens: input_tokens as i64,
|
|
output_tokens: output_tokens as i64,
|
|
}
|
|
}
|
|
|
|
pub fn total(&self) -> i64 {
|
|
self.input_tokens + self.output_tokens
|
|
}
|
|
}
|
|
|
|
/// Resolve token usage: remote data is preferred, tiktoken is the fallback.
|
|
///
|
|
/// `remote` — `Some` when API returned usage; `None` when not available.
|
|
/// `model` — model name, required for tiktoken fallback.
|
|
/// `input_text` — input text length hint for fallback estimate (uses ~4 chars/token).
|
|
pub fn resolve_usage(
|
|
remote: Option<TokenUsage>,
|
|
model: &str,
|
|
input_text: &str,
|
|
output_text: &str,
|
|
) -> TokenUsage {
|
|
if let Some(usage) = remote {
|
|
return usage;
|
|
}
|
|
|
|
// Fallback: tiktoken estimate
|
|
let input = count_message_text(input_text, model).unwrap_or_else(|_| {
|
|
// Rough estimate: ~4 chars per token
|
|
(input_text.len() / 4).max(1)
|
|
});
|
|
let output = output_text.len() / 4;
|
|
TokenUsage::from_estimate(input, output)
|
|
}
|
|
|
|
/// Estimate the number of tokens in a text string using the appropriate tokenizer.
|
|
pub fn count_text(text: &str, model: &str) -> Result<usize> {
|
|
let bpe = get_tokenizer(model)?;
|
|
// Use encode_ordinary since we're counting raw text, not chat messages
|
|
let tokens = bpe.encode_ordinary(text);
|
|
Ok(tokens.len())
|
|
}
|
|
|
|
/// Count tokens in a single chat message (text content only).
|
|
pub fn count_message_text(text: &str, model: &str) -> Result<usize> {
|
|
let bpe = get_tokenizer(model)?;
|
|
// For messages, use encode_with_special_tokens to count role/separator tokens
|
|
let tokens = bpe.encode_with_special_tokens(text);
|
|
Ok(tokens.len())
|
|
}
|
|
|
|
/// Estimate the maximum number of characters that fit within a token budget
|
|
/// given a model's context limit and a reserve for the output.
|
|
///
|
|
/// Uses a rough estimate of ~4 characters per token (typical for English text).
|
|
/// For non-Latin scripts, this is less accurate.
|
|
pub fn estimate_max_chars(
|
|
_model: &str,
|
|
context_limit: usize,
|
|
reserve_output_tokens: usize,
|
|
) -> Result<usize> {
|
|
let chars_per_token = 4;
|
|
// Subtract reserve for output, system overhead, and a safety margin (10%)
|
|
let safe_limit = context_limit
|
|
.saturating_sub(reserve_output_tokens)
|
|
.saturating_sub(512); // 512 token safety margin
|
|
Ok(safe_limit.saturating_mul(chars_per_token))
|
|
}
|
|
|
|
/// Truncate text to fit within a token budget for a given model.
|
|
pub fn truncate_to_token_budget(
|
|
text: &str,
|
|
model: &str,
|
|
context_limit: usize,
|
|
reserve_output_tokens: usize,
|
|
) -> Result<String> {
|
|
let max_chars = estimate_max_chars(model, context_limit, reserve_output_tokens)?;
|
|
|
|
if text.len() <= max_chars {
|
|
return Ok(text.to_string());
|
|
}
|
|
|
|
// Binary search for the exact character boundary that fits the token budget
|
|
let bpe = get_tokenizer(model)?;
|
|
let mut low = 0usize;
|
|
let mut high = text.len();
|
|
let mut result = text.to_string();
|
|
|
|
while low + 100 < high {
|
|
let mid = (low + high) / 2;
|
|
let candidate = &text[..mid];
|
|
let tokens = bpe.encode_ordinary(candidate);
|
|
|
|
if tokens.len() <= safe_token_budget(context_limit, reserve_output_tokens) {
|
|
result = candidate.to_string();
|
|
low = mid;
|
|
} else {
|
|
high = mid;
|
|
}
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// Returns the safe token budget: the context limit minus the output reserve
/// and a fixed 512-token safety margin, floored at zero.
fn safe_token_budget(context_limit: usize, reserve: usize) -> usize {
    // Saturating add/sub keep the result well-defined for any inputs.
    let overhead = reserve.saturating_add(512);
    context_limit.saturating_sub(overhead)
}
|
|
|
|
/// Get the appropriate tiktoken tokenizer for a model.
|
|
///
|
|
/// Model name mapping:
|
|
/// - "gpt-4o", "o1", "o3", "o4" → o200k_base
|
|
/// - "claude-*", "gpt-3.5-turbo", "gpt-4" → cl100k_base
|
|
/// - Unknown → cl100k_base (safe fallback)
|
|
fn get_tokenizer(model: &str) -> Result<tiktoken_rs::CoreBPE> {
|
|
use tiktoken_rs;
|
|
|
|
// Try model-specific tokenizer first
|
|
if let Ok(bpe) = tiktoken_rs::get_bpe_from_model(model) {
|
|
return Ok(bpe);
|
|
}
|
|
|
|
// Fallback: use cl100k_base for unknown models
|
|
tiktoken_rs::cl100k_base()
|
|
.map_err(|e| AgentError::Internal(format!("Failed to init tokenizer: {}", e)))
|
|
}
|
|
|
|
/// Estimate tokens for a simple prefix/suffix pattern (e.g., "assistant\n" + text).
|
|
/// Returns the token count including the prefix.
|
|
pub fn count_with_prefix(text: &str, prefix: &str, model: &str) -> Result<usize> {
|
|
let bpe = get_tokenizer(model)?;
|
|
let prefixed = format!("{}{}", prefix, text);
|
|
let tokens = bpe.encode_with_special_tokens(&prefixed);
|
|
Ok(tokens.len())
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_text() {
        // Any non-empty string must yield at least one token.
        assert!(count_text("Hello, world!", "gpt-4").unwrap() > 0);
    }

    #[test]
    fn test_estimate_max_chars() {
        // gpt-4o has a ~128k-token context window.
        assert!(estimate_max_chars("gpt-4o", 128_000, 2048).unwrap() > 0);
    }

    #[test]
    fn test_truncate() {
        // Budget: 8192 - 512 (reserve) - 512 (margin) = 7168 tokens ≈ 28k chars,
        // so a 50k-char input must come back shorter.
        let input = "a".repeat(50_000);
        let shortened = truncate_to_token_budget(&input, "gpt-4o", 8192, 512).unwrap();
        assert!(shortened.len() < input.len());
    }
}
|