Data Skew in Spark: Using Salting While Avoiding Common Mistakes
Data skew occurs when data is distributed unevenly across partitions. Imagine you’re working with user transaction data, and two users (“power users”) have hundreds of thousands of transactions, while most others have only a few. If you join this data with another dataset on user ID, the few partitions that hold the power users become much larger than the rest, which can lead to problems such as:
- Hotspots: A few partitions handling way more data than others.
- Straggling tasks: Some tasks take forever to finish because they’re overloaded with data.
- Out-of-memory errors: Sometimes Spark can’t handle the overload and crashes.
Not fun, right?
Why Data Skew Matters in Spark
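Apache Spark distributes data across partitions to parallelise processing. Operations such as joins and groupBy shuffle rows so that all records with the same key end up in the same partition. When a few keys dominate, the tasks processing those partitions become stragglers and delay the entire job, because a stage finishes only when its slowest task does.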
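To make the effect concrete, here is a minimal sketch in plain Scala (no Spark required) that assigns keys to partitions by hash and compares the busiest partition with the average. The `partitionOf` helper and the `hot_user` data are made up for illustration, and the hashing is only a simplified stand-in for what Spark does internally.

```scala
object SkewDemo {
  // Simplified stand-in for hash partitioning:
  // non-negative modulo of the key's hash code.
  def partitionOf(key: String, numPartitions: Int): Int = {
    val raw = key.hashCode % numPartitions
    if (raw < 0) raw + numPartitions else raw
  }

  // Count how many records land in each partition.
  def partitionLoads(keys: Seq[String], numPartitions: Int): Map[Int, Int] =
    keys.groupBy(k => partitionOf(k, numPartitions)).map { case (p, ks) => p -> ks.size }

  // Hypothetical data: one hot key with 10,000 records, 100 keys with 10 records each.
  val keys: Seq[String] =
    Seq.fill(10000)("hot_user") ++ (1 to 100).flatMap(i => Seq.fill(10)(s"user_$i"))

  val loads: Map[Int, Int] = partitionLoads(keys, numPartitions = 8)
  val maxLoad: Int = loads.values.max            // size of the busiest partition
  val avgLoad: Double = keys.size.toDouble / 8   // 11,000 records over 8 partitions
}
```

Because every record with the same key hashes to the same partition, the partition that receives `hot_user` holds at least 10,000 of the 11,000 records, while the average is 1,375.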
Introducing Salting!
Salting is a clever way to handle data skew. The idea is to modify the skewed keys by adding a random value (a “salt”) so that the records spread out more evenly across partitions. Here’s how it works:
- Identify Skewed Keys: Figure out which keys are causing the skew.
- Add Salt to Skewed Keys: Append a random number (salt) to the skewed keys in your large dataset.
- Adjust the Other Dataset: Replicate or modify the corresponding keys in the other dataset to match the salted keys.
- Perform the Join: Join the datasets on these new salted keys.
- Post-Processing: Remove the salt if needed after the join.
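Before moving to Spark, the five steps can be sketched on ordinary Scala collections. This is only an illustration under assumptions: the tiny datasets, the `SaltingSketch` object and the fixed random seed are all made up, and a `Map` lookup stands in for the join.

```scala
import scala.util.Random

object SaltingSketch {
  val numSalts = 4
  val skewedKeys = Set("user_1")   // Step 1: assume we already know the hot key
  val rng = new Random(42)         // fixed seed so the sketch is repeatable

  // Large side: (key, amount). "user_1" is the hot key.
  val large: Seq[(String, Int)] =
    Seq.fill(8)("user_1").zipWithIndex ++ Seq(("user_2", 100), ("user_3", 200))

  // Small side: (key, name).
  val small: Seq[(String, String)] =
    Seq("user_1" -> "Ann", "user_2" -> "Bob", "user_3" -> "Cy")

  // Step 2: salt only the skewed keys on the large side.
  val saltedLarge: Seq[(String, (String, Int))] = large.map { case (k, amt) =>
    val salted = if (skewedKeys(k)) s"${k}_${rng.nextInt(numSalts)}" else k
    (salted, (k, amt))
  }

  // Step 3: replicate skewed keys on the small side once per salt value.
  val saltedSmall: Map[String, String] = small.flatMap { case (k, name) =>
    if (skewedKeys(k)) (0 until numSalts).map(s => s"${k}_$s" -> name) else Seq(k -> name)
  }.toMap

  // Steps 4 and 5: join on the salted key, then keep only the original key.
  val joined: Seq[(String, Int, String)] = saltedLarge.flatMap { case (salted, (k, amt)) =>
    saltedSmall.get(salted).map(name => (k, amt, name))
  }

  // Reference result: the same join without any salting.
  val plainJoined: Seq[(String, Int, String)] = large.flatMap { case (k, amt) =>
    small.toMap.get(k).map(name => (k, amt, name))
  }
}
```

`SaltingSketch.joined` and `SaltingSketch.plainJoined` contain the same rows: salting changes how the work is grouped, not the result of the join.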
A Practical Example Using Scala
Alright, enough theory. Let’s see how this works in practice!
Setting Up Our Data
We’ll create two datasets:
- Large Dataset A (large_df): A large dataset with transaction data, where some user_ids are heavily skewed.
- Small Dataset B (small_df): A smaller dataset with user profiles.
Here’s how we’ll set them up:
import org.apache.spark.sql.{SparkSession, DataFrame}
import org.apache.spark.sql.functions._
import scala.util.Random
val spark = SparkSession.builder()
  .appName("SaltingExample")
  .getOrCreate()
import spark.implicits._
// Candidate user IDs, built once instead of once per generated row
val skewedUserIds = Seq("user_1", "user_2")
val normalUserIds = (3 to 1000).map(i => s"user_$i")
// Large dataset with skewed user_ids
val large_df = Seq.fill(1000000) {
  // 10% of rows go to one of the two skewed users, the rest to a random normal user
  val userId = if (Random.nextDouble() < 0.1) {
    skewedUserIds(Random.nextInt(skewedUserIds.length))
  } else {
    normalUserIds(Random.nextInt(normalUserIds.length))
  }
  (userId, Random.nextDouble() * 100)
}.toDF("user_id", "transaction_amount")
// Small dataset with user profiles
val small_df = (1 to 1000).map { i =>
  (s"user_$i", s"User Name $i", s"user_$i@example.com")
}.toDF("user_id", "name", "email")
What’s Happening Here?
- large_df: We’ve created a million rows of transactions. There’s a 10% chance that each transaction will have a user_id of either "user_1" or "user_2" (our skewed users). The remaining 90% are spread across user_3 to user_1000.
- small_df: This is a simple user profile dataset with 1,000 users, from user_1 to user_1000.
Checking for Data Skew
Before we do anything, let’s check how skewed our large_df is.
val userCounts = large_df.groupBy("user_id").count()
userCounts.orderBy(desc("count")).show(5)
Sample Output:
+--------+-----+
| user_id|count|
+--------+-----+
|  user_2|51290|
|  user_1|49690|
|user_270| 1200|
|user_678| 1190|
|user_382| 1180|
+--------+-----+
As you can see, user_1 and user_2 have around 50,000 transactions each, while other users have significantly fewer. That's our data skew right there!
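When the skewed keys are not known up front, they can be picked out of the same counts. The following is a minimal sketch, not a standard Spark utility: `findSkewedKeys` and its `factor` parameter are hypothetical names, and it assumes a non-empty DataFrame whose key column can be cast to a string.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{avg, col}

object SkewDetection {
  // Returns the keys whose row count exceeds `factor` times the average count per key.
  // `factor` is a hypothetical tuning knob, not a Spark setting.
  def findSkewedKeys(df: DataFrame, keyCol: String, factor: Double): Seq[String] = {
    val counts = df.groupBy(keyCol).count()
    // Average number of rows per key (assumes df is not empty).
    val avgCount = counts.agg(avg("count")).first().getDouble(0)
    counts
      .filter(col("count") > avgCount * factor)
      .select(col(keyCol).cast("string"))
      .collect()
      .map(_.getString(0))
      .toSeq
  }
}
```

On the sample data the average is about 1,000 rows per key, so `SkewDetection.findSkewedKeys(large_df, "user_id", 10.0)` would be expected to return only user_1 and user_2. The result is collected to the driver, which is fine for a handful of hot keys but not for a long list.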
Trying to Join Without Salting
Let’s attempt a regular join and see what happens:
val joined_df = large_df.join(small_df, "user_id")
joined_df.show()
This join operation might cause performance issues because the partitions handling user_1 and user_2 will be overloaded. Not good!
Implementing Salting to Fix the Skew
Step 1: Identify Skewed Keys
We already know that user_1 and user_2 are skewed.
val skewedKeys = Seq("user_1", "user_2")
Step 2: Add Salt to Skewed Keys in large_df
We’ll add a random salt to the skewed user_ids in our large dataset.
A common mistake is to salt every value in the join column. That forces you to replicate every row of the other dataset once per salt value, even though most keys were never a problem. Salt only the skewed values instead.
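A quick back-of-the-envelope check shows what this means for the size of the replicated small side, assuming one profile row per user as in small_df and 10 salt values. The `smallSideRows` helper is a made-up name for this illustration.

```scala
object ReplicationCost {
  // Rows on the small side after salting, assuming one row per key:
  // unsalted keys keep their single row, salted keys get one row per salt value.
  def smallSideRows(totalKeys: Int, saltedKeys: Int, numSalts: Int): Int =
    (totalKeys - saltedKeys) + saltedKeys * numSalts

  // Salting all 1,000 users: 1000 * 10 = 10,000 rows.
  val saltEverything: Int = smallSideRows(totalKeys = 1000, saltedKeys = 1000, numSalts = 10)

  // Salting only the 2 skewed users: 998 + 2 * 10 = 1,018 rows.
  val saltSkewedOnly: Int = smallSideRows(totalKeys = 1000, saltedKeys = 2, numSalts = 10)
}
```

With a 1,000-row profile table the difference is harmless, but the same factor applies when the other side has millions of rows.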
// Number of salts
val numSalts = 10
// UDF to add salt to skewed keys only. It is marked non-deterministic
// because it draws a new random number on every call.
val addSaltUDF = udf((userId: String) => {
  if (skewedKeys.contains(userId)) {
    userId + "_" + Random.nextInt(numSalts)
  } else {
    userId
  }
}).asNondeterministic()
val salted_large_df = large_df.withColumn("salted_user_id", addSaltUDF(col("user_id")))
Step 3: Adjust small_df to Match
For each skewed key in small_df, we’ll replicate it with every possible salt value to ensure they match during the join.
// Create salts DataFrame
val salts = spark.range(0, numSalts).withColumnRenamed("id", "salt")
// Filter skewed keys in small_df
val skewed_small_df = small_df.filter(col("user_id").isin(skewedKeys: _*))
// Replicate skewed keys with salts
val expanded_skewed_small_df = skewed_small_df.crossJoin(salts)
  .withColumn("salted_user_id", concat(col("user_id"), lit("_"), col("salt")))
  .drop("salt")
// Non-skewed keys remain the same
val non_skewed_small_df = small_df.filter(!col("user_id").isin(skewedKeys: _*))
  .withColumn("salted_user_id", col("user_id"))
// Combine them back together
val salted_small_df = expanded_skewed_small_df.union(non_skewed_small_df)
Step 4: Perform the Join on Salted Keys
Now we can perform the join on the salted keys.
// Drop user_id from the small side so the result keeps a single user_id column
val salted_joined_df = salted_large_df.join(salted_small_df.drop("user_id"), "salted_user_id")
Step 5: Post-Processing
After the join, we can drop the salted keys if they’re no longer needed.
val final_df = salted_joined_df.drop("salted_user_id")
Why Does This Work?
By adding a random salt to the skewed user_ids in large_df, we've spread those records across multiple keys:
- "user_1" becomes "user_1_0", "user_1_1", ..., "user_1_9".
- Same for "user_2".
We also adjusted small_df to include these salted keys, so the join can still find matches. The join workload for each skewed user is now spread across up to 10 partitions instead of one, which can make the join considerably faster.
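The effect on partition sizes can be estimated without a cluster. Below is a small plain-Scala simulation under stated assumptions: the `hot_user` data, the 8 partitions and the fixed seed are invented for illustration, and `partitionOf` is a simplified stand-in for Spark's hash partitioning.

```scala
import scala.util.Random

object SaltedLoadDemo {
  val numPartitions = 8
  val numSalts = 10
  val rng = new Random(7) // fixed seed so repeated runs give the same numbers

  // Simplified stand-in for hash partitioning: non-negative modulo of the hash code.
  def partitionOf(key: String): Int = {
    val raw = key.hashCode % numPartitions
    if (raw < 0) raw + numPartitions else raw
  }

  // Size of the busiest partition for a given list of keys.
  def maxLoad(keys: Seq[String]): Int =
    keys.groupBy(k => partitionOf(k)).values.map(_.size).max

  // Hypothetical data: one hot key with 10,000 records plus 100 keys with 10 records each.
  val keys: Seq[String] =
    Seq.fill(10000)("hot_user") ++ (1 to 100).flatMap(i => Seq.fill(10)(s"user_$i"))

  // Salt only the hot key, in the spirit of Step 2.
  val saltedKeys: Seq[String] =
    keys.map(k => if (k == "hot_user") s"${k}_${rng.nextInt(numSalts)}" else k)

  val maxBefore: Int = maxLoad(keys)       // busiest partition without salting
  val maxAfter: Int = maxLoad(saltedKeys)  // busiest partition with salting
}
```

The exact numbers depend on the hash function and the seed, but `maxBefore` is at least 10,000 while `maxAfter` should be a fraction of that, since the hot records are now split across several partitions.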
Benefits
- Balanced Load: The skewed data is now distributed across multiple partitions.
- Improved Performance: No single partition is overloaded, reducing the risk of slowdowns or straggling tasks.
- Scalable Solution: This works well even as data sizes grow.
Other Ways to Handle Data Skew
Salting is great, but there are other techniques you might consider:
1. Broadcast Joins
If one of your datasets is small enough to fit into memory, you can broadcast it to all executors and avoid shuffling the large dataset.
import org.apache.spark.sql.functions.broadcast
val joined_df = large_df.join(broadcast(small_df), "user_id")
2. Spark’s Built-in Skew Handling (Spark 3.0+)
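Spark 3.0 introduced adaptive query execution (AQE) with skew join handling. When it is enabled, Spark detects oversized shuffle partitions at runtime and splits them into smaller tasks.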
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
3. Repartitioning after salting
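Manually repartition your data (after salting) on the salted key to distribute it more evenly.
// Repartition the salted DataFrame; large_df itself has no salted_user_id column
val repartitioned_df = salted_large_df.repartition(20, col("salted_user_id"))
Wrapping Up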
Data skew can be a major performance bottleneck in Spark, but with the right approach, you can overcome it. Salting is one such method, and while it may not be perfect for every situation, it’s a handy trick to have up your sleeve.
Key Takeaways:
- Salting spreads skewed keys across multiple partitions to balance the load.
- Identify skewed keys and adjust both datasets to match the salted keys.
- Experiment with different numbers of salts to find the right balance.
- Consider Alternatives: Depending on your data, other methods like broadcast joins might be more suitable.
Follow me on Medium and LinkedIn to stay connected : https://www.linkedin.com/in/ritam378





