Dynamic-Like Spark Windows

Unlocking dynamic-window-like functionality in Spark.

Published October 15, 2023

Introduction

We’ll start with a DataFrame:

   id  total  row_offset
0   1     10           3
1   2      5           0
2   3      1           4
3   4     20           2
4   5      4           2
5   6      1           1
6   7      3           0
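
For completeness, here's a minimal sketch of how this DataFrame can be built. It assumes an active SparkSession named spark; the same data and schema reappear in create_df further down.

from pyspark.sql.types import StructType, StructField, IntegerType, LongType

# LongType is imported here because the expected-output schema below uses it
data = ((1, 10, 3), (2, 5, 0), (3, 1, 4), (4, 20, 2), (5, 4, 2), (6, 1, 1), (7, 3, 0))

schema = StructType(
    [
        StructField("id", IntegerType(), False),
        StructField("total", IntegerType(), False),
        StructField("row_offset", IntegerType(), False),
    ]
)

df = spark.createDataFrame(data, schema)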

Say we’d like to sum total over a look-ahead window defined by row_offset so that the output is:

Code
expected_sum_data = (
    (1, 10, 3, 36),
    (2, 5, 0, 5),
    (3, 1, 4, 29),
    (4, 20, 2, 25),
    (5, 4, 2, 8),
    (6, 1, 1, 4),
    (7, 3, 0, 3),
)

expected_sum_schema = StructType(
    [
        StructField("id", IntegerType(), False),
        StructField("total", IntegerType(), False),
        StructField("row_offset", IntegerType(), False),
        StructField("windowed_sum", LongType(), False),
    ]
)

expected_sum_df = spark.createDataFrame(expected_sum_data, expected_sum_schema)

expected_sum_df.toPandas()
   id  total  row_offset  windowed_sum
0   1     10           3            36
1   2      5           0             5
2   3      1           4            29
3   4     20           2            25
4   5      4           2             8
5   6      1           1             4
6   7      3           0             3

Aggregate window functions seem like a natural fit. Let’s give it a shot.

import pyspark.sql as ps
import pyspark.sql.functions as psf

df.withColumn(
    "windowed_total",
    psf.sum("total").over(ps.Window.orderBy("id").rowsBetween(0, psf.col("row_offset"))),
)
PySparkValueError: [CANNOT_CONVERT_COLUMN_INTO_BOOL] Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.

This error isn’t very helpful, but something is clearly wrong. The docs confirm that rowsBetween only accepts fixed int values for its start and end boundaries.
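
For contrast, the same window with a hard-coded look-ahead compiles without complaint (a quick sketch; the fixed offset of 2 and the column name here are arbitrary, not part of the original example):

# a fixed two-row look-ahead window works fine
df.withColumn(
    "windowed_total_fixed",
    psf.sum("total").over(ps.Window.orderBy("id").rowsBetween(0, 2)),
)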

Whether you’re rolling up hierarchical data or aggregating over a variable lagging window, fixed window sizes are a conspicuous limitation of Spark (and of every SQL flavor I’m familiar with).

Below, we’ll explore a couple workarounds to this common issue.

Conditional Window Approach

Ad-Hoc Implementation

Initial solutions are often better off ignoring DRY, decoupling, and design. The first priority is to get something working. Generalization and optimization can come later.

With this in mind, we’ll begin with an ad-hoc implementation tailored for this specific case and dataset.

# get row offsets observed in data
row_offsets = df.select("row_offset").distinct().toPandas().iloc[:, 0].values.tolist()

# build a window condition for each window size
windowed_sum_conditions_and_choices = [
    (
        psf.col("row_offset") == row_offset,
        psf.sum("total").over(ps.Window.orderBy("id").rowsBetween(0, row_offset)),
    )
    for row_offset in row_offsets
]

windowed_sum_conditions_and_choices
[(Column<'(row_offset = 3)'>,
  Column<'sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 3 FOLLOWING)'>),
 (Column<'(row_offset = 0)'>,
  Column<'sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND CURRENT ROW)'>),
 (Column<'(row_offset = 4)'>,
  Column<'sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 4 FOLLOWING)'>),
 (Column<'(row_offset = 2)'>,
  Column<'sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING)'>),
 (Column<'(row_offset = 1)'>,
  Column<'sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING)'>)]

Composing these window conditions and choices into a single CASE WHEN:

# build conditional sum as composed `CASE WHEN`
conditional_windowed_sum = psf.when(*windowed_sum_conditions_and_choices.pop())
for condition in windowed_sum_conditions_and_choices:
    conditional_windowed_sum = conditional_windowed_sum.when(*condition)
conditional_windowed_sum = conditional_windowed_sum.otherwise(psf.col("total"))

conditional_windowed_sum
Column<'CASE WHEN (row_offset = 1) THEN sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING) WHEN (row_offset = 3) THEN sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 3 FOLLOWING) WHEN (row_offset = 0) THEN sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND CURRENT ROW) WHEN (row_offset = 4) THEN sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 4 FOLLOWING) WHEN (row_offset = 2) THEN sum(total) OVER (ORDER BY id ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING) ELSE total END'>

In pseudo-SQL, this is effectively:

CASE
    WHEN row_offset = 0 THEN sum(total) OVER (ORDER BY id ROWS BETWEEN 0 AND 0)
    WHEN row_offset = 1 THEN sum(total) OVER (ORDER BY id ROWS BETWEEN 0 AND 1)
    . . . 
    WHEN row_offset = N THEN sum(total) OVER (ORDER BY id ROWS BETWEEN 0 AND N)
    ELSE total
END AS windowed_sum
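
For the curious, the pseudo-SQL above maps to runnable Spark SQL along these lines (a sketch; it assumes df is registered as a temp view, here arbitrarily named dynamic_window_demo, and it only enumerates the offsets observed in this dataset):

# register the DataFrame so it can be queried with SQL
df.createOrReplaceTempView("dynamic_window_demo")

spark.sql(
    """
    SELECT
        *,
        CASE
            WHEN row_offset = 0 THEN SUM(total) OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND CURRENT ROW)
            WHEN row_offset = 1 THEN SUM(total) OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING)
            WHEN row_offset = 2 THEN SUM(total) OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING)
            WHEN row_offset = 3 THEN SUM(total) OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND 3 FOLLOWING)
            WHEN row_offset = 4 THEN SUM(total) OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND 4 FOLLOWING)
            ELSE total
        END AS windowed_sum
    FROM dynamic_window_demo
    """
)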

Calculating this conditional window function:

# compute and return
df.withColumn("windowed_sum", conditional_windowed_sum).toPandas()
   id  total  row_offset  windowed_sum
0   1     10           3            36
1   2      5           0             5
2   3      1           4            29
3   4     20           2            25
4   5      4           2             8
5   6      1           1             4
6   7      3           0             3

Looks right! Comparing to the expected table:

from pyspark.testing.utils import assertDataFrameEqual

assertDataFrameEqual(df.withColumn("windowed_sum", conditional_windowed_sum), expected_sum_df)

All good. Moving on to a more general implementation.

Generalized Implementation

To begin the refactor, notice that the ad-hoc implementation has a few distinct jobs:

  1. Collect the row offsets as a list.
  2. Generate window conditions and choices.
  3. Chain the window conditions and choices into a composed CASE WHEN.
  4. Apply the sum over the conditional windows.

Of these, (3), chaining the window conditions, is the most generic and well suited for extraction.

from typing import Callable, List, Any


def method_chain(
    method: Callable,  # method to be chained
    arg_sets: List[List[Any]],  # argument sets, one per chained method call
) -> Any:
    """
    Chains calls to a chainable method.
    E.g., `pyspark.sql.functions.when(*arg_sets[0]).when(*arg_sets[1]). . . .when(*arg_sets[-1])`.
    Note: the last argument set is consumed first, and `arg_sets` is mutated via `pop`.
    """
    result = method(*arg_sets.pop())
    for arg_set in arg_sets:
        result = getattr(result, method.__name__)(*arg_set)

    return result
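
As a quick illustration of how method_chain behaves, here's a hypothetical sign-style CASE WHEN over a column named x (not part of our dataset); note the last argument set is consumed first and the input list is mutated:

# hypothetical: CASE WHEN (x < 0) THEN -1 WHEN (x > 0) THEN 1 ELSE 0 END
sign_expr = method_chain(
    psf.when,
    [
        (psf.col("x") > 0, 1),
        (psf.col("x") < 0, -1),  # popped first, so this becomes the first WHEN clause
    ],
).otherwise(0)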

With that responsibility extracted, the ad-hoc logic rolls up into a dynamic_window_sum function:

def dynamic_window_sum(
    df: ps.DataFrame,
) -> ps.DataFrame:
    # get row offsets observed in data
    row_offsets = df.select("row_offset").distinct().toPandas().iloc[:, 0].values.tolist()

    # build a window condition for each window size
    windowed_sum_conditions = [
        (
            psf.col("row_offset") == row_offset,
            psf.sum("total").over(ps.Window.orderBy("id").rowsBetween(0, row_offset)),
        )
        for row_offset in row_offsets
    ]

    # build conditional sum as composed `CASE WHEN`
    conditional_windowed_sum = method_chain(psf.when, windowed_sum_conditions).otherwise(
        psf.col("total")
    )

    # compute and return
    return df.withColumn("windowed_sum", conditional_windowed_sum)

Verifying we still get the expected result:

assertDataFrameEqual(dynamic_window_sum(df), expected_sum_df)

This is better, but dynamic windows will be useful for any sort of aggregation, not just sums. We’d like a generic way of transforming a ps.Window into a set of conditional windows over which any aggregate function can be applied. We’d also like to allow dynamic start and end values for rowsBetween.

from typing import Tuple, Union, Optional, Literal, Callable


def dynamic_aggregate_window_factory(
    window: ps.Window,  # base window spec with `partitionBy` and / or `orderBy` declared
    row_offset_col_names: List[
        str
    ],  # names of the column(s) identifying each row's (backward, forward) window offsets
    row_offsets: Union[
        List[int], List[Tuple[int, int]]
    ],  # list of (backward, forward), backward-only, or forward-only row offsets to handle
    unioffset_direction: Optional[
        Literal["forward", "backward"]
    ] = None,  # only needed if `row_offsets` is a list of uni-directional offsets
) -> Callable:
    if isinstance(row_offsets[0], int):
        assert (
            unioffset_direction is not None
        ), "`unioffset_direction` must be set if `row_offsets` is a list of uni-directional offsets."

    match unioffset_direction:
        case None:
            window_conditions = [
                (
                    (psf.col(row_offset_col_names[0]) == _rows_offsets[0])
                    & (psf.col(row_offset_col_names[1]) == _rows_offsets[1]),
                    window.rowsBetween(_rows_offsets[0], _rows_offsets[1]),
                )
                for _rows_offsets in row_offsets
            ]
        case "forward":
            window_conditions = [
                (
                    (psf.col(row_offset_col_names[0]) == forward_row_offset),
                    window.rowsBetween(0, forward_row_offset),
                )
                for forward_row_offset in row_offsets
            ]
        case "backward":
            window_conditions = [
                (
                    (psf.col(row_offset_col_names[0]) == backward_row_offset),
                    window.rowsBetween(backward_row_offset, 0),
                )
                for backward_row_offset in row_offsets
            ]

    def _apply_over_dynamic_window(func: Callable):
        return method_chain(
            psf.when, [(condition, func.over(window)) for condition, window in window_conditions]
        )

    return _apply_over_dynamic_window

# specify base window, collect observed row offsets
window = ps.Window.orderBy("id")
row_offset_col_names = ["row_offset"]
row_offsets = df.select("row_offset").distinct().toPandas().iloc[:, 0].values.tolist()
unioffset_direction = "forward"

# build dynamic window application function
apply_over_dynamic_window = dynamic_aggregate_window_factory(
    window, row_offset_col_names, row_offsets, unioffset_direction
)

# apply
computed_df = df.withColumns(
    {
        "windowed_sum": apply_over_dynamic_window(psf.sum("total")),
        "windowed_product": apply_over_dynamic_window(psf.product("total")),
    }
)

computed_df.toPandas()
   id  total  row_offset  windowed_sum  windowed_product
0   1     10           3            36            1000.0
1   2      5           0             5               5.0
2   3      1           4            29             240.0
3   4     20           2            25              80.0
4   5      4           2             8              12.0
5   6      1           1             4               3.0
6   7      3           0             3               3.0

Looks about right. Let’s create a proper test function to protect against regressions.

Code
from pyspark.sql.types import DoubleType


def create_df():
    data = ((1, 10, 3), (2, 5, 0), (3, 1, 4), (4, 20, 2), (5, 4, 2), (6, 1, 1), (7, 3, 0))

    schema = StructType(
        [
            StructField("id", IntegerType(), False),
            StructField("total", IntegerType(), False),
            StructField("row_offset", IntegerType(), False),
        ]
    )

    df = spark.createDataFrame(data, schema)

    return df


def create_expected_df():
    expected_data = (
        (1, 10, 3, 36, 1000.0),
        (2, 5, 0, 5, 5.0),
        (3, 1, 4, 29, 240.0),
        (4, 20, 2, 25, 80.0),
        (5, 4, 2, 8, 12.0),
        (6, 1, 1, 4, 3.0),
        (7, 3, 0, 3, 3.0),
    )

    expected_schema = StructType(
        [
            StructField("id", IntegerType(), False),
            StructField("total", IntegerType(), False),
            StructField("row_offset", IntegerType(), False),
            StructField("windowed_sum", LongType(), False),
            StructField("windowed_product", DoubleType(), False),
        ]
    )

    expected_df = spark.createDataFrame(expected_data, expected_schema)

    return expected_df


def test_dynamic_aggregate_window_factory():
    df = create_df()

    # specify base window, collect observed row offsets
    window = ps.Window.orderBy("id")
    row_offset_col_names = ["row_offset"]
    row_offsets = df.select("row_offset").distinct().toPandas().iloc[:, 0].values.tolist()
    unioffset_direction = "forward"

    # build dynamic window application function
    apply_over_dynamic_window = dynamic_aggregate_window_factory(
        window, row_offset_col_names, row_offsets, unioffset_direction
    )

    # apply
    computed_df = df.withColumns(
        {
            "windowed_sum": apply_over_dynamic_window(psf.sum("total")),
            "windowed_product": apply_over_dynamic_window(psf.product("total")),
        }
    )

    assertDataFrameEqual(computed_df, create_expected_df())

test_dynamic_aggregate_window_factory()

This works well enough, but collecting the distinct row offset values forces an extra computation and breaks Spark’s lazy evaluation model. Self-joining is an alternative approach that avoids this.

Self-Join Approach

Ad-Hoc Implementation

First, number each row of the DataFrame according to the base window:

window = ps.Window().orderBy("id")

numbered_df = df.withColumn("_row_number", psf.row_number().over(window))

numbered_df.toPandas()
   id  total  row_offset  _row_number
0   1     10           3            1
1   2      5           0            2
2   3      1           4            3
3   4     20           2            4
4   5      4           2            5
5   6      1           1            6
6   7      3           0            7

Next, self-join so that rows within the look-ahead window are linked.

(Renaming columns instead of aliasing DataFrames for better readability.)

self_join_df = numbered_df.withColumnsRenamed(
    {col: f"_left_{col}" for col in numbered_df.columns}
).join(
    numbered_df.withColumnsRenamed({col: f"_right_{col}" for col in numbered_df.columns}),
    on=[
        psf.col("_left__row_number") <= psf.col("_right__row_number"),
        psf.col("_left__row_number") + psf.col("_left_row_offset")
        >= psf.col("_right__row_number"),
    ],
    how="left",
)

self_join_df.toPandas()
_left_id _left_total _left_row_offset _left__row_number _right_id _right_total _right_row_offset _right__row_number
0 1 10 3 1 1 10 3 1
1 1 10 3 1 2 5 0 2
2 1 10 3 1 3 1 4 3
3 1 10 3 1 4 20 2 4
4 2 5 0 2 2 5 0 2
5 3 1 4 3 3 1 4 3
6 3 1 4 3 4 20 2 4
7 3 1 4 3 5 4 2 5
8 3 1 4 3 6 1 1 6
9 3 1 4 3 7 3 0 7
10 4 20 2 4 4 20 2 4
11 4 20 2 4 5 4 2 5
12 4 20 2 4 6 1 1 6
13 5 4 2 5 5 4 2 5
14 5 4 2 5 6 1 1 6
15 5 4 2 5 7 3 0 7
16 6 1 1 6 6 1 1 6
17 6 1 1 6 7 3 0 7
18 7 3 0 7 7 3 0 7

Finally, group by the left-side columns, which identify each window, and aggregate the right-side totals, which fall within that window.

computed_df = (
    self_join_df.groupby(["_left_id", "_left_total", "_left_row_offset"])
    .agg(psf.sum("_right_total").alias("windowed_sum"))
    .transform(
        lambda df: df.withColumnsRenamed({col: col.replace("_left_", "") for col in df.columns})
    )
)

computed_df.toPandas()
   id  total  row_offset  windowed_sum
0   1     10           3            36
1   2      5           0             5
2   3      1           4            29
3   4     20           2            25
4   5      4           2             8
5   6      1           1             4
6   7      3           0             3

Looks right. Let’s verify.

assertDataFrameEqual(
    computed_df,
    expected_sum_df,
)

With things working, we’ll move to a more generalized implementation.

Generalized Implementation

from typing import Dict


def agg_over_dynamic_window(
    df: ps.DataFrame,
    window: ps.Window,  # base window
    rows_between: Tuple[
        Union[str, int], Union[str, int]
    ],  # static (int) or dynamic (column name as str) `rowsBetween` (start, end). See `pyspark.sql.Window.rowsBetween`.
    aggs: Dict[
        str, psf.Column
    ],  # similar to `pd.DataFrame.agg`. E.g., {column_name: aggregate_func}.
) -> ps.DataFrame:
    # using prefix here so agg func calls will refer to columns on right-side of join
    rows_between_start = (
        psf.col(f"_left_{rows_between[0]}")
        if isinstance(rows_between[0], str)
        else psf.lit(rows_between[0])
    )
    rows_between_end = (
        psf.col(f"_left_{rows_between[1]}")
        if isinstance(rows_between[1], str)
        else psf.lit(rows_between[1])
    )

    # using underscore for automatically added intermediate columns
    numbered_df = df.withColumn("_row_number", psf.row_number().over(window))

    return (
        numbered_df.withColumnsRenamed({col: f"_left_{col}" for col in numbered_df.columns})
        .join(
            numbered_df,
            on=[
                psf.col("_left__row_number") + rows_between_start <= psf.col("_row_number"),
                psf.col("_left__row_number") + rows_between_end >= psf.col("_row_number"),
            ],
            how="left",
        )
        .groupby(*[psf.col(f"_left_{col}") for col in df.columns])
        .agg(*[func.alias(name) for name, func in aggs.items()])
        .transform(
            lambda df: df.withColumnsRenamed(
                {col: col.replace("_left_", "") for col in df.columns}
            )
        )
    )

window = ps.Window().orderBy("id")

df.transform(
    agg_over_dynamic_window,
    window,
    [0, "row_offset"],
    {
        "windowed_sum": psf.sum("total"),
        "windowed_product": psf.product("total"),
    },
).toPandas()
   id  total  row_offset  windowed_sum  windowed_product
0   1     10           3            36            1000.0
1   2      5           0             5               5.0
2   3      1           4            29             240.0
3   4     20           2            25              80.0
4   5      4           2             8              12.0
5   6      1           1             4               3.0
6   7      3           0             3               3.0

Time for another test.

Code
def test_agg_over_dynamic_window():
    df = create_df()
    window = ps.Window().orderBy("id")
    computed_df = df.transform(
        agg_over_dynamic_window,
        window,
        [0, "row_offset"],
        {
            "windowed_sum": psf.sum("total"),
            "windowed_product": psf.product("total"),
        },
    )

    assertDataFrameEqual(computed_df, create_expected_df())

test_agg_over_dynamic_window()

The self-join method doesn’t force an additional computation, and it requires less (and arguably clearer) code. How does it perform?

Performance

Although testing on small data is not representative of most PySpark use cases, I reckon the rankings will hold for large datasets.

Conditional Window Approach

Full Compute

%%timeit
# collect observed offsets
row_offsets = df.select("row_offset").distinct().toPandas().iloc[:, 0].values.tolist()
# build dynamic window application function
apply_over_dynamic_window = dynamic_aggregate_window_factory(
    window, row_offset_col_names, row_offsets, unioffset_direction
)
# apply
df.withColumns(
    {
        "windowed_sum": apply_over_dynamic_window(psf.sum("total")),
        "windowed_product": apply_over_dynamic_window(psf.product("total")),
    }
).count()
259 ms ± 36.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Precomputed Window Conditionals

%%timeit
# apply
df.withColumns(
    {
        "windowed_sum": apply_over_dynamic_window(psf.sum("total")),
        "windowed_product": apply_over_dynamic_window(psf.product("total")),
    }
).count()
97.9 ms ± 6.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Self-Join Approach

%%timeit
df.transform(
    agg_over_dynamic_window,
    window,
    [0, "row_offset"],
    {
        "windowed_sum": psf.sum("total"),
        "windowed_product": psf.product("total"),
    },
).count()
129 ms ± 14.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Summary

While Spark doesn’t support dynamically sized windows natively, we demonstrated two approaches that unlock dynamic-window-like functionality – Conditional Windows and Self-Joins.

Ranking these approaches on code readability and execution performance:

  1. The Self-Join approach is the most performant option when the entire computation, including collecting the distinct offsets, is counted. It also requires less code than the Conditional Window approach.
  2. The Conditional Window approach with precomputed window conditionals is the fastest per application once the offsets have been collected. It’s recommended when (1) the conditional window will be reused in multiple places throughout a script and (2) the forced computation to collect the offsets is acceptable.
  3. The Conditional Window approach without precomputed window conditionals is the least performant option and also requires the most code.