Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 50 additions & 8 deletions src/cortex-cli/src/stats_cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,28 +120,57 @@ fn load_custom_pricing() -> std::collections::HashMap<String, ModelPricing> {
// Example: CORTEX_PRICING_GPT4O=2.5,10.0
for (key, value) in std::env::vars() {
if let Some(model_suffix) = key.strip_prefix("CORTEX_PRICING_") {
let model_name = model_suffix.to_lowercase().replace('_', "-");
let parts: Vec<&str> = value.split(',').collect();
if parts.len() == 2
&& let (Ok(input), Ok(output)) = (
parts[0].trim().parse::<f64>(),
parts[1].trim().parse::<f64>(),
)
{
custom.insert(
model_name,
ModelPricing {
input_per_million: input,
output_per_million: output,
},
);
let pricing = ModelPricing {
input_per_million: input,
output_per_million: output,
};
for model_name in model_names_from_env_suffix(model_suffix) {
custom.insert(model_name, pricing.clone());
}
}
}
}

custom
}

/// Convert CORTEX_PRICING suffixes into model lookup keys.
///
/// Supports:
/// - Legacy style: `GPT_4O` -> `gpt-4o`
/// - Provider/model style: `OPENAI_GPT_4O` -> `openai/gpt-4o`
/// - Explicit slash style: `OPENAI__GPT_4O` -> `openai/gpt-4o`
fn model_names_from_env_suffix(model_suffix: &str) -> Vec<String> {
let lower = model_suffix.to_lowercase();
let mut names = vec![lower.replace('_', "-")];

// First underscore as provider/model separator:
// OPENAI_GPT_4O -> openai/gpt-4o
if let Some((provider, rest)) = lower.split_once('_')
&& !provider.is_empty()
&& !rest.is_empty()
{
names.push(format!("{}/{}", provider, rest.replace('_', "-")));
}

// Double underscore explicitly encodes slash:
// OPENAI__GPT_4O -> openai/gpt-4o
if lower.contains("__") {
names.push(lower.replace("__", "/").replace('_', "-"));
}

names.sort();
names.dedup();
names
}

impl StatsCli {
/// Run the stats command.
pub async fn run(self) -> Result<()> {
Expand Down Expand Up @@ -755,4 +784,17 @@ mod tests {
let err = validate_days_range("abc").unwrap_err();
assert!(err.contains("not a valid number"));
}

#[test]
fn test_model_names_from_env_suffix_provider_model_mapping() {
let names = model_names_from_env_suffix("OPENAI_GPT_4O");
assert!(names.contains(&"openai-gpt-4o".to_string()));
assert!(names.contains(&"openai/gpt-4o".to_string()));
}

#[test]
fn test_model_names_from_env_suffix_explicit_slash_mapping() {
let names = model_names_from_env_suffix("OPENAI__GPT_4O");
assert!(names.contains(&"openai/gpt-4o".to_string()));
}
}