Unlocking hierarchical graph operations in pure Spark.
October 21, 2023
Graph databases naturally represent hierarchical data, complex relationships, and arbitrary attributes (e.g. key-value pairs). Are they the best option for datasets heavy with these features? Yes, as far as the data model goes. But there are other considerations.

Although graph databases are often the better technology, relational databases have the home-team advantage. In many enterprises, the value of graph databases can't cover the immediate switching costs.

This creates the situation where graph-like data is shoehorned into relational databases. Migrating those datasets to graph databases may be the ideal, but in the meantime…
Some tools, like NetworkX and GraphFrames, provide methods to construct graph abstractions from DataFrames. (See nx.from_pandas_edgelist and Creating GraphFrames.) This allows graph operations on top of familiar tabular structures. These solutions won't outperform a true graph database, but they're a reasonable stopgap. If you need algorithms like betweenness centrality to identify supply chain risk or node similarity to resolve entities, these tools are great options.
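For instance, the entire example hierarchy we'll use below fits in a few lines of NetworkX (a sketch, assuming the edge list is small enough to hold in pandas; the variable names are mine):

```python
import networkx as nx
import pandas as pd

# manager -> employee edges from the example table below
edges = pd.DataFrame(
    {"manager_employee_id": [1, 1, 2, 4, 1, 2], "employee_id": [2, 3, 4, 5, 6, 7]}
)
G = nx.from_pandas_edgelist(
    edges, source="manager_employee_id", target="employee_id", create_using=nx.DiGraph
)
nx.descendants(G, 2)  # {4, 5, 7}
```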
On the other hand, your application may not need the horsepower these tools provide. If you're running a lean deployment, it may be hard to justify the additional dependency. Below, we'll write a descendants method to demonstrate how basic graph operations can be implemented in pure PySpark.
Arbitrarily nested hierarchical data, like org structures or bills-of-materials, is commonly stored in relational databases. For example:
employee_id | sales_total | manager_employee_id |
1 | 253 | NULL |
2 | 308 | 1 |
3 | 92 | 1 |
4 | 20 | 2 |
5 | 148 | 4 |
6 | 377 | 1 |
7 | 87 | 2 |
Visually, the reporting structure forms a tree rooted at employee 1:
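```
1
├── 2
│   ├── 4
│   │   └── 5
│   └── 7
├── 3
└── 6
```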
The goal is to crawl this hierarchy to produce a list of all descendants of each node. Descendants of node 4 should be [5], of node 2 should be [4, 5, 7], and so on. The full target output is:
employee_id | sales_total | manager_employee_id | descendants |
1 | 253 | NULL | [2, 3, 4, 5, 6, 7] |
2 | 308 | 1 | [4, 5, 7] |
3 | 92 | 1 | [] |
4 | 20 | 2 | [5] |
5 | 148 | 4 | [] |
6 | 377 | 1 | [] |
7 | 87 | 2 | [] |
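To follow along, here's a minimal sketch that constructs the example DataFrame (assuming a local SparkSession; `df` is the variable the snippets below operate on):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [
        (1, 253, None),
        (2, 308, 1),
        (3, 92, 1),
        (4, 20, 2),
        (5, 148, 4),
        (6, 377, 1),
        (7, 87, 2),
    ],
    schema="employee_id INT, sales_total INT, manager_employee_id INT",
)
```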
To get the full list of descendants, we'll start with the descendants in the immediate next generation: each employee's direct reports.
```python
import pyspark.sql.functions as psf

direct_reports = (
    df.alias("left")
    .join(
        df.alias("right"),
        # match each employee on the left to the rows that report to them
        on=[psf.col("left.employee_id") == psf.col("right.manager_employee_id")],
        how="left",
    )
    # group back to one row per employee, keeping all original columns
    .groupby(*[f"left.{col}" for col in df.columns])
    # collect the matched IDs into a (possibly empty) array
    .agg(psf.collect_set("right.employee_id").alias("direct_report_employee_ids"))
)
```
This joins each employee in `left` to each of their direct reports in `right`, then collects the matches into an array.
The output:
employee_id | sales_total | manager_employee_id | direct_report_employee_ids |
1 | 253 | NULL | [2, 6, 3] |
2 | 308 | 1 | [7, 4] |
3 | 92 | 1 | [] |
4 | 20 | 2 | [5] |
5 | 148 | 4 | [] |
6 | 377 | 1 | [] |
7 | 87 | 2 | [] |
Capturing descendants in the next level of the hierarchy:
```python
deg2_direct_reports = (
    direct_reports.alias("left")
    .join(
        direct_reports.alias("right"),
        on=[
            # match rows whose manager appears in the left row's direct reports
            psf.array_contains(
                "left.direct_report_employee_ids",
                psf.col("right.manager_employee_id"),
            )
        ],
        how="left",
    )
    .groupby(*[f"left.{col}" for col in direct_reports.columns])
    .agg(psf.collect_set("right.employee_id").alias("deg2_direct_report_employee_ids"))
)
```
This joins each employee in `left` to the direct reports of their direct reports in `right`, i.e. their second-degree descendants.
The output:
employee_id | sales_total | manager_employee_id | direct_report_employee_ids | deg2_direct_report_employee_ids |
1 | 253 | NULL | [2, 6, 3] | [7, 4] |
2 | 308 | 1 | [7, 4] | [5] |
3 | 92 | 1 | [] | [] |
4 | 20 | 2 | [5] | [] |
5 | 148 | 4 | [] | [] |
6 | 377 | 1 | [] | [] |
7 | 87 | 2 | [] | [] |
single_generation_descendants
We'll repeat this process to collect all descendants, querying each generation's descendants from the prior generation's until the hierarchy terminates with no descendants remaining. Here's a generalized implementation:
```python
import pyspark.sql as ps
from typing import Optional


def single_generation_descendants(
    df: ps.DataFrame,
    id_col_name: str,  # id of the referent row
    parent_id_col_name: str,  # reference to the immediate parent
    child_ids_col_name: Optional[str] = None,  # reference to children in any particular generation
    degree_num: Optional[int] = None,  # avoids duplicate column names on chained calls
) -> ps.DataFrame:
    assert (
        "array" not in df.select(id_col_name).dtypes[0][-1]
    ), f"The ID column ({id_col_name}) cannot be an array."
    assert (
        "array" not in df.select(parent_id_col_name).dtypes[0][-1]
    ), f"The parent ID column ({parent_id_col_name}) cannot be an array."
    # the join logic differs between the first and subsequent generations
    if child_ids_col_name is None:
        # first generation: integer-equals-integer join condition
        join_condition = psf.col(f"left.{id_col_name}") == psf.col(f"right.{parent_id_col_name}")
    else:
        assert (
            "array" in df.select(child_ids_col_name).dtypes[0][-1]
        ), f"The child IDs column ({child_ids_col_name}) must be an array."
        # subsequent generations: array-contains-integer join condition
        join_condition = psf.array_contains(
            psf.col(f"left.{child_ids_col_name}"),
            psf.col(f"right.{parent_id_col_name}"),
        )
    agg_func = psf.collect_set(psf.col(f"right.{id_col_name}"))
    descendants_prefix = "next_gen_" if degree_num is None else f"deg{degree_num}_"
    return (
        df.alias("left")
        .join(
            df.alias("right").select(id_col_name, parent_id_col_name),
            on=[join_condition],
            how="left",
        )
        .groupby(*[psf.col(f"left.{col}") for col in df.columns])
        .agg(agg_func.alias(f"{descendants_prefix}descendants"))
    )
```
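Called once, with the integer-equals-integer join condition, single_generation_descendants reproduces direct_reports up to the column name. A sketch of that check, using assertDataFrameEqual (available in pyspark.testing as of Spark 3.5):

```python
from pyspark.testing import assertDataFrameEqual

assertDataFrameEqual(
    df.transform(
        single_generation_descendants, "employee_id", "manager_employee_id", degree_num=1
    ),
    direct_reports.withColumnRenamed("direct_report_employee_ids", "deg1_descendants"),
)
```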
Chaining single_generation_descendants twice, where the second call uses the array-contains join condition, also produces the expected result:
```python
id_col_name = "employee_id"
parent_id_col_name = "manager_employee_id"

assertDataFrameEqual(
    (
        df.transform(
            single_generation_descendants, id_col_name, parent_id_col_name, degree_num=1
        ).transform(
            single_generation_descendants,
            id_col_name,
            parent_id_col_name,
            "deg1_descendants",
            2,
        )
    ),
    deg2_direct_reports.withColumnsRenamed(
        {
            "direct_report_employee_ids": "deg1_descendants",
            "deg2_direct_report_employee_ids": "deg2_descendants",
        }
    ),
)
```
descendants
Recursive algorithms require two fundamental elements: accumulation logic and an end condition. While the final descendants function won't truly be recursive (it won't call itself), it will have both.
Focusing on the end condition, we see single_generation_descendants returns no descendants on the fourth call:
```python
(
    df.transform(
        single_generation_descendants,
        id_col_name,
        parent_id_col_name,
        degree_num=1,
    )
    .transform(
        single_generation_descendants,
        id_col_name,
        parent_id_col_name,
        "deg1_descendants",
        degree_num=2,
    )
    .transform(
        single_generation_descendants,
        id_col_name,
        parent_id_col_name,
        "deg2_descendants",
        degree_num=3,
    )
    .transform(
        single_generation_descendants,
        id_col_name,
        parent_id_col_name,
        "deg3_descendants",
        degree_num=4,
    )
)
```
employee_id | sales_total | manager_employee_id | deg1_descendants | deg2_descendants | deg3_descendants | deg4_descendants |
1 | 253 | NULL | [2, 6, 3] | [7, 4] | [5] | [] |
2 | 308 | 1 | [7, 4] | [5] | [] | [] |
3 | 92 | 1 | [] | [] | [] | [] |
4 | 20 | 2 | [5] | [] | [] | [] |
5 | 148 | 4 | [] | [] | [] | [] |
6 | 377 | 1 | [] | [] | [] | [] |
7 | 87 | 2 | [] | [] | [] | [] |
So we're one generation beyond the end state when filtering on size({last_crawled_gen_descendants}) > 0 returns no rows:
```python
(
    df.transform(
        single_generation_descendants,
        id_col_name,
        parent_id_col_name,
        degree_num=1,
    )
    .transform(
        single_generation_descendants,
        id_col_name,
        parent_id_col_name,
        "deg1_descendants",
        degree_num=2,
    )
    .transform(
        single_generation_descendants,
        id_col_name,
        parent_id_col_name,
        "deg2_descendants",
        degree_num=3,
    )
    .transform(
        single_generation_descendants,
        id_col_name,
        parent_id_col_name,
        "deg3_descendants",
        degree_num=4,
    )
).filter(psf.size("deg4_descendants") > 0).rdd.isEmpty()
```
True
Wrapping this together:
```python
def descendants(
    df: ps.DataFrame,
    id_col_name: str,  # column name of the ID column that identifies the node
    parent_id_col_name: str,  # column name of the parent ID column that identifies the parent node
) -> ps.DataFrame:
    def _next_gen_exists(
        maybe_next_gen_descendants: ps.DataFrame,
    ) -> bool:
        # the last column is always the most recently added degN_descendants array
        return not maybe_next_gen_descendants.filter(
            psf.size(maybe_next_gen_descendants.columns[-1]) > 0
        ).rdd.isEmpty()

    # get next generation descendants;
    # single_generation_descendants accumulates by appending to the returned DataFrame
    degree_num = 1
    maybe_next_gen_descendants = single_generation_descendants(
        df, id_col_name, parent_id_col_name, degree_num=degree_num
    )
    # handles the degenerate case where no node has any descendants
    wide_descendants = maybe_next_gen_descendants
    # if at least one descendant exists in the next generation,
    # update the table of all descendants and check the next generation
    while _next_gen_exists(maybe_next_gen_descendants):
        wide_descendants = maybe_next_gen_descendants
        degree_num += 1
        maybe_next_gen_descendants = single_generation_descendants(
            wide_descendants,
            id_col_name,
            parent_id_col_name,
            maybe_next_gen_descendants.columns[-1],
            degree_num=degree_num,
        )
    # merge the per-generation descendant arrays into one sorted list
    descendant_col_names = [
        col for col in wide_descendants.columns if col not in df.columns
    ]
    return wide_descendants.withColumn(
        "descendants",
        psf.array_sort(psf.flatten(psf.array(*descendant_col_names))),
    ).drop(*descendant_col_names)
```
Calling descendants(df, "employee_id", "manager_employee_id") produces the target output:
employee_id | sales_total | manager_employee_id | descendants |
1 | 253 | NULL | [2, 3, 4, 5, 6, 7] |
2 | 308 | 1 | [4, 5, 7] |
3 | 92 | 1 | [] |
4 | 20 | 2 | [5] |
5 | 148 | 4 | [] |
6 | 377 | 1 | [] |
7 | 87 | 2 | [] |
Verifying:
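A sketch of the check, building the target table by hand (assuming the `spark` session from the setup above; assertDataFrameEqual was imported earlier):

```python
expected = spark.createDataFrame(
    [
        (1, 253, None, [2, 3, 4, 5, 6, 7]),
        (2, 308, 1, [4, 5, 7]),
        (3, 92, 1, []),
        (4, 20, 2, [5]),
        (5, 148, 4, []),
        (6, 377, 1, []),
        (7, 87, 2, []),
    ],
    schema="employee_id INT, sales_total INT, manager_employee_id INT, descendants ARRAY<INT>",
)
assertDataFrameEqual(descendants(df, "employee_id", "manager_employee_id"), expected)
```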
All good.
With a full list of descendants for each record, it's trivial to run aggregations over the rolled-up hierarchy.
```python
descendant_df = descendants(df, "employee_id", "manager_employee_id")

descendant_df.alias("left").join(
    descendant_df.alias("right").select("employee_id", "sales_total"),
    on=[
        # each row matches itself...
        (psf.col("left.employee_id") == psf.col("right.employee_id"))
        # ...and every one of its descendants
        | (
            psf.array_contains(
                psf.col("left.descendants"),
                psf.col("right.employee_id"),
            )
        )
    ],
).groupby([psf.col(f"left.{col}") for col in descendant_df.columns]).agg(
    psf.sum("right.sales_total").alias("rollup_sales_total")
)
```
employee_id | sales_total | manager_employee_id | descendants | rollup_sales_total |
1 | 253 | NULL | [2, 3, 4, 5, 6, 7] | 1285 |
2 | 308 | 1 | [4, 5, 7] | 563 |
3 | 92 | 1 | [] | 92 |
4 | 20 | 2 | [5] | 168 |
5 | 148 | 4 | [] | 148 |
6 | 377 | 1 | [] | 377 |
7 | 87 | 2 | [] | 87 |
Generalized into a reusable helper that returns the grouped data, leaving the choice of aggregation to the caller:

```python
def hierarchical_rollup(
    df: ps.DataFrame, id_col_name: str, parent_id_col_name: str
) -> ps.GroupedData:
    descendant_df = descendants(df, id_col_name, parent_id_col_name)
    rolled_up_data = (
        # prefix the left side's columns so the right side's stay unambiguous
        descendant_df.withColumnsRenamed(
            {col: f"_left_{col}" for col in descendant_df.columns if col != id_col_name}
        )
        .join(
            descendant_df.withColumnsRenamed({id_col_name: f"_right_{id_col_name}"}),
            on=[
                (psf.col(id_col_name) == psf.col(f"_right_{id_col_name}"))
                | (
                    psf.array_contains(
                        psf.col("_left_descendants"),
                        psf.col(f"_right_{id_col_name}"),
                    )
                )
            ],
        )
        .groupby(id_col_name)
    )
    return rolled_up_data
```
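For example, this reproduces the rollup_sales_total figures from the table above, keyed by employee_id (the right side's sales_total column keeps its original name, so it's the one being summed):

```python
hierarchical_rollup(df, "employee_id", "manager_employee_id").agg(
    psf.sum("sales_total").alias("rollup_sales_total")
)
```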
In some cases, implementing your own graph algorithms in your team's SQL engine may be the best option. Above, we covered implementing descendants in the PySpark DataFrame API using self-joins and built-in array operations. Similar techniques can be used to look up ancestors, degree, and more.
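As one parting sketch (my own variation, not covered above): flipping the join direction of the direct_reports query collects each node's immediate parent as an array, which can then be chained upward exactly like descendants:

```python
# single-generation ancestor lookup: join child rows to their manager's row
parents = (
    df.alias("left")
    .join(
        df.alias("right").select("employee_id"),
        on=[psf.col("left.manager_employee_id") == psf.col("right.employee_id")],
        how="left",
    )
    .groupby(*[f"left.{col}" for col in df.columns])
    # a one-element array for every node except the root, which gets []
    .agg(psf.collect_set("right.employee_id").alias("parent_employee_ids"))
)
```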