krunarsson, 4 years ago
Scala Question

Create one hot encoded vector from category list in Spark

I have data containing 5 categories (A, B, C, D, E) and a dataset of customers, where each customer can belong to one, many, or none of the categories. How can I take a dataset like this:

id, categories
1 , [A,C]
2 , [B]
3 , []
4 , [D,E]


and transform the categories column into one-hot encoded vectors, like this:

id, categories, encoded
1 , [A,C] , [1,0,1,0,0]
2 , [B] , [0,1,0,0,0]
3 , [] , [0,0,0,0,0]
4 , [D,E] , [0,0,0,1,1]


Has anyone found a simple way to do this in Spark?

Answer Source

An easy approach that gives nearly the same result is to use a CountVectorizerModel with a fixed vocabulary:

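import org.apache.spark.ml.feature.CountVectorizerModel

// Each customer has an id and a (possibly empty) list of categories.
val df = spark.createDataFrame(Seq(
  (1, Seq("A", "C")),
  (2, Seq("B")),
  (3, Seq.empty[String]),
  (4, Seq("D", "E"))
)).toDF("id", "category")

// Build the model directly from a known vocabulary; no fitting step is needed.
val cvm = new CountVectorizerModel(Array("A", "B", "C", "D", "E"))
  .setInputCol("category")
  .setOutputCol("features")

cvm.transform(df).show()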

/*
+---+--------+-------------------+
| id|category|           features|
+---+--------+-------------------+
|  1|  [A, C]|(5,[0,2],[1.0,1.0])|
|  2|     [B]|      (5,[1],[1.0])|
|  3|      []|          (5,[],[])|
|  4|  [D, E]|(5,[3,4],[1.0,1.0])|
+---+--------+-------------------+
*/

This isn't exactly what you wanted, but the feature vector tells you which categories are present in each row. Each value is a sparse vector printed as (size, indices, values). For instance, in row 1, [0,2] lists the indices of the first and third elements of the vocabulary, i.e. "A" and "C", and row 3 has no indices because its category list is empty.
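
If the dense form from the question is needed, the sparse vector can be expanded into an array. The following is only a sketch, not part of the original answer: the object name OneHotSketch, the helper toInts, the column name "encoded", and the local SparkSession settings are assumptions made for illustration. It also calls setBinary(true), on the assumption that a category listed twice should still produce 1 rather than a count of 2.

```scala
import org.apache.spark.ml.feature.CountVectorizerModel
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object OneHotSketch {
  // Hypothetical helper: expand a (possibly sparse) ML vector into a dense array of 0/1 ints.
  val toInts: Vector => Array[Int] = v => v.toArray.map(_.toInt)

  def main(args: Array[String]): Unit = {
    // Assumed local session for the sketch; in spark-shell, `spark` already exists.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("one-hot-sketch")
      .getOrCreate()

    val df = spark.createDataFrame(Seq(
      (1, Seq("A", "C")),
      (2, Seq("B")),
      (3, Seq.empty[String]),
      (4, Seq("D", "E"))
    )).toDF("id", "category")

    // Fixed vocabulary; binary mode caps every nonzero count at 1.0.
    val cvm = new CountVectorizerModel(Array("A", "B", "C", "D", "E"))
      .setInputCol("category")
      .setOutputCol("features")
      .setBinary(true)

    // Wrap the helper as a UDF so it can be applied to the vector column.
    val toIntsUdf = udf(toInts)

    cvm.transform(df)
      .withColumn("encoded", toIntsUdf(col("features")))
      .select("id", "category", "encoded")
      .show(truncate = false)

    spark.stop()
  }
}
```

The sparse column is usually the better input for downstream ML stages, so the dense array is mainly useful for inspection or export.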
