Dynamic-Like Spark Windows

Unlocking dynamic-window-like functionality in Spark.

Introduction

We’ll start with a DataFrame:

|   | id | total | row_offset |
|---|---|---|---|
| 0 | 1 | 10 | 3 |
| 1 | 2 | 5 | 0 |
| 2 | 3 | 1 | 4 |
| 3 | 4 | 20 | 2 |
| 4 | 5 | 4 | 2 |
| 5 | 6 | 1 | 1 |
| 6 | 7 | 3 | 0 |

Say we’d like to sum `total` over a look-ahead window defined by `row_offset`, so that the output is:
from pyspark.sql.types import StructType, StructField, IntegerType, LongType

expected_sum_data = (
    (1, 10, 3, 36),
    (2, 5, 0, 5),
    (3, 1, 4, 29),
    (4, 20, 2, 25),
    (5, 4, 2, 8),
    (6, 1, 1, 4),
    (7, 3, 0, 3),
)

expected_sum_schema = StructType(
    [
        StructField("id", IntegerType(), False),
        StructField("total", IntegerType(), False),
        StructField("row_offset", IntegerType(), False),
        StructField("windowed_sum", LongType(), False),
    ]
)

expected_sum_df = spark.createDataFrame(expected_sum_data, expected_sum_schema)

expected_sum_df.toPandas()
|   | id | total | row_offset | windowed_sum |
|---|---|---|---|---|
| 0 | 1 | 10 | 3 | 36 |
| 1 | 2 | 5 | 0 | 5 |
| 2 | 3 | 1 | 4 | 29 |
| 3 | 4 | 20 | 2 | 25 |
| 4 | 5 | 4 | 2 | 8 |
| 5 | 6 | 1 | 1 | 4 |
| 6 | 7 | 3 | 0 | 3 |
Aggregate window functions seem like a natural fit. Let’s give it a shot.
import pyspark.sql as ps
import pyspark.sql.functions as psf

df.withColumn(
    "windowed_total",
    psf.sum("total").over(ps.Window.orderBy("id").rowsBetween(0, psf.col("row_offset"))),
)
PySparkValueError: [CANNOT_CONVERT_COLUMN_INTO_BOOL] Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.
This error isn’t very helpful, but clearly something’s wrong. Looking at the docs, it’s clear `rowsBetween` only accepts fixed `int` values.
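For contrast, a fixed-offset window (the only kind `rowsBetween` accepts) works fine; a minimal sketch, using a constant two-row look-ahead, which is not what we actually want here:

# works: rowsBetween with literal int bounds, i.e. the same 2-row look-ahead for every row
fixed_window = ps.Window.orderBy("id").rowsBetween(0, 2)
df.withColumn("fixed_windowed_total", psf.sum("total").over(fixed_window))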
Whether rolling up hierarchical data or aggregating over a variable lagging window, fixed window sizes are a conspicuous limitation of Spark (and all SQL flavors I’m familiar with).
Below, we’ll explore a couple workarounds to this common issue.
Conditional Window Approach
Ad-Hoc Implementation
Initial solutions are often better off ignoring DRY, decoupling, and design. The first priority is to get something working. Generalization and optimization can come later.
With this in mind, we’ll begin with an ad-hoc implementation tailored for this specific case and dataset.
# get row offsets observed in data
row_offsets = df.select("row_offset").distinct().toPandas().iloc[:, 0].values.tolist()

# build a window condition for each window size
windowed_sum_conditions_and_choices = [
    (
        psf.col("row_offset") == row_offset,
        psf.sum("total").over(ps.Window.orderBy("id").rowsBetween(0, row_offset)),
    )
    for row_offset in row_offsets
]

windowed_sum_conditions_and_choices
[(Column<'(row_offset = 3)'>,
Column<'sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 3 FOLLOWING)'>),
(Column<'(row_offset = 0)'>,
Column<'sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND CURRENT ROW)'>),
(Column<'(row_offset = 4)'>,
Column<'sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 4 FOLLOWING)'>),
(Column<'(row_offset = 2)'>,
Column<'sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING)'>),
(Column<'(row_offset = 1)'>,
Column<'sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING)'>)]
Composing these window conditions and choices into a single `CASE WHEN`:
# build conditional sum as composed `CASE WHEN`
conditional_windowed_sum = psf.when(*windowed_sum_conditions_and_choices.pop())
for condition in windowed_sum_conditions_and_choices:
    conditional_windowed_sum = conditional_windowed_sum.when(*condition)
conditional_windowed_sum = conditional_windowed_sum.otherwise(psf.col("total"))

conditional_windowed_sum
Column<'CASE WHEN (row_offset = 1) THEN sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING) WHEN (row_offset = 3) THEN sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 3 FOLLOWING) WHEN (row_offset = 0) THEN sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND CURRENT ROW) WHEN (row_offset = 4) THEN sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 4 FOLLOWING) WHEN (row_offset = 2) THEN sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING) ELSE total END'>
In pseudo-SQL, this is effectively:
CASE
    WHEN row_offset = 0 THEN sum(total) OVER (ORDER BY id ROWS BETWEEN 0 AND 0)
    WHEN row_offset = 1 THEN sum(total) OVER (ORDER BY id ROWS BETWEEN 0 AND 1)
    . . .
    WHEN row_offset = N THEN sum(total) OVER (ORDER BY id ROWS BETWEEN 0 AND N)
    ELSE total
END AS windowed_sum
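The same expression can also be written in actual Spark SQL. Below is a hedged sketch, assuming `df` has been registered as a temporary view named `source` (a name introduced here purely for illustration), with one `WHEN` per offset observed in our data:

# hedged sketch: the composed CASE WHEN expressed in Spark SQL
# (assumes `df` is registered as a temp view named "source", a name chosen for illustration)
df.createOrReplaceTempView("source")

spark.sql(
    """
    SELECT
        id,
        total,
        row_offset,
        CASE
            WHEN row_offset = 0 THEN SUM(total) OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND CURRENT ROW)
            WHEN row_offset = 1 THEN SUM(total) OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING)
            WHEN row_offset = 2 THEN SUM(total) OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING)
            WHEN row_offset = 3 THEN SUM(total) OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND 3 FOLLOWING)
            WHEN row_offset = 4 THEN SUM(total) OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND 4 FOLLOWING)
            ELSE total
        END AS windowed_sum
    FROM source
    """
)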
Calculating this conditional window function:
# compute and return
df.withColumn("windowed_sum", conditional_windowed_sum).toPandas()
|   | id | total | row_offset | windowed_sum |
|---|---|---|---|---|
| 0 | 1 | 10 | 3 | 36 |
| 1 | 2 | 5 | 0 | 5 |
| 2 | 3 | 1 | 4 | 29 |
| 3 | 4 | 20 | 2 | 25 |
| 4 | 5 | 4 | 2 | 8 |
| 5 | 6 | 1 | 1 | 4 |
| 6 | 7 | 3 | 0 | 3 |
Looks right! Comparing to the expected table:
from pyspark.testing.utils import assertDataFrameEqual

assertDataFrameEqual(df.withColumn("windowed_sum", conditional_windowed_sum), expected_sum_df)
All good. Moving on to a more general implementation.
Generalized Implementation
To begin the refactor, notice that the ad-hoc implementation, soon to be wrapped up as `dynamic_window_sum`, has a few distinct jobs:

1. Collect the row offsets as a list.
2. Generate window conditions and choices.
3. Chain the window conditions and choices into a composed `CASE WHEN`.
4. Apply the sum over the conditional windows.
Of these, (3), chaining the window conditions, is the most generic and well suited for extraction.
from typing import Callable, List, Any


def method_chain(
    # method to be chained
    method: Callable,
    # arguments allocated across chained method calls
    arg_sets: List[List[Any]],
) -> Any:
    """
    Chains calls to chainable class methods.

    E.g., `pyspark.sql.functions.when(*arg_sets[0]).when(*arg_sets[1]). . . .when(*arg_sets[-1])`.
    """
    result = method(*arg_sets.pop())
    for arg_set in arg_sets:
        result = getattr(result, method.__name__)(*arg_set)

    return result
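As a quick illustration (using a hypothetical column `a` and literal choices, purely for demonstration), `method_chain` turns a list of condition/value argument sets into one chained `when` expression:

# hypothetical illustration: chain psf.when over a list of (condition, value) arg sets
example_case = method_chain(
    psf.when,
    [
        [psf.col("a") == 1, psf.lit(10)],
        [psf.col("a") == 2, psf.lit(20)],
    ],
)
# note: `arg_sets.pop()` consumes the last arg set first (and mutates the input list),
# so the WHEN clauses appear in reverse input order
example_case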
With that responsibility extracted, `dynamic_window_sum` becomes:
def dynamic_window_sum(
    df: ps.DataFrame,
) -> ps.DataFrame:
    # get row offsets observed in data
    row_offsets = df.select("row_offset").distinct().toPandas().iloc[:, 0].values.tolist()

    # build a window condition for each window size
    windowed_sum_conditions = [
        (
            psf.col("row_offset") == row_offset,
            psf.sum("total").over(ps.Window.orderBy("id").rowsBetween(0, row_offset)),
        )
        for row_offset in row_offsets
    ]

    # build conditional sum as composed `CASE WHEN`
    conditional_windowed_sum = method_chain(psf.when, windowed_sum_conditions).otherwise(
        psf.col("total")
    )

    # compute and return
    return df.withColumn("windowed_sum", conditional_windowed_sum)
Verifying we still get the expected result:
assertDataFrameEqual(dynamic_window_sum(df), expected_sum_df)
This is better, but dynamic windows will be useful for any sort of aggregation, not just sums. We’d like a generic way of transforming a `ps.Window` into a set of conditional windows over which any aggregate function can be applied. We’d also like to allow dynamic `start` and `end` values for `rowsBetween`.
from typing import Tuple, Union, Optional, Literal, Callable


def dynamic_aggregate_window_factory(
    # base window spec with `partitionBy` and / or `orderBy` declared
    window: ps.Window,
    # names of column(s) identifying window (backward, forward) offsets for each row
    row_offset_col_names: List[str],
    # list of (backward, forward), backward, or forward row offsets to handle
    row_offsets: Union[List[int], List[Tuple[int, int]]],
    # only needed if `row_offsets` is a list of uni-directional offsets
    unioffset_direction: Optional[Literal["forward", "backward"]] = None,
) -> Callable:
    if isinstance(row_offsets[0], int):
        assert (
            unioffset_direction is not None
        ), "`unioffset_direction` must be set if `row_offsets` is a list of uni-directional offsets."

    match unioffset_direction:
        case None:
            window_conditions = [
                (
                    (psf.col(row_offset_col_names[0]) == _row_offsets[0])
                    & (psf.col(row_offset_col_names[1]) == _row_offsets[1]),
                    window.rowsBetween(_row_offsets[0], _row_offsets[1]),
                )
                for _row_offsets in row_offsets
            ]
        case "forward":
            window_conditions = [
                (
                    psf.col(row_offset_col_names[0]) == forward_row_offset,
                    window.rowsBetween(0, forward_row_offset),
                )
                for forward_row_offset in row_offsets
            ]
        case "backward":
            window_conditions = [
                (
                    psf.col(row_offset_col_names[0]) == backward_row_offset,
                    window.rowsBetween(backward_row_offset, 0),
                )
                for backward_row_offset in row_offsets
            ]

    def _apply_over_dynamic_window(func: Callable):
        return method_chain(
            psf.when,
            [(condition, window_) for condition, window_ in window_conditions] and
            [(condition, func.over(window_)) for condition, window_ in window_conditions],
        )

    return _apply_over_dynamic_window
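The `case None` branch handles bidirectional offsets. As a hedged sketch (not run here), it would be invoked with a list of (backward, forward) pairs and two offset columns; `back_offset` and `fwd_offset` are hypothetical names not present in our toy DataFrame:

# hedged sketch: bidirectional dynamic windows pass (backward, forward) offset pairs
# and two hypothetical offset column names, leaving `unioffset_direction` as None
hypothetical_offset_pairs = [(-1, 2), (0, 3), (-2, 0)]

apply_over_bidirectional_window = dynamic_aggregate_window_factory(
    ps.Window.orderBy("id"),
    ["back_offset", "fwd_offset"],
    hypothetical_offset_pairs,
)
# apply_over_bidirectional_window(psf.sum("total")) would then yield the composed CASE WHEN column

For our single forward-offset column, the invocation looks like this: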
# specify base window, collect observed row offsets
window = ps.Window.orderBy("id")
row_offset_col_names = ["row_offset"]
row_offsets = df.select("row_offset").distinct().toPandas().iloc[:, 0].values.tolist()
unioffset_direction = "forward"

# build dynamic window application function
apply_over_dynamic_window = dynamic_aggregate_window_factory(
    window, row_offset_col_names, row_offsets, unioffset_direction
)

# apply
computed_df = df.withColumns(
    {
        "windowed_sum": apply_over_dynamic_window(psf.sum("total")),
        "windowed_product": apply_over_dynamic_window(psf.product("total")),
    }
)

computed_df.toPandas()
|   | id | total | row_offset | windowed_sum | windowed_product |
|---|---|---|---|---|---|
| 0 | 1 | 10 | 3 | 36 | 1000.0 |
| 1 | 2 | 5 | 0 | 5 | 5.0 |
| 2 | 3 | 1 | 4 | 29 | 240.0 |
| 3 | 4 | 20 | 2 | 25 | 80.0 |
| 4 | 5 | 4 | 2 | 8 | 12.0 |
| 5 | 6 | 1 | 1 | 4 | 3.0 |
| 6 | 7 | 3 | 0 | 3 | 3.0 |
Looks about right. Let’s create a proper test function to protect against regressions.
from pyspark.sql.types import DoubleType


def create_df():
    data = ((1, 10, 3), (2, 5, 0), (3, 1, 4), (4, 20, 2), (5, 4, 2), (6, 1, 1), (7, 3, 0))

    schema = StructType(
        [
            StructField("id", IntegerType(), False),
            StructField("total", IntegerType(), False),
            StructField("row_offset", IntegerType(), False),
        ]
    )

    df = spark.createDataFrame(data, schema)

    return df


def create_expected_df():
    expected_data = (
        (1, 10, 3, 36, 1000.0),
        (2, 5, 0, 5, 5.0),
        (3, 1, 4, 29, 240.0),
        (4, 20, 2, 25, 80.0),
        (5, 4, 2, 8, 12.0),
        (6, 1, 1, 4, 3.0),
        (7, 3, 0, 3, 3.0),
    )

    expected_schema = StructType(
        [
            StructField("id", IntegerType(), False),
            StructField("total", IntegerType(), False),
            StructField("row_offset", IntegerType(), False),
            StructField("windowed_sum", LongType(), False),
            StructField("windowed_product", DoubleType(), False),
        ]
    )

    expected_df = spark.createDataFrame(expected_data, expected_schema)

    return expected_df


def test_dynamic_aggregate_window_factory():
    df = create_df()

    # specify base window, collect observed row offsets
    window = ps.Window.orderBy("id")
    row_offset_col_names = ["row_offset"]
    row_offsets = df.select("row_offset").distinct().toPandas().iloc[:, 0].values.tolist()
    unioffset_direction = "forward"

    # build dynamic window application function
    apply_over_dynamic_window = dynamic_aggregate_window_factory(
        window, row_offset_col_names, row_offsets, unioffset_direction
    )

    # apply
    computed_df = df.withColumns(
        {
            "windowed_sum": apply_over_dynamic_window(psf.sum("total")),
            "windowed_product": apply_over_dynamic_window(psf.product("total")),
        }
    )

    assertDataFrameEqual(computed_df, create_expected_df())


test_dynamic_aggregate_window_factory()
This works well enough, but the call to collect the distinct row offset values adds a computational dependency and breaks Spark’s lazy evaluation paradigm. Self-joining is an alternative approach that avoids it.
Self-Join Approach
Ad-Hoc Implementation
First, number each row of the DataFrame according to the base window:
window = ps.Window().orderBy("id")

numbered_df = df.withColumn("_row_number", psf.row_number().over(window))

numbered_df.toPandas()
|   | id | total | row_offset | _row_number |
|---|---|---|---|---|
| 0 | 1 | 10 | 3 | 1 |
| 1 | 2 | 5 | 0 | 2 |
| 2 | 3 | 1 | 4 | 3 |
| 3 | 4 | 20 | 2 | 4 |
| 4 | 5 | 4 | 2 | 5 |
| 5 | 6 | 1 | 1 | 6 |
| 6 | 7 | 3 | 0 | 7 |
Next, self-join so that rows within the look-ahead window are linked.
(Renaming columns instead of aliasing DataFrames for better readability.)
self_join_df = numbered_df.withColumnsRenamed(
    {col: f"_left_{col}" for col in numbered_df.columns}
).join(
    numbered_df.withColumnsRenamed({col: f"_right_{col}" for col in numbered_df.columns}),
    on=[
        psf.col("_left__row_number") <= psf.col("_right__row_number"),
        psf.col("_left__row_number") + psf.col("_left_row_offset")
        >= psf.col("_right__row_number"),
    ],
    how="left",
)

self_join_df.toPandas()
|   | _left_id | _left_total | _left_row_offset | _left__row_number | _right_id | _right_total | _right_row_offset | _right__row_number |
|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 10 | 3 | 1 | 1 | 10 | 3 | 1 |
| 1 | 1 | 10 | 3 | 1 | 2 | 5 | 0 | 2 |
| 2 | 1 | 10 | 3 | 1 | 3 | 1 | 4 | 3 |
| 3 | 1 | 10 | 3 | 1 | 4 | 20 | 2 | 4 |
| 4 | 2 | 5 | 0 | 2 | 2 | 5 | 0 | 2 |
| 5 | 3 | 1 | 4 | 3 | 3 | 1 | 4 | 3 |
| 6 | 3 | 1 | 4 | 3 | 4 | 20 | 2 | 4 |
| 7 | 3 | 1 | 4 | 3 | 5 | 4 | 2 | 5 |
| 8 | 3 | 1 | 4 | 3 | 6 | 1 | 1 | 6 |
| 9 | 3 | 1 | 4 | 3 | 7 | 3 | 0 | 7 |
| 10 | 4 | 20 | 2 | 4 | 4 | 20 | 2 | 4 |
| 11 | 4 | 20 | 2 | 4 | 5 | 4 | 2 | 5 |
| 12 | 4 | 20 | 2 | 4 | 6 | 1 | 1 | 6 |
| 13 | 5 | 4 | 2 | 5 | 5 | 4 | 2 | 5 |
| 14 | 5 | 4 | 2 | 5 | 6 | 1 | 1 | 6 |
| 15 | 5 | 4 | 2 | 5 | 7 | 3 | 0 | 7 |
| 16 | 6 | 1 | 1 | 6 | 6 | 1 | 1 | 6 |
| 17 | 6 | 1 | 1 | 6 | 7 | 3 | 0 | 7 |
| 18 | 7 | 3 | 0 | 7 | 7 | 3 | 0 | 7 |
Finally, group by the left-side columns, which identify each window, and aggregate the right-side totals, which fall within the window.
computed_df = (
    self_join_df.groupby(["_left_id", "_left_total", "_left_row_offset"])
    .agg(psf.sum("_right_total").alias("windowed_sum"))
    .transform(
        lambda df: df.withColumnsRenamed({col: col.replace("_left_", "") for col in df.columns})
    )
)

computed_df.toPandas()
|   | id | total | row_offset | windowed_sum |
|---|---|---|---|---|
| 0 | 1 | 10 | 3 | 36 |
| 1 | 2 | 5 | 0 | 5 |
| 2 | 3 | 1 | 4 | 29 |
| 3 | 4 | 20 | 2 | 25 |
| 4 | 5 | 4 | 2 | 8 |
| 5 | 6 | 1 | 1 | 4 |
| 6 | 7 | 3 | 0 | 3 |
Looks right. Let’s verify.
assertDataFrameEqual(
    computed_df,
    expected_sum_df,
)
With things working, we’ll move to a more generalized implementation.
Generalized Implementation
from typing import Dict


def agg_over_dynamic_window(
    df: ps.DataFrame,
    # base window
    window: ps.Window,
    # static (int) or dynamic (column name as str) `rowsBetween` (start, end). See `pyspark.sql.Window.rowsBetween`.
    rows_between: Tuple[Union[str, int], Union[str, int]],
    # similar to `pd.DataFrame.agg`. E.g., {column_name: aggregate_func}.
    aggs: Dict[str, psf.Column],
) -> ps.DataFrame:
    # using prefix here so agg func calls will refer to columns on the right side of the join
    rows_between_start = (
        psf.col(f"_left_{rows_between[0]}")
        if isinstance(rows_between[0], str)
        else psf.lit(rows_between[0])
    )
    rows_between_end = (
        psf.col(f"_left_{rows_between[1]}")
        if isinstance(rows_between[1], str)
        else psf.lit(rows_between[1])
    )

    # using underscore for automatically added intermediate columns
    numbered_df = df.withColumn("_row_number", psf.row_number().over(window))

    return (
        numbered_df.withColumnsRenamed({col: f"_left_{col}" for col in numbered_df.columns})
        .join(
            numbered_df,
            on=[
                psf.col("_left__row_number") + rows_between_start <= psf.col("_row_number"),
                psf.col("_left__row_number") + rows_between_end >= psf.col("_row_number"),
            ],
            how="left",
        )
        .groupby(*[psf.col(f"_left_{col}") for col in df.columns])
        .agg(*[func.alias(name) for name, func in aggs.items()])
        .transform(
            lambda df: df.withColumnsRenamed(
                {col: col.replace("_left_", "") for col in df.columns}
            )
        )
    )
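Because the `rowsBetween` start can also be a column name, a variable lagging window falls out of the same function. A hedged sketch, assuming a hypothetical `neg_offset` column holding non-positive offsets (e.g. -3 for three rows back):

# hedged sketch: a variable *lagging* window that starts `neg_offset` rows back
# (`neg_offset` is a hypothetical column, not present in our toy DataFrame)
lagging_df = df.transform(
    agg_over_dynamic_window,
    ps.Window().orderBy("id"),
    ["neg_offset", 0],  # dynamic start, static end
    {"lagging_sum": psf.sum("total")},
)

Our toy DataFrame only carries the forward `row_offset` column, so the call below passes `[0, "row_offset"]`: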
window = ps.Window().orderBy("id")

df.transform(
    agg_over_dynamic_window,
    window,
    [0, "row_offset"],
    {
        "windowed_sum": psf.sum("total"),
        "windowed_product": psf.product("total"),
    },
).toPandas()
|   | id | total | row_offset | windowed_sum | windowed_product |
|---|---|---|---|---|---|
| 0 | 1 | 10 | 3 | 36 | 1000.0 |
| 1 | 2 | 5 | 0 | 5 | 5.0 |
| 2 | 3 | 1 | 4 | 29 | 240.0 |
| 3 | 4 | 20 | 2 | 25 | 80.0 |
| 4 | 5 | 4 | 2 | 8 | 12.0 |
| 5 | 6 | 1 | 1 | 4 | 3.0 |
| 6 | 7 | 3 | 0 | 3 | 3.0 |
Time for another test.
def test_agg_over_dynamic_window():
    df = create_df()
    window = ps.Window().orderBy("id")
    computed_df = df.transform(
        agg_over_dynamic_window,
        window,
        [0, "row_offset"],
        {
            "windowed_sum": psf.sum("total"),
            "windowed_product": psf.product("total"),
        },
    )

    assertDataFrameEqual(computed_df, create_expected_df())


test_agg_over_dynamic_window()
The self-join method doesn’t force an additional computation and requires less, arguably clearer, code. How does it perform?
Performance
Although testing on small data is not representative of most PySpark use cases, I reckon the rankings will hold for large datasets.
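(If you’d like to check that at scale, here is a hedged sketch for generating a larger synthetic input; the row count, totals, and offset range are arbitrary choices.)

# hedged sketch: a larger synthetic input for re-running the benchmarks
# (arbitrary size, totals, and offset range)
big_df = (
    spark.range(1_000_000)  # provides the `id` column
    .withColumn("total", (psf.rand(seed=0) * 100).cast("int"))
    .withColumn("row_offset", (psf.rand(seed=1) * 10).cast("int"))
)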
Conditional Window Approach
Full Compute
%%timeit
# collect observed offsets
row_offsets = df.select("row_offset").distinct().toPandas().iloc[:, 0].values.tolist()
# build dynamic window application function
apply_over_dynamic_window = dynamic_aggregate_window_factory(
    window, row_offset_col_names, row_offsets, unioffset_direction
)
# apply
df.withColumns(
    {
        "windowed_sum": apply_over_dynamic_window(psf.sum("total")),
        "windowed_product": apply_over_dynamic_window(psf.product("total")),
    }
).count()
259 ms ± 36.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Precomputed Window Conditionals
%%timeit
# apply
df.withColumns(
    {
        "windowed_sum": apply_over_dynamic_window(psf.sum("total")),
        "windowed_product": apply_over_dynamic_window(psf.product("total")),
    }
).count()
97.9 ms ± 6.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Self-Join Approach
%%timeit
df.transform(
    agg_over_dynamic_window,
    window,
    [0, "row_offset"],
    {
        "windowed_sum": psf.sum("total"),
        "windowed_product": psf.product("total"),
    },
).count()
129 ms ± 14.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Summary
While Spark doesn’t support dynamically sized windows natively, we demonstrated two approaches that unlock dynamic-window-like functionality – Conditional Windows and Self-Joins.
Ranking these approaches on code readability and execution performance:
- The Self-Join approach is the most performant option when the entire computation, including collecting the offsets, is considered. It also requires less code than the Conditional Window approach.
- The Conditional Window approach with a precomputed conditional is the most performant at query time. It is recommended when (1) the conditional window will be reused in multiple places throughout a script and (2) the forced computation to collect offsets is acceptable.
- The Conditional Window approach without precomputed conditional windows is the least performant option and also requires the most code.