2. Testing Spark Code Locally

Notebook cells are a good way to teach Spark, but production Spark logic should be testable outside a notebook. The easiest pattern is to move transformation logic into functions that accept DataFrames and return DataFrames. Tests can build tiny input DataFrames, call the function, and compare the result with an expected DataFrame.

Spark’s local mode runs the driver and executors inside a single JVM on one machine, so those tests can run on a laptop or a CI runner without a cluster.

[1]:
from pathlib import Path
import shutil

from pyspark.sql import SparkSession, functions as F

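DATA_DIR = Path.cwd()
OUTPUT_DIR = DATA_DIR / "_spark_output" / "testing-spark"

# Start every run from an empty output directory so stale files cannot leak into results.
if OUTPUT_DIR.exists():
    shutil.rmtree(OUTPUT_DIR)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

spark = (
    SparkSession.builder
    # Driver and executors share one local JVM, with one worker thread per core.
    .master("local[*]")
    .appName("spark-intro-testing-spark")
    # Bind to the loopback interface to avoid hostname and network surprises.
    .config("spark.driver.host", "127.0.0.1")
    .config("spark.driver.bindAddress", "127.0.0.1")
    # Tiny data does not need the default 200 shuffle partitions.
    .config("spark.sql.shuffle.partitions", "4")
    .config("spark.default.parallelism", "4")
    # Keep partitioning predictable from run to run.
    .config("spark.sql.adaptive.enabled", "false")
    # No progress bars in the cell output.
    .config("spark.ui.showConsoleProgress", "false")
    .getOrCreate()
)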

sc = spark.sparkContext
sc.setLogLevel("ERROR")

print("Spark version:", spark.version)
print("Spark master:", sc.master)

Spark version: 3.5.8
Spark master: local[*]

2.1. A Transformation Function

The function below reads no files, writes no output, and does not reference a global Spark session; it only transforms the DataFrames it is given. That makes it easy to test, because the caller owns the input and output boundaries.

[2]:
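def build_customer_features(customers, events):
    """Return one row per customer with event, purchase, and spend features.

    customers: customer_id, region
    events: customer_id, event_time, event_type, amount
    """
    # Normalize raw events: a calendar date and a 0/1 purchase flag.
    # (event_date is not used by the aggregation below.)
    clean_events = (
        events
        .withColumn("event_date", F.to_date("event_time"))
        .withColumn("is_purchase", (F.col("event_type") == "purchase").cast("int"))
    )

    # Aggregate per customer; only customers with at least one event appear here.
    behavior = (
        clean_events
        .groupBy("customer_id")
        .agg(
            F.count("*").alias("event_count"),
            F.sum("is_purchase").alias("purchase_count"),
            F.round(F.sum("amount"), 2).alias("total_amount"),
        )
    )

    # The left join keeps customers with no events; their null features become zeros.
    return (
        customers
        .join(behavior, on="customer_id", how="left")
        .fillna({"event_count": 0, "purchase_count": 0, "total_amount": 0.0})
    )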

2.2. Tiny Test Data

A good Spark unit test is usually tiny. The goal is not to simulate a cluster. The goal is to cover the transformation contract: input columns, edge cases, and expected rows.

[3]:
from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual

customers = spark.createDataFrame(
    [(1, "north"), (2, "south"), (3, "west")],
    ["customer_id", "region"],
)

events = spark.createDataFrame(
    [
        (1, "2026-01-01", "view", 0.0),
        (1, "2026-01-01", "purchase", 10.0),
        (2, "2026-01-02", "purchase", 7.5),
    ],
    ["customer_id", "event_time", "event_type", "amount"],
)

actual = build_customer_features(customers, events).orderBy("customer_id")
expected = spark.createDataFrame(
    [(1, "north", 2, 1, 10.0), (2, "south", 1, 1, 7.5), (3, "west", 0, 0, 0.0)],
    ["customer_id", "region", "event_count", "purchase_count", "total_amount"],
).orderBy("customer_id")

actual.show()

+-----------+------+-----------+--------------+------------+
|customer_id|region|event_count|purchase_count|total_amount|
+-----------+------+-----------+--------------+------------+
|          1| north|          2|             1|        10.0|
|          2| south|          1|             1|         7.5|
|          3|  west|          0|             0|         0.0|
+-----------+------+-----------+--------------+------------+

2.3. Assertions

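PySpark 3.5 added testing helpers in pyspark.testing: assertSchemaEqual compares two schemas, and assertDataFrameEqual compares two DataFrames row by row. By default assertDataFrameEqual ignores row order (checkRowOrder=False) and compares floating-point values with a small tolerance, so the orderBy calls above are not required for it. They keep the show() output and any position-based checks deterministic.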

[4]:
assertSchemaEqual(actual.schema, expected.schema)
assertDataFrameEqual(actual, expected)
print("feature test passed")

feature test passed

2.4. Test The Edge Case You Expect To Break

Here the edge case is a customer with no events. The result should keep that customer and fill the numeric features with zeros. A direct assertion makes the intended behavior hard to miss.

[5]:
row = actual.where(F.col("customer_id") == 3).first().asDict()
assert row["event_count"] == 0
assert row["purchase_count"] == 0
assert row["total_amount"] == 0.0
print(row)

{'customer_id': 3, 'region': 'west', 'event_count': 0, 'purchase_count': 0, 'total_amount': 0.0}

2.5. How This Looks In A Test Suite

In a repository, the Spark session usually belongs in a pytest fixture with session scope: starting a session launches a JVM and takes seconds, so sharing one session across the whole test run keeps the suite fast. Individual tests should create small DataFrames, call transformation functions, and assert results. Keep tests local, deterministic, and independent of file paths unless the behavior being tested is file IO.

[6]:
import textwrap

example_test = '''
def test_build_customer_features(spark):
    customers = spark.createDataFrame([(1, "north")], ["customer_id", "region"])
    events = spark.createDataFrame([(1, "2026-01-01", "purchase", 10.0)], ["customer_id", "event_time", "event_type", "amount"])
    actual = build_customer_features(customers, events)
    assert actual.first()["purchase_count"] == 1
'''
print(textwrap.dedent(example_test).strip())

def test_build_customer_features(spark):
    customers = spark.createDataFrame([(1, "north")], ["customer_id", "region"])
    events = spark.createDataFrame([(1, "2026-01-01", "purchase", 10.0)], ["customer_id", "event_time", "event_type", "amount"])
    actual = build_customer_features(customers, events)
    assert actual.first()["purchase_count"] == 1

2.6. What To Remember

Test transformations as functions. Keep input data tiny. Compare DataFrames deliberately, deciding whether row order and exact float values are part of the contract. Avoid tests that depend on cluster services, wall-clock time, random ordering, or already-existing output directories.

[7]:
spark.stop()