diff --git a/.positai/settings.json b/.positai/settings.json index 510265ef..500757de 100644 --- a/.positai/settings.json +++ b/.positai/settings.json @@ -11,5 +11,9 @@ "executeCode": { "*": "allow" } + }, + "model": { + "id": "claude-opus-4-6", + "provider": "positai" } } \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index cc81b957..2f0146ac 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -27,3 +27,10 @@ of internal nodes for any topologically identical tree. Splits objects are defined in `as.Splits()`, and denote split membership as binary 0/1 in an underlying `raw` object. + +## Workflow requirements + +- After completing each optimization or user-visible change, update `NEWS.md` + before moving on to the next task. +- Increment the `.900X` dev version suffix in `DESCRIPTION` with each + `NEWS.md` update. diff --git a/DESCRIPTION b/DESCRIPTION index 04beda37..772b6ca2 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: TreeTools Title: Create, Modify and Analyse Phylogenetic Trees -Version: 2.1.0.9004 +Version: 2.1.0.9005 Authors@R: c( person("Martin R.", 'Smith', role = c("aut", "cre", "cph"), email = "martin.smith@durham.ac.uk", diff --git a/NEWS.md b/NEWS.md index bf4bb119..85481258 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,7 @@ -# TreeTools 2.1.0.9004 (2026-03-12) # +# TreeTools 2.1.0.9005 (2026-03-13) # -- Rewrite popcount calculation for more efficient `TipsInSplits()`. +- `SplitFrequency(reference = NULL)` split normalization moved to C++, + eliminating an R-level per-split loop. # TreeTools 2.1.0.9003 (2026-03-09) # diff --git a/R/RcppExports.R b/R/RcppExports.R index 8117f656..b5e6e8f7 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -137,6 +137,10 @@ cpp_count_splits <- function(edge, nTip) { .Call(`_TreeTools_cpp_count_splits`, edge, nTip) } +normalize_splits <- function(splits, n_tip) { + .Call(`_TreeTools_normalize_splits`, splits, n_tip) +} + splits_to_edge <- function(splits, nTip) { .Call(`_TreeTools_splits_to_edge`, splits, nTip) } diff --git a/R/Support.R b/R/Support.R index b105f416..e080f9cc 100644 --- a/R/Support.R +++ b/R/Support.R @@ -58,30 +58,10 @@ SplitFrequency <- function(reference, forest = NULL) { return(structure(splits, nTip = nTip, tip.label = tipLabels, # nocov count = integer(), class = "Splits")) # nocov } - # The ClusterTable outputs clusters (clades); normalize so bit 0 (tip 1) - # is not in the set (matching as.Splits convention) - nTipMod <- nTip %% 8L - lastByteMask <- if (nTipMod == 0L) as.raw(0xff) else as.raw(bitwShiftL(1L, nTipMod) - 1L) - keep <- logical(nrow(splits)) - for (i in seq_len(nrow(splits))) { - val <- splits[i, ] - # Count bits set (to filter trivial splits) - nBits <- sum(vapply(as.integer(val), function(b) sum(as.integer(intToBits(b))), integer(1))) - if (nBits < 2L || nBits > nTip - 2L) next # trivial split - # Normalize: if bit 0 is NOT set, complement to match as.Splits format - if (!as.logical(as.integer(val[1]) %% 2L)) { - for (j in seq_along(val)) { - splits[i, j] <- as.raw(bitwXor(as.integer(val[j]), 0xffL)) - } - # Mask last byte - if (nTipMod > 0L) { - splits[i, nbin] <- as.raw(bitwAnd(as.integer(splits[i, nbin]), - as.integer(lastByteMask))) - } - } - keep[i] <- TRUE - } - splits <- splits[keep, , drop = FALSE] + # Normalize splits: ensure bit 0 is set, filter trivial splits + normalized <- normalize_splits(splits, nTip) + keep <- normalized[["keep"]] + splits <- normalized[["splits"]][keep, , drop = FALSE] counts <- counts[keep] ret <- structure(splits, nTip = nTip, diff --git a/benchmark/_compare_results.R b/benchmark/_compare_results.R index b0aded68..a719d1a0 100644 --- a/benchmark/_compare_results.R +++ b/benchmark/_compare_results.R @@ -24,8 +24,12 @@ for (pr_file in pr_files) { # Prepare a report report <- list() + # Use deparse1 for reliable expression-to-string conversion; + # as.character(unlist()) decomposes call objects into their components. + expr_names <- vapply(pr1[["expression"]], deparse1, "") + # Iterate over each function benchmarked - for (fn_name in unique(as.character(unlist(pr1[["expression"]])))) { + for (fn_name in unique(expr_names)) { pr1_times <- as.numeric(pr1[["time"]][[1]]) pr2_times <- as.numeric(pr2[["time"]][[1]]) pr_times <- if (rep_exists) c(pr1_times, pr2_times) else pr1_times @@ -79,18 +83,21 @@ for (pr_file in pr_files) { # Create a markdown-formatted message has_significant_regression <- FALSE - for (fn_name in names(report)) { - res <- report[[fn_name]] - status <- if (res$matched) { + for (i in seq_along(report)) { + fn_name <- names(report)[[i]] + res <- report[[i]] + if (is.null(res) || is.null(res$matched)) next + + status <- if (isTRUE(res$matched)) { if (res$slower) { - if (abs(percentage_change) > threshold_percent) { + if (abs(res$change) > threshold_percent) { has_significant_regression <- TRUE "\U1F7E0 Slower \U1F641" } else { "\U1F7E3 ~same" } } else if (res$faster) { - if (abs(percentage_change) > threshold_percent) { + if (abs(res$change) > threshold_percent) { "\U1F7E2 Faster!" } else { "\U1F7E3 ~same" @@ -111,14 +118,14 @@ for (pr_file in pr_files) { signif(res$median_pr * 1e3, 3), ", ", signif(res$median_cf * 1e3, 3), " |\n" ) + + cat(message) + output <- paste0(output, message) } if (has_significant_regression) { regressions <- TRUE } - - cat(message) - output <- paste0(output, message) } cat(paste0(output, "\nEOF"), file = Sys.getenv("GITHUB_OUTPUT"), append = TRUE) diff --git a/benchmark/_init.R b/benchmark/_init.R index 11af7c83..216945c9 100644 --- a/benchmark/_init.R +++ b/benchmark/_init.R @@ -1,9 +1,14 @@ library("TreeTools") Benchmark <- function(..., min_iterations = NULL, min_time = NULL) { - args <- list(..., min_iterations = min_iterations %||% 3, time_unit = "us") - if (!is.null(min_time)) args[["min_time"]] <- min_time - result <- do.call(bench::mark, args) + # Pass ... directly to bench::mark to preserve non-standard evaluation; + # do.call() would evaluate expressions first, breaking expression capture. + result <- if (is.null(min_time)) { + bench::mark(..., min_iterations = min_iterations %||% 3, time_unit = "us") + } else { + bench::mark(..., min_iterations = min_iterations %||% 3, + min_time = min_time, time_unit = "us") + } if (interactive()) { print(result) } else { diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index ccf10bbe..d08ceb81 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -426,6 +426,18 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// normalize_splits +Rcpp::List normalize_splits(Rcpp::RawMatrix splits, const int n_tip); +RcppExport SEXP _TreeTools_normalize_splits(SEXP splitsSEXP, SEXP n_tipSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::RawMatrix >::type splits(splitsSEXP); + Rcpp::traits::input_parameter< const int >::type n_tip(n_tipSEXP); + rcpp_result_gen = Rcpp::wrap(normalize_splits(splits, n_tip)); + return rcpp_result_gen; +END_RCPP +} // splits_to_edge IntegerMatrix splits_to_edge(const RawMatrix splits, const IntegerVector nTip); RcppExport SEXP _TreeTools_splits_to_edge(SEXP splitsSEXP, SEXP nTipSEXP) { @@ -517,6 +529,7 @@ static const R_CallMethodDef CallEntries[] = { {"_TreeTools_pack_splits_logical", (DL_FUNC) &_TreeTools_pack_splits_logical, 1}, {"_TreeTools_pack_splits_logical_vec", (DL_FUNC) &_TreeTools_pack_splits_logical_vec, 1}, {"_TreeTools_cpp_count_splits", (DL_FUNC) &_TreeTools_cpp_count_splits, 2}, + {"_TreeTools_normalize_splits", (DL_FUNC) &_TreeTools_normalize_splits, 2}, {"_TreeTools_splits_to_edge", (DL_FUNC) &_TreeTools_splits_to_edge, 2}, {"_TreeTools_tips_in_splits", (DL_FUNC) &_TreeTools_tips_in_splits, 1}, {"_TreeTools_edge_to_rooted_shape", (DL_FUNC) &_TreeTools_edge_to_rooted_shape, 3}, diff --git a/src/splits.cpp b/src/splits.cpp index 096dc2d7..2024caa3 100644 --- a/src/splits.cpp +++ b/src/splits.cpp @@ -714,3 +714,48 @@ int cpp_count_splits(const Rcpp::IntegerMatrix& edge, const int nTip) { return (n_internal - n_singles) - 1 - (is_rooted ? 1 : 0); } + +// Normalize splits from ClusterTable output for as.Splits convention: +// - Filter trivial splits (fewer than 2 or more than nTip-2 bits set) +// - Ensure bit 0 is set (complement if not) +// - Mask unused trailing bits in last byte +// Returns a List with "splits" (RawMatrix) and "keep" (LogicalVector) +// [[Rcpp::export]] +Rcpp::List normalize_splits(Rcpp::RawMatrix splits, const int n_tip) { + const int n_split = splits.nrow(); + const int n_bin = splits.ncol(); + const int n_spare = n_tip % BIN_SIZE; + const Rbyte last_mask = n_spare == 0 + ? Rbyte(0xff) + : static_cast((1 << n_spare) - 1); + + Rcpp::LogicalVector keep(n_split, false); + + for (int i = 0; i < n_split; ++i) { + // Count bits set + int n_bits = 0; + for (int j = 0; j < n_bin; ++j) { + n_bits += __builtin_popcount(static_cast(splits(i, j))); + } + + if (n_bits < 2 || n_bits > n_tip - 2) continue; // trivial + + // Normalize: if bit 0 is NOT set, complement + if (!(splits(i, 0) & Rbyte(1))) { + for (int j = 0; j < n_bin; ++j) { + splits(i, j) = static_cast(~splits(i, j)); + } + // Mask trailing bits in last byte + if (n_spare > 0) { + splits(i, n_bin - 1) &= last_mask; + } + } + + keep[i] = true; + } + + return Rcpp::List::create( + Rcpp::Named("splits") = splits, + Rcpp::Named("keep") = keep + ); +}