//! Token counting utilities using tiktoken. //! //! Provides accurate token counting for OpenAI-compatible models. //! Uses the `tiktoken-rs` crate (already in workspace dependencies). //! //! # Strategy //! //! Remote usage from API response is always preferred. When the API does not //! return usage metadata (e.g., local models, streaming), tiktoken is used as //! a fallback for accurate counting. use crate::error::{AgentError, Result}; /// Token usage data. Use `from_remote()` when the API returns usage info, /// or `from_estimate()` when falling back to tiktoken. #[derive(Debug, Clone, Copy, Default, serde::Serialize, serde::Deserialize)] pub struct TokenUsage { pub input_tokens: i64, pub output_tokens: i64, } impl TokenUsage { /// Create from remote API usage data. Returns `None` if all values are zero /// (some providers return zeroed usage on error). pub fn from_remote(prompt_tokens: u32, completion_tokens: u32) -> Option { if prompt_tokens == 0 && completion_tokens == 0 { None } else { Some(Self { input_tokens: prompt_tokens as i64, output_tokens: completion_tokens as i64, }) } } /// Create from tiktoken estimate. pub fn from_estimate(input_tokens: usize, output_tokens: usize) -> Self { Self { input_tokens: input_tokens as i64, output_tokens: output_tokens as i64, } } pub fn total(&self) -> i64 { self.input_tokens + self.output_tokens } } /// Resolve token usage: remote data is preferred, tiktoken is the fallback. /// /// `remote` — `Some` when API returned usage; `None` when not available. /// `model` — model name, required for tiktoken fallback. /// `input_text` — input text length hint for fallback estimate (uses ~4 chars/token). pub fn resolve_usage( remote: Option, model: &str, input_text: &str, output_text: &str, ) -> TokenUsage { if let Some(usage) = remote { return usage; } // Fallback: tiktoken estimate let input = count_message_text(input_text, model).unwrap_or_else(|_| { // Rough estimate: ~4 chars per token (input_text.len() / 4).max(1) }); let output = output_text.len() / 4; TokenUsage::from_estimate(input, output) } /// Estimate the number of tokens in a text string using the appropriate tokenizer. pub fn count_text(text: &str, model: &str) -> Result { let bpe = get_tokenizer(model)?; // Use encode_ordinary since we're counting raw text, not chat messages let tokens = bpe.encode_ordinary(text); Ok(tokens.len()) } /// Count tokens in a single chat message (text content only). pub fn count_message_text(text: &str, model: &str) -> Result { let bpe = get_tokenizer(model)?; // For messages, use encode_with_special_tokens to count role/separator tokens let tokens = bpe.encode_with_special_tokens(text); Ok(tokens.len()) } /// Estimate the maximum number of characters that fit within a token budget /// given a model's context limit and a reserve for the output. /// /// Uses a rough estimate of ~4 characters per token (typical for English text). /// For non-Latin scripts, this is less accurate. pub fn estimate_max_chars( _model: &str, context_limit: usize, reserve_output_tokens: usize, ) -> Result { let chars_per_token = 4; // Subtract reserve for output, system overhead, and a safety margin (10%) let safe_limit = context_limit .saturating_sub(reserve_output_tokens) .saturating_sub(512); // 512 token safety margin Ok(safe_limit.saturating_mul(chars_per_token)) } /// Truncate text to fit within a token budget for a given model. pub fn truncate_to_token_budget( text: &str, model: &str, context_limit: usize, reserve_output_tokens: usize, ) -> Result { let max_chars = estimate_max_chars(model, context_limit, reserve_output_tokens)?; if text.len() <= max_chars { return Ok(text.to_string()); } // Binary search for the exact character boundary that fits the token budget let bpe = get_tokenizer(model)?; let mut low = 0usize; let mut high = text.len(); let mut result = text.to_string(); while low + 100 < high { let mid = (low + high) / 2; let candidate = &text[..mid]; let tokens = bpe.encode_ordinary(candidate); if tokens.len() <= safe_token_budget(context_limit, reserve_output_tokens) { result = candidate.to_string(); low = mid; } else { high = mid; } } Ok(result) } /// Returns the safe token budget (context limit minus reserve and margin). fn safe_token_budget(context_limit: usize, reserve: usize) -> usize { context_limit.saturating_sub(reserve).saturating_sub(512) } /// Get the appropriate tiktoken tokenizer for a model. /// /// Model name mapping: /// - "gpt-4o", "o1", "o3", "o4" → o200k_base /// - "claude-*", "gpt-3.5-turbo", "gpt-4" → cl100k_base /// - Unknown → cl100k_base (safe fallback) fn get_tokenizer(model: &str) -> Result { use tiktoken_rs; // Try model-specific tokenizer first if let Ok(bpe) = tiktoken_rs::get_bpe_from_model(model) { return Ok(bpe); } // Fallback: use cl100k_base for unknown models tiktoken_rs::cl100k_base() .map_err(|e| AgentError::Internal(format!("Failed to init tokenizer: {}", e))) } /// Estimate tokens for a simple prefix/suffix pattern (e.g., "assistant\n" + text). /// Returns the token count including the prefix. pub fn count_with_prefix(text: &str, prefix: &str, model: &str) -> Result { let bpe = get_tokenizer(model)?; let prefixed = format!("{}{}", prefix, text); let tokens = bpe.encode_with_special_tokens(&prefixed); Ok(tokens.len()) } #[cfg(test)] mod tests { use super::*; #[test] fn test_count_text() { let count = count_text("Hello, world!", "gpt-4").unwrap(); assert!(count > 0); } #[test] fn test_estimate_max_chars() { // gpt-4o context ~128k tokens let chars = estimate_max_chars("gpt-4o", 128_000, 2048).unwrap(); assert!(chars > 0); } #[test] fn test_truncate() { // 50k chars exceeds budget: 8192 - 512 - 512 = 7168 tokens → ~28k chars let long_text = "a".repeat(50000); let truncated = truncate_to_token_budget(&long_text, "gpt-4o", 8192, 512).unwrap(); assert!(truncated.len() < long_text.len()); } }