Can cut() be improved?

96 Views Asked by At

Currently the last 4 lines of cut.default are written as:

code <- .bincode(x, breaks, right, include.lowest)
if (codes.only) 
  code
else factor(code, seq_along(labels), labels, ordered = ordered_result)

Could this not be improved by rewriting it in the following way?

  code <- .bincode(x, breaks, right, include.lowest)
  if (!codes.only) {
    levels(code) <- as.character(labels)
    class(code) <- c(if (ordered_result) "ordered" else character(0), "factor")
  }
  code
}

If there's a reason this might not work, or if it is simply a misguided approach, I'd be happy to hear your thoughts.

It seems to improve performance when the result is a factor (the default).

Small data

# bench supplies mark(), used for every timing comparison in this post.
library(bench)

# Small input: the doubles 1..100, binned against the 13 breaks 0, 7, ..., 84.
x <- as.double(1:100) + 0
breaks <- seq(0, 90, 7)
# Shuffled copy of the same breaks for the second comparison further down.
# NOTE: no seed is set before sample(), so the shuffle order varies per run.
unsorted_breaks <- sample(breaks)
# Time base cut.default() against the proposed cut2() on the sorted breaks,
# requesting at least 10^4 iterations of each expression.
mark(cut.default(x, breaks), 
            cut2(x, breaks), 
     min_iterations = 10^4)
# A tibble: 2 × 13
  expression       min median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory    
  <bch:expr>     <bch> <bch:>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list> <list>    
1 cut.default(x… 130µs  148µs     6372.    2.97KB     1.27  9998     2      1.57s <fct>  <Rprofmem>
2 cut2(x, break… 101µs  115µs     8082.      448B     2.43  9997     3      1.24s <fct>  <Rprofmem>
# ℹ 2 more variables: time <list>, gc <list>

# Same comparison with the breaks supplied in shuffled order; cut2() sorts
# them with sort.int() before binning.
mark(cut.default(x, unsorted_breaks), 
            cut2(x, unsorted_breaks),
     min_iterations = 10^4)
# A tibble: 2 × 13
  expression       min median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory    
  <bch:expr>     <bch> <bch:>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list> <list>    
1 cut.default(x… 130µs  148µs     6328.    2.97KB     1.90  9997     3      1.58s <fct>  <Rprofmem>
2 cut2(x, unsor… 106µs  116µs     8281.      448B     1.66  9998     2      1.21s <fct>  <Rprofmem>
# ℹ 2 more variables: time <list>, gc <list>

Larger data

# Fixed seed so the random draws below are reproducible.
set.seed(42)
# One million normal draws (mean 0, sd 1e5) ...
x <- rnorm(10^6, 0, 10^5)
# ... binned into width-5 intervals spanning the observed range, so the
# resulting factor has a very large number of levels.
breaks <- seq(min(x), max(x), by = 5)

# Time base cut.default() against cut2() on the many-levels case.
mark(cut.default(x, breaks),
            cut2(x, breaks))
# A tibble: 2 × 13
  expression       min median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory    
  <bch:expr>     <bch> <bch:>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list> <list>    
1 cut.default(x… 1.78s  1.78s     0.561   113.9MB        0     1     0      1.78s <fct>  <Rprofmem>
2 cut2(x, break… 1.06s  1.06s     0.944    74.7MB        0     1     0      1.06s <fct>  <Rprofmem>
# ℹ 2 more variables: time <list>, gc <list>

# Ten million standard-normal draws with few levels: breaks run from 0 up to
# max(x) in steps of 0.2, so every non-positive value lies outside the breaks.
x <- rnorm(10^7) + 0
b <- seq(0, max(x), 0.2) + 0

# e1 = base cut.default(), e2 = proposed cut2().
mark(e1 = cut.default(x, b),
     e2 = cut2(x, b))
# A tibble: 2 × 13
  expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>
1 e1            1.49s    1.49s     0.671     267MB    0.671     1     1      1.49s <fct> 
2 e2         282.39ms 303.67ms     3.29     38.1MB    0         2     0   607.34ms <fct> 
# ℹ 3 more variables: memory <list>, time <list>, gc <list>

cut2

# cut2: bin a numeric vector into intervals, like cut.default(), but build the
# resulting factor by attaching attributes straight onto the integer codes
# returned by .bincode() instead of going through factor().
#
# Arguments:
#   x              numeric vector to bin; anything else is an error.
#   breaks         either a single number >= 2 (the number of equal-width
#                  intervals to create over range(x)) or a vector of cut
#                  points, which is coerced to double and sorted.
#   labels         NULL to generate interval labels such as "(0,7]", FALSE to
#                  return the bare integer codes, or a vector with exactly one
#                  label per interval.
#   include.lowest passed to .bincode(); also closes the outermost bracket of
#                  the generated labels.
#   right          TRUE for intervals closed on the right, FALSE for closed on
#                  the left.
#   dig.lab        starting number of significant digits used when formatting
#                  break values for generated labels.
#   ordered_result TRUE to return an ordered factor.
#   ...            accepted for signature compatibility; not used.
#
# Returns an integer vector of codes when labels is FALSE, otherwise a factor
# (ordered if requested) with one level per distinct label.
#
# Errors: "'x' must be numeric", "invalid number of intervals",
# "'breaks' are not unique", "number of intervals and length of 'labels' differ".
cut2 <- function (x, breaks, labels = NULL, include.lowest = FALSE, right = TRUE, 
                  dig.lab = 3L, ordered_result = FALSE, ...) {
  if (!is.numeric(x)) {
    stop("'x' must be numeric")
  }

  if (length(breaks) == 1L) {
    # A single number means "this many equal-width intervals".
    if (is.na(breaks) || breaks < 2L) {
      stop("invalid number of intervals")
    }
    nb <- as.integer(breaks + 1)
    rx <- range(x, na.rm = TRUE)
    dx <- diff(rx)
    if (dx == 0) {
      # Constant data: invent a small width around the single value.
      dx <- if (rx[1L] != 0) abs(rx[1L]) else 1
      breaks <- seq.int(rx[1L] - dx/1000, rx[2L] + dx/1000, 
                        length.out = nb)
    } else {
      breaks <- seq.int(rx[1L], rx[2L], length.out = nb)
      # Widen the outer breaks slightly so the extremes fall inside a bin.
      breaks[c(1L, nb)] <- c(rx[1L] - dx/1000, rx[2L] + dx/1000)
    }
  } else {
    breaks <- sort.int(as.double(breaks))
    nb <- length(breaks)
  }

  if (anyDuplicated(breaks)) {
    stop("'breaks' are not unique")
  }

  codes.only <- FALSE
  if (is.null(labels)) {
    # Increase the number of digits until adjacent breaks format differently.
    for (dig in dig.lab:max(12L, dig.lab)) {
      ch.br <- formatC(0 + breaks, digits = dig, width = 1L)
      if (ok <- all(ch.br[-1L] != ch.br[-nb])) {
        break
      }
    }
    if (ok) {
      lhs <- if (right) "(" else "["
      rhs <- if (right) "]" else ")"
      labels <- paste0(lhs, ch.br[-nb], ",", ch.br[-1L], rhs)
    } else {
      # Breaks cannot be told apart even at maximum precision.
      labels <- paste0("Range_", seq_len(nb - 1L))
    }
    if (ok && include.lowest) {
      if (right) {
        substr(labels[1L], 1L, 1L) <- "["
      } else {
        last <- nb - 1L
        substring(labels[last], nchar(labels[last], "c")) <- "]"
      }
    }
  } else if (is.logical(labels) && !labels) {
    codes.only <- TRUE
  } else if (length(labels) != nb - 1L) {
    stop("number of intervals and length of 'labels' differ")
  }

  code <- .bincode(x, breaks, right, include.lowest)
  if (codes.only) {
    return(code)
  }

  lev <- as.character(labels)
  if (anyDuplicated(lev)) {
    # factor() merges duplicated labels into a single level. Do the same here,
    # otherwise the result would be a malformed factor with repeated levels
    # and would disagree with cut.default(). Only taken when duplicates exist,
    # so the common case keeps the cheap attribute-only path.
    uniq <- unique(lev)
    code <- match(lev, uniq)[code]
    lev <- uniq
  }
  levels(code) <- lev
  class(code) <- c(if (ordered_result) "ordered", "factor")
  code
}
0

There are 0 best solutions below