2. Joins, Skew, And Data Movement
Many slow Spark jobs are slow because data has to move. A join can require Spark to bring rows with matching keys together from many partitions, often across the network between executors. That movement is called a shuffle. This chapter focuses on practical join questions: which rows the join keeps, when Spark shuffles, when Spark can broadcast instead, and what happens when one key has far more rows than the others.
The examples are tiny, but they show the same behavior that matters at scale.
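As a rough mental model (not Spark's implementation), a shuffle for a join assigns every row to a partition by hashing its join key. Rows with equal keys from both tables therefore land in the same partition and can be joined locally. The sketch below is plain Python. It uses the built-in hash() in place of the Murmur3-based hash that Spark uses, and hash_partition, toy_orders and toy_customers are hypothetical names chosen so they do not clash with the DataFrames created later.

```python
from collections import defaultdict

def hash_partition(rows, key_index, num_partitions):
    """Toy shuffle: route each row to the partition chosen by hashing its key."""
    partitions = defaultdict(list)
    for row in rows:
        # Python's hash() stands in for Spark's Murmur3-based partitioning hash.
        partitions[hash(row[key_index]) % num_partitions].append(row)
    return dict(partitions)

# (order_id, customer_id) and (customer_id, name), shaped like the tables used later.
toy_orders = [(101, 1), (102, 1), (103, 2), (104, 3), (105, 99)]
toy_customers = [(1, "Ada"), (2, "Ben"), (3, "Cid"), (4, "Dia")]

order_parts = hash_partition(toy_orders, key_index=1, num_partitions=4)
customer_parts = hash_partition(toy_customers, key_index=0, num_partitions=4)

# Both sides were partitioned the same way, so each partition can be joined on its
# own: a key never sits in one partition on one side and a different one on the other.
joined = []
for p, rows in order_parts.items():
    names = dict(customer_parts.get(p, []))
    for order_id, customer_id in rows:
        if customer_id in names:
            joined.append((order_id, customer_id, names[customer_id]))

print(sorted(joined))  # [(101, 1, 'Ada'), (102, 1, 'Ada'), (103, 2, 'Ben'), (104, 3, 'Cid')]
```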
[1]:
from pathlib import Path
import shutil
from pyspark.sql import SparkSession, functions as F
DATA_DIR = Path.cwd()
OUTPUT_DIR = DATA_DIR / "_spark_output" / "joins-skew"
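# Start each run from an empty output directory.
if OUTPUT_DIR.exists():
    shutil.rmtree(OUTPUT_DIR)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("spark-intro-joins-skew")
    .config("spark.driver.host", "127.0.0.1")
    .config("spark.driver.bindAddress", "127.0.0.1")
    # A few shuffle partitions keep the tiny examples readable.
    .config("spark.sql.shuffle.partitions", "4")
    .config("spark.default.parallelism", "4")
    # Disable adaptive query execution so the plans below are not re-optimized at runtime.
    .config("spark.sql.adaptive.enabled", "false")
    .config("spark.ui.showConsoleProgress", "false")
    # Disable automatic broadcast joins; a broadcast happens only when hinted.
    .config("spark.sql.autoBroadcastJoinThreshold", "-1")
    .getOrCreate()
)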
sc = spark.sparkContext
sc.setLogLevel("ERROR")
print("Spark version:", spark.version)
print("Spark master:", sc.master)
Spark version: 3.5.8
Spark master: local[*]
2.1. A Small Customer And Order Dataset
The customer table has one row per customer. The order table has many rows per customer. This is a common shape: a smaller dimension table joined to a larger event or transaction table.
[2]:
customers = spark.createDataFrame(
[(1, "Ada", "north"), (2, "Ben", "south"), (3, "Cid", "west"), (4, "Dia", "north")],
["customer_id", "name", "region"],
)
orders = spark.createDataFrame(
[(101, 1, 25.50), (102, 1, 80.00), (103, 2, 19.99), (104, 3, 120.00), (105, 99, 12.00)],
["order_id", "customer_id", "amount"],
)
customers.show()
orders.show()
+-----------+----+------+
|customer_id|name|region|
+-----------+----+------+
| 1| Ada| north|
| 2| Ben| south|
| 3| Cid| west|
| 4| Dia| north|
+-----------+----+------+
+--------+-----------+------+
|order_id|customer_id|amount|
+--------+-----------+------+
| 101| 1| 25.5|
| 102| 1| 80.0|
| 103| 2| 19.99|
| 104| 3| 120.0|
| 105| 99| 12.0|
+--------+-----------+------+
2.2. Join Types
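Choose the join type by deciding which side of the relationship must be preserved. Inner joins keep only matching rows. Left joins keep every row on the left and fill the right-side columns with NULL when there is no match. Semi joins keep left rows that have at least one match, return each of them once no matter how many matches exist, and add no columns from the right. Anti joins keep left rows that have no match, which makes them a quick way to find orphaned records such as order 105.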
[3]:
print("inner")
orders.join(customers, on="customer_id", how="inner").orderBy("order_id").show()
print("left")
orders.join(customers, on="customer_id", how="left").orderBy("order_id").show()
print("semi: customers with orders")
customers.join(orders, on="customer_id", how="left_semi").orderBy("customer_id").show()
print("anti: orders without a customer")
orders.join(customers, on="customer_id", how="left_anti").orderBy("order_id").show()
inner
+-----------+--------+------+----+------+
|customer_id|order_id|amount|name|region|
+-----------+--------+------+----+------+
| 1| 101| 25.5| Ada| north|
| 1| 102| 80.0| Ada| north|
| 2| 103| 19.99| Ben| south|
| 3| 104| 120.0| Cid| west|
+-----------+--------+------+----+------+
left
+-----------+--------+------+----+------+
|customer_id|order_id|amount|name|region|
+-----------+--------+------+----+------+
| 1| 101| 25.5| Ada| north|
| 1| 102| 80.0| Ada| north|
| 2| 103| 19.99| Ben| south|
| 3| 104| 120.0| Cid| west|
| 99| 105| 12.0|NULL| NULL|
+-----------+--------+------+----+------+
semi: customers with orders
+-----------+----+------+
|customer_id|name|region|
+-----------+----+------+
| 1| Ada| north|
| 2| Ben| south|
| 3| Cid| west|
+-----------+----+------+
anti: orders without a customer
+-----------+--------+------+
|customer_id|order_id|amount|
+-----------+--------+------+
| 99| 105| 12.0|
+-----------+--------+------+
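To check these semantics without Spark, the same four joins can be written over plain Python tuples and dicts. This is an illustrative model, not how Spark executes joins. The names order_rows, customer_lookup and the result variables are hypothetical, chosen so they do not overwrite the DataFrames above, and the tuple column order differs from Spark's output, which puts the join key first.

```python
# Plain-Python model of the four joins above; tuples stand in for DataFrame rows.
order_rows = [(101, 1, 25.50), (102, 1, 80.00), (103, 2, 19.99), (104, 3, 120.00), (105, 99, 12.00)]
customer_lookup = {1: ("Ada", "north"), 2: ("Ben", "south"), 3: ("Cid", "west"), 4: ("Dia", "north")}
customers_with_orders = {customer_id for _, customer_id, _ in order_rows}

# Inner: only orders whose customer_id exists, with the customer columns appended.
inner = [o + customer_lookup[o[1]] for o in order_rows if o[1] in customer_lookup]
# Left: every order; missing customer columns become None (Spark shows NULL).
left_outer = [o + customer_lookup.get(o[1], (None, None)) for o in order_rows]
# Left semi: each matching customer once, even customer 1 with two orders.
semi = [cid for cid in customer_lookup if cid in customers_with_orders]
# Left anti: orders whose customer_id has no match.
anti = [o for o in order_rows if o[1] not in customer_lookup]

print(len(inner), len(left_outer), semi, anti)  # 4 5 [1, 2, 3] [(105, 99, 12.0)]
```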
2.3. Duplicate Key Explosions
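In an equi-join, every row on the left is matched with every row on the right that has the same key. If a key appears m times on the left and n times on the right, it produces m × n output rows, so duplicate keys on both sides can make the output grow much faster than expected. In the example below, id 1 appears twice on the left and three times on the right, giving 2 × 3 = 6 rows, while id 2 has no match and drops out of the inner join. When a join result is surprisingly large, count duplicate keys on both sides before tuning anything else.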
[4]:
left = spark.createDataFrame([(1, "L1"), (1, "L2"), (2, "L3")], ["id", "left_value"])
right = spark.createDataFrame([(1, "R1"), (1, "R2"), (1, "R3")], ["id", "right_value"])
left.join(right, on="id", how="inner").orderBy("left_value", "right_value").show()
left.groupBy("id").count().withColumnRenamed("count", "left_rows").join(
right.groupBy("id").count().withColumnRenamed("count", "right_rows"),
on="id",
how="outer",
).fillna(0).show()
+---+----------+-----------+
| id|left_value|right_value|
+---+----------+-----------+
| 1| L1| R1|
| 1| L1| R2|
| 1| L1| R3|
| 1| L2| R1|
| 1| L2| R2|
| 1| L2| R3|
+---+----------+-----------+
+---+---------+----------+
| id|left_rows|right_rows|
+---+---------+----------+
| 2| 1| 0|
| 1| 2| 3|
+---+---------+----------+
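Because each key contributes left_count × right_count rows, the size of an inner join can be predicted from key counts alone, which is what the count table above makes visible. A minimal sketch, assuming the keys of each side fit in a Python list; inner_join_size is a hypothetical helper. In Spark, the same estimate is the sum of left_rows * right_rows over the joined count table.

```python
from collections import Counter

def inner_join_size(left_keys, right_keys):
    """Rows produced by an inner equi-join: the sum over keys of left_count * right_count."""
    left_counts = Counter(left_keys)
    right_counts = Counter(right_keys)
    # NULL keys never match in an equality join, so None is skipped.
    return sum(n * right_counts[k] for k, n in left_counts.items() if k is not None)

print(inner_join_size([1, 1, 2], [1, 1, 1]))    # 6  (id 1: 2 * 3, id 2: 1 * 0)
print(inner_join_size([7] * 1000, [7] * 1000))  # 1000000 rows from a single key
print(inner_join_size([None, None], [None]))    # 0
```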
2.4. Shuffle Join Versus Broadcast Join
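With automatic broadcast disabled in this session (spark.sql.autoBroadcastJoinThreshold is -1), Spark plans a sort-merge join, which is a shuffle join. The physical plan contains one Exchange per side because both tables must be hash-partitioned by customer_id so that matching keys meet in the same task. When one side is small, a broadcast hint changes the plan: Spark sends a full copy of the small table to every executor, and the large table is joined where it already sits. The second plan shows this as a BroadcastExchange on customers and no Exchange on orders. Explicit hints are honored even though automatic broadcasting is off.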
[5]:
shuffle_join = orders.join(customers, on="customer_id", how="inner")
print("Shuffle join plan")
shuffle_join.explain("formatted")
broadcast_join = orders.join(F.broadcast(customers), on="customer_id", how="inner")
print("Broadcast join plan")
broadcast_join.explain("formatted")
Shuffle join plan
== Physical Plan ==
* Project (10)
+- * SortMergeJoin Inner (9)
:- * Sort (4)
: +- Exchange (3)
: +- * Filter (2)
: +- * Scan ExistingRDD (1)
+- * Sort (8)
+- Exchange (7)
+- * Filter (6)
+- * Scan ExistingRDD (5)
(1) Scan ExistingRDD [codegen id : 1]
Output [3]: [order_id#6L, customer_id#7L, amount#8]
Arguments: [order_id#6L, customer_id#7L, amount#8], MapPartitionsRDD[9] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)
(2) Filter [codegen id : 1]
Input [3]: [order_id#6L, customer_id#7L, amount#8]
Condition : isnotnull(customer_id#7L)
(3) Exchange
Input [3]: [order_id#6L, customer_id#7L, amount#8]
Arguments: hashpartitioning(customer_id#7L, 4), ENSURE_REQUIREMENTS, [plan_id=444]
(4) Sort [codegen id : 2]
Input [3]: [order_id#6L, customer_id#7L, amount#8]
Arguments: [customer_id#7L ASC NULLS FIRST], false, 0
(5) Scan ExistingRDD [codegen id : 3]
Output [3]: [customer_id#0L, name#1, region#2]
Arguments: [customer_id#0L, name#1, region#2], MapPartitionsRDD[4] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)
(6) Filter [codegen id : 3]
Input [3]: [customer_id#0L, name#1, region#2]
Condition : isnotnull(customer_id#0L)
(7) Exchange
Input [3]: [customer_id#0L, name#1, region#2]
Arguments: hashpartitioning(customer_id#0L, 4), ENSURE_REQUIREMENTS, [plan_id=450]
(8) Sort [codegen id : 4]
Input [3]: [customer_id#0L, name#1, region#2]
Arguments: [customer_id#0L ASC NULLS FIRST], false, 0
(9) SortMergeJoin [codegen id : 5]
Left keys [1]: [customer_id#7L]
Right keys [1]: [customer_id#0L]
Join type: Inner
Join condition: None
(10) Project [codegen id : 5]
Output [5]: [customer_id#7L, order_id#6L, amount#8, name#1, region#2]
Input [6]: [order_id#6L, customer_id#7L, amount#8, customer_id#0L, name#1, region#2]
Broadcast join plan
== Physical Plan ==
* Project (7)
+- * BroadcastHashJoin Inner BuildRight (6)
:- * Filter (2)
: +- * Scan ExistingRDD (1)
+- BroadcastExchange (5)
+- * Filter (4)
+- * Scan ExistingRDD (3)
(1) Scan ExistingRDD [codegen id : 2]
Output [3]: [order_id#6L, customer_id#7L, amount#8]
Arguments: [order_id#6L, customer_id#7L, amount#8], MapPartitionsRDD[9] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)
(2) Filter [codegen id : 2]
Input [3]: [order_id#6L, customer_id#7L, amount#8]
Condition : isnotnull(customer_id#7L)
(3) Scan ExistingRDD [codegen id : 1]
Output [3]: [customer_id#0L, name#1, region#2]
Arguments: [customer_id#0L, name#1, region#2], MapPartitionsRDD[4] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)
(4) Filter [codegen id : 1]
Input [3]: [customer_id#0L, name#1, region#2]
Condition : isnotnull(customer_id#0L)
(5) BroadcastExchange
Input [3]: [customer_id#0L, name#1, region#2]
Arguments: HashedRelationBroadcastMode(List(input[0, bigint, false]),false), [plan_id=497]
(6) BroadcastHashJoin [codegen id : 2]
Left keys [1]: [customer_id#7L]
Right keys [1]: [customer_id#0L]
Join type: Inner
Join condition: None
(7) Project [codegen id : 2]
Output [5]: [customer_id#7L, order_id#6L, amount#8, name#1, region#2]
Input [6]: [order_id#6L, customer_id#7L, amount#8, customer_id#0L, name#1, region#2]
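The broadcast plan has no Exchange on the orders side. Spark builds a hash table from the small side, ships a copy to every executor, and each task probes it with the rows it already holds. A plain-Python sketch of that build-and-probe pattern, assuming the orders are already split into partitions that never move; broadcast_hash_join and the variable names are hypothetical, not a Spark API.

```python
def broadcast_hash_join(large_partitions, small_rows, large_key, small_key):
    """Inner-join each partition of the large side against one shared hash table."""
    # Build phase: hash the small side once (Spark builds this and broadcasts it).
    table = {}
    for row in small_rows:
        table.setdefault(row[small_key], []).append(row)
    # Probe phase: each partition is joined in place; no large-side row moves.
    output = []
    for partition in large_partitions:
        joined_partition = []
        for row in partition:
            for match in table.get(row[large_key], []):
                # Unlike Spark's Project, this keeps both copies of the key column.
                joined_partition.append(row + match)
        output.append(joined_partition)
    return output

order_partitions = [
    [(101, 1, 25.50), (102, 1, 80.00)],
    [(103, 2, 19.99), (104, 3, 120.00), (105, 99, 12.00)],
]
customer_rows = [(1, "Ada", "north"), (2, "Ben", "south"), (3, "Cid", "west"), (4, "Dia", "north")]

result = broadcast_hash_join(order_partitions, customer_rows, large_key=1, small_key=0)
for i, part in enumerate(result):
    print(i, part)
```

Each output partition corresponds to one input partition, which is why the large side never needs an Exchange.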
2.5. Skewed Keys
A skewed key has far more records than other keys. Hash partitioning sends every row with the same key to the same partition, so one task may receive most of the work while the others finish quickly, and the stage cannot complete until that straggler does. A fast way to detect skew is to count records by key and sort by the count.
[6]:
skewed_orders = spark.createDataFrame(
[(order_id, 1, 1.0) for order_id in range(1000, 1040)]
+ [(2001, 2, 3.0), (2002, 3, 4.0), (2003, 4, 5.0)],
["order_id", "customer_id", "amount"],
)
skewed_orders.groupBy("customer_id").count().orderBy(F.desc("count")).show()
+-----------+-----+
|customer_id|count|
+-----------+-----+
| 1| 40|
| 2| 1|
| 3| 1|
| 4| 1|
+-----------+-----+
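Counting by key shows the skew; a toy partitioning shows why it hurts. The sketch below assumes the same four partitions as this session and uses Python's hash() as a stand-in for Spark's hash function, so the exact partition numbers may differ from Spark's. The max-to-median ratio is one simple, hypothetical skew signal, not a Spark metric.

```python
from collections import Counter
from statistics import median

# customer_id of every row in skewed_orders: 40 rows for customer 1, one each for 2, 3, 4.
skewed_keys = [1] * 40 + [2, 3, 4]
num_partitions = 4  # matches spark.sql.shuffle.partitions in this session

# Toy hash partitioning: Python's hash() stands in for Spark's hash function.
rows_per_partition = Counter(hash(k) % num_partitions for k in skewed_keys)
loads = [rows_per_partition[p] for p in range(num_partitions)]
print(loads)  # [1, 40, 1, 1]

# Simple skew signal: rows in the busiest partition divided by the median partition.
skew_ratio = max(loads) / median(loads)
print(skew_ratio)  # 40.0
```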
2.6. Salting A Skewed Join
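One mitigation is salting. A salt is an extra join column that splits one hot key into several (key, salt) pairs, so its rows can be spread across several tasks. The other side of the join must be expanded with every salt value, one copy of each row per salt, so that each salted row still finds its match. That expansion multiplies the size of the other side by the number of salts. Salting is not the first fix for every skewed join, but it is useful when a join between two large tables has a few heavy keys.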
[7]:
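num_salts = 4
salts = spark.range(num_salts).withColumnRenamed("id", "salt")
# Give each order a deterministic salt in [0, num_salts) derived from its order_id.
salted_orders = skewed_orders.withColumn(
    "salt",
    F.pmod(F.xxhash64("order_id"), F.lit(num_salts)).cast("long"),
)
# Replicate every customer once per salt value so each (customer_id, salt) pair exists.
salted_customers = customers.crossJoin(salts)
# The hot key (customer 1) is now split across four join keys.
salted_orders.groupBy("customer_id", "salt").count().orderBy("customer_id", "salt").show()
# Joining on both columns returns the same rows as the unsalted inner join.
salted_join = salted_orders.join(salted_customers, on=["customer_id", "salt"], how="inner")
salted_join.select("order_id", "customer_id", "salt", "name", "amount").orderBy("order_id").show(8)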
+-----------+----+-----+
|customer_id|salt|count|
+-----------+----+-----+
| 1| 0| 6|
| 1| 1| 11|
| 1| 2| 9|
| 1| 3| 14|
| 2| 2| 1|
| 3| 3| 1|
| 4| 1| 1|
+-----------+----+-----+
+--------+-----------+----+----+------+
|order_id|customer_id|salt|name|amount|
+--------+-----------+----+----+------+
| 1000| 1| 3| Ada| 1.0|
| 1001| 1| 1| Ada| 1.0|
| 1002| 1| 1| Ada| 1.0|
| 1003| 1| 3| Ada| 1.0|
| 1004| 1| 2| Ada| 1.0|
| 1005| 1| 1| Ada| 1.0|
| 1006| 1| 1| Ada| 1.0|
| 1007| 1| 1| Ada| 1.0|
+--------+-----------+----+----+------+
only showing top 8 rows
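Salting changes how the work is grouped, not which rows the join returns. A plain-Python check of both claims, assuming order_id % NUM_SALTS as the salt instead of the xxhash64-based salt in the Spark cell (any deterministic function of the row behaves the same way); NUM_SALTS and the other variable names are hypothetical.

```python
from collections import Counter

NUM_SALTS = 4  # assumption: same salt count as the Spark example

# Toy rows: (order_id, customer_id); customer 1 is the hot key.
hot_orders = [(oid, 1) for oid in range(1000, 1040)] + [(2001, 2), (2002, 3), (2003, 4)]
dim_rows = [(1, "Ada"), (2, "Ben"), (3, "Cid"), (4, "Dia")]

# Salt the large side deterministically (a stand-in for pmod(xxhash64(order_id), 4)).
salted_facts = [(oid, cid, oid % NUM_SALTS) for oid, cid in hot_orders]
# Replicate the small side once per salt value so every (key, salt) pair exists.
salted_dim = {(cid, s): name for cid, name in dim_rows for s in range(NUM_SALTS)}

salted_result = sorted(
    (oid, cid, salted_dim[(cid, s)]) for oid, cid, s in salted_facts if (cid, s) in salted_dim
)
plain_lookup = dict(dim_rows)
plain_result = sorted((oid, cid, plain_lookup[cid]) for oid, cid in hot_orders if cid in plain_lookup)

# Same output rows as the unsalted join.
print(salted_result == plain_result)  # True
# The hot key is now split into NUM_SALTS join keys of 10 rows each.
print(sorted(Counter(s for _, cid, s in salted_facts if cid == 1).values()))  # [10, 10, 10, 10]
```

The replicated small side has len(dim_rows) * NUM_SALTS rows, which is the price of salting. Whether the four salted keys land in four different partitions still depends on how Spark hashes (customer_id, salt).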
2.7. What To Remember
Pick the join type by row semantics first. Then inspect duplicate keys, skew, and the physical plan. Prefer broadcast joins when one side is small enough to fit comfortably in memory on the driver and on every executor. Use salting only when a hot key is large enough to dominate the work and simpler fixes, such as broadcasting the smaller side, do not apply.
[8]:
spark.stop()