I have a Spark DataFrame with the following structure:
root
|-- topicDistribution: vector (nullable = true)
+--------------------+
| topicDistribution|
+--------------------+
| [0.1, 0.2] |
| [0.3, 0.2] |
| [0.5, 0.2] |
| [0.1, 0.7] |
| [0.1, 0.8] |
| [0.1, 0.9] |
+--------------------+
I want to add a column containing, for each row, the index of the largest value in the vector, so the result looks like this:

root
|-- topicDistribution: vector (nullable = true)
|-- max_index: integer (nullable = true)
+--------------------+-----------+
| topicDistribution| max_index |
+--------------------+-----------+
| [0.1, 0.2] | 1 |
| [0.3, 0.2] | 0 |
| [0.5, 0.2] | 0 |
| [0.1, 0.7] | 1 |
| [0.1, 0.8] | 1 |
| [0.1, 0.9] | 1 |
+--------------------+-----------+
I tried the following UDF:

import org.apache.spark.sql.functions.udf
val func = udf( (x: Vector[Double]) => x.indices.maxBy(x) )
df.withColumn("max_idx", func($"topicDistribution")).show()

but it fails with:
Exception in thread "main" org.apache.spark.sql.AnalysisException:
cannot resolve 'UDF(topicDistribution)' due to data type mismatch:
argument 1 requires array<double> type, however, '`topicDistribution`'
is of vector type.;;
The problem is that Scala's Vector[Double] is the standard-library collection, which Spark registers as array&lt;double&gt;, while topicDistribution is stored as Spark's VectorUDT. The UDF has to take Spark's own Vector type instead:

// create some sample data:
import org.apache.spark.mllib.linalg.{Vectors, Vector}
case class myrow(topics: Vector)
val rdd = sc.parallelize(Array(myrow(Vectors.dense(0.1, 0.2)), myrow(Vectors.dense(0.6, 0.2))))
val mydf = sqlContext.createDataFrame(rdd)
mydf.show()
+----------+
| topics|
+----------+
|[0.1, 0.2]|
|[0.6, 0.2]|
+----------+
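This uses the older RDD-based org.apache.spark.mllib.linalg package with sqlContext. If you are on Spark 2.x or later, the DataFrame-based equivalents live in org.apache.spark.ml.linalg and SparkSession; a minimal sketch of the same setup, assuming a SparkSession named spark is in scope:

import org.apache.spark.ml.linalg.{Vectors, Vector}
case class myrow(topics: Vector)
// SparkSession can build the DataFrame directly from a Seq of case classes
val mydf = spark.createDataFrame(Seq(myrow(Vectors.dense(0.1, 0.2)), myrow(Vectors.dense(0.6, 0.2))))
mydf.show()

The UDF below works the same way with either import.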
// build the UDF against Spark's Vector type
import org.apache.spark.sql.functions.udf
val func = udf( (x: Vector) => {
  val values = x.toArray                 // materializes dense and sparse vectors alike
  values.indices.maxBy(i => values(i))   // index of the largest value
})
mydf.withColumn("max_idx", func($"topics")).show()
+----------+-------+
| topics|max_idx|
+----------+-------+
|[0.1, 0.2]| 1|
|[0.6, 0.2]| 0|
+----------+-------+
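If your Spark version is 1.5 or later, you can skip the manual index search: the Vector trait already provides argmax, which returns the index of the first maximal element:

val argmaxFunc = udf( (x: Vector) => x.argmax )
mydf.withColumn("max_idx", argmaxFunc($"topics")).show()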
// note: the parameter type must match the column type: Spark's Vector for a
// VectorUDT column, Seq[Double] (or Array[Double]) for an array<double> column
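For example, if the column really were array&lt;double&gt;, the Seq signature from the question would be the right one (sketch, assuming a hypothetical DataFrame arrayDf with an array&lt;double&gt; column named probs):

val seqFunc = udf( (x: Seq[Double]) => x.indices.maxBy(i => x(i)) )
arrayDf.withColumn("max_idx", seqFunc($"probs")).show()

And on Spark 3.0+ you can avoid the UDF entirely: the built-in vector_to_array converts the vector column to an array, after which SQL array functions do the rest; a sketch:

import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.sql.functions.expr

mydf
  .withColumn("arr", vector_to_array($"topics"))                            // VectorUDT -> array<double>
  .withColumn("max_idx", expr("array_position(arr, array_max(arr)) - 1"))  // array_position is 1-based
  .show()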