"comparison of these types is not implemented" when using predict


(Seems there is no tag for clustMixType, tag suggestions welcome)

I am attempting to use library clustMixType to make some clusters.

library(tidyverse)
library(clustMixType)


# no scaling or real data prep here, just reproducing an issue with minimal code
my_diamonds <- diamonds %>% 
  mutate(is_color_g = factor(ifelse(color == 'G', 1, 0))) %>% 
  select(cut, carat, is_color_g, depth, table, price) %>% 
  group_by(cut) %>% 
  nest %>% 
  mutate(k = 3)

my_diamonds <- my_diamonds %>% 
  mutate(mod.kproto = map2(data, k, ~kproto(.x, k = .y, lambda = NULL, iter.max = 100, nstart = 1, na.rm = 'no')))

This results in a list column with a cluster model for each level of cut:

my_diamonds
# A tibble: 5 × 4
# Groups:   cut [5]
  cut       data                      k mod.kproto
  <ord>     <list>                <dbl> <list>    
1 Ideal     <tibble [21,551 × 5]>     3 <kproto>  
2 Premium   <tibble [13,791 × 5]>     3 <kproto>  
3 Good      <tibble [4,906 × 5]>      3 <kproto>  
4 Very Good <tibble [12,082 × 5]>     3 <kproto>  
5 Fair      <tibble [1,610 × 5]>      3 <kproto>  

According to the library docs (pdf), we can use predict to assign newdata to the nearest cluster.

Under predict.kproto there is an example: predicted.clusters <- predict(kpres, x) where x is new data. I gave it a try:

my_diamonds <- my_diamonds %>% 
  mutate(preds = map2(data, mod.kproto, ~predict(.y, .x)))
Error in `mutate()`:
! Problem while computing `preds = map2(data, mod.kproto, ~predict(.y, .x))`.
ℹ The error occurred in group 1: cut = Fair.
Caused by error in `x[, j] != rep(protos[i, j], nrows)`:
! comparison of these types is not implemented
Run `rlang::last_error()` to see where the error occurred.
Warning message:
Problem while computing `preds = map2(data, mod.kproto, ~predict(.y, .x))`.
ℹ Incompatible methods ("Ops.data.frame", "Ops.factor") for "!="
ℹ The warning occurred in group 1: cut = Fair. 

Why am I getting this error and how can I overcome it to use clustMixType's predict function to assign clusters to newdata?

henhesu On BEST ANSWER

It seems that passing x as a standard data.frame does the trick:

my_diamonds %>% 
  mutate(preds = map2(data, mod.kproto, ~predict(.y, as.data.frame(.x))))
#> # A tibble: 5 × 5
#> # Groups:   cut [5]
#>   cut       data                      k mod.kproto preds           
#>   <ord>     <list>                <dbl> <list>     <list>          
#> 1 Ideal     <tibble [21,551 × 5]>     3 <kproto>   <named list [2]>
#> 2 Premium   <tibble [13,791 × 5]>     3 <kproto>   <named list [2]>
#> 3 Good      <tibble [4,906 × 5]>      3 <kproto>   <named list [2]>
#> 4 Very Good <tibble [12,082 × 5]>     3 <kproto>   <named list [2]>
#> 5 Fair      <tibble [1,610 × 5]>      3 <kproto>   <named list [2]>

Created on 2023-09-19 by the reprex package (v2.0.1)

Update / Deep Dive

I did a bit more debugging and noticed that the error arises from this line:

d2 <- sapply(which(catvars), function(j) return(x[,j] != rep(protos[i,j], nrows)) )

What happens here is that x is subset with [, j], where j takes the values of which(catvars), which returns

is_color_g 
         2 

in your case.

The error arises because of the different ways that base::data.frame() and tibble::tibble() handle one-dimensional results of subsetting operations. As taken from this answer:

  • By default, [.data.frame will drop the dimensions if the result has only 1 column, similar to how matrix subsetting works. So the result is a vector.
  • [.tbl_df will never drop dimensions like this; it always returns a tbl.

See for yourself:

iris[,1]
#>   [1] 5.1 4.9 4.7 4.6 5.0 5.4 4.6 5.0 4.4 4.9 5.4 4.8 4.8 4.3 5.8 5.7 5.4 5.1
#>  [19] 5.7 5.1 5.4 5.1 4.6 5.1 4.8 5.0 5.0 5.2 5.2 4.7 4.8 5.4 5.2 5.5 4.9 5.0
#>  [37] 5.5 4.9 4.4 5.1 5.0 4.5 4.4 5.0 5.1 4.8 5.1 4.6 5.3 5.0 7.0 6.4 6.9 5.5
#>  [55] 6.5 5.7 6.3 4.9 6.6 5.2 5.0 5.9 6.0 6.1 5.6 6.7 5.6 5.8 6.2 5.6 5.9 6.1
#>  [73] 6.3 6.1 6.4 6.6 6.8 6.7 6.0 5.7 5.5 5.5 5.8 6.0 5.4 6.0 6.7 6.3 5.6 5.5
#>  [91] 5.5 6.1 5.8 5.0 5.6 5.7 5.7 6.2 5.1 5.7 6.3 5.8 7.1 6.3 6.5 7.6 4.9 7.3
#> [109] 6.7 7.2 6.5 6.4 6.8 5.7 5.8 6.4 6.5 7.7 7.7 6.0 6.9 5.6 7.7 6.3 6.7 7.2
#> [127] 6.2 6.1 6.4 7.2 7.4 7.9 6.4 6.3 6.1 7.7 6.3 6.4 6.0 6.9 6.7 6.9 5.8 6.8
#> [145] 6.7 6.7 6.3 6.5 6.2 5.9

tibble::as_tibble(iris)[,1]
#> # A tibble: 150 × 1
#>    Sepal.Length
#>           <dbl>
#>  1          5.1
#>  2          4.9
#>  3          4.7
#>  4          4.6
#>  5          5  
#>  6          5.4
#>  7          4.6
#>  8          5  
#>  9          4.4
#> 10          4.9
#> # … with 140 more rows
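
The same difference can be checked programmatically, along with two ways of getting a plain vector out of a tibble. This is only a small sketch on made-up data (the objects df and tbl are illustrative, not taken from the question); it relies on the tibble package's [[ accessor and the drop argument of [.

```r
library(tibble)

# Toy data, purely illustrative
df  <- data.frame(a = 1:3, b = c("x", "y", "z"))
tbl <- as_tibble(df)

v_df  <- df[, 1]    # base data.frame: a single column is dropped to a vector
t_sub <- tbl[, 1]   # tibble: a single column stays a one-column tibble

# Two ways to get a vector from a tibble column
v_tbl1 <- tbl[[1]]
v_tbl2 <- tbl[, 1, drop = TRUE]

# And the reverse: keep the data.frame structure in base R
df_keep <- df[, 1, drop = FALSE]

is.data.frame(t_sub)     # TRUE
identical(v_tbl1, v_df)  # TRUE
identical(v_tbl2, v_df)  # TRUE
```

Code that has to accept both classes is usually safer with [[ or an explicit drop argument than with a bare [, j].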


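This means that with x being a tibble, x[, j] is a one-column tibble instead of a vector. Comparing that tibble with the factor vector rep(protos[i, j], nrows) via != is what triggers the error; it is a runtime type error, not a syntax error.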
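
The failing comparison can be reproduced without clustMixType. The sketch below uses made-up data and a hypothetical proto vector standing in for rep(protos[i, j], nrows); it is not the package's actual code, only the same kind of comparison.

```r
library(tibble)

# Toy data: one numeric and one categorical column (illustrative only)
df  <- data.frame(num = c(1.5, 2.5, 3.5), cat = factor(c("0", "1", "0")))
tbl <- as_tibble(df)

# Stand-in for rep(protos[i, j], nrows): a factor vector of prototype values
proto <- rep(df$cat[1], nrow(df))

# data.frame: df[, 2] is a factor, so factor != factor works elementwise
ok <- df[, 2] != proto
ok  # FALSE TRUE FALSE

# tibble: tbl[, 2] is still a data frame, so != sees a data frame and a factor.
# The condition is captured here instead of stopping the script.
res <- tryCatch(
  suppressWarnings(tbl[, 2] != proto),
  error = function(e) e
)
inherits(res, "error")

# Converting first restores the vector behaviour
fixed <- as.data.frame(tbl)[, 2] != proto
identical(fixed, ok)  # TRUE
```

This is why wrapping newdata in as.data.frame() before calling predict avoids the problem: the subsetting inside the function then yields a plain factor again.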