midnightfalcon - 1 year ago
Python Question

PySpark DataFrame Imputations: Replace Unknown & Missing Values with the Column Mean Based on a Specified Condition

Given a Spark dataframe, I would like to compute a column mean based on the non-missing and non-unknown values for that column. I would then like to take this mean and use it to replace the column's missing & unknown values.

For example, assume I'm working with:

  • A DataFrame named df, where each record represents one individual and all columns are integer or numeric

  • Column named age (ages for each record)

  • Column named missing_age (which equals 1 if that individual has no age, 0 otherwise)

  • Column named unknown_age (which equals 1 if that individual has unknown age, 0 otherwise)

Then I can compute this mean as shown below.

from pyspark.sql.functions import avg, col
calc_mean = df.where((col("unknown_age") == 0) & (col("missing_age") == 0)).agg(avg("age")).first()[0]

Or via SQL and window functions:

# df must first be registered as a table, e.g. df.registerTempTable("df")
mean_compute = hiveContext.sql("""select avg(age) over() as mean from df
where missing_age = 0 and unknown_age = 0""")

I'd rather avoid SQL/window functions if I can. My challenge has been taking this mean and replacing the unknown/missing values with it using non-SQL methods.

I've tried when(), where(), replace(), withColumn(), UDFs, and combinations of these. Whatever I do, I either get errors or results I don't expect. Here's an example of one of many things I've tried that didn't work.

imputed = df.when((col("unknown_age") == 1) | (col("missing_age") == 1),

I've scoured the web but haven't found similar imputation questions, so any help is much appreciated. It could be something very simple that I've missed.

A side note: I'm trying to apply this code to every column in the Spark DataFrame whose name doesn't contain unknown_ or missing_. Can I just wrap the Spark-related code in a Python for loop and iterate over all of the applicable columns?


I also figured out how to loop through the columns. Here's an example:

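from pyspark.sql.functions import avg, when
for x in df.columns:
    # skip the indicator columns themselves
    if 'unknown_' not in x and 'missing_' not in x:
        # mean of x over the rows not flagged as missing
        avg_compute = df.where(df['missing_' + x] != 1).agg(avg(x)).first()[0]
        df = df.withColumn(x + 'mean_miss_imp', when(df['missing_' + x] == 1, avg_compute).otherwise(df[x]))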


If unknown or missing ages are stored as some placeholder value (-1 below):

from pyspark.sql.functions import col, avg, when

df = sc.parallelize([
    (10, 0, 0), (20, 0, 0), (-1, 1, 0), (-1, 0, 1)
]).toDF(["age", "missing_age", "unknown_age"])

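# mean age over rows that are neither unknown nor missing
avg_age = df.where(
    (col("unknown_age") != 1) & (col("missing_age") != 1)
).agg(avg("age")).first()[0]

# replace flagged ages with that mean, keep the rest unchanged
df.withColumn("age_imp", when(
    (col("unknown_age") == 1) | (col("missing_age") == 1), avg_age
).otherwise(col("age")))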

If unknown or missing ages are NULL, you can simplify this to:

df = sc.parallelize([
    (10, 0, 0), (20, 0, 0), (None, 1, 0), (None, 0, 1)
]).toDF(["age", "missing_age", "unknown_age"])

df.na.fill(df.na.drop().agg(avg("age")).first()[0], ["age"])