5. Broadcast Variables And Accumulators

Broadcast variables and accumulators belong to Spark’s lower-level distributed programming model. They are less common in DataFrame-first code, but they are worth understanding because they show how driver-side values can be shared with tasks and how tasks can report diagnostic counts back to the driver.

[1]:
from pathlib import Path
import shutil

from pyspark.sql import SparkSession, functions as F

DATA_DIR = Path.cwd()
OUTPUT_DIR = DATA_DIR / "_spark_output" / "broadcast-accumulators"

if OUTPUT_DIR.exists():
    shutil.rmtree(OUTPUT_DIR)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("spark-intro-broadcast-accumulators")
    .config("spark.driver.host", "127.0.0.1")
    .config("spark.driver.bindAddress", "127.0.0.1")
    .config("spark.sql.shuffle.partitions", "4")
    .config("spark.default.parallelism", "4")
    .config("spark.sql.adaptive.enabled", "false")
    .config("spark.ui.showConsoleProgress", "false")
    .getOrCreate()
)

sc = spark.sparkContext
sc.setLogLevel("ERROR")

print("Spark version:", spark.version)
print("Spark master:", sc.master)

Spark version: 3.5.8
Spark master: local[*]

5.1. Broadcast Variables

A broadcast variable ships a read-only value from the driver to each executor once, where every task can then reuse it, instead of serializing the value into every task closure. This is useful for small lookup dictionaries in RDD code. In DataFrame code, a broadcast join is usually the better API.

[2]:
tax_rates = {"north": 0.05, "south": 0.06, "west": 0.075}
broadcast_rates = sc.broadcast(tax_rates)

orders = sc.parallelize([
    (1, "north", 100.00),
    (2, "south", 80.00),
    (3, "west", 40.00),
    (4, "unknown", 25.00),
])

def add_tax(row):
    order_id, region, amount = row
    rate = broadcast_rates.value.get(region, 0.0)
    return (order_id, region, amount, round(amount * rate, 2))

orders.map(add_tax).collect()

[2]:
[(1, 'north', 100.0, 5.0),
 (2, 'south', 80.0, 4.8),
 (3, 'west', 40.0, 3.0),
 (4, 'unknown', 25.0, 0.0)]

5.2. DataFrame Alternative

When the lookup data is tabular, prefer a small DataFrame and a broadcast join. The optimizer can see and plan around the join, whereas a Python dictionary lookup inside an RDD function is opaque to it.

[3]:
order_df = spark.createDataFrame(orders, ["order_id", "region", "amount"])
rate_df = spark.createDataFrame([(k, v) for k, v in tax_rates.items()], ["region", "tax_rate"])

with_tax = (
    order_df
    .join(F.broadcast(rate_df), on="region", how="left")
    .fillna({"tax_rate": 0.0})
    .withColumn("tax", F.round(F.col("amount") * F.col("tax_rate"), 2))
)

with_tax.orderBy("order_id").show()
with_tax.explain("formatted")

+-------+--------+------+--------+---+
| region|order_id|amount|tax_rate|tax|
+-------+--------+------+--------+---+
|  north|       1| 100.0|    0.05|5.0|
|  south|       2|  80.0|    0.06|4.8|
|   west|       3|  40.0|   0.075|3.0|
|unknown|       4|  25.0|     0.0|0.0|
+-------+--------+------+--------+---+

== Physical Plan ==
* Project (7)
+- * Project (6)
   +- * BroadcastHashJoin LeftOuter BuildRight (5)
      :- * Scan ExistingRDD (1)
      +- BroadcastExchange (4)
         +- * Filter (3)
            +- * Scan ExistingRDD (2)


(1) Scan ExistingRDD [codegen id : 2]
Output [3]: [order_id#0L, region#1, amount#2]
Arguments: [order_id#0L, region#1, amount#2], MapPartitionsRDD[6] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)

(2) Scan ExistingRDD [codegen id : 1]
Output [2]: [region#6, tax_rate#7]
Arguments: [region#6, tax_rate#7], MapPartitionsRDD[11] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)

(3) Filter [codegen id : 1]
Input [2]: [region#6, tax_rate#7]
Condition : isnotnull(region#6)

(4) BroadcastExchange
Input [2]: [region#6, tax_rate#7]
Arguments: HashedRelationBroadcastMode(List(input[0, string, false]),false), [plan_id=72]

(5) BroadcastHashJoin [codegen id : 2]
Left keys [1]: [region#1]
Right keys [1]: [region#6]
Join type: LeftOuter
Join condition: None

(6) Project [codegen id : 2]
Output [4]: [region#1, order_id#0L, amount#2, coalesce(nanvl(tax_rate#7, null), 0.0) AS tax_rate#18]
Input [5]: [order_id#0L, region#1, amount#2, region#6, tax_rate#7]

(7) Project [codegen id : 2]
Output [5]: [region#1, order_id#0L, amount#2, tax_rate#18, round((amount#2 * tax_rate#18), 2) AS tax#23]
Input [4]: [region#1, order_id#0L, amount#2, tax_rate#18]


5.3. Accumulators

An accumulator lets tasks add to a value that the driver can read after an action completes. Use accumulators for diagnostics, not for business logic: Spark guarantees each task's update is applied exactly once only for accumulators updated inside actions (such as `foreach`), while updates made inside transformations can be applied more than once if a task is retried or a stage is recomputed, so accumulator values are not a transactional data output.

[4]:
invalid_regions = sc.accumulator(0)
valid_regions = set(tax_rates)
broadcast_valid_regions = sc.broadcast(valid_regions)

records = sc.parallelize(["north,10.0", "bad,7.0", "south,5.0", "west,9.0", "bad,1.0"])

def parse_record(line):
    region, amount = line.split(",")
    if region not in broadcast_valid_regions.value:
        invalid_regions.add(1)
        return None
    return (region, float(amount))

parsed = records.map(parse_record).filter(lambda row: row is not None).collect()
print("valid records:", parsed)
print("invalid region count:", invalid_regions.value)

valid records: [('north', 10.0), ('south', 5.0), ('west', 9.0)]
invalid region count: 2

5.4. What To Remember

Broadcast variables share small read-only values with tasks. Accumulators report task-side diagnostics back to the driver. For structured data, use DataFrame joins and aggregations first; reach for broadcast variables and accumulators when the lower-level behavior is exactly what you need.

[5]:
broadcast_rates.unpersist()
broadcast_valid_regions.unpersist()
spark.stop()