
/chunkify-task

Convert a single CLD sub-DAG task into a parallelized chunk+union pattern. Splits the task across payer chunks using Airflow dynamic task mapping, then unions the results into the final table.

Usage

/chunkify-task <tasks-module> <function-name>

Provide the module name (filename under tasks/ without .py) and the function name to chunkify. For example: /chunkify-task benchmarks build_benchmarks.

What It Does

A "chunkified" task splits its work across payer chunks using Airflow's dynamic task mapping (.expand()). Each chunk processes a subset of payer networks in parallel, writing to a per-chunk intermediate table. A union task then combines all chunks into the final table — the same table name the rest of the DAG expects.

get_ros_payer_chunks()
|
v
build_<name>_chunk xN (one per chunk, runs in parallel)
|
v
build_<name>_union (combines all chunk tables into the final table)
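The flow above can be modeled in plain Python to make the map-then-union shape concrete. This is a sketch under stated assumptions, not Airflow code:

- split_into_chunks and process_chunk are hypothetical stand-ins.
- The grouping logic of get_ros_payer_chunks() is not described here, so a round-robin split is assumed.
- In-memory lists play the role of Trino tables.
- A thread pool capped at 8 workers stands in for max_active_tis_per_dag=8.

```python
from concurrent.futures import ThreadPoolExecutor


def split_into_chunks(payer_networks, n_chunks):
    """Hypothetical splitter: round-robin the payer/network tuples into n_chunks
    lists, returned as (chunk index, tuples) pairs like the chunk task unpacks."""
    buckets = [[] for _ in range(n_chunks)]
    for i, item in enumerate(payer_networks):
        buckets[i % n_chunks].append(item)
    return [(n, bucket) for n, bucket in enumerate(buckets) if bucket]


def process_chunk(payer_id_chunk, rows):
    """Stand-in for build_<name>_chunk: keep only rows for payers in this chunk.
    (Simplified: ignores the network-level disambiguation of multi-network payers.)"""
    n_chunk, payer_network_tuples = payer_id_chunk
    payer_ids = {payer_id for payer_id, _, _, _ in payer_network_tuples}
    return n_chunk, [row for row in rows if row["payer_id"] in payer_ids]


def run_chunked(rows, payer_networks, n_chunks, max_parallel=8):
    chunks = split_into_chunks(payer_networks, n_chunks)
    # Map step: chunks run independently; max_workers mimics max_active_tis_per_dag.
    with ThreadPoolExecutor(max_workers=max_parallel) as pool:
        chunk_tables = dict(pool.map(lambda c: process_chunk(c, rows), chunks))
    # Union step: UNION ALL of every chunk "table", in chunk-index order.
    return [row for n in sorted(chunk_tables) for row in chunk_tables[n]]


rows = [{"payer_id": p, "rate": r} for p, r in [("1", 10), ("2", 20), ("3", 30), ("4", 40)]]
payer_networks = [(p, None, None, None) for p in ["1", "2", "3", "4"]]
result = run_chunked(rows, payer_networks, n_chunks=2)
print(len(result))  # 4: every row lands in exactly one chunk
```

The union only reproduces the unchunked result if each row matches exactly one chunk. That is why the per-chunk filter, including the extra network conditions for multi-network payers, matters.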

The skill makes four categories of changes:

  1. Updates tasks/<module>.py: appends _chunk to the original function name (for example, build_benchmarks becomes build_benchmarks_chunk), adds a payer_id_chunk parameter and payer network filtering logic, sets max_active_tis_per_dag=8 to cap how many chunks query Trino at once, and adds a new build_<name>_union function.

  2. Updates the task's existing SQL file, which becomes the chunk SQL: inserts {{ n_chunk }} into the table name, adds chunk header comments, and adds a WHERE filter on payer_network_list. Multi-network payers (Aetna/7, UHC/76, Cigna/643) get additional network-level disambiguation.

  3. Creates a union SQL file — a new <name>_union.sql that does UNION ALL across all chunk tables, preserving the original PARTITIONING.

  4. Updates __init__.py — replaces the single task call with a TaskGroup containing chunk expansion and union, preserving the original variable name so downstream dependencies don't change.

Files Modified

| File | Change |
|---|---|
| tasks/<module>.py | Rename function to _chunk, add payer_id_chunk arg, add _union function |
| sql/<folder>/<name>.sql | Add {{ n_chunk }} to table name, add payer filter |
| sql/<folder>/<name>_union.sql | New file: UNION ALL across all chunk tables |
| __init__.py | Replace single task call with TaskGroup containing chunks + union |

Reference

Chunkify Task: Reference

Task File Patterns

Chunk task

```python
@task(retries=3, max_active_tis_per_dag=8)
def build_<name>_chunk(payer_id_chunk, dag_run=None, params=None):
    sql_loc, version, sub_version, schema_name = get_cld_sub_dag_params(dag_run, params)
    # One element of get_ros_payer_chunks(): (chunk index, list of payer/network tuples)
    n_chunk, payer_network_tuples = payer_id_chunk

    # Convert tuples to dicts so the SQL template can read item.payer_id, item.network_type, ...
    payer_network_list = [
        {"payer_id": payer_id, "network_type": network_type, "network_name": network_name, "bill_type": bill_type}
        for payer_id, network_type, network_name, bill_type in payer_network_tuples
    ]

    run_query(
        loc=sql_loc,
        file="/<folder>/<name>.sql",
        params={
            "sub_version": sub_version,
            "schema_name": f"{schema_name}{version}",
            # ... any extra params from the original function ...
            "payer_network_list": payer_network_list,
            "n_chunk": n_chunk,
            "run_id": dag_run.run_id if dag_run else None,
            "task_name": "build_<name>_chunk",
        },
        # Same per-chunk table name that the chunk SQL creates
        drop_table_name=f"{schema_name}{version}.tmp_int_<prefix>_{n_chunk}_{sub_version}",
    )
```
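The unpacking at the top of the chunk task can be exercised on its own. The sketch below copies that logic into a plain function so the shape of one payer_id_chunk element is visible. Some inputs are assumptions rather than documented facts:

- The element layout (chunk index, list of 4-tuples) is inferred from the unpacking code.
- The schema, version, prefix, sub_version, and payer values are hypothetical.

```python
def to_payer_network_list(payer_id_chunk):
    """Unpack one payer chunk the same way build_<name>_chunk does."""
    n_chunk, payer_network_tuples = payer_id_chunk
    payer_network_list = [
        {
            "payer_id": payer_id,
            "network_type": network_type,
            "network_name": network_name,
            "bill_type": bill_type,
        }
        for payer_id, network_type, network_name, bill_type in payer_network_tuples
    ]
    return n_chunk, payer_network_list


def chunk_table_name(schema_name, version, prefix, n_chunk, sub_version):
    """Per-chunk table name, matching drop_table_name in build_<name>_chunk."""
    return f"{schema_name}{version}.tmp_int_{prefix}_{n_chunk}_{sub_version}"


# Hypothetical chunk: payer 7 (multi-network) with network details, payer 42 without.
example_chunk = (3, [("7", "type_a", "name_a", "bill_a"), ("42", None, None, None)])
n_chunk, payer_network_list = to_payer_network_list(example_chunk)
print(chunk_table_name("cld_", "v12", "benchmarks", n_chunk, "s1"))
# cld_v12.tmp_int_benchmarks_3_s1
```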

Union task

```python
@task(retries=10)
def build_<name>_union(payer_ids, dag_run=None, params=None):
    sql_loc, version, sub_version, schema_name = get_cld_sub_dag_params(dag_run, params)
    # Only the chunk indices are needed to name the chunk tables being unioned
    n_chunks = [x[0] for x in payer_ids]

    run_query(
        loc=sql_loc,
        file="/<folder>/<name>_union.sql",
        params={
            "sub_version": sub_version,
            "schema_name": f"{schema_name}{version}",
            "n_chunks": n_chunks,
            "run_id": dag_run.run_id if dag_run else None,
            "task_name": "build_<name>_union",
        },
        # Final table: the name the rest of the DAG already expects
        drop_table_name=f"{schema_name}{version}.tmp_int_<prefix>_{sub_version}",
    )
```

SQL Chunk File Changes

Header comment block

```sql
-- Table: {{ schema_name }}.tmp_int_<prefix>_{{ n_chunk }}_{{ sub_version }}
-- N Chunk: {{ n_chunk }}
-- Payer Network List: {{ payer_network_list }}
```

Payer network filter

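Add this filter to the final SELECT of the chunk SQL, where <ALIAS> is the alias of the primary table. The network_type and network_name conditions are hard-coded to the alias n; change it to match the alias of the table that holds those columns in your query. If the final SELECT already has a WHERE clause, combine the two conditions with AND.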

```sql
WHERE (
  {% for item in payer_network_list %}
    (<ALIAS>.payer_id = '{{ item.payer_id }}'
      {% if item.payer_id in ['7', '76', '643'] and item.network_type %}
        AND n.network_type = '{{ item.network_type }}'
      {% endif %}
      {% if item.payer_id in ['7', '76', '643'] and item.network_name %}
        AND n.network_name = '{{ item.network_name }}'
      {% endif %}
      {% if item.payer_id in ['7', '76', '643'] and item.bill_type %}
        AND <ALIAS>.bill_type = '{{ item.bill_type }}'
      {% endif %}
    )
    {% if not loop.last %} OR {% endif %}
  {% endfor %}
)
```

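Payer IDs 7, 76, and 643 are multi-network payers (Aetna, UHC, Cigna), so payer_id alone does not isolate a single network. For these payers only, the filter also matches network_type, network_name, and bill_type whenever those fields are set. For every other payer, the payer_id match is enough.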
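A plain-Python mirror of the Jinja loop shows what the filter expands to for a small chunk. It is a sketch for inspection only. It reproduces the template's logic rather than calling the real renderer, and the alias r and the payer values are hypothetical.

```python
MULTI_NETWORK_PAYERS = {"7", "76", "643"}  # Aetna, UHC, Cigna (string IDs, as in the template)


def render_payer_filter(payer_network_list, alias):
    """Build the WHERE clause the Jinja template produces (whitespace aside)."""
    clauses = []
    for item in payer_network_list:
        parts = [f"{alias}.payer_id = '{item['payer_id']}'"]
        # Extra conditions only for multi-network payers, and only when the field is set
        if item["payer_id"] in MULTI_NETWORK_PAYERS:
            if item["network_type"]:
                parts.append(f"n.network_type = '{item['network_type']}'")
            if item["network_name"]:
                parts.append(f"n.network_name = '{item['network_name']}'")
            if item["bill_type"]:
                parts.append(f"{alias}.bill_type = '{item['bill_type']}'")
        clauses.append("(" + " AND ".join(parts) + ")")
    return "WHERE (" + " OR ".join(clauses) + ")"


chunk = [
    {"payer_id": "7", "network_type": "type_a", "network_name": None, "bill_type": None},
    {"payer_id": "42", "network_type": "type_b", "network_name": None, "bill_type": None},
]
print(render_payer_filter(chunk, "r"))
# WHERE ((r.payer_id = '7' AND n.network_type = 'type_a') OR (r.payer_id = '42'))
```

One consequence worth checking: the template compares payer_id against the strings '7', '76', and '643'. If payer IDs ever arrive as integers, the network conditions are silently dropped for multi-network payers, and the sketch behaves the same way. An empty payer_network_list would also render an invalid WHERE ().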


Union SQL Template

```sql
CREATE OR REPLACE TABLE {{ schema_name }}.tmp_int_<prefix>_{{ sub_version }}
WITH (
  PARTITIONING = ARRAY[<copy from chunk SQL>]
)
AS
{% for n_chunk in n_chunks %}
SELECT * FROM {{ schema_name }}.tmp_int_<prefix>_{{ n_chunk }}_{{ sub_version }}
{% if not loop.last %}UNION ALL{% endif %}
{% endfor %}
```
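To preview what the union template produces for a given set of chunks, the same statement can be assembled in plain Python. This is a sketch that mirrors the template rather than rendering it. A few inputs are assumptions:

- The schema, prefix, sub_version, and partitioning column are hypothetical.
- schema_name stands for the already-combined f"{schema_name}{version}" value that the union task passes in.

```python
def render_union_sql(schema_name, prefix, sub_version, n_chunks, partitioning):
    """Assemble the union statement the template renders (whitespace aside)."""
    selects = [
        f"SELECT * FROM {schema_name}.tmp_int_{prefix}_{n}_{sub_version}" for n in n_chunks
    ]
    return (
        f"CREATE OR REPLACE TABLE {schema_name}.tmp_int_{prefix}_{sub_version}\n"
        f"WITH (\n  PARTITIONING = ARRAY[{partitioning}]\n)\nAS\n"
        + "\nUNION ALL\n".join(selects)
    )


# Hypothetical output of get_ros_payer_chunks() with two chunks
payer_ids = [(0, [("1", None, None, None)]), (1, [("2", None, None, None)])]
n_chunks = [x[0] for x in payer_ids]  # same extraction as build_<name>_union
sql = render_union_sql("cld_v12", "benchmarks", "s1", n_chunks, "'payer_id'")
print(sql)
```

UNION ALL with SELECT * matches columns by position. The pattern therefore relies on every chunk table coming from the same chunk SQL, which gives them identical column order.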

__init__.py TaskGroup Pattern

```python
with TaskGroup("<name>_in_chunks", tooltip="<name> in payer chunks") as <original_var>:
    payer_id_chunks = utils.get_ros_payer_chunks()
    <name>_chunks = (
        <module>.build_<name>_chunk.partial().expand(payer_id_chunk=payer_id_chunks)
    )
    <name>_union = <module>.build_<name>_union(payer_id_chunks)

    <name>_chunks >> <name>_union
```

Real Examples

benchmarks

  • Module: tasks/benchmarks.py
  • Function: build_benchmarks → build_benchmarks_chunk + build_benchmarks_union
  • SQL: sql/benchmarks/benchmarks.sql (updated) + sql/benchmarks/benchmarks_union.sql (new)
  • Table prefix: tmp_int_benchmarks

accuracy_raw

  • Module: tasks/accuracy.py
  • Function: build_accuracy_raw → build_accuracy_raw_chunk + build_accuracy_raw_union
  • SQL: sql/accuracy/accuracy_raw.sql (updated) + sql/accuracy/accuracy_raw_union.sql (new)
  • Table prefix: tmp_int_accuracy_raw