Tim - 11 months ago

Scala Question

I have a DataFrame in Spark that looks like this:

```
+-------+----------+-------+
|  value|     group|     ts|
+-------+----------+-------+
|      A|         X|      1|
|      B|         X|      2|
|      B|         X|      3|
|      D|         X|      4|
|      E|         X|      5|
|      A|         Y|      1|
|      C|         Y|      2|
+-------+----------+-------+
```

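Endgoal: I'd like to find how many sequences `A-B-E` there are in each group (a sequence being a list of subsequent rows), with the added constraint that consecutive parts of the sequence can be at most `n` rows apart. For this example, let `n` be 2.

Consider group `X`. In this case there is exactly one `D` between `B` and `E` (consecutive `B`s count as a single `B`), so the `E` is one row away from the last `B` and the group contains exactly one `A-B-E` sequence.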
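To pin down the counting rule, here is a plain-Scala reference sketch (no Spark) that works on one group's `value`s already sorted by `ts`. It rests on assumptions: "at most `n` rows apart" is read as "at most `n` other rows in between", a run of consecutive `B`s counts as one `B` (the gap to `E` is measured from the last `B` of the run), and each `A` is counted at most once. `AbeReference` and `countAbe` are hypothetical names.

```scala
object AbeReference {
  // Hypothetical reference implementation, not Spark code.
  // values: one group's `value` column, sorted by ts.
  // n: maximum number of rows allowed between consecutive parts of A-B-E (assumed reading).
  def countAbe(values: IndexedSeq[String], n: Int): Int = {
    val lastIdx = values.length - 1
    values.indices.count { i =>
      values(i) == "A" && {
        // first B with at most n rows between it and the A
        val firstB = ((i + 1) to math.min(i + n + 1, lastIdx)).find(j => values(j) == "B")
        firstB.exists { b =>
          // skip the rest of a run of consecutive B's
          val afterRun = values.indexWhere(v => v != "B", b)
          val lastB = if (afterRun == -1) lastIdx else afterRun - 1
          // an E with at most n rows between it and the last B of the run
          ((lastB + 1) to math.min(lastB + n + 1, lastIdx)).exists(k => values(k) == "E")
        }
      }
    }
  }
}
```

Under these assumptions, `countAbe(Vector("A", "B", "B", "D", "E"), 2)` gives 1 for group `X` and `countAbe(Vector("A", "C"), 2)` gives 0 for group `Y`, which is handy for checking a distributed solution on small samples.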

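I have thought about using `collect_list()` to build one string per group and searching it with a regex, but I was wondering if there is a more elegant, distributed way, perhaps using window functions?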
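A minimal plain-Scala sketch of the string-plus-regex idea: once a group's values are collected in `ts` order and concatenated (e.g. `"ABBDE"` for group `X`), a regular expression can count the matches. The pattern and the reading of "at most `n` rows apart" (at most `n` rows in between) are assumptions, and `AbeRegex`/`countByRegex` are hypothetical names. Note that this approach puts a whole group into a single row, which matters when groups are long.

```scala
object AbeRegex {
  // Hypothetical sketch: s is one group's values, sorted by ts and concatenated.
  def countByRegex(s: String, n: Int): Int = {
    // an A, then up to n rows that are neither A nor B, then a run of B's,
    // then up to n rows that are not B, then an E
    val pattern = s"A[^AB]{0,$n}B+[^B]{0,$n}E".r
    // findAllIn returns non-overlapping matches, so "AABE" is counted once
    pattern.findAllIn(s).size
  }
}
```

Because `[^AB]` cannot match an `A`, the string `"AABE"` yields one match (starting at the second `A`) rather than two.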

Edit:

Note that the provided DataFrame is just an example. The real DataFrame (and thus the groups) can be arbitrarily long.

Answer

Edited to answer @Tim's comment and to fix patterns of the type "AABE".

Yep, using a window function helps, but I created an `id` column to have an ordering:

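```
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._ // toDF and 'column syntax; assumes a SparkSession named `spark`, as in spark-shell

val df = List(
  (1,"A","X",1),
  (2,"B","X",2),
  (3,"B","X",3),
  (4,"D","X",4),
  (5,"E","X",5),
  (6,"A","Y",1),
  (7,"C","Y",2)
).toDF("id","value","group","ts")

// one window partition per group, rows ordered by id
val w = Window.partitionBy('group).orderBy('id)
```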
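Conceptually, `Window.partitionBy('group).orderBy('id)` splits the rows by `group` and sorts each partition by `id`; unlike a `groupBy` aggregation, the window still produces one output row per input row. A plain-Scala sketch of that ordered view of the same data (no Spark involved; the `Row` case class and the `partitions` name are hypothetical):

```scala
object WindowModel {
  // Hypothetical stand-in for one DataFrame row
  final case class Row(id: Int, value: String, group: String, ts: Int)

  val rows: List[Row] = List(
    Row(1, "A", "X", 1), Row(2, "B", "X", 2), Row(3, "B", "X", 3),
    Row(4, "D", "X", 4), Row(5, "E", "X", 5),
    Row(6, "A", "Y", 1), Row(7, "C", "Y", 2)
  )

  // partitionBy('group) ~ groupBy; orderBy('id) ~ sortBy within each partition
  val partitions: Map[String, Vector[String]] =
    rows.groupBy(_.group).map { case (g, rs) => g -> rs.sortBy(_.id).map(_.value).toVector }
}
```

Each window function in the answer then looks at neighbours inside one of these ordered partitions, never across groups.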

Then `lag` will collect what is needed (with a negative offset it looks ahead, like `lead`), but a function is required to generate the `Column` expression. Note the `split` on "A", which eliminates double counting of "AABE". *WARNING (TODO): this creates a bug on patterns of the type "ABAE", which are then ignored*:

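```
// For every row, concatenate the next 2*m values of its group
// (lag with a negative offset looks ahead; coalesce turns
// "past the end of the group" into ""), then keep only the part
// before the next "A"
def createSeq(m: Int) = split(
  concat(
    (1 to 2*m)
      .map(i => coalesce(lag('value, -i).over(w), lit(""))): _*),
  "A")(0)

val m = 2
val tmp = df
  .withColumn("seq", createSeq(m))
tmp.show
+---+-----+-----+---+----+
| id|value|group| ts| seq|
+---+-----+-----+---+----+
|  6|    A|    Y|  1|   C|
|  7|    C|    Y|  2|    |
|  1|    A|    X|  1|BBDE|
|  2|    B|    X|  2| BDE|
|  3|    B|    X|  3|  DE|
|  4|    D|    X|  4|   E|
|  5|    E|    X|  5|    |
+---+-----+-----+---+----+
```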
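To see what `createSeq` computes without a cluster, here is a plain-Scala model of it for one partition (values already sorted by `id`). It is a sketch, not Spark code; `SeqModel` and `seqFor` are hypothetical names. It reproduces the `seq` column above and shows the "ABAE" problem from the warning: the first `A` only sees `"B"`, because everything from the next `A` onward is cut off.

```scala
object SeqModel {
  // Plain-Scala model of createSeq(m) for one partition (values sorted by id).
  // lag('value, -k) looks k rows ahead; coalesce(..., lit("")) turns
  // "past the end of the partition" into an empty string.
  def seqFor(values: IndexedSeq[String], m: Int): IndexedSeq[String] =
    values.indices.map { i =>
      val ahead = (1 to 2 * m).map(k => values.lift(i + k).getOrElse("")).mkString
      // split(..., "A")(0): keep only what comes before the next "A".
      // The -1 limit keeps empty pieces, so "A".split("A", -1)(0) is "" instead of failing.
      ahead.split("A", -1)(0)
    }
}
```

For group `X` this gives `BBDE, BDE, DE, E, ""`; for `A, A, B, E` the first `A` gets an empty `seq`, which is how the split avoids double counting "AABE".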

Because the `Column` API offers a poor set of collection functions, it is much easier to avoid regex altogether by using a UDF:

```
// true if, after some "B", an "E" follows (before the next "B")
// with at most m other rows in between
def patternInSeq(m: Int) = udf((str: String) =>
  str
    .split("B")
    .exists(seg => seg.contains("E") && seg.indexOf("E") <= m)
)

val res = tmp
  .filter(('value === "A") && (locate("B", 'seq) > 0))       // rows at an A with a B ahead
  .filter(locate("B", 'seq) <= m && (locate("E", 'seq) > 1)) // first B at most m positions ahead; first E not the very next row
  .filter(patternInSeq(m)('seq))
  .groupBy('group)
  .count
res.show
+-----+-----+
|group|count|
+-----+-----+
|    X|    1|
+-----+-----+
```
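The three filters can also be checked on plain strings. Below is a sketch that models them for a row whose `value` is `"A"`, taking the `seq` strings from the `tmp` table as input; `MatchModel`, `matches` and `count` are hypothetical names. Spark's `locate` is 1-based and returns 0 when the substring is absent, which is why the sketch adds 1 to Java's 0-based `indexOf` (which returns -1 when absent).

```scala
object MatchModel {
  // Plain-Scala model of the filters applied to one row whose value is "A".
  def matches(seq: String, m: Int): Boolean = {
    val firstB = seq.indexOf("B") + 1 // like locate("B", 'seq)
    val firstE = seq.indexOf("E") + 1 // like locate("E", 'seq)
    firstB > 0 && firstB <= m && firstE > 1 &&
      // patternInSeq: a piece between B's holds an E with at most m rows before it
      seq.split("B").exists(seg => seg.contains("E") && seg.indexOf("E") <= m)
  }

  // rows are (value, seq) pairs of one group, as in the tmp table;
  // the result corresponds to that group's count in res
  def count(rows: Seq[(String, String)], m: Int): Int =
    rows.count { case (value, seq) => value == "A" && matches(seq, m) }
}
```

With the group `X` rows from `tmp` this yields 1, matching `res`. The "ABAE" case from the warning fails here because the first `A`'s `seq` is just `"B"`, which contains no `E`.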