Anthony Anthony - 1 year ago 57
Scala Question

Exclude items from training set data

I have my data in two tables:

colors contains all colors
excluded_colors contains some colors that I wish to exclude from my training set.

I am trying to split the data into a training and a testing set and ensure that the colors in excluded_colors are not in my training set but do exist in the testing set.

In order to achieve the above, I did this

import org.apache.spark.sql.functions.{col, udf}

// Keep only the colors that do not appear in excluded_colors
val colors = spark.sql("""
  select colors.*
  from colors
  LEFT JOIN excluded_colors
  ON excluded_colors.color_id = colors.color_id
  where excluded_colors.color_id IS NULL
""")

// Constant-valued UDFs that flag a row as training (0) or testing (1)
val trainer: (Int => Int) = (arg:Int) => 0
val sqlTrainer = udf(trainer)
val tester: (Int => Int) = (arg:Int) => 1
val sqlTester = udf(tester)

val splits = colors.randomSplit(Array(0.7, 0.3))
val train_colors = splits(0).select("color_id").withColumn("test",sqlTrainer(col("color_id")))
val test_colors = splits(1).select("color_id").withColumn("test",sqlTester(col("color_id")))

However, I'm realizing that with the above, the colors in excluded_colors are completely ignored. They are not even in my testing set.

How can I split the data 70/30 while also ensuring that the colors in excluded_colors are not in training but are present in testing?

Answer Source

What we want to do is remove the "excluded colors" from the training set, keep them in the test set, and still end up with a 70/30 training/test split.

What we need is a bit of math.

Given the total dataset (TD) and the excluded colors dataset (E), the train dataset (Tr) and the test dataset (Ts) satisfy:

|Tr| = x * (|TD|-|E|)
|Ts| = |E| + (1-x) * (|TD|-|E|)

We also know that |Tr| = 0.7 |TD|

Hence x = 0.7 |TD| / (|TD| - |E|)

Since x is a sampling fraction it cannot exceed 1, so this only works when |E| <= 0.3 |TD|.
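
As a quick sanity check of the formula, here is a small sketch in plain Scala. The helper name samplingFraction and the example sizes (1000 colors, 100 of them excluded) are made up for illustration and are not part of the question.

```scala
// Hypothetical helper: sampling fraction x = trainShare * |TD| / (|TD| - |E|).
def samplingFraction(total: Long, excluded: Long, trainShare: Double = 0.7): Double = {
  require(total > 0 && excluded >= 0 && excluded < total, "need 0 <= |E| < |TD|")
  val x = trainShare * total / (total - excluded)
  // A fraction above 1 means too many rows are excluded to reach the train share.
  require(x <= 1.0, s"cannot reach a train share of $trainShare: x = $x exceeds 1")
  x
}

// Assumed example sizes: |TD| = 1000, |E| = 100.
val x = samplingFraction(total = 1000L, excluded = 100L)
// x is 700 / 900, about 0.778
```

With 400 excluded colors instead, the same call fails the second check, because 700 / 600 is larger than 1.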

Now that we know the sampling factor x, we can say:

Tr = (TD-E).sample(withReplacement = false, fraction = x)
// where (TD - E) is the result of the SQL expr above

Ts = TD.except(Tr)
// the test set is everything that was not drawn into Tr, so it contains all of E and is about 30% of TD
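
Put together, the split could look roughly like the following Spark sketch. It is only an illustration under some assumptions: the helper name splitExcluding, its default arguments, and the seed are hypothetical, and color_id is assumed to identify each row of colors uniquely. Note that it builds the test set as the complement of Tr (a left anti join on the key) rather than as an independent 30% sample of TD, so every excluded color lands in testing and the two sets cannot overlap.

```scala
import org.apache.spark.sql.DataFrame

// Hypothetical helper; assumes `key` uniquely identifies a row in `all`.
def splitExcluding(
    all: DataFrame,
    excluded: DataFrame,
    key: String = "color_id",
    trainShare: Double = 0.7,
    seed: Long = 42L): (DataFrame, DataFrame) = {

  // TD - E: rows of `all` whose key does not appear in `excluded`.
  val eligible = all.join(excluded.select(key), Seq(key), "left_anti")

  val total = all.count()
  val eligibleCount = eligible.count()
  require(eligibleCount > 0, "every row is excluded, nothing left to train on")

  // x = trainShare * |TD| / (|TD| - |E|)
  val x = trainShare * total / eligibleCount
  require(x <= 1.0, s"too many excluded rows: sampling fraction $x exceeds 1")

  // sample() keeps each row with probability x, so |Tr| is only approximately
  // trainShare * |TD|. Caching keeps the drawn sample stable across reuse.
  val train = eligible.sample(withReplacement = false, fraction = x, seed = seed).cache()

  // Ts = TD minus Tr: contains all of E plus the eligible rows not drawn.
  val test = all.join(train.select(key), Seq(key), "left_anti")

  (train, test)
}

// Possible usage, assuming a SparkSession named `spark` and the two tables:
// val (train, test) = splitExcluding(spark.table("colors"), spark.table("excluded_colors"))
```

Because sample() is probabilistic, the resulting ratio will be close to 70/30 rather than exact, especially on small datasets.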