5. Broadcast Variables And Accumulators
Broadcast variables and accumulators belong to Spark’s lower-level distributed programming model. They are less common in DataFrame-first code, but they are worth understanding because they show how driver-side values are shared with tasks and how tasks report diagnostic counts back to the driver.
[1]:
from pathlib import Path
import shutil
from pyspark.sql import SparkSession, functions as F
DATA_DIR = Path.cwd()
OUTPUT_DIR = DATA_DIR / "_spark_output" / "broadcast-accumulators"
if OUTPUT_DIR.exists():
    shutil.rmtree(OUTPUT_DIR)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("spark-intro-broadcast-accumulators")
    .config("spark.driver.host", "127.0.0.1")
    .config("spark.driver.bindAddress", "127.0.0.1")
    .config("spark.sql.shuffle.partitions", "4")
    .config("spark.default.parallelism", "4")
    .config("spark.sql.adaptive.enabled", "false")
    .config("spark.ui.showConsoleProgress", "false")
    .getOrCreate()
)
sc = spark.sparkContext
sc.setLogLevel("ERROR")
print("Spark version:", spark.version)
print("Spark master:", sc.master)
Spark version: 3.5.8
Spark master: local[*]
5.1. Broadcast Variables
A broadcast variable ships a read-only value from the driver to each executor once; tasks running on that executor then reuse the cached copy instead of receiving the value with every task. This is useful for small lookup dictionaries in RDD code. In DataFrame code, a broadcast join is usually the better API.
[2]:
tax_rates = {"north": 0.05, "south": 0.06, "west": 0.075}
broadcast_rates = sc.broadcast(tax_rates)
orders = sc.parallelize([
    (1, "north", 100.00),
    (2, "south", 80.00),
    (3, "west", 40.00),
    (4, "unknown", 25.00),
])

def add_tax(row):
    order_id, region, amount = row
    rate = broadcast_rates.value.get(region, 0.0)
    return (order_id, region, amount, round(amount * rate, 2))
orders.map(add_tax).collect()
[2]:
[(1, 'north', 100.0, 5.0),
(2, 'south', 80.0, 4.8),
(3, 'west', 40.0, 3.0),
(4, 'unknown', 25.0, 0.0)]
5.2. DataFrame Alternative
When the lookup data is tabular, prefer a small DataFrame and a broadcast join: the optimizer can see and plan the join, whereas a Python dictionary lookup inside an RDD function is opaque to Spark.
[3]:
order_df = spark.createDataFrame(orders, ["order_id", "region", "amount"])
rate_df = spark.createDataFrame([(k, v) for k, v in tax_rates.items()], ["region", "tax_rate"])
with_tax = (
    order_df
    .join(F.broadcast(rate_df), on="region", how="left")
    .fillna({"tax_rate": 0.0})
    .withColumn("tax", F.round(F.col("amount") * F.col("tax_rate"), 2))
)
with_tax.orderBy("order_id").show()
with_tax.explain("formatted")
+-------+--------+------+--------+---+
| region|order_id|amount|tax_rate|tax|
+-------+--------+------+--------+---+
| north| 1| 100.0| 0.05|5.0|
| south| 2| 80.0| 0.06|4.8|
| west| 3| 40.0| 0.075|3.0|
|unknown| 4| 25.0| 0.0|0.0|
+-------+--------+------+--------+---+
== Physical Plan ==
* Project (7)
+- * Project (6)
   +- * BroadcastHashJoin LeftOuter BuildRight (5)
      :- * Scan ExistingRDD (1)
      +- BroadcastExchange (4)
         +- * Filter (3)
            +- * Scan ExistingRDD (2)
(1) Scan ExistingRDD [codegen id : 2]
Output [3]: [order_id#0L, region#1, amount#2]
Arguments: [order_id#0L, region#1, amount#2], MapPartitionsRDD[6] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)
(2) Scan ExistingRDD [codegen id : 1]
Output [2]: [region#6, tax_rate#7]
Arguments: [region#6, tax_rate#7], MapPartitionsRDD[11] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)
(3) Filter [codegen id : 1]
Input [2]: [region#6, tax_rate#7]
Condition : isnotnull(region#6)
(4) BroadcastExchange
Input [2]: [region#6, tax_rate#7]
Arguments: HashedRelationBroadcastMode(List(input[0, string, false]),false), [plan_id=72]
(5) BroadcastHashJoin [codegen id : 2]
Left keys [1]: [region#1]
Right keys [1]: [region#6]
Join type: LeftOuter
Join condition: None
(6) Project [codegen id : 2]
Output [4]: [region#1, order_id#0L, amount#2, coalesce(nanvl(tax_rate#7, null), 0.0) AS tax_rate#18]
Input [5]: [order_id#0L, region#1, amount#2, region#6, tax_rate#7]
(7) Project [codegen id : 2]
Output [5]: [region#1, order_id#0L, amount#2, tax_rate#18, round((amount#2 * tax_rate#18), 2) AS tax#23]
Input [4]: [region#1, order_id#0L, amount#2, tax_rate#18]
5.3. Accumulators
An accumulator lets tasks add to a value that the driver can read after an action. Use accumulators for diagnostics, not for business logic: Spark may re-run failed or speculative tasks, and updates made inside transformations are re-applied whenever the RDD is recomputed, so accumulator values are approximate counters, not a transactional data output.
[4]:
invalid_regions = sc.accumulator(0)
valid_regions = set(tax_rates)
broadcast_valid_regions = sc.broadcast(valid_regions)
records = sc.parallelize(["north,10.0", "bad,7.0", "south,5.0", "west,9.0", "bad,1.0"])
def parse_record(line):
    region, amount = line.split(",")
    if region not in broadcast_valid_regions.value:
        invalid_regions.add(1)
        return None
    return (region, float(amount))
parsed = records.map(parse_record).filter(lambda row: row is not None).collect()
print("valid records:", parsed)
print("invalid region count:", invalid_regions.value)
valid records: [('north', 10.0), ('south', 5.0), ('west', 9.0)]
invalid region count: 2
5.4. What To Remember
Broadcast variables share small read-only values with tasks. Accumulators report task-side diagnostics back to the driver. For structured data, use DataFrame joins and aggregations first; reach for broadcast variables and accumulators when the lower-level behavior is exactly what you need.
[5]:
broadcast_rates.unpersist()
broadcast_valid_regions.unpersist()
spark.stop()