schoon - 2 years ago
Scala Question

Why can't I display prediction column of Spark MultilayerPerceptronClassifier?

I am using Spark's MultilayerPerceptronClassifier. This generates a 'prediction' column in the 'predictions' DataFrame. When I try to show it, I get the error:

SparkException: Failed to execute user defined function($anonfun$1: (vector) => double) ...
Caused by: java.lang.IllegalArgumentException: requirement failed: A & B Dimension mismatch!

Other columns, for example 'vector', display fine.
Part of the predictions schema:

|-- vector: vector (nullable = true)
|-- prediction: double (nullable = true)

My code is:

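import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.feature.{StringIndexer, Word2Vec}
import spark.implicits._ // enables the 'racist column syntax (spark is the SparkSession)

// racist is Boolean; cast it to String so StringIndexer can index it:
val train2 = train.withColumn("racist", 'racist.cast("String"))
val test2 = test.withColumn("racist", 'racist.cast("String"))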

val indexer = new StringIndexer().setInputCol("racist").setOutputCol("indexracist")

val word2Vec = new Word2Vec().setInputCol("lemma").setOutputCol("vector") //.setVectorSize(3).setMinCount(0)

val layers = Array[Int](4,5, 2)

val mpc = new MultilayerPerceptronClassifier()
  .setLayers(layers)
  .setBlockSize(128)
  .setSeed(1234L)
  .setMaxIter(100)
  .setFeaturesCol("vector")
  .setLabelCol("indexracist")

val pipeline = new Pipeline().setStages(Array(indexer, word2Vec, mpc))

val model = pipeline.fit(train2)

val predictions = model.transform(test2)
predictions.select("prediction").show()
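For reference, MultilayerPerceptronClassifier expects the first entry of `layers` to equal the length of each feature vector and the last entry to equal the number of label classes. The sketch below is a hypothetical pure-Scala helper (`mlpLayers` is not a Spark API) that builds the array from those two numbers so they cannot drift apart. It assumes the feature size is known up front, for example from the `vectorSize` passed to Word2Vec.

```scala
object LayerSpec {
  // Hypothetical helper (not part of Spark): builds the `layers` array for an MLP
  // so the input layer matches the feature-vector length and the output layer
  // matches the number of label classes.
  def mlpLayers(featureSize: Int, hidden: Seq[Int], numClasses: Int): Array[Int] = {
    require(featureSize > 0, s"featureSize must be positive, got $featureSize")
    require(numClasses >= 2, s"need at least 2 classes, got $numClasses")
    require(hidden.forall(_ > 0), s"hidden layer sizes must be positive, got $hidden")
    (Seq(featureSize) ++ hidden ++ Seq(numClasses)).toArray
  }

  def main(args: Array[String]): Unit = {
    // Word2Vec's default vectorSize is 100, so with the defaults the input layer must be 100, not 4.
    println(mlpLayers(100, Seq(5), 2).mkString(", ")) // 100, 5, 2
    // With .setVectorSize(3) on Word2Vec:
    println(mlpLayers(3, Seq(5), 2).mkString(", "))   // 3, 5, 2
  }
}
```

In the pipeline above, `layers` would then come from something like `LayerSpec.mlpLayers(vectorSize, Seq(5), 2)`, with the same `vectorSize` value passed to `Word2Vec.setVectorSize`.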

EDIT: The proposed similar question's problem was

val layers = Array[Int](0, 0, 0, 0)

which is not the case here, nor is it the same error.

EDIT AGAIN: part0 of train and test are saved in PARQUET format here.

Answer

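Adding .setVectorSize(3).setMinCount(0) and changing val layers = Array[Int](3,5, 2) made it work. The first entry of layers (the input layer) must equal the length of the feature vectors, and Word2Vec produces 100-dimensional vectors by default, so the original input layer of size 4 did not match: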

val word2Vec = new Word2Vec().setInputCol("lemma").setOutputCol("vector").setVectorSize(3).setMinCount(0)

// specify layers for the neural network:
// input layer of size 3 (must match Word2Vec's vectorSize), one hidden layer of size 5,
// and output of size 2 (classes)
val layers = Array[Int](3, 5, 2)
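Why the mismatch only appears at `show()`: DataFrame transformations are lazy, and the stack trace shows that `prediction` is computed by a UDF of type `(vector) => double`, so the network only runs when that column is actually computed; displaying other columns does not need it. The first step of that computation multiplies the input layer's weights by the feature vector, which requires the vector length to equal `layers(0)`. The following is a minimal pure-Scala sketch of that one step, assuming a plain dense layer with weights stored as an `out x in` matrix; `Dense` and `constantLayer` are illustrative names, not Spark's implementation.

```scala
object DenseLayerCheck {
  // A minimal dense layer: `weights` has `out` rows and `in` columns.
  // It mirrors the shape rule an MLP's first layer enforces; it is not Spark's code.
  final case class Dense(weights: Array[Array[Double]], bias: Array[Double]) {
    val out: Int = weights.length
    val in: Int = if (weights.isEmpty) 0 else weights(0).length

    def forward(x: Array[Double]): Array[Double] = {
      // Same kind of check as the "Dimension mismatch" requirement in the stack trace;
      // require throws IllegalArgumentException("requirement failed: ...").
      require(x.length == in, s"Dimension mismatch! layer expects $in inputs, got ${x.length}")
      Array.tabulate(out) { i =>
        var s = bias(i)
        var j = 0
        while (j < in) { s += weights(i)(j) * x(j); j += 1 }
        s
      }
    }
  }

  // Builds a layer whose weights all equal `w` (values are illustrative only).
  def constantLayer(in: Int, out: Int, w: Double): Dense =
    Dense(Array.fill(out, in)(w), Array.fill(out)(0.0))

  def main(args: Array[String]): Unit = {
    val firstLayer = constantLayer(in = 4, out = 5, w = 0.1) // layers = (4, 5, 2)
    val w2vDefault = Array.fill(100)(1.0)                    // Word2Vec default vectorSize = 100
    val result = scala.util.Try(firstLayer.forward(w2vDefault))
    println(result.isFailure) // true: 100 inputs fed into a 4-input layer

    val fixedLayer = constantLayer(in = 3, out = 5, w = 0.1) // layers = (3, 5, 2)
    println(fixedLayer.forward(Array(1.0, 2.0, 3.0)).length) // 5 hidden-unit outputs
  }
}
```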