During my first data engineering gig at Amazon, on the audience insights team, I took the unit testing of Spark code for granted. Those were Scala functions at the end of the day, so it was natural to test them: craft some test inputs, invoke the function, get a result and assert on it. Now, at a smaller company and dealing with pure SQL, I miss this a lot. Sure, we test for data quality, notably with DBT, but this is not the same at all: it’s much harder to battle-test your code against edge cases. I find myself missing the reassurance and understanding that unit tests grant you, especially when confronted with complex SQL code.
Part of it is bad practice, but to be fair, the other part is that SQL tooling is lacking. With Spark, the runtime is typically right there in your environment, which is not the case for SQL. I don’t believe in having a test database for that purpose, for many reasons, but depending on the engine one option could be to spin up a containerized instance in your test environment (i.e. in CI/CD) and run the tests against it. Even with DBT, unit testing plugins are clunky at best. The other option is using DuckDB.
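Before getting to DuckDB, here is roughly what the containerized route involves: a minimal sketch, assuming the testcontainers and SQLAlchemy packages plus a Docker daemon on the CI runner, none of which is part of my stack.

# hypothetical test module, e.g. test_with_postgres.py
from testcontainers.postgres import PostgresContainer
import sqlalchemy

def test_against_a_throwaway_postgres():
    # spin up a disposable Postgres for the duration of the test
    with PostgresContainer("postgres:16") as pg:
        engine = sqlalchemy.create_engine(pg.get_connection_url())
        with engine.connect() as conn:
            # load fixtures, run the query under test, assert on the result...
            assert conn.execute(sqlalchemy.text("select 1")).scalar() == 1

That is a lot of moving parts just to exercise one query, which is why the DuckDB route appeals to me.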
I first heard of DuckDB during the summer 2022 hype, as a “local database that enables reading directly from local CSV files”. Another very cool thing about DuckDB is its Python package, which lets you run DuckDB inside your Python programs and query local pandas DataFrames seamlessly:
import duckdb
import pandas as pd

df = pd.DataFrame([{"a": 1}, {"a": 2}])
duckdb.sql("select max(a) from df").df()  # DuckDB picks up `df` by its variable name => one-row DataFrame with max(a) = 2
That is when it clicked: you can construct test DataFrames in your test suite, “inject” them into your query by giving them the same names as the relations the query reads from, run the query, and get back a result DataFrame to run assertions on.
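As a toy sketch of the pattern, with a made-up users table and query:

import duckdb
import pandas as pd

# the query under test reads from a relation called `users`
query = "select count(*) as adult_count from users where age >= 18"

def test_counts_only_adults():
    # a DataFrame named `users` in scope stands in for the real table
    users = pd.DataFrame([{"age": 15}, {"age": 22}, {"age": 40}])
    res = duckdb.sql(query).df()
    assert res.adult_count.values[0] == 2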
For a more realistic example, in a current side project I am dealing with the following query to compute the typing speed of a user:
with time_windows as (
    select
        generate_series as window_start
    from
        generate_series(
            now() - interval '6 hours',
            now(),
            interval '15 minutes'
        ) as window_start
),
type_intervals as (
    select
        window_start,
        source_url,
        session_id,
        record_time,
        lead(record_time) over (
            partition by window_start, source_url, session_id
            order by record_time asc
        ) - record_time as interval_to_next_event,
        is_return as is_error
    from keyevents
    join time_windows on
        record_time >= window_start and
        record_time < window_start + interval '15 minutes'
    order by record_time asc
),
flows as (
    select
        *,
        case
            when lag(interval_to_next_event) over (
                partition by window_start, source_url, session_id
                order by record_time asc
            ) > interval '5 seconds' then 1
            else 0
        end as flow_indicator
    from type_intervals
),
grouped_flows as (
    select
        *,
        sum(flow_indicator) over (
            partition by window_start, source_url, session_id
            order by record_time asc
        ) as flow_id
    from flows
),
stats_by_flow as (
    select
        window_start,
        source_url,
        session_id,
        flow_id,
        avg(extract(milliseconds from interval_to_next_event)) as avg_type_speed,
        count(*) as event_count,
        sum(cast(is_error as int)) as error_count
    from grouped_flows
    group by window_start, source_url, session_id, flow_id
    having count(*) > 1
)
select
    window_start,
    sum(avg_type_speed * event_count) / sum(event_count) as speed,
    coalesce(
        stddev(avg_type_speed),
        0
    ) as volatility,
    sum(event_count) as event_count,
    sum(error_count) as error_count
from stats_by_flow
where event_count > 1
group by window_start
order by window_start desc
Let’s say, for example, that I want to test whether the weighted average typing speed is computed correctly. I can generate a time series with 75% of the records at a speed of 8 and 25% of the records at a speed of 3, so that I know what output to expect:
import duckdb
import pandas as pd
from datetime import datetime as dt, timedelta

def make_event(t, source_url="example.com", session_id="1", is_error=False):
    return {
        "session_id": session_id,
        "source_url": source_url,
        "record_time": t,
        "is_return": is_error,
    }

now = dt.now()
keyevents = pd.DataFrame(
    # 10 events (25%) at 3ms intervals, 30 events (75%) at 8ms intervals on another source_url
    [make_event(now - timedelta(milliseconds=k * 3)) for k in range(10)] +
    [make_event(now - timedelta(milliseconds=k * 8 + 55), source_url="google.com") for k in range(30)]
)
Then, since I give my DataFrame the same name as my real table, DuckDB will use it to compute the query, and I can write my test simply as follows: run the query, get back a DataFrame, and run some assertions on it. Here I know the expected number of returned rows and the expected weighted average typing speed.
with open("typing_speed_current.sql") as fd:
    query = fd.read()

def test_weight_is_accounted_per_flow():
    now = dt.now()
    keyevents = pd.DataFrame(
        # 25% of the events typed at 3ms intervals, 75% at 8ms intervals
        [make_event(now - timedelta(milliseconds=k * 3)) for k in range(10)] +
        [make_event(now - timedelta(milliseconds=k * 8 + 55), source_url="google.com") for k in range(30)]
    )
    res = duckdb.sql(query).df()
    assert len(res) == 1
    expected = 0.25 * 3 + 0.75 * 8
    assert res.speed.values[0] == expected
There you have it. The only big caveats with this solution are that I am not sure how to deal with fully qualified relation names (because you can’t name a Python variable schema.table), and that you need a Postgres-compatible SQL query.
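For the first caveat, one idea I have not actually tried would be to use an explicit connection, create the schema in DuckDB, and materialize the test DataFrame under the qualified name the query expects (analytics.keyevents below is made up for illustration):

import duckdb
import pandas as pd

def run_with_schema(query: str, keyevents: pd.DataFrame) -> pd.DataFrame:
    con = duckdb.connect()  # fresh in-memory database
    con.execute("create schema analytics")
    # register the DataFrame under a plain name, then copy it into
    # the fully qualified relation the query refers to
    con.register("keyevents_raw", keyevents)
    con.execute("create table analytics.keyevents as select * from keyevents_raw")
    return con.sql(query).df()

Take that as a starting point rather than a fix; the Postgres-compatibility constraint remains either way.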