How DuckDB enables unit testing SQL queries

August 05, 2023

During my first data engineering gig at Amazon, in the audience insights team, I took the unit testing of Spark code for granted. Those were Scala functions at the end of the day, so it was natural to test them: craft some test inputs, invoke the function, get a result and assert on it. Now at a smaller company, and dealing with pure SQL, I find myself missing those tests a lot. Sure, we test for data quality, notably with DBT, but that is not the same thing at all: it’s much harder to battle-test your code against edge cases. I miss the reassurance and understanding that unit tests grant you, especially when confronted with complex SQL code.

Part of it is bad practice, but to be fair, the other part is that SQL tooling is lacking. With Spark, the runtime is typically right there in your environment, which is not the case for SQL. I don’t believe in having a shared test database for that purpose, for many reasons, but depending on the engine, one option is to spin up a containerized instance in your test environment (e.g. in CI/CD) and run the tests against it. Even with DBT, unit testing plugins are clunky at best. The other option is DuckDB.

I first heard of DuckDB as a “local database that enables reading directly from local CSV files” during the summer 2022 hype. Another very cool thing about DuckDB is its Python package, which lets you run DuckDB inside your Python programs and query local pandas DataFrames seamlessly.

import duckdb
import pandas as pd

df = pd.DataFrame([{"a": 1}, {"a": 2}])
duckdb.sql("select max(a) from df").df() # => [2]

Then it clicked: you can construct test DataFrames in your test suite, “inject” them into the query by naming them the same as the relations it reads from, run the query, and get back a result DataFrame that you can run assertions on.
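
In its simplest form, the pattern looks like this (a minimal sketch with a hypothetical users table and query, not the project code below):

import duckdb
import pandas as pd

def test_count_adults():
    # The fixture is named exactly like the relation the query reads from.
    users = pd.DataFrame([{"age": 17}, {"age": 25}, {"age": 40}])
    res = duckdb.sql("select count(*) as n from users where age >= 18").df()
    assert res.n.values[0] == 2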

For example, in a current side project I am dealing with the following query, which computes the typing speed of a user:

with time_windows as (
  select 
      generate_series as window_start
  from 
    generate_series(
      now() - interval '6 hours',
      now(),
      interval '15 minutes'
    ) as window_start
),
type_intervals as (
  select 
    window_start,
    source_url,
    session_id,
    record_time,
    lead(record_time) over (
      partition by window_start, source_url, session_id
      order by record_time asc
    ) - record_time  as interval_to_next_event,
    is_return as is_error
  from keyevents 
  join time_windows on 
    record_time >= window_start and 
    record_time < window_start + interval '15 minutes'
  order by record_time asc
),
flows as (
  select 
    *,
    case 
      when lag(interval_to_next_event) over (partition by window_start, source_url, session_id order by record_time asc) > interval '5 seconds' then 1 
      else 0
    end as flow_indicator
  from type_intervals
),
grouped_flows as (
  select 
    *,
    sum(flow_indicator) over (
      partition by window_start, source_url, session_id 
      order by record_time asc
    ) as flow_id
  from flows
),
stats_by_flow as (
  select 
    window_start,
    source_url,
    session_id,
    flow_id,
    avg(extract(milliseconds from interval_to_next_event)) as avg_type_speed,
    count(*) as event_count,
    sum(cast(is_error as int)) as error_count
  from grouped_flows
  group by window_start, source_url, session_id, flow_id
  having count(*) > 1
)
select 
  window_start,
  sum(avg_type_speed * event_count) / sum(event_count) as speed,
  coalesce(
    stddev(avg_type_speed),
    0
  ) as volatility,
  sum(event_count) as event_count,
  sum(error_count) as error_count
from stats_by_flow
where event_count > 1
group by window_start
order by window_start desc

Let’s say, for example, that I want to test that the weighted average typing speed is computed correctly. I can generate a time series with 75% of the records at a speed of 8 and 25% of the records at a speed of 3, so that I know what output to expect:

import duckdb
import pandas as pd
from datetime import datetime as dt, timedelta

def make_event(t, source_url="example.com", session_id="1", is_error=False):
    return {
        "session_id": session_id,
        "source_url": source_url,
        "record_time": t,
        "is_return": is_error
    }

now = dt.now()
keyevents = pd.DataFrame(
    [make_event(now - timedelta(milliseconds=k * 3)) for k in range(10)] +
    [make_event(now - timedelta(milliseconds=k * 8 + 55), source_url="google.com") for k in range(30)]
)

Then, because I name my DataFrame exactly like my real table, DuckDB will use it when it runs the query. I can run my test simply as follows: load the query, execute it, get back a DataFrame and run some assertions on it. Here I know the expected number of returned rows and the expected weighted average typing speed.

with open("typing_speed_current.sql") as fd:
    query = fd.read()

def test_weight_is_accounted_per_flow():
    now = dt.now()
    # 10 keystrokes 3 ms apart on example.com and 30 keystrokes 8 ms apart on google.com.
    keyevents = pd.DataFrame(
        [make_event(now - timedelta(milliseconds=k * 3)) for k in range(10)] +
        [make_event(now - timedelta(milliseconds=k * 8 + 55), source_url="google.com") for k in range(30)]
    )
    # DuckDB resolves the keyevents relation in the query to the local DataFrame above.
    res = duckdb.sql(query).df()
    assert len(res) == 1
    expected = 0.25 * 3 + 0.75 * 8
    assert res.speed.values[0] == expected
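
Assuming this lives in a file your test runner discovers (for pytest, something like test_typing_speed.py, a name of your choosing), it runs like any other unit test, with no database instance to provision locally or in CI.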

There you have it. The only big caveats with this solution are that I am not sure how to deal with fully qualified relation names (you can’t name a Python variable schema.table), and that your query has to be written in a dialect DuckDB can run, which is close to, but not exactly, Postgres.
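
For the schema-qualified case, one workaround that might be worth exploring (a sketch I have not battle-tested, using a hypothetical analytics.keyevents relation) is to register the DataFrame on a connection and expose it through a view in the right schema:

import duckdb
import pandas as pd

con = duckdb.connect()  # in-memory database
keyevents_df = pd.DataFrame([{"session_id": "1", "source_url": "example.com"}])

# Expose the DataFrame under a plain name, then alias it into a schema.
con.register("keyevents", keyevents_df)
con.execute("create schema analytics")
con.execute("create view analytics.keyevents as select * from keyevents")

con.execute("select count(*) from analytics.keyevents").df()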


Written by Frédéric Gessler. Deep learning, hardware accelerators and compilers enthusiast. Hacker. Builder. EPFL and Amazon alumnus.