2. DataFrames
A simple way to think of a DataFrame in Spark is as a distributed table of data, or a distributed CSV file. However, this distributed table comes with plenty of bells and whistles, batteries included. The most prized features of Spark DataFrames are that they resemble R and pandas data frames and that you can issue SQL-like commands against them.
As a side note, there are three main distributed data structures in Spark.
- RDD
- DataFrame
- Dataset
The RDD is the original distributed data structure, and its records can be anything. RDDs were friendly to experienced data engineers and programmers. DataFrames were a move away from RDDs: they give records a tabular structure, and they made Spark accessible to other kinds of data programmers (such as those who are comfortable with SQL). Still, every DataFrame record is a generic Row, so Datasets were created to keep the tabular structure of DataFrames while letting the record type be defined explicitly. In Spark's Scala API, DataFrame is simply an alias for Dataset[Row]: a Dataset whose records are of type Row. A Dataset, in turn, is a strongly typed, distributed, tabular data structure whose record type is user-defined. The typed Dataset API exists only in Scala and Java; in PySpark, we work with DataFrames.
It is said that when you use an RDD, you describe how you are doing something, and when you use a DataFrame or Dataset, you describe what you are doing. The "how" corresponds to the imperative programming paradigm, and the "what" corresponds to the declarative paradigm, a family that includes functional programming. Code that states what it does is often argued to be easier to understand than code that spells out how it does it. However, which approach is easier to understand differs from person to person. Sometimes the functional style of coding results in deeply nested code. Take the example below, where behavior is built up by composing functions.
is(this(how(we(want(to(code))))))
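The nesting problem is easy to reproduce in plain Python. In the sketch below, the helpers drop_negatives(), double(), total() and pipe() are hypothetical and exist only for illustration. The same three steps are written once as nested calls, which read from the inside out, and once as a left-to-right pipeline. The pipeline has the same shape as chaining DataFrame methods such as df.filter(...).select(...).

```python
from functools import reduce


def drop_negatives(xs):
    # keep only non-negative numbers
    return [x for x in xs if x >= 0]


def double(xs):
    # multiply every element by two
    return [2 * x for x in xs]


def total(xs):
    # collapse the list to a single number
    return sum(xs)


def pipe(value, *funcs):
    # feed value through funcs from left to right
    return reduce(lambda acc, f: f(acc), funcs, value)


data = [3, -1, 4, -5, 2]

# Nested composition: the first step is the innermost call.
nested = total(double(drop_negatives(data)))

# Pipeline: the steps read in the order they run.
piped = pipe(data, drop_negatives, double, total)

print(nested, piped)  # 18 18
```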
2.1. Acquiring a DataFrame
2.1.1. Convert Pandas DataFrame to Spark DataFrame
The easiest way to get a Spark DataFrame is to convert a pandas DataFrame into one. The sqlContext provides a convenience method for this, createDataFrame(); the same method is also available on the SparkSession as spark.createDataFrame(). Below, the pandas DataFrame is assigned to pdf and the Spark DataFrame to sdf.
[1]:
import pandas as pd
n_cols = 10
n_rows = 10
# every row holds the same values: 0, 1, ..., 9
data = [tuple(range(n_cols)) for _ in range(n_rows)]
cols = [f'x{i}' for i in range(n_cols)]
pdf = pd.DataFrame(data, columns=cols)
sdf = sqlContext.createDataFrame(pdf)
[2]:
pdf
[2]:
| | x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 | x8 | x9 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 
| 1 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 
| 2 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 
| 3 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 
| 4 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 
| 5 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 
| 6 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 
| 7 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 
| 8 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 
| 9 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 
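Notice that the records of the Spark DataFrame, as returned by collect(), are of type Row. A Row is a subclass of Python’s tuple, so it supports positional indexing, and it also supports lookup by field name, much like a dictionary.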
[3]:
sdf.collect()
[3]:
[Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9)]
Let’s grab the first row of this DataFrame and see how we can interact with it.
[4]:
row = sdf.take(1)[0]
row
[4]:
Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9)
We can access the fields of a Row via index notation (like a tuple).
[5]:
for i in range(len(row)):
    print(row[i])
0
1
2
3
4
5
6
7
8
9
We can also reference the values by keys in a row (like a dictionary).
[6]:
for i in range(len(row)):
    key = f'x{i}'
    print(row[key])
0
1
2
3
4
5
6
7
8
9
Sometimes, it’s just better to convert the row to a dictionary and then use our Python knowledge of iterating and manipulating dictionaries.
[7]:
for k, v in row.asDict().items():
    print(f'{k}: {v}')
x0: 0
x1: 1
x2: 2
x3: 3
x4: 4
x5: 5
x6: 6
x7: 7
x8: 8
x9: 9
Enough about Rows; let’s get back to the Spark DataFrame. We can display its contents with show(), which prints the first 20 rows by default.
[8]:
sdf.show()
+---+---+---+---+---+---+---+---+---+---+
| x0| x1| x2| x3| x4| x5| x6| x7| x8| x9|
+---+---+---+---+---+---+---+---+---+---+
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
+---+---+---+---+---+---+---+---+---+---+
What if we want to inspect the schema of a DataFrame?
[9]:
sdf.printSchema()
root
 |-- x0: long (nullable = true)
 |-- x1: long (nullable = true)
 |-- x2: long (nullable = true)
 |-- x3: long (nullable = true)
 |-- x4: long (nullable = true)
 |-- x5: long (nullable = true)
 |-- x6: long (nullable = true)
 |-- x7: long (nullable = true)
 |-- x8: long (nullable = true)
 |-- x9: long (nullable = true)
2.1.2. Convert an RDD to a DataFrame
An RDD can also be converted to a DataFrame. Here we supply a schema that names each column and declares its type; without one, Spark would have to infer the types from the data. Below is a concise way of creating a schema for an RDD. Notice that createDataFrame() accepts different kinds of input: before, we passed in a pandas DataFrame; here, we pass in an RDD and a schema.
[10]:
from pyspark.sql.types import *
n_cols = 10
n_rows = 10
# an RDD of plain Python lists, one list per record
rdd = sc.parallelize([list(range(n_cols)) for _ in range(n_rows)])
# one nullable IntegerType field per column: x0, x1, ..., x9
schema = StructType([StructField(f'x{i}', IntegerType(), True) for i in range(n_cols)])
df = sqlContext.createDataFrame(rdd, schema)
If the above example is too concise, let’s build the schema manually. A schema is a StructType, and a StructType is built from a list of StructFields. Each StructField takes a column name, a data type, and a flag saying whether the column is nullable.
[11]:
struct_fields = [
    StructField('x0', IntegerType(), True),
    StructField('x1', IntegerType(), True),
    StructField('x2', IntegerType(), True),
    StructField('x3', IntegerType(), True),
    StructField('x4', IntegerType(), True),
    StructField('x5', IntegerType(), True),
    StructField('x6', IntegerType(), True),
    StructField('x7', IntegerType(), True),
    StructField('x8', IntegerType(), True),
    StructField('x9', IntegerType(), True)
]
struct_type = StructType(struct_fields)
df = sqlContext.createDataFrame(rdd, struct_type)
Let’s inspect the RDD. It’s a list of lists (of integers).
[12]:
rdd.collect()
[12]:
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
Let’s inspect the DataFrame. collect() returns its records as a list of Rows.
[13]:
df.collect()
[13]:
[Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9),
 Row(x0=0, x1=1, x2=2, x3=3, x4=4, x5=5, x6=6, x7=7, x8=8, x9=9)]
Here’s a display of the DataFrame.
[14]:
df.show()
+---+---+---+---+---+---+---+---+---+---+
| x0| x1| x2| x3| x4| x5| x6| x7| x8| x9|
+---+---+---+---+---+---+---+---+---+---+
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
|  0|  1|  2|  3|  4|  5|  6|  7|  8|  9|
+---+---+---+---+---+---+---+---+---+---+
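And let’s inspect the schema. This time the columns are integer rather than long, because our schema declared them as IntegerType (the pandas int64 columns earlier became long).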
[15]:
df.printSchema()
root
 |-- x0: integer (nullable = true)
 |-- x1: integer (nullable = true)
 |-- x2: integer (nullable = true)
 |-- x3: integer (nullable = true)
 |-- x4: integer (nullable = true)
 |-- x5: integer (nullable = true)
 |-- x6: integer (nullable = true)
 |-- x7: integer (nullable = true)
 |-- x8: integer (nullable = true)
 |-- x9: integer (nullable = true)
2.1.3. Convert JSON data to Spark DataFrame
We have seen how to create a Spark DataFrame from a pandas DataFrame or an RDD. Let’s see how we can create one by reading a JSON file. First, let’s upload the JSON file to HDFS.
[16]:
%%sh
hdfs dfs -copyFromLocal -f /root/ipynb/people.json /people.json
2022-02-19 16:46:03,575 INFO sasl.SaslDataTransferClient: SASL encryption trust check: localHostTrusted = false, remoteHostTrusted = false
Use the sqlContext.read.json() method to read the JSON file from HDFS.
[17]:
df = sqlContext.read.json('hdfs://localhost/people.json')
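By default, Spark's JSON reader expects JSON Lines: one complete JSON object per line, rather than a single pretty-printed document or array. For that other layout, the reader needs the multiLine option, e.g. read.json(path, multiLine=True). The contents of people.json are not shown here, so the plain-Python sketch below writes two hypothetical records in the JSON Lines layout. They are modeled on the rows displayed next and include a nested object (address) and an array (sports).

```python
import json
import os
import tempfile

# Two hypothetical records shaped like the people.json rows shown below.
people = [
    {'id': 1, 'first_name': 'John',
     'address': {'city': 'Washington', 'state': 'DC'},
     'sports': ['hockey', 'tennis']},
    {'id': 2, 'first_name': 'Jane',
     'address': {'city': 'Washington', 'state': 'DC'},
     'sports': ['basketball', 'tennis']},
]

path = os.path.join(tempfile.mkdtemp(), 'people_sample.json')

# JSON Lines: one complete JSON object per line, no enclosing array.
with open(path, 'w') as f:
    for person in people:
        f.write(json.dumps(person) + '\n')

with open(path) as f:
    lines = f.read().splitlines()

print(len(lines))                               # 2
print(json.loads(lines[0])['address']['city'])  # Washington
```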
Let’s see what’s inside the Spark DataFrame. Notice that it is still tabular. JSON can be deeply nested, but only the top-level keys are mapped to the DataFrame’s columns; nested objects and arrays are packed into single columns, such as address and sports below.
[18]:
df.show()
+--------------------+---+----------+------+---+---------+-----+--------------------+------+
|             address|age|first_name|height| id|last_name| male|              sports|weight|
+--------------------+---+----------+------+---+---------+-----+--------------------+------+
|[Washington, DC, ...| 27|      John|   6.5|  1|      Doe| true|    [hockey, tennis]| 155.5|
|[Washington, DC, ...| 22|      Jane|   5.7|  2|    Smith|false|[basketball, tennis]| 135.5|
|[Los Angeles, CA,...| 25|      Jack|   6.6|  3|    Smith| true|  [baseball, soccer]| 175.5|
|[Los Angeles, CA,...| 18|     Janet|   5.5|  4|      Doe|false|    [judo, baseball]| 125.5|
+--------------------+---+----------+------+---+---------+-----+--------------------+------+
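The schema, however, preserves the nesting: address is a struct with its own fields, and sports is an array of strings. Nested fields can still be reached with dot notation, for example df.select('address.city').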
[19]:
df.printSchema()
root
 |-- address: struct (nullable = true)
 |    |-- city: string (nullable = true)
 |    |-- state: string (nullable = true)
 |    |-- street: string (nullable = true)
 |    |-- zip: long (nullable = true)
 |-- age: long (nullable = true)
 |-- first_name: string (nullable = true)
 |-- height: double (nullable = true)
 |-- id: long (nullable = true)
 |-- last_name: string (nullable = true)
 |-- male: boolean (nullable = true)
 |-- sports: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- weight: double (nullable = true)
2.2. Data Types and Schema
Spark has many data types; below, we show a few of them. Take note of the ArrayType whose elements are MapType: complex types such as arrays, maps and structs can be nested inside one another.
[20]:
from pyspark.sql.types import *
meta = [{'is_president': 'True'}, {'is_active': 'False'}]
data = [
    ('Joe Biden', 23, 180.8, 6.2, {'line1': '123 Main', 'line2': 'Washington, DC 20021'}, [80.1, 90.2], meta),
    ('Donald Trump', 22, 190.9, 6.5, {'line1': '234 Main', 'line2': 'Washington, DC 20031'}, [90.1, 99.2], meta)
]
schema = StructType([
    StructField('name', StringType(), True),
    StructField('age', IntegerType(), True),
    StructField('weight', FloatType(), True),
    StructField('height', DoubleType(), True),
    StructField('address', MapType(StringType(), StringType()), True),
    StructField('scores', ArrayType(FloatType()), True),
    StructField('meta', ArrayType(MapType(StringType(), StringType())), True)
])
df = spark.createDataFrame(data=data, schema=schema)
[21]:
df.printSchema()
root
 |-- name: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- weight: float (nullable = true)
 |-- height: double (nullable = true)
 |-- address: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)
 |-- scores: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- meta: array (nullable = true)
 |    |-- element: map (containsNull = true)
 |    |    |-- key: string
 |    |    |-- value: string (valueContainsNull = true)
[22]:
df.show(truncate=False)
+------------+---+------+------+--------------------------------------------------+------------+----------------------------------------------+
|name        |age|weight|height|address                                           |scores      |meta                                          |
+------------+---+------+------+--------------------------------------------------+------------+----------------------------------------------+
|Joe Biden   |23 |180.8 |6.2   |[line2 -> Washington, DC 20021, line1 -> 123 Main]|[80.1, 90.2]|[[is_president -> True], [is_active -> False]]|
|Donald Trump|22 |190.9 |6.5   |[line2 -> Washington, DC 20031, line1 -> 234 Main]|[90.1, 99.2]|[[is_president -> True], [is_active -> False]]|
+------------+---+------+------+--------------------------------------------------+------------+----------------------------------------------+
2.3. DataFrame operations
What can we actually do with a Spark DataFrame? Can we do amazing things with DataFrames as with RDDs?
2.3.1. Create data
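Let’s first create a dummy Spark DataFrame. Column x3 is left empty (None) roughly 70% of the time, which gives us missing values to work with later, and height is randomly tall or short.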
[23]:
import pandas as pd
from random import randint, choice
def generate_num(col):
    if col == 3:
        p = randint(1, 100)
        if p < 70:
            return None
        return randint(1, 10)
    return randint(1, 10)
def generate_height():
    return choice(['tall', 'short'])
n_cols = 10
n_rows = 10
pdf = pd.DataFrame(
    [tuple([generate_height()] + [generate_num(c) for c in range(n_cols)]) for r in range(n_rows)],
    columns=['height'] + [f'x{i}' for i in range(n_cols)])
sdf = sqlContext.createDataFrame(pdf)
[24]:
pdf
[24]:
| | height | x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 | x8 | x9 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | tall | 5 | 3 | 1 | NaN | 6 | 5 | 4 | 10 | 3 | 2 | 
| 1 | tall | 5 | 2 | 3 | NaN | 2 | 2 | 2 | 4 | 1 | 5 | 
| 2 | tall | 8 | 9 | 9 | 3.0 | 7 | 4 | 6 | 9 | 3 | 10 | 
| 3 | tall | 2 | 8 | 2 | NaN | 7 | 2 | 4 | 7 | 2 | 8 | 
| 4 | short | 2 | 1 | 1 | NaN | 7 | 3 | 3 | 10 | 9 | 1 | 
| 5 | short | 5 | 7 | 4 | NaN | 4 | 5 | 8 | 5 | 5 | 6 | 
| 6 | tall | 5 | 8 | 10 | 3.0 | 7 | 1 | 8 | 6 | 5 | 7 | 
| 7 | tall | 1 | 1 | 2 | 5.0 | 8 | 4 | 2 | 4 | 9 | 4 | 
| 8 | short | 2 | 7 | 10 | NaN | 8 | 7 | 3 | 8 | 6 | 4 | 
| 9 | short | 7 | 7 | 3 | NaN | 7 | 6 | 4 | 9 | 7 | 5 | 
2.3.2. Select
We can use select() to grab specific columns from a Spark DataFrame.
[25]:
sdf.select('x0').show()
+---+
| x0|
+---+
|  5|
|  5|
|  8|
|  2|
|  2|
|  5|
|  5|
|  1|
|  2|
|  7|
+---+
What if we want to grab multiple columns?
[26]:
sdf.select('x0', 'x1').show()
+---+---+
| x0| x1|
+---+---+
|  5|  3|
|  5|  2|
|  8|  9|
|  2|  8|
|  2|  1|
|  5|  7|
|  5|  8|
|  1|  1|
|  2|  7|
|  7|  7|
+---+---+
We can also select columns by referencing Column objects from the DataFrame. You will see both variants in the wild: one passing column names as strings, and the one below referencing the Column itself (sdf['x0']). When would you use the column name rather than the Column object? When we need to transform the values in a column inline, as in the next few cells, we need a Column expression (or a SQL expression string passed to selectExpr()).
[27]:
sdf.select(sdf['x0'], sdf['x1']).show()
+---+---+
| x0| x1|
+---+---+
|  5|  3|
|  5|  2|
|  8|  9|
|  2|  8|
|  2|  1|
|  5|  7|
|  5|  8|
|  1|  1|
|  2|  7|
|  7|  7|
+---+---+
We can even modify the values we retrieve. Below, we multiply the first column by two and the second by three. Observe the column names: Spark names each result after the expression that produced it. Yuck!
[28]:
sdf.select(sdf['x0'] * 2, sdf['x1'] * 3).show()
+--------+--------+
|(x0 * 2)|(x1 * 3)|
+--------+--------+
|      10|       9|
|      10|       6|
|      16|      27|
|       4|      24|
|       4|       3|
|      10|      21|
|      10|      24|
|       2|       3|
|       4|      21|
|      14|      21|
+--------+--------+
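We can fix the column names with alias(). Uh-oh! Notice how the parentheses are creeping in? They are required because a method call binds more tightly than *: without them, .alias() would attach to the literal 2 (and fail) instead of to the product. Is this style of coding clear in its intention?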
[29]:
sdf.select((sdf['x0'] * 2).alias('y0'), (sdf['x1'] * 3).alias('y1')).show()
+---+---+
| y0| y1|
+---+---+
| 10|  9|
| 10|  6|
| 16| 27|
|  4| 24|
|  4|  3|
| 10| 21|
| 10| 24|
|  2|  3|
|  4| 21|
| 14| 21|
+---+---+
I suppose a little formatting might help.
[30]:
sdf.select(
    (sdf['x0'] * 2).alias('y0'),
    (sdf['x1'] * 3).alias('y1'))\
    .show()
+---+---+
| y0| y1|
+---+---+
| 10|  9|
| 10|  6|
| 16| 27|
|  4| 24|
|  4|  3|
| 10| 21|
| 10| 24|
|  2|  3|
|  4| 21|
| 14| 21|
+---+---+
We can also apply boolean expressions with select().
[31]:
sdf.select(sdf['x0'] > 5).show()
+--------+
|(x0 > 5)|
+--------+
|   false|
|   false|
|    true|
|   false|
|   false|
|   false|
|   false|
|   false|
|   false|
|    true|
+--------+
How do we get distinct values?
[32]:
sdf.select('x0').distinct().show()
+---+
| x0|
+---+
|  7|
|  5|
|  1|
|  8|
|  2|
+---+
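We can compute a set difference with subtract(), which returns the rows of one DataFrame that do not appear in another. subtract() follows SQL EXCEPT DISTINCT semantics, so its result already contains no duplicates; the distinct() call below is redundant, though harmless.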
[33]:
x0 = sdf.select('x0')
x1 = sdf.select('x1')
diff = x0.subtract(x1)
diff.distinct().show()
+---+
| x0|
+---+
|  5|
+---+
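To make the difference between set and multiset semantics concrete, here is a plain-Python sketch (no Spark involved) using the x0 and x1 values shown above. Python sets model SQL EXCEPT DISTINCT, which is what subtract() does. collections.Counter models EXCEPT ALL, which PySpark exposes as DataFrame.exceptAll().

```python
from collections import Counter

# The x0 and x1 columns from the DataFrame shown above.
x0 = [5, 5, 8, 2, 2, 5, 5, 1, 2, 7]
x1 = [3, 2, 9, 8, 1, 7, 8, 1, 7, 7]

# Set semantics (EXCEPT DISTINCT, like subtract()):
# each value in x0 that never appears in x1, reported once.
except_distinct = sorted(set(x0) - set(x1))
print(except_distinct)  # [5]

# Multiset semantics (EXCEPT ALL, like exceptAll()):
# each occurrence in x1 cancels at most one matching occurrence in x0.
except_all = sorted((Counter(x0) - Counter(x1)).elements())
print(except_all)  # [2, 2, 5, 5, 5, 5]
```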
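We can also add columns with withColumn(). Like every DataFrame transformation, it returns a new DataFrame and leaves sdf itself unchanged.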
[34]:
sdf.withColumn('height_truthy', sdf['height'] == 'tall').show()
+------+---+---+---+---+---+---+---+---+---+---+-------------+
|height| x0| x1| x2| x3| x4| x5| x6| x7| x8| x9|height_truthy|
+------+---+---+---+---+---+---+---+---+---+---+-------------+
|  tall|  5|  3|  1|NaN|  6|  5|  4| 10|  3|  2|         true|
|  tall|  5|  2|  3|NaN|  2|  2|  2|  4|  1|  5|         true|
|  tall|  8|  9|  9|3.0|  7|  4|  6|  9|  3| 10|         true|
|  tall|  2|  8|  2|NaN|  7|  2|  4|  7|  2|  8|         true|
| short|  2|  1|  1|NaN|  7|  3|  3| 10|  9|  1|        false|
| short|  5|  7|  4|NaN|  4|  5|  8|  5|  5|  6|        false|
|  tall|  5|  8| 10|3.0|  7|  1|  8|  6|  5|  7|         true|
|  tall|  1|  1|  2|5.0|  8|  4|  2|  4|  9|  4|         true|
| short|  2|  7| 10|NaN|  8|  7|  3|  8|  6|  4|        false|
| short|  7|  7|  3|NaN|  7|  6|  4|  9|  7|  5|        false|
+------+---+---+---+---+---+---+---+---+---+---+-------------+
Columns are dropped with drop(). Below, we first list the columns after adding height_truthy, and then drop it again.
[35]:
sdf.withColumn('height_truthy', sdf['height'] == 'tall').columns
[35]:
['height',
 'x0',
 'x1',
 'x2',
 'x3',
 'x4',
 'x5',
 'x6',
 'x7',
 'x8',
 'x9',
 'height_truthy']
[36]:
sdf.withColumn('height_truthy', sdf['height'] == 'tall').drop('height_truthy').columns
[36]:
['height', 'x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']
2.3.3. Filtering
How do we filter records in a Spark DataFrame?
[37]:
sdf.filter(sdf['x0'] > 5).show()
+------+---+---+---+---+---+---+---+---+---+---+
|height| x0| x1| x2| x3| x4| x5| x6| x7| x8| x9|
+------+---+---+---+---+---+---+---+---+---+---+
|  tall|  8|  9|  9|3.0|  7|  4|  6|  9|  3| 10|
| short|  7|  7|  3|NaN|  7|  6|  4|  9|  7|  5|
+------+---+---+---+---+---+---+---+---+---+---+
Multiple filters?
[38]:
sdf.filter(sdf['x0'] > 5).filter(sdf['x1'] > 5).show()
+------+---+---+---+---+---+---+---+---+---+---+
|height| x0| x1| x2| x3| x4| x5| x6| x7| x8| x9|
+------+---+---+---+---+---+---+---+---+---+---+
|  tall|  8|  9|  9|3.0|  7|  4|  6|  9|  3| 10|
| short|  7|  7|  3|NaN|  7|  6|  4|  9|  7|  5|
+------+---+---+---+---+---+---+---+---+---+---+
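Chained filter() calls are combined with a logical AND. The same filter can be written in one call as sdf.filter((sdf['x0'] > 5) & (sdf['x1'] > 5)), and the parentheses there are not optional, because Python's & binds more tightly than comparison operators. The plain-Python sketch below uses ordinary integers from the first matching row, not Spark Columns, to show how the unparenthesized form is actually parsed. With Spark Columns, that form raises a ValueError instead of filtering.

```python
# Values from the first row that passed both filters above.
x0, x1 = 8, 9

# Intended meaning: (x0 > 5) AND (x1 > 5)
with_parens = (x0 > 5) & (x1 > 5)

# Without parentheses, Python parses this as the chained comparison
# x0 > (5 & x1) > 5, because & binds more tightly than >.
without_parens = x0 > 5 & x1 > 5

print(with_parens)     # True
print(without_parens)  # False
```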
What about dropping duplicates?
[39]:
sdf.select('x0').dropDuplicates().show()
+---+
| x0|
+---+
|  7|
|  5|
|  1|
|  8|
|  2|
+---+
2.3.4. Ordering
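How do we order the records? orderBy() sorts in ascending order by default. Rows that tie on the sort key (several rows share x0 = 2 or x0 = 5) come back in no guaranteed order.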
[40]:
sdf.orderBy('x0').show()
+------+---+---+---+---+---+---+---+---+---+---+
|height| x0| x1| x2| x3| x4| x5| x6| x7| x8| x9|
+------+---+---+---+---+---+---+---+---+---+---+
|  tall|  1|  1|  2|5.0|  8|  4|  2|  4|  9|  4|
|  tall|  2|  8|  2|NaN|  7|  2|  4|  7|  2|  8|
| short|  2|  1|  1|NaN|  7|  3|  3| 10|  9|  1|
| short|  2|  7| 10|NaN|  8|  7|  3|  8|  6|  4|
|  tall|  5|  2|  3|NaN|  2|  2|  2|  4|  1|  5|
| short|  5|  7|  4|NaN|  4|  5|  8|  5|  5|  6|
|  tall|  5|  3|  1|NaN|  6|  5|  4| 10|  3|  2|
|  tall|  5|  8| 10|3.0|  7|  1|  8|  6|  5|  7|
| short|  7|  7|  3|NaN|  7|  6|  4|  9|  7|  5|
|  tall|  8|  9|  9|3.0|  7|  4|  6|  9|  3| 10|
+------+---+---+---+---+---+---+---+---+---+---+
And ordering by multiple columns?
[41]:
sdf.orderBy('height', 'x0').show()
+------+---+---+---+---+---+---+---+---+---+---+
|height| x0| x1| x2| x3| x4| x5| x6| x7| x8| x9|
+------+---+---+---+---+---+---+---+---+---+---+
| short|  2|  7| 10|NaN|  8|  7|  3|  8|  6|  4|
| short|  2|  1|  1|NaN|  7|  3|  3| 10|  9|  1|
| short|  5|  7|  4|NaN|  4|  5|  8|  5|  5|  6|
| short|  7|  7|  3|NaN|  7|  6|  4|  9|  7|  5|
|  tall|  1|  1|  2|5.0|  8|  4|  2|  4|  9|  4|
|  tall|  2|  8|  2|NaN|  7|  2|  4|  7|  2|  8|
|  tall|  5|  8| 10|3.0|  7|  1|  8|  6|  5|  7|
|  tall|  5|  2|  3|NaN|  2|  2|  2|  4|  1|  5|
|  tall|  5|  3|  1|NaN|  6|  5|  4| 10|  3|  2|
|  tall|  8|  9|  9|3.0|  7|  4|  6|  9|  3| 10|
+------+---+---+---+---+---+---+---+---+---+---+
And in descending order? We call desc() on each Column we sort by.
[42]:
sdf.orderBy(sdf['height'].desc(), sdf['x0'].desc()).show()
+------+---+---+---+---+---+---+---+---+---+---+
|height| x0| x1| x2| x3| x4| x5| x6| x7| x8| x9|
+------+---+---+---+---+---+---+---+---+---+---+
|  tall|  8|  9|  9|3.0|  7|  4|  6|  9|  3| 10|
|  tall|  5|  8| 10|3.0|  7|  1|  8|  6|  5|  7|
|  tall|  5|  3|  1|NaN|  6|  5|  4| 10|  3|  2|
|  tall|  5|  2|  3|NaN|  2|  2|  2|  4|  1|  5|
|  tall|  2|  8|  2|NaN|  7|  2|  4|  7|  2|  8|
|  tall|  1|  1|  2|5.0|  8|  4|  2|  4|  9|  4|
| short|  7|  7|  3|NaN|  7|  6|  4|  9|  7|  5|
| short|  5|  7|  4|NaN|  4|  5|  8|  5|  5|  6|
| short|  2|  7| 10|NaN|  8|  7|  3|  8|  6|  4|
| short|  2|  1|  1|NaN|  7|  3|  3| 10|  9|  1|
+------+---+---+---+---+---+---+---+---+---+---+
2.3.5. Missing values
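How do we handle missing values? Notice that x3 displays NaN: pandas stores missing numbers as NaN, and they came across to Spark as NaN. Spark’s dropna() and fillna() treat both null and NaN as missing in numeric columns. First, we can drop every row that contains a missing value.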
[43]:
sdf.dropna().show()
+------+---+---+---+---+---+---+---+---+---+---+
|height| x0| x1| x2| x3| x4| x5| x6| x7| x8| x9|
+------+---+---+---+---+---+---+---+---+---+---+
|  tall|  8|  9|  9|3.0|  7|  4|  6|  9|  3| 10|
|  tall|  5|  8| 10|3.0|  7|  1|  8|  6|  5|  7|
|  tall|  1|  1|  2|5.0|  8|  4|  2|  4|  9|  4|
+------+---+---+---+---+---+---+---+---+---+---+
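By default, dropna() drops a row if any of its values is missing. Its how, thresh and subset parameters change that rule: how='all' drops a row only if every value is missing, thresh keeps rows with at least that many non-missing values (and takes precedence over how), and subset restricts the check to certain columns. The plain-Python sketch below restates that rule for a single row; keep_row() and is_missing() are hypothetical helpers, not Spark's implementation, and the rows are made up in the spirit of the x0/x3 columns above.

```python
import math


def is_missing(value):
    # dropna()/fillna() treat both None (null) and NaN as missing in numeric columns.
    return value is None or (isinstance(value, float) and math.isnan(value))


def keep_row(row, how='any', thresh=None, subset=None):
    """Sketch of the rule dropna() applies to one row (given as a dict)."""
    values = [row[col] for col in (subset or row)]
    present = sum(not is_missing(v) for v in values)
    if thresh is not None:          # thresh takes precedence over how
        return present >= thresh
    if how == 'any':                # drop if any value is missing
        return present == len(values)
    return present > 0              # how='all': drop only if every value is missing


rows = [
    {'x0': 5, 'x3': float('nan')},
    {'x0': 8, 'x3': 3.0},
    {'x0': None, 'x3': float('nan')},
]

print(sum(keep_row(r) for r in rows))                 # 1
print(sum(keep_row(r, how='all') for r in rows))      # 2
print(sum(keep_row(r, thresh=1) for r in rows))       # 2
print(sum(keep_row(r, subset=['x0']) for r in rows))  # 2
```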
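Or maybe we want to replace missing values with a sentinel such as -1? fillna() only fills columns whose type matches the fill value, so here only the numeric columns are affected.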
[44]:
sdf.fillna(-1).show()
+------+---+---+---+----+---+---+---+---+---+---+
|height| x0| x1| x2|  x3| x4| x5| x6| x7| x8| x9|
+------+---+---+---+----+---+---+---+---+---+---+
|  tall|  5|  3|  1|-1.0|  6|  5|  4| 10|  3|  2|
|  tall|  5|  2|  3|-1.0|  2|  2|  2|  4|  1|  5|
|  tall|  8|  9|  9| 3.0|  7|  4|  6|  9|  3| 10|
|  tall|  2|  8|  2|-1.0|  7|  2|  4|  7|  2|  8|
| short|  2|  1|  1|-1.0|  7|  3|  3| 10|  9|  1|
| short|  5|  7|  4|-1.0|  4|  5|  8|  5|  5|  6|
|  tall|  5|  8| 10| 3.0|  7|  1|  8|  6|  5|  7|
|  tall|  1|  1|  2| 5.0|  8|  4|  2|  4|  9|  4|
| short|  2|  7| 10|-1.0|  8|  7|  3|  8|  6|  4|
| short|  7|  7|  3|-1.0|  7|  6|  4|  9|  7|  5|
+------+---+---+---+----+---+---+---+---+---+---+
2.3.6. Group by
How do we do grouping?
[45]:
sdf.groupBy('height').count().show()
+------+-----+
|height|count|
+------+-----+
|  tall|    6|
| short|    4|
+------+-----+
We can also compute aggregations with agg() after a group-by. Here, the dictionary maps a column name to the name of an aggregate function.
[46]:
sdf.groupBy('height').agg({'x0': 'mean'}).show()
+------+-----------------+
|height|          avg(x0)|
+------+-----------------+
|  tall|4.333333333333333|
| short|              4.0|
+------+-----------------+
Multiple aggregations over different columns.
[47]:
sdf.groupBy('height').agg({'x0': 'mean', 'x1': 'mean'}).show()
+------+-----------------+-----------------+
|height|          avg(x0)|          avg(x1)|
+------+-----------------+-----------------+
|  tall|4.333333333333333|5.166666666666667|
| short|              4.0|              5.5|
+------+-----------------+-----------------+
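Uh-oh: if we ask for two aggregations over the same column, only the last one is computed. This is not a Spark limitation. A Python dict cannot hold the same key twice, so {'x0': 'mean', 'x0': 'stddev'} collapses to {'x0': 'stddev'} before Spark ever sees it. To compute several aggregations of one column, pass Column expressions instead, for example agg(avg('x0'), stddev('x0')) with avg and stddev imported from pyspark.sql.functions.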
[48]:
sdf.groupBy('height').agg({'x0': 'mean', 'x0': 'stddev'}).show()
+------+-----------------+
|height|       stddev(x0)|
+------+-----------------+
|  tall|2.503331114069145|
| short|2.449489742783178|
+------+-----------------+
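The collapse happens in Python itself, before agg() is even called. A minimal plain-Python check:

```python
# The dictionary passed to agg() in the cell above.
spec = {'x0': 'mean', 'x0': 'stddev'}

# A dict keeps only the last value for a repeated key,
# so agg() receives a single aggregation for x0.
print(spec)       # {'x0': 'stddev'}
print(len(spec))  # 1
```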
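Aggregate functions can also be applied to the whole DataFrame, without grouping. Note that stddev() computes the sample standard deviation, which is why its result column is named stddev_samp(x0).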
[49]:
from pyspark.sql.functions import countDistinct, avg, stddev
sdf.select(avg('x0'), stddev('x0'), countDistinct('x0')).show()
+-------+------------------+------------------+
|avg(x0)|   stddev_samp(x0)|count(DISTINCT x0)|
+-------+------------------+------------------+
|    4.2|2.3475755815545347|                 5|
+-------+------------------+------------------+
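As a cross-check, Python's statistics module reproduces these numbers. statistics.stdev() divides by n - 1 like Spark's stddev() (stddev_samp), while statistics.pstdev() divides by n like Spark's stddev_pop(). The sketch uses the x0 values shown earlier, plus the same values split by height.

```python
import statistics

# The x0 column shown earlier, and its values split by height.
x0 = [5, 5, 8, 2, 2, 5, 5, 1, 2, 7]
by_height = {'tall': [5, 5, 8, 2, 5, 1], 'short': [2, 5, 2, 7]}

print(statistics.mean(x0))    # 4.2, like avg(x0)
print(statistics.stdev(x0))   # sample standard deviation, like stddev_samp(x0)
print(statistics.pstdev(x0))  # population standard deviation, slightly smaller
print(len(set(x0)))           # 5, like count(DISTINCT x0)

for height, values in by_height.items():
    # per-group sample standard deviations, like the groupBy('height') stddev output
    print(height, statistics.stdev(values))
```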
2.3.7. Cross-tabulation
For cross-tabulation, we use crosstab(), which counts how often each pair of values from two columns occurs together.
[50]:
sdf.crosstab('height', 'x1').show()
+---------+---+---+---+---+---+---+
|height_x1|  1|  2|  3|  7|  8|  9|
+---------+---+---+---+---+---+---+
|    short|  1|  0|  0|  3|  0|  0|
|     tall|  1|  1|  1|  0|  2|  1|
+---------+---+---+---+---+---+---+
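What crosstab() counts can be spelled out in plain Python. The sketch below counts (height, x1) pairs with collections.Counter, using the values from the DataFrame shown earlier, and lays them out with one row per height and one column per distinct x1 value. It reproduces the table above, including the zeros for pairs that never occur.

```python
from collections import Counter

# height and x1 values from the DataFrame shown earlier.
height = ['tall', 'tall', 'tall', 'tall', 'short', 'short', 'tall', 'tall', 'short', 'short']
x1 = [3, 2, 9, 8, 1, 7, 8, 1, 7, 7]

counts = Counter(zip(height, x1))  # how often each (height, x1) pair occurs
levels = sorted(set(x1))           # the column headers: 1, 2, 3, 7, 8, 9

for h in sorted(set(height)):
    print(h, [counts[(h, v)] for v in levels])
# short [1, 0, 0, 3, 0, 0]
# tall [1, 1, 1, 0, 2, 1]
```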
2.3.8. Statistics
I want statistics. describe() reports the count, mean, standard deviation, min and max of each column. Ugh, the standard deviation has too much precision. How can we fix the precision?
[51]:
sdf.describe('x0', 'x1').show()
+-------+------------------+----------------+
|summary|                x0|              x1|
+-------+------------------+----------------+
|  count|                10|              10|
|   mean|               4.2|             5.3|
| stddev|2.3475755815545343|3.16403399335581|
|    min|                 1|               1|
|    max|                 8|               9|
+-------+------------------+----------------+
This takes some coding gymnastics: describe() returns every statistic as a string, so we first cast the columns x0 and x1 to DoubleType and then use format_number() to round them to two decimal places.
[52]:
from pyspark.sql.functions import format_number, col
sdf.describe('x0', 'x1')\
    .withColumn('x0', col('x0').cast(DoubleType()))\
    .withColumn('x1', col('x1').cast(DoubleType()))\
    .select('summary', format_number('x0', 2).alias('x0'), format_number('x1', 2).alias('x1'))\
    .show()
+-------+-----+-----+
|summary|   x0|   x1|
+-------+-----+-----+
|  count|10.00|10.00|
|   mean| 4.20| 5.30|
| stddev| 2.35| 3.16|
|    min| 1.00| 1.00|
|    max| 8.00| 9.00|
+-------+-----+-----+
Here’s another way, doing the casts as SQL expressions in selectExpr().
[53]:
sdf.describe('x0', 'x1')\
    .selectExpr('summary', 'cast(x0 as double) as x0', 'cast(x1 as double) as x1')\
    .select(
        'summary',
        format_number('x0', 2).alias('x0'),
        format_number('x1', 2).alias('x1'))\
    .show()
+-------+-----+-----+
|summary|   x0|   x1|
+-------+-----+-----+
|  count|10.00|10.00|
|   mean| 4.20| 5.30|
| stddev| 2.35| 3.16|
|    min| 1.00| 1.00|
|    max| 8.00| 9.00|
+-------+-----+-----+
2.3.9. Sampling
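We can also draw a random sample. The arguments are withReplacement, fraction and seed. The fraction is the expected share of rows rather than an exact count, which is why asking for half of our 10 rows returned only 2 here; with replacement, the same row may also appear more than once.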
[54]:
sdf.sample(withReplacement=True, fraction=0.5, seed=37).show()
+------+---+---+---+---+---+---+---+---+---+---+
|height| x0| x1| x2| x3| x4| x5| x6| x7| x8| x9|
+------+---+---+---+---+---+---+---+---+---+---+
|  tall|  8|  9|  9|3.0|  7|  4|  6|  9|  3| 10|
| short|  2|  7| 10|NaN|  8|  7|  3|  8|  6|  4|
+------+---+---+---+---+---+---+---+---+---+---+
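Why is the sample size not fixed? The fraction is applied row by row rather than as a quota. The plain-Python sketch below illustrates the idea under an assumption about the general approach; it is not a transcription of Spark's sampler. Without replacement, each row is kept with probability fraction (a Bernoulli trial). With replacement, each row is copied k times, with k drawn from a Poisson distribution whose mean is fraction. sample_rows() and poisson_count() are hypothetical names.

```python
import math
import random


def poisson_count(lam, rng):
    # Knuth's method: multiply uniform draws until the product drops below e^-lam.
    threshold = math.exp(-lam)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1


def sample_rows(rows, with_replacement, fraction, seed):
    """Sketch of per-row sampling; the result size varies around len(rows) * fraction."""
    rng = random.Random(seed)
    out = []
    for row in rows:
        if with_replacement:
            # 0, 1 or more copies of this row
            out.extend([row] * poisson_count(fraction, rng))
        elif rng.random() < fraction:
            # keep this row at most once
            out.append(row)
    return out


rows = list(range(10))  # stand-ins for the 10 rows of sdf
sizes = [len(sample_rows(rows, True, 0.5, seed)) for seed in range(1000)]
print(min(sizes), max(sizes))   # the sample size varies from seed to seed
print(sum(sizes) / len(sizes))  # but averages close to 10 * 0.5 = 5
```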
2.3.10. Windowing
Often you will want to window over your data. For example, if you have sequential records, you might want to compute a new field from the previous record and the current one. Let’s mock up some transactions with dates. In this example, we want to compute the number of days between each transaction and the previous one.
[55]:
pdf = pd.DataFrame({
    'user_id': [1 if i % 2 == 0 else 2 for i in range(1, 31)],
    'transaction_id': [i for i in range(1, 31)],
    'date': [f'2022-01-{d:02}' for d in range(1, 31)]
})
pdf['date'] = pd.to_datetime(pdf['date'])
sdf = sqlContext.createDataFrame(pdf)
sdf.show(10)
+-------+--------------+-------------------+
|user_id|transaction_id|               date|
+-------+--------------+-------------------+
|      2|             1|2022-01-01 00:00:00|
|      1|             2|2022-01-02 00:00:00|
|      2|             3|2022-01-03 00:00:00|
|      1|             4|2022-01-04 00:00:00|
|      2|             5|2022-01-05 00:00:00|
|      1|             6|2022-01-06 00:00:00|
|      2|             7|2022-01-07 00:00:00|
|      1|             8|2022-01-08 00:00:00|
|      2|             9|2022-01-09 00:00:00|
|      1|            10|2022-01-10 00:00:00|
+-------+--------------+-------------------+
only showing top 10 rows
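To line up the previous record’s date with the current record, we can define a Window and use the lag() function. The window specifies how records are ordered (here, by date), and lag() specifies which column to pull from the preceding record. The first record has no previous record, so its prev_date is null. Note that a window without partitionBy() moves all rows into a single partition, which is fine for toy data but scales poorly on large data.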
[56]:
from pyspark.sql.window import Window
from pyspark.sql import functions as F
window = Window.orderBy('date')
sdf \
    .withColumn('prev_date', F.lag(F.col('date')).over(window)) \
    .show(10)
+-------+--------------+-------------------+-------------------+
|user_id|transaction_id|               date|          prev_date|
+-------+--------------+-------------------+-------------------+
|      2|             1|2022-01-01 00:00:00|               null|
|      1|             2|2022-01-02 00:00:00|2022-01-01 00:00:00|
|      2|             3|2022-01-03 00:00:00|2022-01-02 00:00:00|
|      1|             4|2022-01-04 00:00:00|2022-01-03 00:00:00|
|      2|             5|2022-01-05 00:00:00|2022-01-04 00:00:00|
|      1|             6|2022-01-06 00:00:00|2022-01-05 00:00:00|
|      2|             7|2022-01-07 00:00:00|2022-01-06 00:00:00|
|      1|             8|2022-01-08 00:00:00|2022-01-07 00:00:00|
|      2|             9|2022-01-09 00:00:00|2022-01-08 00:00:00|
|      1|            10|2022-01-10 00:00:00|2022-01-09 00:00:00|
+-------+--------------+-------------------+-------------------+
only showing top 10 rows
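To make the semantics of lag() concrete, and to finish the day-difference calculation this section set out to do, here is a plain-Python sketch, not Spark code. The hypothetical helper lag_column() orders the rows, optionally groups them by a partition key, and attaches the value from offset rows earlier in the same group. day_gaps() then turns each date and prev_date pair into a number of days; in Spark, the analogous step would be F.datediff('date', 'prev_date').

```python
from datetime import date
from itertools import groupby

# The synthetic transactions from above: one per day in January 2022;
# even transaction ids belong to user 1, odd ones to user 2.
rows = [
    {'user_id': 1 if i % 2 == 0 else 2, 'transaction_id': i, 'date': date(2022, 1, i)}
    for i in range(1, 31)
]


def lag_column(rows, field, order_by, offset=1, partition_by=None):
    """Attach prev_<field>: the value `offset` rows earlier in the same
    partition, or None when there is no such row."""
    part = (lambda r: r[partition_by]) if partition_by else (lambda r: 0)
    ordered = sorted(rows, key=lambda r: (part(r), r[order_by]))
    out = []
    for _, group in groupby(ordered, key=part):
        group = list(group)
        for i, row in enumerate(group):
            prev = group[i - offset][field] if i >= offset else None
            out.append({**row, f'prev_{field}': prev})
    return out


def day_gaps(lagged):
    # days between date and prev_date; None where prev_date is missing
    return [(r['date'] - r['prev_date']).days if r['prev_date'] is not None else None
            for r in lagged]


print(day_gaps(lag_column(rows, 'date', 'date'))[:4])                          # [None, 1, 1, 1]
print(day_gaps(lag_column(rows, 'date', 'date', offset=2))[:4])                # [None, None, 2, 2]
print(day_gaps(lag_column(rows, 'date', 'date', partition_by='user_id'))[:4])  # [None, 2, 2, 2]
```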
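We can change the offset with the second argument of lag(); here, we lag by 2 records. The keyword name of that parameter has differed across PySpark versions (count in some releases, offset in others), so passing it positionally works with either.
[57]:
sdf \
    .withColumn('prev_date', F.lag(F.col('date'), 2).over(window)) \
    .show(10)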
+-------+--------------+-------------------+-------------------+
|user_id|transaction_id|               date|          prev_date|
+-------+--------------+-------------------+-------------------+
|      2|             1|2022-01-01 00:00:00|               null|
|      1|             2|2022-01-02 00:00:00|               null|
|      2|             3|2022-01-03 00:00:00|2022-01-01 00:00:00|
|      1|             4|2022-01-04 00:00:00|2022-01-02 00:00:00|
|      2|             5|2022-01-05 00:00:00|2022-01-03 00:00:00|
|      1|             6|2022-01-06 00:00:00|2022-01-04 00:00:00|
|      2|             7|2022-01-07 00:00:00|2022-01-05 00:00:00|
|      1|             8|2022-01-08 00:00:00|2022-01-06 00:00:00|
|      2|             9|2022-01-09 00:00:00|2022-01-07 00:00:00|
|      1|            10|2022-01-10 00:00:00|2022-01-08 00:00:00|
+-------+--------------+-------------------+-------------------+
only showing top 10 rows
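In this synthetic data, each transaction belongs to a user. To window over each user separately, we add partitionBy() to the window. lag() then restarts within each user’s records, so each user’s first transaction has a null prev_date.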
[58]:
window = Window.partitionBy('user_id').orderBy('date')
sdf \
    .withColumn('prev_date', F.lag(F.col('date')).over(window)) \
    .show(31)
+-------+--------------+-------------------+-------------------+
|user_id|transaction_id|               date|          prev_date|
+-------+--------------+-------------------+-------------------+
|      1|             2|2022-01-02 00:00:00|               null|
|      1|             4|2022-01-04 00:00:00|2022-01-02 00:00:00|
|      1|             6|2022-01-06 00:00:00|2022-01-04 00:00:00|
|      1|             8|2022-01-08 00:00:00|2022-01-06 00:00:00|
|      1|            10|2022-01-10 00:00:00|2022-01-08 00:00:00|
|      1|            12|2022-01-12 00:00:00|2022-01-10 00:00:00|
|      1|            14|2022-01-14 00:00:00|2022-01-12 00:00:00|
|      1|            16|2022-01-16 00:00:00|2022-01-14 00:00:00|
|      1|            18|2022-01-18 00:00:00|2022-01-16 00:00:00|
|      1|            20|2022-01-20 00:00:00|2022-01-18 00:00:00|
|      1|            22|2022-01-22 00:00:00|2022-01-20 00:00:00|
|      1|            24|2022-01-24 00:00:00|2022-01-22 00:00:00|
|      1|            26|2022-01-26 00:00:00|2022-01-24 00:00:00|
|      1|            28|2022-01-28 00:00:00|2022-01-26 00:00:00|
|      1|            30|2022-01-30 00:00:00|2022-01-28 00:00:00|
|      2|             1|2022-01-01 00:00:00|               null|
|      2|             3|2022-01-03 00:00:00|2022-01-01 00:00:00|
|      2|             5|2022-01-05 00:00:00|2022-01-03 00:00:00|
|      2|             7|2022-01-07 00:00:00|2022-01-05 00:00:00|
|      2|             9|2022-01-09 00:00:00|2022-01-07 00:00:00|
|      2|            11|2022-01-11 00:00:00|2022-01-09 00:00:00|
|      2|            13|2022-01-13 00:00:00|2022-01-11 00:00:00|
|      2|            15|2022-01-15 00:00:00|2022-01-13 00:00:00|
|      2|            17|2022-01-17 00:00:00|2022-01-15 00:00:00|
|      2|            19|2022-01-19 00:00:00|2022-01-17 00:00:00|
|      2|            21|2022-01-21 00:00:00|2022-01-19 00:00:00|
|      2|            23|2022-01-23 00:00:00|2022-01-21 00:00:00|
|      2|            25|2022-01-25 00:00:00|2022-01-23 00:00:00|
|      2|            27|2022-01-27 00:00:00|2022-01-25 00:00:00|
|      2|            29|2022-01-29 00:00:00|2022-01-27 00:00:00|
+-------+--------------+-------------------+-------------------+
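The partitioned version can likewise be sketched in plain Python: group the rows by user_id, sort each group by date, and lag only within each group, so the first row of every group gets None. The function `partitioned_lag` and the tuple layout `(user_id, transaction_id, date)` are assumptions for illustration; Spark performs the same logic in a distributed way.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

Row = Tuple[int, int, str]  # (user_id, transaction_id, ISO date string)


def partitioned_lag(rows: List[Row]) -> List[Tuple[int, int, str, Optional[str]]]:
    """Hypothetical helper: append the previous date within each user's rows."""
    groups: Dict[int, List[Row]] = defaultdict(list)
    for row in rows:
        groups[row[0]].append(row)  # like partitionBy('user_id')
    result = []
    # Sorting by user_id only makes the output deterministic; Spark does not
    # guarantee the order of partitions in show().
    for user_id in sorted(groups):
        # ISO date strings sort chronologically, like orderBy('date')
        ordered = sorted(groups[user_id], key=lambda r: r[2])
        prev_date: Optional[str] = None  # each partition starts with no previous row
        for uid, tid, date in ordered:
            result.append((uid, tid, date, prev_date))
            prev_date = date
    return result


rows = [(2, 1, "2022-01-01"), (1, 2, "2022-01-02"), (2, 3, "2022-01-03"), (1, 4, "2022-01-04")]
for r in partitioned_lag(rows):
    print(r)
```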
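A practical use of windowing and lagging is computing the time elapsed between consecutive transactions. unix_timestamp() converts a timestamp to seconds since the Unix epoch, so subtracting the previous value from the current one gives the gap in seconds.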
[59]:
window = Window.partitionBy('user_id').orderBy('date')
sdf \
    .withColumn('prev_date', F.lag(F.col('date')).over(window)) \
    .withColumn('date_diff_seconds', F.unix_timestamp('date') - F.unix_timestamp('prev_date')) \
    .show(31)
+-------+--------------+-------------------+-------------------+-----------------+
|user_id|transaction_id|               date|          prev_date|date_diff_seconds|
+-------+--------------+-------------------+-------------------+-----------------+
|      1|             2|2022-01-02 00:00:00|               null|             null|
|      1|             4|2022-01-04 00:00:00|2022-01-02 00:00:00|           172800|
|      1|             6|2022-01-06 00:00:00|2022-01-04 00:00:00|           172800|
|      1|             8|2022-01-08 00:00:00|2022-01-06 00:00:00|           172800|
|      1|            10|2022-01-10 00:00:00|2022-01-08 00:00:00|           172800|
|      1|            12|2022-01-12 00:00:00|2022-01-10 00:00:00|           172800|
|      1|            14|2022-01-14 00:00:00|2022-01-12 00:00:00|           172800|
|      1|            16|2022-01-16 00:00:00|2022-01-14 00:00:00|           172800|
|      1|            18|2022-01-18 00:00:00|2022-01-16 00:00:00|           172800|
|      1|            20|2022-01-20 00:00:00|2022-01-18 00:00:00|           172800|
|      1|            22|2022-01-22 00:00:00|2022-01-20 00:00:00|           172800|
|      1|            24|2022-01-24 00:00:00|2022-01-22 00:00:00|           172800|
|      1|            26|2022-01-26 00:00:00|2022-01-24 00:00:00|           172800|
|      1|            28|2022-01-28 00:00:00|2022-01-26 00:00:00|           172800|
|      1|            30|2022-01-30 00:00:00|2022-01-28 00:00:00|           172800|
|      2|             1|2022-01-01 00:00:00|               null|             null|
|      2|             3|2022-01-03 00:00:00|2022-01-01 00:00:00|           172800|
|      2|             5|2022-01-05 00:00:00|2022-01-03 00:00:00|           172800|
|      2|             7|2022-01-07 00:00:00|2022-01-05 00:00:00|           172800|
|      2|             9|2022-01-09 00:00:00|2022-01-07 00:00:00|           172800|
|      2|            11|2022-01-11 00:00:00|2022-01-09 00:00:00|           172800|
|      2|            13|2022-01-13 00:00:00|2022-01-11 00:00:00|           172800|
|      2|            15|2022-01-15 00:00:00|2022-01-13 00:00:00|           172800|
|      2|            17|2022-01-17 00:00:00|2022-01-15 00:00:00|           172800|
|      2|            19|2022-01-19 00:00:00|2022-01-17 00:00:00|           172800|
|      2|            21|2022-01-21 00:00:00|2022-01-19 00:00:00|           172800|
|      2|            23|2022-01-23 00:00:00|2022-01-21 00:00:00|           172800|
|      2|            25|2022-01-25 00:00:00|2022-01-23 00:00:00|           172800|
|      2|            27|2022-01-27 00:00:00|2022-01-25 00:00:00|           172800|
|      2|            29|2022-01-29 00:00:00|2022-01-27 00:00:00|           172800|
+-------+--------------+-------------------+-------------------+-----------------+
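As a sanity check on the 172800 in date_diff_seconds, the plain-Python sketch below reproduces the arithmetic: two timestamps two days apart differ by 2 × 24 × 3600 = 172800 seconds, and, as with arithmetic on nulls in Spark, a missing previous date yields None. The helper `diff_seconds` is hypothetical and works on naive datetime objects, so it ignores time zones.

```python
from datetime import datetime
from typing import Optional


def diff_seconds(current: datetime, previous: Optional[datetime]) -> Optional[int]:
    """Hypothetical helper: seconds between two timestamps, None if previous is missing."""
    if previous is None:
        return None  # any arithmetic involving null is null in Spark SQL
    return int((current - previous).total_seconds())


print(diff_seconds(datetime(2022, 1, 4), datetime(2022, 1, 2)))  # 172800
print(diff_seconds(datetime(2022, 1, 2), None))  # None
print(2 * 24 * 3600)  # 172800, i.e. 2 days in seconds
```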
The same result can be computed more concisely by lagging the Unix timestamp directly, without materializing the intermediate prev_date column.
[60]:
window = Window.partitionBy('user_id').orderBy('date')
sdf \
    .withColumn('diff', F.unix_timestamp('date') - F.lag(F.unix_timestamp(F.col('date'))).over(window)) \
    .show(31)
+-------+--------------+-------------------+------+
|user_id|transaction_id|               date|  diff|
+-------+--------------+-------------------+------+
|      1|             2|2022-01-02 00:00:00|  null|
|      1|             4|2022-01-04 00:00:00|172800|
|      1|             6|2022-01-06 00:00:00|172800|
|      1|             8|2022-01-08 00:00:00|172800|
|      1|            10|2022-01-10 00:00:00|172800|
|      1|            12|2022-01-12 00:00:00|172800|
|      1|            14|2022-01-14 00:00:00|172800|
|      1|            16|2022-01-16 00:00:00|172800|
|      1|            18|2022-01-18 00:00:00|172800|
|      1|            20|2022-01-20 00:00:00|172800|
|      1|            22|2022-01-22 00:00:00|172800|
|      1|            24|2022-01-24 00:00:00|172800|
|      1|            26|2022-01-26 00:00:00|172800|
|      1|            28|2022-01-28 00:00:00|172800|
|      1|            30|2022-01-30 00:00:00|172800|
|      2|             1|2022-01-01 00:00:00|  null|
|      2|             3|2022-01-03 00:00:00|172800|
|      2|             5|2022-01-05 00:00:00|172800|
|      2|             7|2022-01-07 00:00:00|172800|
|      2|             9|2022-01-09 00:00:00|172800|
|      2|            11|2022-01-11 00:00:00|172800|
|      2|            13|2022-01-13 00:00:00|172800|
|      2|            15|2022-01-15 00:00:00|172800|
|      2|            17|2022-01-17 00:00:00|172800|
|      2|            19|2022-01-19 00:00:00|172800|
|      2|            21|2022-01-21 00:00:00|172800|
|      2|            23|2022-01-23 00:00:00|172800|
|      2|            25|2022-01-25 00:00:00|172800|
|      2|            27|2022-01-27 00:00:00|172800|
|      2|            29|2022-01-29 00:00:00|172800|
+-------+--------------+-------------------+------+
2.3.11. User defined function (UDF)
If you have complicated logic to transform a column, you can use a user-defined function (UDF). First, define the Python function that performs the transformation; below, times_two() takes a number and returns that number times two. Second, create the UDF with udf(), which takes two arguments:
- the function that performs the transformation
- the return type (optional; it defaults to StringType())
The return type is a data type from the pyspark.sql.types module, such as IntegerType(). Finally, apply the UDF to a column, for example inside select().
[61]:
import random
n_cols = 10
n_rows = 10
data = [tuple([random.randint(0, 5) for c in range(n_cols)]) for r in range(n_rows)]
columns = [f'x{i}' for i in range(n_cols)]
pdf = pd.DataFrame(data, columns=columns)
sdf = sqlContext.createDataFrame(pdf)
sdf.show()
+---+---+---+---+---+---+---+---+---+---+
| x0| x1| x2| x3| x4| x5| x6| x7| x8| x9|
+---+---+---+---+---+---+---+---+---+---+
|  1|  1|  1|  4|  5|  4|  3|  2|  3|  1|
|  4|  3|  2|  3|  0|  3|  2|  4|  5|  2|
|  4|  3|  2|  0|  1|  3|  3|  5|  1|  3|
|  4|  4|  1|  0|  2|  4|  5|  2|  5|  5|
|  5|  2|  3|  3|  5|  2|  0|  3|  0|  3|
|  3|  2|  4|  2|  1|  3|  1|  2|  5|  3|
|  1|  4|  2|  5|  0|  1|  4|  0|  1|  5|
|  0|  2|  4|  5|  4|  0|  1|  2|  2|  0|
|  2|  1|  4|  0|  4|  1|  4|  3|  0|  2|
|  2|  5|  5|  5|  1|  1|  4|  0|  4|  3|
+---+---+---+---+---+---+---+---+---+---+
[62]:
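from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
# Step 1: a plain Python function that doubles its input
def times_two(num):
    return num * 2
# Step 2: wrap it as a UDF whose results are integers
times_two_udf = udf(times_two, IntegerType())
# Step 3: apply the UDF to column x0 inside select()
sdf.select('x0', times_two_udf('x0').alias('times_two')).show()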
+---+---------+
| x0|times_two|
+---+---------+
|  1|        2|
|  4|        8|
|  4|        8|
|  4|        8|
|  5|       10|
|  3|        6|
|  1|        2|
|  0|        0|
|  2|        4|
|  2|        4|
+---+---------+
A UDF can also accept multiple arguments. Here’s an example.
[63]:
def add_them(a, b):
    return a + b
add_them_udf = udf(add_them, IntegerType())
sdf.select('x0', 'x1', add_them_udf('x0', 'x1').alias('add_them')).show()
+---+---+--------+
| x0| x1|add_them|
+---+---+--------+
|  1|  1|       2|
|  4|  3|       7|
|  4|  3|       7|
|  4|  4|       8|
|  5|  2|       7|
|  3|  2|       5|
|  1|  4|       5|
|  0|  2|       2|
|  2|  1|       3|
|  2|  5|       7|
+---+---+--------+
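One caveat with Python UDFs: Spark passes a SQL null to the function as None, so times_two(None) or add_them(None, 1) would raise a TypeError and fail the task. A common pattern is to guard against None inside the function. The sketch below shows a hypothetical plain-Python decorator, `none_if_any_null`, that returns None (which Spark stores as null) whenever any argument is None; it is exercised here without Spark.

```python
import functools
from typing import Any, Callable, Optional


def none_if_any_null(func: Callable[..., Any]) -> Callable[..., Optional[Any]]:
    """Hypothetical decorator: return None if any positional argument is None."""
    @functools.wraps(func)  # keep the original function name
    def wrapper(*args: Any) -> Optional[Any]:
        if any(arg is None for arg in args):
            return None  # propagate null instead of raising TypeError
        return func(*args)
    return wrapper


@none_if_any_null
def add_them(a, b):
    return a + b


print(add_them(2, 5))     # 7
print(add_them(None, 5))  # None
```

With Spark available, the guarded function can be wrapped exactly as before, e.g. `udf(add_them, IntegerType())`.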
2.3.12. User defined function (UDF) with annotation
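It’s often simplest to create a UDF by applying the @udf decorator directly to the function definition. The decorator is parameterized with the return type, which can be given as a DDL-formatted string such as 'int' instead of an IntegerType() object; if it is omitted, the return type defaults to string.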
[64]:
from pyspark.sql.functions import udf
@udf('int')
def times_three(num):
    return num * 3
sdf.select('x0', times_three('x0').alias('times_three')).show()
+---+-----------+
| x0|times_three|
+---+-----------+
|  1|          3|
|  4|         12|
|  4|         12|
|  4|         12|
|  5|         15|
|  3|          9|
|  1|          3|
|  0|          0|
|  2|          6|
|  2|          6|
+---+-----------+
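Because @udf('int') takes an argument, it acts as a decorator factory: udf('int') is called first and returns the decorator that is then applied to times_three, so the cell above is essentially equivalent to writing times_three = udf(times_three, 'int'). The plain-Python sketch below illustrates that two-step mechanism with a hypothetical `returns` factory that simply records the declared type on the function; it is not how PySpark implements udf().

```python
from typing import Callable, TypeVar

FuncT = TypeVar("FuncT", bound=Callable)


def returns(type_name: str) -> Callable[[FuncT], FuncT]:
    """Hypothetical decorator factory: calling it returns the actual decorator."""
    def decorator(func: FuncT) -> FuncT:
        func.return_type = type_name  # record the declared return type on the function
        return func
    return decorator


@returns('int')  # same as: times_three = returns('int')(times_three)
def times_three(num):
    return num * 3


print(times_three(4), times_three.return_type)  # 12 int
```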