gitdataai/libs/agent/tokent.rs
2026-04-14 19:02:01 +08:00

200 lines
6.5 KiB
Rust

//! Token counting utilities using tiktoken.
//!
//! Provides accurate token counting for OpenAI-compatible models.
//! Uses the `tiktoken-rs` crate (already in workspace dependencies).
//!
//! # Strategy
//!
//! Remote usage from API response is always preferred. When the API does not
//! return usage metadata (e.g., local models, streaming), tiktoken is used as
//! a fallback for accurate counting.
use crate::error::{AgentError, Result};
/// Token usage data. Use `from_remote()` when the API returns usage info,
/// or `from_estimate()` when falling back to tiktoken.
///
/// Counts are `i64` (not `u32`/`usize`) so the value round-trips cleanly
/// through serde/JSON and common database integer columns.
#[derive(Debug, Clone, Copy, Default, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    // Tokens consumed by the prompt / input side of the request.
    pub input_tokens: i64,
    // Tokens produced as the completion / output.
    pub output_tokens: i64,
}
impl TokenUsage {
    /// Build usage from values reported by the remote API.
    ///
    /// Returns `None` when both counts are zero, because some providers
    /// send a zeroed usage block on error and those values are meaningless.
    pub fn from_remote(prompt_tokens: u32, completion_tokens: u32) -> Option<Self> {
        match (prompt_tokens, completion_tokens) {
            (0, 0) => None,
            (input, output) => Some(Self {
                input_tokens: i64::from(input),
                output_tokens: i64::from(output),
            }),
        }
    }

    /// Build usage from a local tiktoken estimate.
    pub fn from_estimate(input_tokens: usize, output_tokens: usize) -> Self {
        Self {
            input_tokens: input_tokens as i64,
            output_tokens: output_tokens as i64,
        }
    }

    /// Combined input + output token count.
    pub fn total(&self) -> i64 {
        self.input_tokens + self.output_tokens
    }
}
/// Resolve token usage: remote data is preferred, tiktoken is the fallback.
///
/// `remote` — `Some` when API returned usage; `None` when not available.
/// `model` — model name, required for tiktoken fallback.
/// `input_text` — input text length hint for fallback estimate (uses ~4 chars/token).
pub fn resolve_usage(
remote: Option<TokenUsage>,
model: &str,
input_text: &str,
output_text: &str,
) -> TokenUsage {
if let Some(usage) = remote {
return usage;
}
// Fallback: tiktoken estimate
let input = count_message_text(input_text, model).unwrap_or_else(|_| {
// Rough estimate: ~4 chars per token
(input_text.len() / 4).max(1)
});
let output = output_text.len() / 4;
TokenUsage::from_estimate(input, output)
}
/// Estimate the number of tokens in a text string using the appropriate tokenizer.
///
/// # Errors
///
/// Fails only when the tokenizer for `model` cannot be initialized.
pub fn count_text(text: &str, model: &str) -> Result<usize> {
    // encode_ordinary: raw text only — no special chat/role tokens.
    let token_count = get_tokenizer(model)?.encode_ordinary(text).len();
    Ok(token_count)
}
/// Count tokens in a single chat message (text content only).
///
/// Uses `encode_with_special_tokens` so role/separator special tokens are
/// included in the count, unlike [`count_text`].
///
/// # Errors
///
/// Fails only when the tokenizer for `model` cannot be initialized.
pub fn count_message_text(text: &str, model: &str) -> Result<usize> {
    let tokenizer = get_tokenizer(model)?;
    Ok(tokenizer.encode_with_special_tokens(text).len())
}
/// Estimate the maximum number of characters that fit within a token budget
/// given a model's context limit and a reserve for the output.
///
/// Uses a rough ~4 characters/token heuristic (typical for English text;
/// less accurate for non-Latin scripts). A fixed 512-token safety margin is
/// subtracted on top of the output reserve. The `_model` parameter is kept
/// for interface stability; the heuristic is currently model-independent.
pub fn estimate_max_chars(
    _model: &str,
    context_limit: usize,
    reserve_output_tokens: usize,
) -> Result<usize> {
    const CHARS_PER_TOKEN: usize = 4;
    const SAFETY_MARGIN_TOKENS: usize = 512;
    let usable_tokens = context_limit
        .saturating_sub(reserve_output_tokens)
        .saturating_sub(SAFETY_MARGIN_TOKENS);
    Ok(usable_tokens.saturating_mul(CHARS_PER_TOKEN))
}
/// Truncate text to fit within a token budget for a given model.
///
/// Returns the input unchanged when it already fits the character-based
/// estimate. Otherwise binary-searches for the longest prefix — always on a
/// `char` boundary — whose tiktoken count stays within the safe budget.
///
/// Fixes over the previous version:
/// * `&text[..mid]` could panic when `mid` fell inside a multi-byte UTF-8
///   sequence; midpoints are now snapped down to a char boundary.
/// * When no candidate prefix ever fit, the FULL (over-budget) text was
///   returned; now the largest verified-fitting prefix (possibly empty) is.
/// * The loop-invariant budget is computed once instead of per iteration.
///
/// # Errors
///
/// Fails only when the tokenizer for `model` cannot be initialized.
pub fn truncate_to_token_budget(
    text: &str,
    model: &str,
    context_limit: usize,
    reserve_output_tokens: usize,
) -> Result<String> {
    let max_chars = estimate_max_chars(model, context_limit, reserve_output_tokens)?;
    if text.len() <= max_chars {
        return Ok(text.to_string());
    }
    let bpe = get_tokenizer(model)?;
    let budget = safe_token_budget(context_limit, reserve_output_tokens);
    // Invariant: `low` is a char-boundary prefix length known to fit the
    // budget; `high` is a length known (or assumed) not to fit.
    let mut low = 0usize;
    let mut high = text.len();
    // Stop refining once the window is under 100 bytes — close enough.
    while low + 100 < high {
        let mut mid = (low + high) / 2;
        // Snap down to a char boundary so slicing cannot panic on UTF-8.
        while mid > low && !text.is_char_boundary(mid) {
            mid -= 1;
        }
        if mid == low {
            // No boundary strictly inside (low, high); cannot refine further.
            break;
        }
        if bpe.encode_ordinary(&text[..mid]).len() <= budget {
            low = mid;
        } else {
            high = mid;
        }
    }
    // `low` is 0 or a snapped midpoint, so this slice is boundary-safe.
    Ok(text[..low].to_string())
}
/// Safe token budget: the context limit minus the output reserve and a fixed
/// 512-token safety margin, saturating at zero on underflow.
fn safe_token_budget(context_limit: usize, reserve: usize) -> usize {
    let after_reserve = context_limit.saturating_sub(reserve);
    after_reserve.saturating_sub(512)
}
/// Get the appropriate tiktoken tokenizer for a model.
///
/// Resolution order:
/// 1. `tiktoken_rs::get_bpe_from_model` for known model names
///    ("gpt-4o"/"o1"/"o3"/"o4" → o200k_base; "gpt-3.5-turbo"/"gpt-4" → cl100k_base).
/// 2. `cl100k_base` as a safe fallback for unknown/local models.
///
/// Note: the redundant `use tiktoken_rs;` (importing a crate under its own
/// name) was removed; the crate path is already in scope.
///
/// # Errors
///
/// Returns `AgentError::Internal` only when even the fallback tokenizer
/// fails to initialize.
fn get_tokenizer(model: &str) -> Result<tiktoken_rs::CoreBPE> {
    // Try model-specific tokenizer first
    if let Ok(bpe) = tiktoken_rs::get_bpe_from_model(model) {
        return Ok(bpe);
    }
    // Fallback: use cl100k_base for unknown models
    tiktoken_rs::cl100k_base()
        .map_err(|e| AgentError::Internal(format!("Failed to init tokenizer: {}", e)))
}
/// Estimate tokens for a simple prefix/suffix pattern (e.g., "assistant\n" + text).
/// Returns the token count including the prefix.
///
/// # Errors
///
/// Fails only when the tokenizer for `model` cannot be initialized.
pub fn count_with_prefix(text: &str, prefix: &str, model: &str) -> Result<usize> {
    // Concatenate first so the prefix/text boundary tokenizes as one stream.
    let combined = format!("{}{}", prefix, text);
    let tokenizer = get_tokenizer(model)?;
    Ok(tokenizer.encode_with_special_tokens(&combined).len())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_text() {
        // A short English sentence must produce at least one token.
        let n = count_text("Hello, world!", "gpt-4").unwrap();
        assert!(n > 0);
    }

    #[test]
    fn test_estimate_max_chars() {
        // gpt-4o context ~128k tokens; budget must be positive after reserve.
        let budget = estimate_max_chars("gpt-4o", 128_000, 2048).unwrap();
        assert!(budget > 0);
    }

    #[test]
    fn test_truncate() {
        // 50k chars exceeds budget: 8192 - 512 - 512 = 7168 tokens → ~28k chars
        let long_text = "a".repeat(50000);
        let shortened =
            truncate_to_token_budget(&long_text, "gpt-4o", 8192, 512).unwrap();
        assert!(shortened.len() < long_text.len());
    }
}