Anthony - 24 days ago
Scala Question

Exclude items from training set data

I have my data in two tables, colors and excluded_colors.

colors contains all colors, while excluded_colors contains some colors that I wish to exclude from my training set.

I am trying to split the data into a training set and a testing set, and to ensure that the colors in excluded_colors are not in my training set but do appear in the testing set.

To achieve this, I did the following:

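```scala
// Assumes a SparkSession named spark (as in spark-shell) and registered tables colors and excluded_colors
import org.apache.spark.sql.functions.lit

// Keep only the colors that do NOT appear in excluded_colors (anti-join via LEFT JOIN ... IS NULL)
val includedColors = spark.sql("""
  select colors.*
  from colors
  LEFT JOIN excluded_colors
    ON excluded_colors.color_id = colors.color_id
  where excluded_colors.color_id IS NULL
""")

// 70/30 random split of the remaining colors
val splits = includedColors.randomSplit(Array(0.7, 0.3))

// Tag each row: test = 0 for training, test = 1 for testing
// (lit() replaces the constant-returning UDFs; the split is stored in splits, not rsplit)
val train_colors = splits(0).select("color_id").withColumn("test", lit(0))
val test_colors  = splits(1).select("color_id").withColumn("test", lit(1))
```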


However, I realize that with this approach the colors in excluded_colors are dropped entirely: they do not appear even in my testing set.

Question
How can I split the data 70/30 while ensuring that the colors in excluded_colors are absent from the training set but present in the testing set?

Answer

We want to remove the excluded colors from the training set, keep them in the testing set, and still end up with a 70/30 training/testing split.

What we need is a bit of math.

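Given the total dataset (TD) and the excluded colors dataset (E), suppose we sample a fraction x of the non-excluded rows for the training dataset (Tr) and put everything else in the test dataset (Ts). Then:

|Tr| = x * (|TD| - |E|)
|Ts| = |E| + (1 - x) * (|TD| - |E|)

so that |Tr| + |Ts| = |TD|. We also want |Tr| = 0.7 |TD|.

Hence x = 0.7 |TD| / (|TD| - |E|)

Since x is a sampling fraction it cannot exceed 1, so this only works when |E| <= 0.3 |TD|. If more than 30% of the rows are excluded, no 70/30 split can keep all of them out of training.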
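A quick numeric check of these formulas, as a minimal plain-Scala sketch. The names SplitMath, trainingFraction and expectedSizes are hypothetical helpers (not part of Spark), and the 0.7 training share is passed as a default parameter.

```scala
object SplitMath {
  /** Sampling fraction x such that x * (|TD| - |E|) = trainShare * |TD|. */
  def trainingFraction(total: Long, excluded: Long, trainShare: Double = 0.7): Double = {
    require(total > excluded, "need at least one non-excluded row")
    val x = trainShare * total / (total - excluded)
    // x > 1 means there are not enough non-excluded rows to fill the training share
    require(x <= 1.0, s"too many excluded rows: x = $x > 1")
    x
  }

  /** Expected (|Tr|, |Ts|) when a fraction x of the non-excluded rows goes to training. */
  def expectedSizes(total: Long, excluded: Long, x: Double): (Double, Double) = {
    val remaining = (total - excluded).toDouble
    val train = x * remaining                  // |Tr| = x * (|TD| - |E|)
    val test = excluded + (1 - x) * remaining  // |Ts| = |E| + (1 - x) * (|TD| - |E|)
    (train, test)
  }
}
```

For example, with |TD| = 1000 and |E| = 100, x = 700 / 900 ≈ 0.778, which gives 700 training rows and 300 test rows. With |E| = 500, x would be 1.4 and trainingFraction rejects the input.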

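Now that we know the sampling factor x, we can draw the training set from the non-excluded rows and put everything else, including all of E, in the test set:

```scala
// TD: the full colors dataset; TDMinusE: TD without E, i.e. the result of the SQL query above
// sample() keeps each row independently with probability x, so |Tr| is only approximately 0.7 |TD|
val Tr = TDMinusE.sample(withReplacement = false, fraction = x)

// Take the test set as the complement of Tr in the original dataset instead of sampling TD again:
// a second sample could overlap Tr and could miss rows of E
val Ts = TD.join(Tr.select("color_id"), Seq("color_id"), "left_anti")
```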
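The same idea can be tried without Spark. Below is a minimal plain-Scala sketch on an in-memory list of distinct ids; ExcludedSplit.split is a hypothetical helper, and it swaps Spark's Bernoulli sampling for an exact shuffle-and-take so the 70/30 sizes come out exact up to rounding. The test set is the complement of the training set, which guarantees that every excluded id lands in it and that no id appears in both sets.

```scala
import scala.util.Random

object ExcludedSplit {
  /**
   * Splits distinct ids into (train, test) so that no id in excluded is in train,
   * every excluded id is in test, and train holds about trainShare of all ids.
   */
  def split[A](all: Seq[A], excluded: Set[A], trainShare: Double = 0.7, seed: Long = 42L): (Seq[A], Seq[A]) = {
    val candidates = all.filterNot(excluded.contains)        // TD - E
    val trainSize = math.round(trainShare * all.size).toInt  // 0.7 * |TD|
    require(trainSize <= candidates.size, "too many excluded ids for this split")
    val shuffled: Seq[A] = new Random(seed).shuffle(candidates)
    val train = shuffled.take(trainSize)                     // exact-size random subset of TD - E
    val trainSet = train.toSet
    val test = all.filterNot(trainSet.contains)              // E plus the rest of TD - E
    (train, test)
  }
}
```

For example, with ids 1 to 100 and ids 1 to 10 excluded, the training set gets 70 ids drawn from 11 to 100 and the testing set gets the other 30, including all ten excluded ids.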