Adrian - 3 months ago
R Question

Conversion from R data.table to nested list

Suppose I'd like to create a nested list from an R data.table, as in the toy example below:

library(data.table)

generate_dt <- function(num_unique_id=100, rows_per_id=2) {
    num_rows <- num_unique_id * rows_per_id
    my_dt <- data.table(my_id=rep(seq(1, num_unique_id), rows_per_id),
                        y1=rnorm(num_rows), y2=rnorm(num_rows), y3=rnorm(num_rows),
                        z=runif(num_rows))
    setkey(my_dt, my_id)
    return(my_dt)
}

## Suppose I want to go from my_dt to a nested list
list_from_dt <- function(my_dt) {
    num_unique_id <- length(unique(my_dt$my_id))
    my_list <- lapply(seq_len(num_unique_id), function(id) {
        my_dt_subset <- my_dt[J(id)]
        return(list(y=as.matrix(my_dt_subset[, c("y1", "y2", "y3"), with=FALSE]),
                    max_z=max(my_dt_subset$z)))
    })
    stopifnot(is.matrix(my_list[[1]]$y))
    return(my_list)
}

my_dt <- generate_dt()
my_list <- list_from_dt(my_dt) # Suppose I have some code that expects a nested list like this

system.time(replicate(100, unused <- generate_dt())) # Fast
system.time(replicate(100, unused <- list_from_dt(my_dt))) # Roughly 200 times slower


Why is creating the nested list so slow relative to creating the data table? Is there a way to speed up my list_from_dt function? I assume the lookups into my_dt are relatively fast since it is keyed by my_id. Is the bottleneck coming from allocating lots of small, fragmented pieces of memory for the matrices in my nested list?
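One rough way to test the allocation hypothesis is to time the keyed lookups and the small-matrix construction separately. This is a sketch, assuming a toy table shaped like generate_dt(); the names toy_dt, n_rep, t_lookup and t_matrix are illustrative only. If step (a) dominates, the slowdown most likely comes from the fixed per-call overhead of [.data.table, paid once per id, rather than from allocating the matrices.

```r
library(data.table)

set.seed(42)
# Toy table shaped like the question's generate_dt(): 100 ids, 2 rows each
toy_dt <- data.table(my_id = rep(1:100, 2), y1 = rnorm(200), y2 = rnorm(200),
                     y3 = rnorm(200), z = runif(200))
setkey(toy_dt, my_id)
ids <- unique(toy_dt$my_id)
n_rep <- 20

# (a) Only the keyed lookups: one [.data.table call per id
t_lookup <- system.time(for (r in seq_len(n_rep)) {
  subsets <- lapply(ids, function(id) toy_dt[J(id)])
})

# (b) Only the small-matrix allocations, built from the already-extracted subsets
t_matrix <- system.time(for (r in seq_len(n_rep)) {
  mats <- lapply(subsets, function(s) cbind(y1 = s$y1, y2 = s$y2, y3 = s$y3))
})

# Elapsed seconds for each step; the actual numbers depend on the machine
print(rbind(lookup = t_lookup, matrix = t_matrix)[, "elapsed"])
```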

Answer

Here's what I see with split and gmax (data.table's internally optimized grouped max, which it uses automatically for max(z) with by=):

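f = function(){
    # reads my_dt from the global environment
    # split the y columns by my_id (dropping z and my_id), turn each piece into a matrix
    s  = lapply(split(my_dt[, !"z", with=FALSE], by="my_id", keep.by=FALSE), as.matrix)
    # one grouped call for max(z) per id, evaluated by data.table's gmax
    mz = my_dt[, max(z), by=my_id]
    # pair each matrix with its max; both are in my_id (key) order
    Map(list, ys = s, mz = mz$V1)
}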
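The f above reads my_dt from the global environment and names the pieces ys and mz, whereas the question's downstream code expects y and max_z. Below is a sketch of a variant that takes the table as an argument and keeps the question's names. f_named and ycols are hypothetical names, and the use of sorted = TRUE and keyby (an assumption beyond the answer) is meant to keep the two results aligned by my_id even if the table were not keyed.

```r
library(data.table)

# Hypothetical variant of f(): takes the table as an argument and uses the
# element names y and max_z that the question's list_from_dt() produces
f_named <- function(my_dt, ycols = c("y1", "y2", "y3")) {
  # One split call: per-id pieces holding only the y columns, sorted by my_id
  pieces <- split(my_dt[, c("my_id", ycols), with = FALSE],
                  by = "my_id", sorted = TRUE, keep.by = FALSE)
  ys <- lapply(pieces, as.matrix)
  # One grouped call: max(z) per id, also sorted by my_id
  mz <- my_dt[, .(max_z = max(z)), keyby = my_id]
  # Both results are in the same my_id order, so they can be zipped together
  Map(list, y = ys, max_z = mz$max_z)
}

# Usage on a small keyed table
set.seed(1)
toy_dt <- data.table(my_id = rep(1:3, 2), y1 = rnorm(6), y2 = rnorm(6),
                     y3 = rnorm(6), z = runif(6))
setkey(toy_dt, my_id)
out <- f_named(toy_dt)
str(out[[1]])
```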

system.time(replicate(100, generate_dt()))          #  0.8
system.time(replicate(100, list_from_dt(my_dt)))    # 20.1
system.time(replicate(100, f()))                    #  2.1
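A different technique, not taken from the answer, is to convert the y columns to a single matrix once and then split row indices by id, so that no per-group data.table subsetting happens at all. This is a base-R sketch, assuming the input has columns my_id, y1, y2, y3 and z; list_from_df is a hypothetical name. Whether it beats f() on a given machine should be measured with the same system.time(replicate(...)) pattern rather than assumed.

```r
# Hypothetical base-R variant: convert the y columns to one matrix up front,
# then take cheap row subsets per id instead of one data.table lookup per id.
# Accepts a data.frame or a data.table with columns my_id, y1..y3 and z.
list_from_df <- function(df, ycols = c("y1", "y2", "y3")) {
  df <- as.data.frame(df)                       # plain data.frame semantics
  ymat <- as.matrix(df[ycols])                  # a single matrix allocation
  rows <- split(seq_len(nrow(df)), df$my_id)    # row indices per id, ids sorted
  max_z <- tapply(df$z, df$my_id, max)          # grouped max, same id order
  Map(function(r, mz) list(y = ymat[r, , drop = FALSE], max_z = mz),
      rows, max_z)
}
```

On the question's data it would be called as list_from_df(my_dt), and the elements carry the question's names y and max_z.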

The result looks like this:

> res <- f()
> head(res, 2)
$`1`
$`1`$ys
              y1          y2          y3
[1,] -0.04493979 -1.01340856  0.08481358
[2,] -0.75860610  0.04113645 -0.36270441

$`1`$mz
[1] 0.9362695


$`2`
$`2`$ys
            y1         y2        y3
[1,] 0.7718361 -0.8005803 1.2195464
[2,] 0.1658420 -1.2846028 0.4607024

$`2`$mz
[1] 0.8551927

The numbers `1` and `2` are the my_id values, which now serve as the names of the list elements.
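Because those names are character strings, res[["2"]] (by id) and res[[2]] (by position) pick the same element here only because the ids are 1, 2, ..., n and come out in sorted order. A small base-R illustration with a made-up list; the ids 5 and 10 and the name res_demo are hypothetical:

```r
# Hypothetical result for two non-contiguous ids, 5 and 10, in the same
# shape as f()'s output (names are the my_id values as character strings)
res_demo <- list(`5`  = list(ys = matrix(1:3, nrow = 1), mz = 0.3),
                 `10` = list(ys = matrix(4:6, nrow = 1), mz = 0.7))

res_demo[["10"]]$mz   # look up by my_id value (the name): 0.7
res_demo[[2]]$mz      # look up by position: also 0.7, since "10" is 2nd
# Position 10 does not exist, so positional lookup by the id value fails
inherits(try(res_demo[[10]], silent = TRUE), "try-error")   # TRUE
```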