diff --git a/src/cortex-cli/src/stats_cmd.rs b/src/cortex-cli/src/stats_cmd.rs index 1e40750..590d238 100644 --- a/src/cortex-cli/src/stats_cmd.rs +++ b/src/cortex-cli/src/stats_cmd.rs @@ -120,7 +120,6 @@ fn load_custom_pricing() -> std::collections::HashMap { // Example: CORTEX_PRICING_GPT4O=2.5,10.0 for (key, value) in std::env::vars() { if let Some(model_suffix) = key.strip_prefix("CORTEX_PRICING_") { - let model_name = model_suffix.to_lowercase().replace('_', "-"); let parts: Vec<&str> = value.split(',').collect(); if parts.len() == 2 && let (Ok(input), Ok(output)) = ( @@ -128,13 +127,13 @@ fn load_custom_pricing() -> std::collections::HashMap { parts[1].trim().parse::(), ) { - custom.insert( - model_name, - ModelPricing { - input_per_million: input, - output_per_million: output, - }, - ); + let pricing = ModelPricing { + input_per_million: input, + output_per_million: output, + }; + for model_name in model_names_from_env_suffix(model_suffix) { + custom.insert(model_name, pricing.clone()); + } } } } @@ -142,6 +141,36 @@ fn load_custom_pricing() -> std::collections::HashMap { custom } +/// Convert CORTEX_PRICING suffixes into model lookup keys. +/// +/// Supports: +/// - Legacy style: `GPT_4O` -> `gpt-4o` +/// - Provider/model style: `OPENAI_GPT_4O` -> `openai/gpt-4o` +/// - Explicit slash style: `OPENAI__GPT_4O` -> `openai/gpt-4o` +fn model_names_from_env_suffix(model_suffix: &str) -> Vec { + let lower = model_suffix.to_lowercase(); + let mut names = vec![lower.replace('_', "-")]; + + // First underscore as provider/model separator: + // OPENAI_GPT_4O -> openai/gpt-4o + if let Some((provider, rest)) = lower.split_once('_') + && !provider.is_empty() + && !rest.is_empty() + { + names.push(format!("{}/{}", provider, rest.replace('_', "-"))); + } + + // Double underscore explicitly encodes slash: + // OPENAI__GPT_4O -> openai/gpt-4o + if lower.contains("__") { + names.push(lower.replace("__", "/").replace('_', "-")); + } + + names.sort(); + names.dedup(); + names +} + impl StatsCli { /// Run the stats command. pub async fn run(self) -> Result<()> { @@ -755,4 +784,17 @@ mod tests { let err = validate_days_range("abc").unwrap_err(); assert!(err.contains("not a valid number")); } + + #[test] + fn test_model_names_from_env_suffix_provider_model_mapping() { + let names = model_names_from_env_suffix("OPENAI_GPT_4O"); + assert!(names.contains(&"openai-gpt-4o".to_string())); + assert!(names.contains(&"openai/gpt-4o".to_string())); + } + + #[test] + fn test_model_names_from_env_suffix_explicit_slash_mapping() { + let names = model_names_from_env_suffix("OPENAI__GPT_4O"); + assert!(names.contains(&"openai/gpt-4o".to_string())); + } }