The first and most important step toward developing a powerful machine learning model is acquiring good data. It doesn’t matter whether you’re using a simple logistic regression or the fanciest state-of-the-art neural network to make predictions: if you feed your model garbage, it will produce garbage in return.
This exposes an unfortunate truth that every hopeful, young data scientist has to come to terms with: Most of the time spent developing a production-ready model is invested in writing ETL that gets your data in a clean format.
In this post aimed at SQL practitioners who would rather spend their time writing Python, I’ll outline a trick that can make your ETL more maintainable and easier to write using one of my favorite libraries, jinja2. R enthusiasts can also follow along, developing analogous patterns using one of the language’s many text templating packages, such as whisker.
Unlike the “one weird trick” from those ridiculous advertisements, ours actually works! Before we dig into the details, I’ll provide a little more motivation.
Writing SQL is a pain
As we’ve written about previously, Stitch Fix engineers do not write ETL for data scientists. Instead, they provide tools that empower data scientists to more easily perform computations and deploy applications themselves. In turn, data scientists are responsible for the end-to-end development of whatever data product they’re building, from data acquisition, to processing and modeling, to feeding their polished results into an API, web app, or analysis. That means most data scientists have to write a decent amount of ETL. Usually this means writing SQL, and there are a number of fundamentally annoying issues with writing SQL.
It’s hard to stay DRY
Especially at a large organization, sources of truth for data resources naturally change over time. For example, the underlying schema of your application data may change to accommodate scalability needs, or you may swap out an advertising vendor, making your marketing data live in different tables from one month to the next.
Writing a complex query almost always involves writing the same chunks of code over and over again. This means that if something does change (say, the name of a column in a table your ETL reads from), you may have to make the same update in 10 places in a single query. This makes Python programmers feel slightly nauseous, as we’re violating the don’t repeat yourself (DRY) principle.
Spaghettification and hot potatoes
As the project which depends on your ETL becomes increasingly sophisticated and data resources change, SQL files tend to have more and more logic heaped onto them, turning your once-elegant code into a swirling mess of ad hoc fixes. It then becomes harder and harder for new collaborators to contribute to your project as its learning curve becomes impossibly steep: your code becomes spaghetti and your project becomes a hot potato. Not as tasty as it sounds!
Thinking inside of the box
Think of your development framework as a box within which you’re working. When you begin working within the framework, there is lots of room to explore and grow. As you learn more and become more capable, you start to fill out the box and incremental learnings tend to focus on minutiae. If your ability and thirst for learning start to exceed the flexibility of your framework, you’ll feel stifled as you perceive your growth as the walls closing in around you.
For this reason, end-users of a database who are less concerned with its internal workings (e.g., most data scientists) can become bored writing SQL and doing database work faster than they would working on, say, advanced modeling or software engineering. If you’re not feeling inspired by what you’re working on, it’s easy to start ignoring best practices, accelerating spaghettification.
Getting Pythonic with Jinja2
If you’ve ever written a web application with Flask or Django, you might know a couple things about writing HTML:
- It’s a pain for a lot of the same reasons that writing SQL is: It’s usually verbose, redundant, and difficult to modularize.
- There’s an out-of-the-box solution to make it less painful and more Pythonic!
This out-of-the-box solution is the text templating mini-language jinja2. You can think of it as Python’s built-in str.format on steroids:
$ python
>>> from jinja2 import Template
>>> my_items = ['apple', 'orange', 'banana']
>>> template_str = """
<ul>{% for i in my_items %}
<li>{{ i }}{% if hungry %}!{% endif %}</li>{% endfor %}
</ul>
"""
>>> print(Template(template_str).render(my_items=my_items, hungry=True))
<ul>
<li>apple!</li>
<li>orange!</li>
<li>banana!</li>
</ul>
Looking at the above chunk of code, we can see jinja2 allows us to:
- render for-loops with {% for ... %}...{% endfor %},
- implement if-else logic with {% if ... %}...{% endif %}, and
- unpack variables with {{ ... }},
all using Python data structures. The syntax is intuitive, and the friendly documentation covers a wealth of additional capabilities.
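Beyond these basics, jinja2 also ships with built-in filters that come in handy for SQL. As a small illustration (the table and column names here are made up, and we won’t need this below), the join filter collapses a Python list into a comma-separated column list:
>>> columns = ['height', 'weight', 'shoe_size']
>>> print(Template("SELECT {{ columns | join(', ') }} FROM my_table").render(columns=columns))
SELECT height, weight, shoe_size FROM my_table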
Our “one weird trick” is that we can apply this tool to writing SQL, helping us to stay DRY and improve maintainability.
Example: Implementing a linear model
We’ll create a toy example to show how we can “productionize” a linear regression model typically encountered in statistics textbooks – modeling an individual’s income based on height and other biometric features. While this use-case is unrealistically simple, it will illustrate how jinja2 can make a modeling codebase that requires ETL easy both to prototype rapidly and to iterate on.
A functioning version of the finished code can be found in this GitHub repo. For additional examples, check out the Appendix section at the end of this post.
Let’s assume the following setup, which has fake data and a few realistic data-environment issues to illustrate the advantages of applying these tools in an enterprise use-case.
Setup
Our input data:
- We have a table called customer_attr of attributes on some customers, including income, height, weight, and is_male, along with a customer id.
- We’ll assume all inputs are integers and that each column has some missing values.
Our goal:
- We’ll assume income has the most missing values, so our goal will be to fill in the gaps by predicting it from the other attributes.
- We’ll want to write a script which will train the model on available data and write out predictions to a table called predicted_income.
Our environment:
- We’ll pretend that customer_attr is too big to fit into memory on the machine which will run our code, but we want to use scikit-learn to create our model. We’ll use a 10% sample from the data where income is not null for training.
- For simplicity, we’ll assume we’re provided 2 functions (a sketch of plausible implementations follows this list):
  - query_to_pandas(query), which executes a SQL query against a database and returns the output as a DataFrame.
  - query_to_table(query, table), which executes a SQL query and writes the output to table in our database.
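These two helpers aren’t shown in the post; here is a minimal sketch of what a my_tools module might contain, assuming pandas and a SQLAlchemy engine (the connection URL is a placeholder you would swap for your own database):
import pandas as pd
from sqlalchemy import create_engine, text

# Placeholder connection string -- swap in your own database URL.
ENGINE = create_engine('postgresql://user:password@host:5432/mydb')

def query_to_pandas(query):
    """ Executes a SQL query and returns the result as a DataFrame """
    return pd.read_sql(query, ENGINE)

def query_to_table(query, table):
    """ Executes a SQL query and writes the result to `table` in the database """
    with ENGINE.begin() as conn:
        conn.execute(text('CREATE TABLE {} AS {}'.format(table, query)))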
Implementation
We’ll make the first iteration of our model extremely simple: income ~ height. To implement our model we’ll create 3 files:
- predictors.json will describe the input data for our model.
- template.sql will specify the queries we use for training and evaluating our model.
- model.py will allow us to execute everything in Python.
The code required to run a model with a single regressor may appear unnecessarily complex. However, we’ll see the payoff later on when we update our model with additional features.
In predictors.json we’ll give our one input variable a name and a number to fill missing values with.
[
{"name": "height", "fill_na": 70}
]
In template.sql
, we have a template of the queries for our model. Don’t bother trying to read this thoroughly now – I’ll provide an explanation shortly.
WITH cleaned_input as (
SELECT customer_id{% for p in predictors %}
, CASE WHEN {{ p['name'] }} IS NULL
THEN {{ p['fill_na'] }}
ELSE {{ p['name'] }} END
AS {{ p['name'] }}{% endfor %}{% if train %}
, income{% endif %}
FROM customer_attr{% if train %}
WHERE customer_id % 10 = 0 -- Our 10% sample
AND income IS NOT NULL{% endif %}
)
{% if train %}
SELECT *
FROM cleaned_input
{% else %}
SELECT customer_id
, {{ coefs['intercept'] }}{% for p in predictors %}
+ {{ coefs[p['name']] }}*{{ p['name'] }}{% endfor %}
AS predicted_income
FROM cleaned_input
{% endif %}
Our Python file model.py
trains a linear model in memory using a sample of our data, then uses the coefficients to evaluate the model on the entire data set purely using SQL.
import json
from sklearn.linear_model import LinearRegression
from jinja2 import Template
from my_tools import query_to_pandas, query_to_table

with open('predictors.json') as f:
    PREDICTORS = json.load(f)

with open('template.sql') as f:
    QUERY_TEMPLATE = f.read()


def get_model_coefs():
    """ Returns a dictionary of coefs from training """
    query = Template(
        QUERY_TEMPLATE
    ).render(
        predictors=PREDICTORS,
        train=True
    )
    print("-- Training query --")
    print(query)
    data = query_to_pandas(query)
    model = LinearRegression().fit(
        data[[p['name'] for p in PREDICTORS]],
        data['income']
    )
    output = {'intercept': model.intercept_}
    for i, p in enumerate(PREDICTORS):
        output[p['name']] = model.coef_[i]
    return output


def evaluate_model(coefs):
    """ Uses coefs to evaluate a model
    and write output to a table """
    query = Template(
        QUERY_TEMPLATE
    ).render(
        predictors=PREDICTORS,
        train=False,
        coefs=coefs
    )
    print("-- Evaluation query --")
    print(query)
    query_to_table(query, 'predicted_income')


if __name__ == "__main__":
    coefs = get_model_coefs()
    evaluate_model(coefs)
Here is the output (with some annotations added):
-- Training query --
WITH cleaned_input as (
SELECT customer_id
, CASE WHEN height IS NULL -- 1
THEN 70
ELSE height END
AS height
, income
FROM customer_attr
WHERE customer_id % 10 = 0 -- 2
AND income IS NOT NULL
)
SELECT * -- 3
FROM cleaned_input
-- Evaluation query --
WITH cleaned_input as (
SELECT customer_id
, CASE WHEN height IS NULL
THEN 70
ELSE height END
AS height
FROM customer_attr -- 2
)
SELECT customer_id -- 3
, 32.9282024803 -- 4
+ 0.0947854710774*height
AS predicted_income
FROM cleaned_input
Looking back through template.sql and model.py, we can walk through what happened when we rendered our template:
- The template selects all of the input features we specified and fills in missing values as needed.
- It can tell whether we’re running a query for training or evaluation and samples if necessary.
- It chooses the correct SELECT statement depending on whether we’re training or evaluating.
- If we’re evaluating predictions, the template unpacks our regression equation.
Adding new features
To add a feature – say, is_male – to our toy model, we only need to add a line to predictors.json:
[
{"name": "height", "fill_na": 70},
{"name": "is_male", "fill_na": 0.5}
]
Now we’ll run model.py again and look at the output:
-- Training query --
WITH cleaned_input as (
SELECT customer_id
, CASE WHEN height IS NULL
THEN 70
ELSE height END
AS height
, CASE WHEN is_male IS NULL
THEN 0.5
ELSE is_male END
AS is_male
, income
FROM customer_attr
WHERE customer_id % 10 = 0
AND income IS NOT NULL
)
SELECT *
FROM cleaned_input
-- Evaluation query --
WITH cleaned_input as (
SELECT customer_id
, CASE WHEN height IS NULL
THEN 70
ELSE height END
AS height
, CASE WHEN is_male IS NULL
THEN 0.5
ELSE is_male END
AS is_male
FROM customer_attr
)
SELECT customer_id
, 30.0
+ 0.1*height
+ 5.0*is_male
AS predicted_income
FROM cleaned_input
Benefits of this approach
Marvel at all of the additional query logic we got by adding a single line to our codebase! Our single new line of json produced 9 more lines of SQL. If we were to add a dozen features, our use of jinja2
would save us over 100 lines of SQL code. Because each variable is referenced only once in our codebase, making updates to input data is easy. Clearly this trick is useful for staying DRY and avoiding spaghettification.
As a bonus, we can perform our model evaluation purely in SQL! This saves us from costly in-memory calculations (like pulling the entire customer_attr table into memory, which we assumed was impossible). While in practice our linear model may be too simple for most predictive tasks, it allows us to get a prototype with minimal software dependencies ready in an extremely short amount of time, so it can start producing business value.
Wrapping up
In this post we demonstrated how a web development tool, jinja2
, empowers us with a framework to put together creative solutions to database work. This “one weird trick”, usually outside of the standard ETL toolkit, allows us to keep our code DRY and makes ETL development feel more like standard Python development.
Even if you choose not to incorporate it into your workflow, there is a more important takeaway I’d like to convey: If you find a problem is redundant or just plain uninteresting, it’s far better to innovate or automate your way to a solution than slump into dispassion.
By investing time in learning things outside of your comfort zone, you may find new approaches to old problems with different perspectives. For me, working on a few web apps ended up being more valuable than I’d anticipated.
Appendix: Other novel use-cases
In our example above, we used jinja2
to manage features for a simple regression model. Here are a few other useful applications:
One-hot encoding with LIKE statements
For predictive models which learn from text, we can use text templating to engineer one-hot encoded dummy variables directly from a database.
-- Loop over words in word_list to generate columns
SELECT text_field{% for word in word_list %}
, CASE WHEN LOWER(text_field) LIKE '%{{ word }}%'
THEN 1 ELSE 0 END AS text_like_{{ word }}{% endfor %}
FROM my_table
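To give a sense of how this might be driven from Python, here is a sketch; the vocabulary and the template’s filename are made up:
from jinja2 import Template

# Hypothetical vocabulary; in practice this might come from a word-frequency count.
word_list = ['shirt', 'dress', 'denim']

# Assumes the template above is saved as one_hot_template.sql (a made-up filename).
with open('one_hot_template.sql') as f:
    print(Template(f.read()).render(word_list=word_list))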
Programmatically handling weird table names
It’s not uncommon for raw application data to have snapshots warehoused in separate tables with date-labeled names – my_table_20170101
, my_table_20170102
, etc. jinja2
makes combining data from such a collection of tables trivially simple:
-- Loop over a date range and UNION
{% for yyyymmdd in date_list %}
SELECT my_column
, '{{ yyyymmdd }}' AS as_of_date
FROM my_table_{{ yyyymmdd }}
{% if not loop.last %}UNION ALL{% endif %}{% endfor %}
Here we use jinja2’s special loop.last variable to prevent UNION ALL from appearing one too many times.
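A nice consequence is that the date range itself can live in Python rather than being typed out by hand. Here is a sketch, assuming the template above is saved as union_template.sql (a made-up filename):
from datetime import date, timedelta
from jinja2 import Template

def yyyymmdd_range(start, end):
    """ Returns a list of yyyymmdd strings, one per day from start to end inclusive """
    n_days = (end - start).days + 1
    return [(start + timedelta(days=i)).strftime('%Y%m%d') for i in range(n_days)]

date_list = yyyymmdd_range(date(2017, 1, 1), date(2017, 1, 3))
# date_list == ['20170101', '20170102', '20170103']

with open('union_template.sql') as f:
    print(Template(f.read()).render(date_list=date_list))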
Pivot tables
While Spark does have a pivot operation, performing pivots or transposes in most flavors of SQL simply isn’t possible out of the box. Using jinja2, it becomes easy to perform these operations in any version of SQL, and doing so in WITH clauses and sub-queries can greatly reduce the complexity and execution time of large SQL scripts.
As an example, let’s say we have a table of customer purchase records with columns customer_id, price_paid, and item_type, where item_type has a small number of values. We can cleanly present the data on how much each customer spent on each item type as a pivot table:
-- Loops over item_type to generate columns
SELECT customer_id{% for t in item_type %}
, SUM(CASE WHEN item_type = '{{ t }}'
THEN price_paid ELSE 0 END)
AS {{ t }}{% endfor %}
FROM purchases
GROUP BY customer_id
Here we get a row for each customer and a column for each item type.
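Because the pivot columns come from the data, we can even pull them dynamically before rendering. Here is a sketch, reusing the query_to_pandas helper assumed earlier and a made-up pivot_template.sql file holding the query above:
from jinja2 import Template
from my_tools import query_to_pandas  # the same helper assumed earlier

# Pull the distinct item types so the pivot columns always match the data.
item_types = query_to_pandas(
    'SELECT DISTINCT item_type FROM purchases'
)['item_type'].tolist()

with open('pivot_template.sql') as f:
    print(Template(f.read()).render(item_type=item_types))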
With this tool in our repertoire, building pipelines can feel closer to working with data scientists’ favorite data munging tools, dplyr and pandas.