2. Testing Spark Code Locally
Notebook cells are a good way to teach Spark, but production Spark logic should be testable outside a notebook. The easiest pattern is to move transformation logic into functions that accept DataFrames and return DataFrames. Tests can build tiny input DataFrames, call the function, and compare the result with an expected DataFrame.
Spark’s local mode makes those tests practical without a cluster.
[1]:
from pathlib import Path
import shutil
from pyspark.sql import SparkSession, functions as F
DATA_DIR = Path.cwd()
OUTPUT_DIR = DATA_DIR / "_spark_output" / "testing-spark"
if OUTPUT_DIR.exists():
    shutil.rmtree(OUTPUT_DIR)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("spark-intro-testing-spark")
    .config("spark.driver.host", "127.0.0.1")
    .config("spark.driver.bindAddress", "127.0.0.1")
    .config("spark.sql.shuffle.partitions", "4")
    .config("spark.default.parallelism", "4")
    .config("spark.sql.adaptive.enabled", "false")
    .config("spark.ui.showConsoleProgress", "false")
    .getOrCreate()
)
sc = spark.sparkContext
sc.setLogLevel("ERROR")
print("Spark version:", spark.version)
print("Spark master:", sc.master)
Spark version: 3.5.8
Spark master: local[*]
2.1. A Transformation Function
The function below contains no file reads, no writes, and no global Spark session. That makes it easy to test. The caller owns the input and output boundaries.
[2]:
def build_customer_features(customers, events):
    """Join per-customer event aggregates onto the customers DataFrame."""
    clean_events = (
        events
        .withColumn("event_date", F.to_date("event_time"))
        .withColumn("is_purchase", (F.col("event_type") == "purchase").cast("int"))
    )
    behavior = (
        clean_events
        .groupBy("customer_id")
        .agg(
            F.count("*").alias("event_count"),
            F.sum("is_purchase").alias("purchase_count"),
            F.round(F.sum("amount"), 2).alias("total_amount"),
        )
    )
    # Left join keeps customers with no events; fillna gives them zero features.
    return (
        customers
        .join(behavior, on="customer_id", how="left")
        .fillna({"event_count": 0, "purchase_count": 0, "total_amount": 0.0})
    )
2.2. Tiny Test Data
A good Spark unit test is usually tiny. The goal is not to simulate a cluster. The goal is to cover the transformation contract: input columns, edge cases, and expected rows.
[3]:
from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual
customers = spark.createDataFrame(
    [(1, "north"), (2, "south"), (3, "west")],
    ["customer_id", "region"],
)
events = spark.createDataFrame(
    [
        (1, "2026-01-01", "view", 0.0),
        (1, "2026-01-01", "purchase", 10.0),
        (2, "2026-01-02", "purchase", 7.5),
    ],
    ["customer_id", "event_time", "event_type", "amount"],
)
actual = build_customer_features(customers, events).orderBy("customer_id")
expected = spark.createDataFrame(
    [(1, "north", 2, 1, 10.0), (2, "south", 1, 1, 7.5), (3, "west", 0, 0, 0.0)],
    ["customer_id", "region", "event_count", "purchase_count", "total_amount"],
).orderBy("customer_id")
actual.show()
+-----------+------+-----------+--------------+------------+
|customer_id|region|event_count|purchase_count|total_amount|
+-----------+------+-----------+--------------+------------+
| 1| north| 2| 1| 10.0|
| 2| south| 1| 1| 7.5|
| 3| west| 0| 0| 0.0|
+-----------+------+-----------+--------------+------------+
2.3. Assertions
PySpark includes testing helpers for DataFrame equality and schema equality. Sorting before comparison removes accidental ordering differences from the test.
[4]:
assertSchemaEqual(actual.schema, expected.schema)
assertDataFrameEqual(actual, expected)
print("feature test passed")
feature test passed
2.4. Test The Edge Case You Expect To Break
Here the edge case is a customer with no events. The result should keep that customer and fill the numeric features with zeros. A direct assertion makes the intended behavior hard to miss.
[5]:
row = actual.where(F.col("customer_id") == 3).first().asDict()
assert row["event_count"] == 0
assert row["purchase_count"] == 0
assert row["total_amount"] == 0.0
print(row)
{'customer_id': 3, 'region': 'west', 'event_count': 0, 'purchase_count': 0, 'total_amount': 0.0}
2.5. How This Looks In A Test Suite
In a repository, the Spark session usually belongs in a pytest fixture with session scope. Individual tests should create small DataFrames, call transformation functions, and assert results. Keep tests local, deterministic, and independent of file paths unless the behavior being tested is file IO.
[6]:
import textwrap
example_test = '''
def test_build_customer_features(spark):
    customers = spark.createDataFrame([(1, "north")], ["customer_id", "region"])
    events = spark.createDataFrame([(1, "2026-01-01", "purchase", 10.0)], ["customer_id", "event_time", "event_type", "amount"])
    actual = build_customer_features(customers, events)
    assert actual.first()["purchase_count"] == 1
'''
print(textwrap.dedent(example_test).strip())
def test_build_customer_features(spark):
    customers = spark.createDataFrame([(1, "north")], ["customer_id", "region"])
    events = spark.createDataFrame([(1, "2026-01-01", "purchase", 10.0)], ["customer_id", "event_time", "event_type", "amount"])
    actual = build_customer_features(customers, events)
    assert actual.first()["purchase_count"] == 1
2.6. What To Remember
Test transformations as functions. Keep input data tiny. Compare DataFrames intentionally. Avoid tests that depend on cluster services, wall-clock time, random ordering, or already-existing output directories.
[7]:
spark.stop()