Traversing a tree stored as an adjacency list using a Recursive CTE built in SQLAlchemy
This is part of a series on combining PostgreSQL Materialized Views, Recursive CTEs, and SQLAlchemy:
- Using SQLAlchemy to create and manage PostgreSQL Materialized Views
- Traversing a tree stored as an adjacency list using a Recursive CTE built in SQLAlchemy (this post)
The problem:
Recently, while working on the RockClimbing.com codebase, I encountered several situations where data needed to be pre-calculated/cached because on-the-fly calculations were too slow.
For example, RockClimbing.com has over 100,000 climbing routes scattered across 35,000 locations. The locations are essentially a tree structure, ranging in specificity from a general area down to a specific wall: “North America” > “California” > “Yosemite” > “El Capitan” > “North Face”.
As you browse different areas, it’d be nice to see the total number of climbing routes within each area. For example, if you’re viewing North America, we’d like to display the total number of climbing routes in each state.
In other words, for each node, find all of its descendant nodes (children, grandchildren, and so on) and then sum the routes attached to the node itself and to all of those descendants.
There are a number of ways to map a tree structure to SQL tables. For simplicity, I store the locations as an adjacency list, where each location row holds the id of its parent. Adjacency lists make it easy to insert and reorder nodes, but traversing the tree (especially downwards) can be slow: finding a node’s children means looking up every row whose parent_id matches that node, and descending several levels repeats that lookup once per level. An index on parent_id helps, but it’s still not fast when you need to go multiple levels deep.
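To make that cost concrete, here is a minimal sketch that uses Python’s built-in sqlite3 module as a stand-in for PostgreSQL; the table layout, index name, descendant_ids() helper, and sample rows are all hypothetical. Finding every descendant with ordinary queries takes one round trip per level of the tree, even with an index on parent_id:

```python
import sqlite3

# Hypothetical minimal adjacency-list schema: each row points at its parent.
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE location (
        id INTEGER PRIMARY KEY,
        parent_id INTEGER REFERENCES location(id),
        name TEXT NOT NULL
    );
    -- Index so "find the children of X" doesn't scan the whole table
    CREATE INDEX location_parent_id_idx ON location (parent_id);
""")
conn.executemany(
    "INSERT INTO location (id, parent_id, name) VALUES (?, ?, ?)",
    [
        (1, None, "North America"),
        (2, 1, "California"),
        (3, 2, "Yosemite"),
        (4, 3, "El Capitan"),
        (5, 4, "North Face"),
        (6, 2, "Joshua Tree"),
    ],
)


def descendant_ids(conn, root_id):
    """Walk down one level per query.

    Returns (ids of all descendants of root_id, number of queries issued).
    """
    found = []
    frontier = [root_id]
    queries = 0
    while frontier:
        placeholders = ",".join("?" * len(frontier))
        rows = conn.execute(
            f"SELECT id FROM location WHERE parent_id IN ({placeholders})",
            frontier,
        ).fetchall()
        queries += 1
        frontier = [row[0] for row in rows]  # the next level down
        found.extend(frontier)
    return found, queries
```

Starting from North America, descendant_ids(conn, 1) returns the five locations below it after five queries: one per level, plus a final query that finds no more children.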
Trying to solve it with Python:
My initial prototype looked something like this:
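```python
def recursive_child_locations(location):
    """Return all descendants of `location` (children, grandchildren, ...)."""
    child_l = []
    # With lazy='dynamic', each access to .children issues a new query
    for child in location.children:
        child_l.append(child)
        child_l.extend(recursive_child_locations(child))
    return child_l

# Count routes attached to the location itself or to any of its descendants
location_ids = [location.id] + [loc.id for loc in recursive_child_locations(location)]
recursive_child_route_count = db.session.query(Route).filter(
    Route.location_id.in_(location_ids)).count()
```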
This worked fine with my normal test dataset of 100 locations, but became too slow once I increased the test dataset to 50,000 locations.
Replacing the recursion with a loop would help slightly, since Python doesn’t optimize tail calls and every function call adds overhead. But the bigger problem is that accessing location.children is expensive: my location model is configured with lazy='dynamic', which makes SQLAlchemy issue a new query every time a location’s children are accessed, so walking the tree costs one query per location. A better solution is to pre-fetch all the locations and then iterate through them in memory, but that’s still a lot of data being unnecessarily transferred between the database and the app.
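A minimal sketch of the pre-fetch approach in plain Python, with no ORM involved. It assumes the (id, parent_id) pairs and a per-location route count have each already been fetched with a single query; build_children_map(), recursive_route_count(), and the sample data are hypothetical. It also swaps the recursion for an explicit stack:

```python
from collections import defaultdict


def build_children_map(rows):
    """Map each parent id to the ids of its direct children.

    rows: iterable of (id, parent_id) tuples, e.g. from one SELECT on location.
    """
    children = defaultdict(list)
    for loc_id, parent_id in rows:
        if parent_id is not None:
            children[parent_id].append(loc_id)
    return children


def recursive_route_count(location_id, children, direct_counts):
    """Sum the direct route counts of location_id and all of its descendants.

    An explicit stack replaces recursion, so very deep trees can't hit
    Python's recursion limit.
    """
    total = 0
    stack = [location_id]
    while stack:
        current = stack.pop()
        total += direct_counts.get(current, 0)
        stack.extend(children.get(current, ()))
    return total


# Hypothetical sample data: 1 North America > 2 California > 3 Yosemite >
# 4 El Capitan > 5 North Face, plus 6 Joshua Tree under California.
rows = [(1, None), (2, 1), (3, 2), (4, 3), (5, 4), (6, 2)]
direct_counts = {4: 2, 5: 3, 6: 1}  # routes attached directly to each location
children = build_children_map(rows)
```

With the sample tree, recursive_route_count(3, children, direct_counts) sums Yosemite, El Capitan, and North Face to get 5. Every row still has to travel from the database to the app first, which is the overhead described above.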
Another attempt using a Recursive CTE:
Instead of doing all this data transfer, if I could write a query that traverses the tree inside the database, then most of the filtering work would happen there and only the results would come back to the app. Plus the database’s compiled C code will be much faster than interpreted Python code.
Traditionally, traversing adjacency lists of arbitrary depth wasn’t possible with plain SQL queries (each join only descends one level, so the depth has to be known in advance) and required writing custom database functions. Thankfully, we’re using PostgreSQL, which supports Recursive CTEs that make it a good deal easier:
```sql
-- Find the recursive child locations under the current node,
-- then count the routes attached to them
WITH RECURSIVE children(id, parent_id) AS (
    -- base case
    SELECT location.id, location.parent_id
    FROM location
    WHERE location.id = 37  -- filtering on id includes the current node;
                            -- filtering on parent_id includes only its children
  -- combine with the recursive part
  UNION ALL
    SELECT location.id, location.parent_id
    FROM location, children
    WHERE children.id = location.parent_id
)
-- Now calculate the count of routes attached to these locations
-- (to inspect the traversal itself, replace this SELECT with: TABLE children;)
SELECT count(route.id) AS count_recursive_child_items_for_single_loc
FROM route, children
WHERE route.location_id = children.id;
```
I tested this using a test dataset of 35,000 locations and 100,000 routes. For the most expensive queries, near the top of the tree, it took about 200-250ms of database time to walk the tree and then calculate the cumulative route count. That’s a huge speedup over doing all the work in Python, but still not fast enough to generate dynamically on every page load, especially since some pages display the route counts for multiple locations and would need to run the query multiple times.
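For experimenting without a PostgreSQL server, the same recursive CTE runs almost unchanged in SQLite, which also supports WITH RECURSIVE. This sketch uses Python’s built-in sqlite3 module; the trimmed-down tables, sample rows, and count_routes_below() helper are hypothetical:

```python
import sqlite3

# In-memory stand-in for the PostgreSQL tables; rows are hypothetical sample data.
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE location (id INTEGER PRIMARY KEY, parent_id INTEGER);
    CREATE TABLE route (id INTEGER PRIMARY KEY, location_id INTEGER);
    -- 1 North America > 2 California > 3 Yosemite > 4 El Capitan > 5 North Face,
    -- plus 6 Joshua Tree under California
    INSERT INTO location (id, parent_id) VALUES
        (1, NULL), (2, 1), (3, 2), (4, 3), (5, 4), (6, 2);
    INSERT INTO route (location_id) VALUES (4), (4), (5), (5), (5), (6);
""")

# Same shape as the PostgreSQL query: the base case selects the starting node,
# and the recursive part keeps adding rows whose parent is already in the set.
COUNT_ROUTES_SQL = """
WITH RECURSIVE children(id, parent_id) AS (
    SELECT id, parent_id FROM location WHERE id = :root
    UNION ALL
    SELECT location.id, location.parent_id
    FROM location JOIN children ON location.parent_id = children.id
)
SELECT count(route.id)
FROM route JOIN children ON route.location_id = children.id
"""


def count_routes_below(conn, root_id):
    """Count routes attached to root_id or to any of its descendants."""
    (count,) = conn.execute(COUNT_ROUTES_SQL, {"root": root_id}).fetchone()
    return count
```

count_routes_below(conn, 3) starts at Yosemite and picks up El Capitan and North Face, returning 5 (two routes plus three).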
Pre-calculating the result and storing it:
The nice thing about this problem is we don’t require the results to always be 100% up to date. Only a handful of new routes are added every day, and nothing bad happens if someone browsing the website sees a count that is a few hours out of date.
So a better solution is to pre-calculate the result and cache it somewhere.
I decided against adding a count_recursive_routes column to the location table because it felt hacky; I’d rather separate the calculated/cached data from the original data. I expect the number of things I want to pre-calculate to grow over time, and I don’t want to keep stacking extra columns on the table. Using an in-memory datastore like Redis is another common solution, but so far we haven’t needed it, and I was hesitant to add the extra complexity to our stack just for this.
Instead, I decided to use a PostgreSQL Materialized View. If you’re not familiar with materialized views, my previous blog post provides a good overview.
Generalizing the Recursive CTE across the entire table:
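The recursive CTE solution above calculates the recursive route count for a single location, but for our materialized view, we need to generalize the query so it returns route counts for all locations. On the surface, that sounds easy, but it turned out to be much harder. I finally turned to StackOverflow, where Erwin Brandstetter suggested solving it using a PostgreSQL ARRAY. While walking down from the root nodes, the query records each node’s path (the list of ids from the root down to that node), so a node’s descendants are exactly the rows whose path begins with that node’s path (which also matches the node itself):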
```sql
WITH RECURSIVE tree AS (
    SELECT id, parent_id, ARRAY[id] AS path
    FROM location
    WHERE parent_id IS NULL

    UNION ALL
    SELECT c.id, c.parent_id, path || c.id
    FROM tree t
    JOIN location c ON c.parent_id = t.id
), tree_ct AS (
    SELECT t.id, t.path, COALESCE(i.item_ct, 0) AS item_ct
    FROM tree t
    LEFT JOIN (
        SELECT location_id AS id, count(*) AS item_ct
        FROM route
        GROUP BY 1
    ) i USING (id)
)
SELECT t.id
     , t.item_ct AS count_direct_child_items
     , sum(t1.item_ct) AS count_recursive_child_items
FROM tree_ct t
LEFT JOIN tree_ct t1 ON t1.path[1:array_upper(t.path, 1)] = t.path
GROUP BY t.id, t.item_ct;
```
This query took approximately 18 minutes to calculate results for the entire test dataset of 35,000 locations and 100,000 routes. This is plenty fast for something that only needs to run once a day as a background job.
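To see the path-prefix idea on a small tree, here is a sketch of the same query in SQLite through Python’s sqlite3 module. SQLite has no ARRAY type, so this swaps in a different encoding: each path is a text string such as '/1/2/3/', and substr() replaces the array slice in the prefix test. The trailing '/' keeps '/1/' from being treated as a prefix of '/12/'. The table layout and sample rows are hypothetical:

```python
import sqlite3

# Hypothetical sample tree and routes (same shape as the PostgreSQL tables).
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE location (id INTEGER PRIMARY KEY, parent_id INTEGER);
    CREATE TABLE route (id INTEGER PRIMARY KEY, location_id INTEGER);
    INSERT INTO location (id, parent_id) VALUES
        (1, NULL), (2, 1), (3, 2), (4, 3), (5, 4), (6, 2);
    INSERT INTO route (location_id) VALUES (4), (4), (5), (5), (5), (6);
""")

ALL_COUNTS_SQL = """
WITH RECURSIVE tree(id, parent_id, path) AS (
    -- roots start their own path, e.g. '/1/'
    SELECT id, parent_id, '/' || id || '/' FROM location WHERE parent_id IS NULL
    UNION ALL
    -- children extend their parent's path, e.g. '/1/2/'
    SELECT c.id, c.parent_id, t.path || c.id || '/'
    FROM tree t JOIN location c ON c.parent_id = t.id
), tree_ct AS (
    SELECT t.id, t.path, COALESCE(i.item_ct, 0) AS item_ct
    FROM tree t
    LEFT JOIN (
        SELECT location_id AS id, count(*) AS item_ct FROM route GROUP BY 1
    ) i USING (id)
)
SELECT t.id,
       t.item_ct AS count_direct_child_items,
       sum(t1.item_ct) AS count_recursive_child_items
FROM tree_ct t
-- t1 is t itself or a descendant of t when t's path is a prefix of t1's path
LEFT JOIN tree_ct t1 ON substr(t1.path, 1, length(t.path)) = t.path
GROUP BY t.id, t.item_ct
ORDER BY t.id
"""

# Each row: (id, count_direct_child_items, count_recursive_child_items)
rows = conn.execute(ALL_COUNTS_SQL).fetchall()
```

For example, Yosemite (id 3) has no routes of its own but 5 in total. Because the prefix test depends on columns from both sides of the self-join, this sketch compares every pair of nodes, so its cost grows quadratically with the number of locations.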
Converting the Recursive CTE to a SQLAlchemy selectable:
In order to use the materialized view with SQLAlchemy, I needed to convert the raw SQL query into a SQLAlchemy selectable and pass it to our custom create_mat_view() function.
Converting the recursive CTE turned out to be more straightforward than I expected. I incrementally built the various components, then combined them into a single query. As mentioned previously, if you’re using Flask-SQLAlchemy, the query needs to be built using the SQLAlchemy core select() function, rather than the more typical session.query().
Any time you’re translating a complicated SQL query to SQLAlchemy, it’s much easier if you print the individual components to make sure you’re assembling the query correctly. In the code below, I left in a few of the print statements I used for debugging. You’ll notice that sometimes, to see the actual output, you need to tell SQLAlchemy to compile the statement using postgresql.dialect().
```python
from sqlalchemy.dialects import postgresql

# Base case: start at the root locations (no parent); path = [id]
_tree_cte = db.select([
    Location.id,
    Location.parent_id,
    # TODO swap postgresql.array to db.array once SQLalchemy 1.1 releases
    postgresql.array([Location.id]).label('path'),
]).where(Location.parent_id.is_(None)).cte(name='tree', recursive=True)

# Recursive part: append each child's id to its parent's path
_tree_cte = _tree_cte.union_all(db.select([
    Location.id,
    Location.parent_id,
    # TODO swap postgresql.array to db.array once SQLalchemy 1.1 releases
    (_tree_cte.c.path + postgresql.array([Location.id])).label('path'),
]).select_from(db.join(_tree_cte, Location, _tree_cte.c.id == Location.parent_id)))
# print(db.select([_tree_cte]).compile(dialect=postgresql.dialect()))  # DEBUG

# Number of routes attached directly to each location
direct_count = db.select([
    Route.location_id.label('id'),
    db.func.count(Route.id).label('item_ct'),
]).group_by(Route.location_id).alias('direct_count')

# Attach the direct count (0 if none) to every node of the tree
_tree_count_cte = db.select([
    _tree_cte.c.id,
    _tree_cte.c.path,
    db.func.coalesce(direct_count.c.item_ct, 0).label('item_ct'),
]).select_from(
    db.join(_tree_cte, direct_count, _tree_cte.c.id == direct_count.c.id, isouter=True)
).cte(name='tree_ct')
# print(db.select([_tree_count_cte]).compile(dialect=postgresql.dialect()))  # DEBUG

# Self-join: t1 is t itself or a descendant of t when t's path is a prefix of t1's path
_aliased_tree_count_cte = _tree_count_cte.alias('t1')
gc_mv_query = db.select([
    _tree_count_cte.c.id.label('id'),
    _tree_count_cte.c.item_ct.label('count_direct_child_items'),
    db.func.sum(_aliased_tree_count_cte.c.item_ct).label('count_recursive_child_items'),
]).select_from(
    db.join(
        _tree_count_cte,
        _aliased_tree_count_cte,
        _tree_count_cte.c.path == _aliased_tree_count_cte.c.path[
            1:db.func.array_upper(_tree_count_cte.c.path, 1)],
        isouter=True)
).group_by(_tree_count_cte.c.id, _tree_count_cte.c.item_ct)
# print(gc_mv_query.compile(dialect=postgresql.dialect()))  # DEBUG
```
The final print debug statement prints a SQL query that exactly matches the definition of the materialized view (excluding the CREATE MATERIALIZED VIEW portion). Now we just need to pass this to the create_mat_view() function:
```python
class GearCategoryMV(MaterializedView):
    __table__ = create_mat_view("gear_category_mv", gc_mv_query)

db.Index('gear_category_mv_id_idx', GearCategoryMV.id, unique=True)
```
Querying:
You can read about how to query a materialized view using SQLAlchemy in my previous blog post.
Wrapup:
The more I use SQLAlchemy, the more impressed I am with it (and with Mike Bayer, its primary developer). Not many ORMs are flexible enough to build a materialized view from a Recursive CTE that uses the PostgreSQL-specific ARRAY datatype, and more.