Multithreading pattern for config-driven Databricks SQL queries


I'm currently working on a project to build a config-driven pattern for running many processes at once, and I'm trying to figure out the best way to do this primarily with Databricks SQL. Right now I have a process that reads a table and pulls in all of the configurations that need to be processed. Each row is then fed to a function that runs multiple spark.sql statements, which can do any of the following: create a table, create a view based on CTEs, merge into a table, drop a table, or vacuum a table, along with logging throughout the process that is used to monitor runtime, query drift from changing configs, and failures with their error messages. Another thing to note: none of the configs depend on another config, so they can be run in any order.
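To make the shape of a config concrete, a single row (after the configs.toJSON().collect() call in the code below) looks roughly like this. The field names come from the processing function further down; the values are invented purely for illustration:

{
  "ScheduleName": "Test",
  "IsEnabled": 1,
  "SilverTable": "Customer",
  "StartingTableStructure": "CustomerId INT NOT NULL, CustomerName STRING",
  "StageQuery": "WITH stage_final AS (SELECT CustomerId, CustomerName FROM bronze.Customer)",
  "QuarantineTable": "Customer",
  "QuarantineQuery": "WITH quarantine_final AS (SELECT * FROM bronze.Customer WHERE CustomerId IS NULL)",
  "HashedColumns": "CustomerId, CustomerName",
  "MergeSourceColumns": "CustomerId, CustomerName",
  "MergeJoin": "Target.CustomerId = Source.CustomerId"
}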

Before showing the code, I realize that DataFrames may be, and probably are, better suited for something like this. The team I work on is in the process of moving from on-prem/Azure MSSQL to Databricks, and we're starting with Databricks SQL to get everyone into the swing of things before jumping into DataFrames.

Current setup:

Foreach Loop

sql_user = 'dbconfigreader'
sql_pass =  dbutils.secrets.get(scope = REDACTED, key = REDACTED)
jdbc_url = REDACTED

config_table = (spark.read
  .format("jdbc")
  .option("url",jdbc_url)
  .option("dbtable", "REDACTED")
  .option("user", sql_user)
  .option("password", sql_pass)
  .load()
)


x = "Test"
driver_manager = spark._sc._gateway.jvm.java.sql.DriverManager
con = driver_manager.getConnection(jdbc_url, sql_user, sql_pass)
configs = config_table.where(f"""ScheduleName = '{x}' and IsEnabled = 1""")
config_ls = configs.toJSON().collect()
for config in config_ls:
    failure = process_table(config)
    print(failure)
con.close()

Process Function

import json
from datetime import datetime

def process_table(config):
  sql_statement = None
  failure_reason = None
  success = 1
  start_time = datetime.now()
  item = json.loads(config)
  full_sql = ''
  try:
    # environment and con are defined at notebook (driver) scope
    full_sql += f"USE CATALOG {environment};"
    spark.sql(f"USE CATALOG {environment}")
    spark.sql(f"""CREATE TABLE IF NOT EXISTS silver.{item['SilverTable']} ({item['StartingTableStructure']}, Hash STRING NOT NULL, CreatedOnDate TIMESTAMP NOT NULL, LastUpdateDate TIMESTAMP, DeletedOnDate TIMESTAMP)""")

    # Run query for Quarantine table
    if item.get('QuarantineQuery') is not None:
      sql_statement = f"""CREATE OR REPLACE TABLE dq.{item['QuarantineTable']} AS {item['QuarantineQuery']} SELECT * FROM quarantine_final"""
      full_sql += sql_statement + ';'
      spark.sql(sql_statement)

    # Run query for stage table
    sql_statement = f"""CREATE OR REPLACE TABLE silver.stage_{item['SilverTable']} AS {item['StageQuery']} SELECT * FROM stage_final"""
    full_sql += sql_statement + ';'
    spark.sql(sql_statement)

    # Build the merge column lists from the comma-separated config fields
    hash_columns = ",".join([f"IFNULL({column.strip()}, '')" for column in item['HashedColumns'].split(',')])
    update_columns = ",".join([f"`{column.replace(' ','')}` = Source.`{column.replace(' ','')}`" for column in item['MergeSourceColumns'].split(',')])
    source_columns = ",".join([f"{column.strip()}" for column in item['MergeSourceColumns'].split(',')])

    sql_statement = f"""MERGE INTO silver.{item['SilverTable']} as Target USING (
        SELECT
          {source_columns},
          sha2(
            CONCAT({ hash_columns }),
            256
          ) as Hash
        FROM
          silver.stage_{item['SilverTable']}
      ) as Source
      ON {item['MergeJoin']}
      WHEN MATCHED AND Target.`Hash` <> Source.`Hash` THEN 
      UPDATE SET {update_columns}, `Hash` = Source.`Hash`, LastUpdateDate = GETDATE()
      WHEN NOT MATCHED BY TARGET THEN 
      INSERT ({source_columns}, `Hash`, `CreatedOnDate`, `LastUpdateDate`, `DeletedOnDate`)
      VALUES ({source_columns}, `Hash`, GETDATE(), GETDATE(), NULL)
      WHEN NOT MATCHED BY Source THEN
      UPDATE SET Target.`DeletedOnDate` = GETDATE(); """
      
    full_sql += sql_statement
    spark.sql(sql_statement)

    sql_statement = (f"""DELETE FROM silver.{item['SilverTable']} 
              WHERE dateadd(Day, -5, GETDATE()) > DeletedOnDate""")
    full_sql += sql_statement + ";"
    spark.sql(sql_statement)


    # Clean up stage tables
    sql_statement = f"""DROP TABLE silver.stage_{item['SilverTable']}"""
    full_sql += sql_statement + ";"
    spark.sql(sql_statement)
    sql_statement = f"""VACUUM silver.{item['SilverTable']} RETAIN 168 HOURS"""
    full_sql += sql_statement + ";"
    spark.sql(sql_statement)
    if item.get('QuarantineQuery') is not None:
      # The quarantine table lives in the dq schema, so check for it there
      if spark.catalog.tableExists(f"dq.{item['QuarantineTable']}"):
        sql_statement = f"""VACUUM dq.{item['QuarantineTable']} RETAIN 168 HOURS"""
        full_sql += sql_statement + ";"
        spark.sql(sql_statement)

  except Exception as e:
    success = 0
    failure_reason = str(e)

  finally:
    end_time = datetime.now()
    full_sql = full_sql.replace("'", "''")
    if failure_reason is None:
      failure_reason = "NULL"
    else:
      # Quote and escape the error text so the logging EXEC stays valid SQL
      failure_reason = "'" + failure_reason.replace("'", "''") + "'"
    sql_statement = f"""EXEC dbo.pPut_dbRunLog '{item['SilverTable']}', '{start_time.strftime('%Y-%m-%d %H:%M:%S')}', '{end_time.strftime('%Y-%m-%d %H:%M:%S')}', {success}, {failure_reason}, '{full_sql}'"""
    exec_statement = con.prepareCall(sql_statement)
    exec_statement.execute()
    exec_statement.close()
    if success == 1:
      sql_statement = f"""EXEC dbo.pPut_dbTableWatermark '{item['SilverTable']}', '{start_time.strftime('%Y-%m-%d %H:%M:%S')}'"""
      exec_statement = con.prepareCall(sql_statement)
      exec_statement.execute()
      exec_statement.close()
  return failure_reason

The above works in the testing I've done; the only problem is that it runs sequentially. So how do I go about running everything in parallel?

I've tried turning the function into a UDF, but that fails because spark.sql can't be called from the worker nodes. I've also tinkered with ThreadPoolExecutor with some success at speeding up the process.

import concurrent.futures as cf

...


x = "Test"
driver_manager = spark._sc._gateway.jvm.java.sql.DriverManager
con = driver_manager.getConnection(jdbc_url, sql_user, sql_pass)
configs = config_table.where(f"""ScheduleName = '{x}' and IsEnabled = 1""")
config_ls = configs.toJSON().collect()
results = []
with cf.ThreadPoolExecutor(max_workers=2) as executor:
    futures = {executor.submit(process_table, config): config for config in config_ls}
    for future in cf.as_completed(futures):
        # result() returns the failure reason, or re-raises an exception from the thread
        results.append(future.result())
con.close()

The high-level idea with the ThreadPoolExecutor is to have a small driver running multiple threads that build the SQL statements, and larger worker nodes executing them. That seems to work, but it feels more like a hack than a solution. Is there something obvious that I'm missing, or is there a better way to parallelize the code above?
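For what it's worth, one variation I've been considering gives each thread its own JDBC connection for the logging calls instead of sharing the single con across threads. This is only a sketch: the process_with_own_connection wrapper is hypothetical, and it assumes process_table is changed to accept the connection as a parameter.

import concurrent.futures as cf

def process_with_own_connection(config):
    # Hypothetical wrapper: open a dedicated JDBC connection per thread so the
    # logging stored-procedure calls don't share one connection across threads.
    driver_manager = spark._sc._gateway.jvm.java.sql.DriverManager
    thread_con = driver_manager.getConnection(jdbc_url, sql_user, sql_pass)
    try:
        # Assumes process_table is modified to take the connection it should log to
        return process_table(config, thread_con)
    finally:
        thread_con.close()

results = []
with cf.ThreadPoolExecutor(max_workers=4) as executor:
    futures = {executor.submit(process_with_own_connection, c): c for c in config_ls}
    for future in cf.as_completed(futures):
        results.append(future.result())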
