newbie_learner - 1 year ago
Scala Question

Scala: How to get a range of rows in a dataframe

I have a DataFrame created by reading a Parquet file.

The DataFrame consists of 300M rows. I need to use these rows as input to another function, but I want to do it in smaller batches to avoid an out-of-memory (OOM) error.
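
The batching itself is simple arithmetic: 300M rows in batches of 1M rows gives 300 half-open ranges [start, end). A minimal pure-Scala sketch of those boundaries, where batchBounds is a hypothetical helper name:

```scala
// Hypothetical helper: split `total` rows into half-open [start, end) ranges
// of at most `batchSize` rows each. Pure Scala, no Spark needed.
def batchBounds(total: Long, batchSize: Long): Seq[(Long, Long)] = {
  require(batchSize > 0, "batchSize must be positive")
  (0L until total by batchSize).map(start => (start, math.min(start + batchSize, total)))
}

// 300M rows in batches of 1M rows
val bounds = batchBounds(300000000L, 1000000L)
println(bounds.size) // 300
println(bounds.head) // (0,1000000)
println(bounds.last) // (299000000,300000000)
```

The last range is clipped to `total`, so a row count that is not a multiple of the batch size still produces a final, smaller batch.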

Currently, I am using df.head(1000000) to read the first 1M rows, but I cannot find a way to read the subsequent rows. I tried df.collect(), but it gives me a Java OOM error.
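
If the goal is only to feed the rows to another function in fixed-size groups, one alternative to df.head and df.collect() is Dataset.toLocalIterator(), which fetches one partition at a time, so the driver needs memory for roughly the largest partition rather than the whole DataFrame. The sketch below assumes Spark 2.x or later. forEachBatch is a hypothetical helper, and the 10-row DataFrame and local SparkSession stand in for the real data.

```scala
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import scala.collection.mutable.ArrayBuffer

// Hypothetical helper (not part of Spark): pulls `df` to the driver one
// partition at a time via toLocalIterator() and passes rows to `process`
// in groups of at most `batchSize`.
def forEachBatch(df: DataFrame, batchSize: Int)(process: Seq[Row] => Unit): Unit = {
  require(batchSize > 0, "batchSize must be positive")
  val it = df.toLocalIterator() // java.util.Iterator[Row]
  val buffer = ArrayBuffer.empty[Row]
  while (it.hasNext) {
    buffer += it.next()
    if (buffer.size == batchSize) {
      process(buffer.toList) // hand over an immutable copy
      buffer.clear()
    }
  }
  if (buffer.nonEmpty) process(buffer.toList) // last, possibly smaller batch
}

// Usage on a toy 10-row DataFrame (stand-in for the 300M-row one).
val spark = SparkSession.builder().master("local[2]").appName("batches").getOrCreate()
val df = spark.range(0, 10).toDF("id")
val sizes = ArrayBuffer.empty[Int]
forEachBatch(df, 4)(batch => sizes += batch.size)
println(sizes.mkString(", ")) // 4, 4, 2
```

This is slower than collect() because partitions are fetched one after another, and each batch of batchSize rows must still fit in driver memory.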

I want to iterate over this DataFrame. Because none of the existing columns in the DataFrame contain only unique values, I tried adding another column with the withColumn() API to generate a unique set of values to iterate over.

For example, I tried val df = df1.withColumn("newColumn", df1("col") + 1) as well as val df = df1.withColumn("newColumn",lit(i+=1)), but neither returns a sequential set of values.
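
Neither attempt can work as intended: df1("col") + 1 keeps any duplicates already in col, and lit(...) wraps a single value that is the same for every row, so a driver-side counter such as i is not re-evaluated per row. Spark's built-in monotonically_increasing_id() does give each row a unique value, as this sketch shows. The toy df1 spread over three partitions and the local SparkSession are assumptions for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.monotonically_increasing_id

val spark = SparkSession.builder().master("local[2]").appName("ids").getOrCreate()
import spark.implicits._

// Toy stand-in for df1, spread over 3 partitions.
val df1 = Seq("a", "b", "c", "d", "e", "f").toDF("col").repartition(3)

// monotonically_increasing_id() is evaluated per row on the executors:
// unique and increasing, but the partition ID sits in the upper bits,
// so values jump between partitions instead of running 0, 1, 2, ...
val withId = df1.withColumn("newColumn", monotonically_increasing_id())
withId.show()
```

The values are unique but not consecutive, so they work as a row key but not directly as a "rows n to 2n" index.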

Is there any other way to get the first n rows of a DataFrame, then the next n rows, and so on, something that works like a range function of SQLContext?
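
For a range-style lookup with consecutive indices, one option (not from the original answer) is RDD.zipWithIndex, which numbers rows 0, 1, 2, ... in partition order. The sketch assumes Spark 2.x or later. withRowIndex, rowRange, and the row_idx column are hypothetical names, and the toy DataFrame stands in for the Parquet data.

```scala
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{LongType, StructField, StructType}

// Hypothetical helper: appends a consecutive 0-based index column using
// RDD.zipWithIndex, so there are no gaps between values.
def withRowIndex(df: DataFrame, name: String = "row_idx"): DataFrame = {
  val schema = StructType(df.schema.fields :+ StructField(name, LongType, nullable = false))
  val rows = df.rdd.zipWithIndex().map { case (row, idx) => Row.fromSeq(row.toSeq :+ idx) }
  df.sparkSession.createDataFrame(rows, schema)
}

// Hypothetical helper: the rows whose index lies in [start, end).
def rowRange(indexed: DataFrame, start: Long, end: Long, name: String = "row_idx"): DataFrame =
  indexed.filter(col(name) >= start && col(name) < end)

// Usage on a toy DataFrame: show it two rows at a time.
val spark = SparkSession.builder().master("local[2]").appName("ranges").getOrCreate()
import spark.implicits._
val df = Seq("a", "b", "c", "d", "e").toDF("value")
val indexed = withRowIndex(df).cache() // compute the index once, reuse it per batch
val total = indexed.count()
val batchSize = 2L
var start = 0L
while (start < total) {
  rowRange(indexed, start, start + batchSize).show()
  start += batchSize
}
```

The numbering follows whatever order the partitions are in, so sort first if a particular order matters. Dataset.cache() defaults to MEMORY_AND_DISK, which keeps the index stable across batches without requiring all rows to fit in memory, and zipWithIndex runs an extra job to count rows per partition when there is more than one partition.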

Answer

You can simply use the limit and except APIs of Datasets or DataFrames, as follows:

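var remaining = df                    // rows not yet handed out
var count = remaining.count()
val limit = 50
while (count > 0) {
  val batch = remaining.limit(limit)   // next (up to) 50 rows; order is not guaranteed
  batch.show(limit)                    // replace with your own processing
  remaining = remaining.except(batch)  // except() is EXCEPT DISTINCT, so duplicate rows are lost
  count = count - limit
}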