Group by does not maintain order in PySpark; use a window function instead
What was a customer’s first purchase? What is a company’s most recent address? When did a given user last log in? Questions like these come up constantly in analytics and batch processing. Data engineers and data scientists who use Spark can answer them with the built-in functions module, but these values are often calculated incorrectly. This article explains why the orderBy().groupBy() pattern is not deterministic and how to use window functions instead.
Spark and orderBy
First, to create a dummy dataset, I built a DataFrame as follows (assuming an existing SparkSession named spark):
from datetime import date
from pyspark.sql import Row
from pyspark.sql import functions as F

data = [
    {"customer_id": "1a", "date": date(2011, 5, 4), "sale": 45},
    {"customer_id": "1a", "date": date(2013, 5, 4), "sale": 20},
    {"customer_id": "1a", "date": date(2015, 5, 4), "sale": 10},
    {"customer_id": "2a", "date": date(2014, 5, 4), "sale": 100},
    {"customer_id": "2a", "date": date(2015, 5, 4), "sale": 200},
    {"customer_id": "2a", "date": date(2016, 5, 4), "sale": 300},
    {"customer_id": "3a", "date": date(2018, 5, 4), "sale": 7},
    {"customer_id": "3a", "date": date(2019, 5, 4), "sale": 9},
    {"customer_id": "3a", "date": date(2020, 5, 4), "sale": 8},
]
df = spark.createDataFrame(Row(**x) for x in data).repartition(6)
A common technique for calculating the first or last value in a DataFrame looks like the following:
first_value = (
    df
    .orderBy("date")
    .groupBy("customer_id")
    .agg(
        F.first("date").alias("first_purchase_date"),
        F.first("sale").alias("sale"),
    )
)
The logic behind this order of operations is that the DataFrame has been sorted by date, so taking the first value per group should yield the row with the earliest date. While this may hold for part of the dataset, it is not guaranteed for all of it. Why? Because orderBy runs before the grouping, and groupBy can trigger a shuffle that rearranges the rows. If a single customer_id spans more than one partition, the date order within that group is effectively randomized after the shuffle. The following shows the results of the query above:
first_value.show()
+-----------+-------------------+----+
|customer_id|first_purchase_date|sale|
+-----------+-------------------+----+
| 1a| 2013-05-04| 20|
| 2a| 2016-05-04| 300|
| 3a| 2020-05-04| 8|
+-----------+-------------------+----+
As you can see, none of the expected values for first_purchase_date or sale are in the resulting dataframe.
How to use a window function for order
Instead of ordering the whole DataFrame, use a window function to get the first value. The following pattern is common in many workflows and lets you select the first row per group:
from pyspark.sql.window import Window

window = Window.partitionBy("customer_id").orderBy(F.col("date").asc())

first_value = (
    df
    .withColumn("row_num", F.row_number().over(window))
    .filter(F.col("row_num") == 1)
    .select(
        "customer_id",
        F.col("date").alias("first_purchase_date"),
        "sale",
    )
)
first_value.show()
+-----------+-------------------+----+
|customer_id|first_purchase_date|sale|
+-----------+-------------------+----+
| 1a| 2011-05-04| 45|
| 2a| 2014-05-04| 100|
| 3a| 2018-05-04| 7|
+-----------+-------------------+----+
The above process returns the first date and sale value for each customer_id, regardless of partitioning, because the ordering is applied within each window partition rather than before a shuffle.
Bonus: Rank
If you need all transactions that share the first value, rather than just one row, row_number is the wrong tool. Instead, Spark has a built-in function called rank. For the above dataset, if customer 1a has two transactions on the first day, both can be selected like so:
from pyspark.sql import Row

data = [
    {"customer_id": "1a", "date": date(2011, 5, 4), "sale": 45},
    {"customer_id": "1a", "date": date(2011, 5, 4), "sale": 20},
    {"customer_id": "1a", "date": date(2015, 5, 4), "sale": 10},
]
df = spark.createDataFrame(Row(**x) for x in data)

from pyspark.sql.window import Window

window = Window.partitionBy("customer_id").orderBy(F.col("date").asc())

first_value = (
    df
    .withColumn("rank", F.rank().over(window))
    .filter(F.col("rank") == 1)
    .select(
        "customer_id",
        F.col("date").alias("first_purchase_date"),
        "sale",
    )
)
first_value.show()
+-----------+-------------------+----+
|customer_id|first_purchase_date|sale|
+-----------+-------------------+----+
| 1a| 2011-05-04| 20|
| 1a| 2011-05-04| 45|
+-----------+-------------------+----+