Zach Quinn

Summary

The article outlines a step-by-step process for refactoring an existing Python ETL pipeline that fetches data from the Reddit API, with a focus on improving code readability, conciseness, and maintainability.

Abstract

The article details the refactoring of a Python script used for extracting, transforming, and loading (ETL) data from Reddit's news subreddits into a BigQuery database. Initially, the script is characterized by redundancy and lack of modularity. The refactoring process involves breaking down the code into distinct functions, such as token retrieval, data request, data frame formatting, and BigQuery data loading. The author emphasizes the importance of logging, config file usage, and idempotency to enhance the script's functionality and reliability. By applying these improvements, the code is reduced from 333 to approximately 175 lines, resulting in a more efficient and maintainable ETL pipeline. The refactored script is designed to be more readable and concise, which is crucial for both current and future developers who may work on the project.

Opinions

  • The author believes that refactoring is essential for maintaining and improving the quality of existing code, especially in the context of data engineering projects.
  • The author values the iterative nature of programming and the continuous improvement of code, suggesting that code should be revisited and revised periodically.
  • There is an emphasis on the importance of modular design and the use of functions to avoid code repetition and improve code reusability.
  • The author advocates for the use of logging to provide clear metadata about program execution, which aids in troubleshooting and transparency.
  • The use of a configuration file is recommended for storing variables and parameters, which helps to keep the main script clean and manageable.
  • The concept of idempotency is highlighted as a best practice to ensure that the ETL process can be run multiple times without affecting the integrity of the data.
  • The author suggests that organizing code into smaller, purpose-driven functions not only simplifies the script but also makes it easier to adapt to future changes or requirements.

Refactoring A Python ETL Pipeline (With Example)

My step-by-step process for revising an existing Python pipeline fetching data from the Reddit API.

Create a job-worthy data portfolio. Learn how with my free project guide.

Refactoring is like renovation for your code. Photo by Jørgen Larsen on Unsplash.

Revise and Refactor Your Python ETL Pipelines

Revisiting a Python script you’ve written months or years ago is like discovering an assignment from high school: You cringe at what you should have known and you marvel at how far you’ve come.

If you’ve had any experience with coding, particularly in Python, you’ll know that programming is an iterative process.

The best problems are solved incrementally and the best code is written and rewritten a bit at a time.

Unlike your high school term paper or a college thesis gathering dust in the attic, production scripts can and should periodically be revisited and revised to make the code more concise, readable and configurable.

This practice is a software engineering process known as refactoring.

The point of refactoring is to maintain the functionality of a piece of code while improving its readability.

While you can refactor a script in a number of different ways, I’ll share my process for revising a simple ETL pipeline that extracts, transforms and loads data from the Reddit API, with a focus on news subreddits.

The Original Script

To get the most out of this walkthrough, I suggest you take a moment to review the code.

One of the biggest problems with this particular Python script is its redundancy: the same operations are repeated when they should be broken out into separate functions.

We’ll get to that in a second.

Behold:

import pandas as pd 
import requests
import os 
import pydata_google_auth
from google.oauth2 import service_account
from google.cloud import language
import json 
from google.cloud import storage 
from google.cloud import bigquery

# Main function. 

def get_today_news():
    
    # Authenticate to GCP. 
  
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"]=''

    # Make initial request to Reddit API to return a token.

    auth = requests.auth.HTTPBasicAuth('CLIENT_ID', 'SECRET_TOKEN')
    data = {
    'grant_type': 'client_credentials',
    'username': 'username',
    'password': 'password'
    }
    headers = {'User-Agent': 'News/0.0.1'}
    request = requests.post('https://www.reddit.com/api/v1/access_token', auth=auth, data=data, headers=headers)
    token = request.json()['access_token']
    headers = {**headers, **{'Authorization': f"bearer {token}"}}
    
    # Make the request to the desired subreddit: r/news. 

    news_requests = requests.get('https://oauth.reddit.com/r/news/hot', headers=headers, params={'limit': '100'})

    # Transformations. 

    # Loop through the returned JSON body to create our columns.

    df = pd.DataFrame()
    for post in news_requests.json()['data']['children']:
        df = df.append({
        'title': post['data']['title'],
        'upvote_ratio': post['data']['upvote_ratio'],
        'score': post['data']['score'],
        'ups': post['data']['ups'],
        'domain': post['data']['domain'],
        'num_comments': post['data']['num_comments']
    }, ignore_index=True)
    
    # Establish a BigQuery client. 

    client = bigquery.Client()
    news_dataset_id = 'reddit_news'
    news_table_id = 'r_news'
    
    news_ref = client.dataset(news_dataset_id)
    news_table_id = news_ref.table(news_table_id)
    
    # Configure the load job.

    job_config = bigquery.LoadJobConfig()
    job_config.write_disposition='WRITE_TRUNCATE'
    job_config.source_format = bigquery.SourceFormat.CSV
    job_config.autodetect=True
    job_config.ignore_unknown_values=True 

    # Load the data to the "r_news" table. 

    job = client.load_table_from_dataframe(
    df,
    news_table_id,
    location='US',
    job_config=job_config)

    job.result()

    # Print a message when complete. 

    print('News table loaded.')

This is a very straightforward example of an ETL pipeline.

Functionally, it really only does three things:

  • Gets data from Reddit
  • Parses the JSON to create a data frame
  • Uploads the clean data as a data frame to BigQuery

If we were only hitting one endpoint, I'd leave this script alone.

The problem is that I make eight more requests by copying and pasting this code.

Here are the refactoring steps I'll need to take to make this script more reflective of the code I write as a more seasoned engineer (a skeleton of the resulting layout is sketched after the list).

  • Break repeated steps into functions that can be called multiple times
  • Add logging so we know when each step begins and ends
  • Define shared resources like the BigQuery client globally
  • Move request body and table parameters to config file
  • Change the BigQuery job configuration to WRITE_APPEND to retain historical data
  • Create a main function to call our newly created functions
  • Add an if __name__ block so this file can be run as a script
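
Put together, the plan looks roughly like the skeleton below. It is only a sketch of the target layout, with function bodies elided; the real implementations follow in the rest of the article.

import logging
import pandas as pd
from google.cloud import bigquery
import news_config as cfg  # config module introduced later in the article

logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

bq_client = bigquery.Client()  # shared client defined once at module level

def get_reddit_token():
    ...  # request an OAuth token and return ready-to-use request headers

def make_request(url: str):
    ...  # GET a subreddit listing using the headers from get_reddit_token()

def format_df(end_point):
    ...  # parse the JSON response into a pandas data frame

def bq_load(table_id: str, df: pd.DataFrame):
    ...  # append the data frame to the target BigQuery table

def main():
    ...  # request, format and load each subreddit

if __name__ == "__main__":
    main()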

Pardon the interruption: For more Python, SQL and cloud computing walkthroughs, follow Pipeline: Your Data Engineering Resource.

To receive my latest writing, you can follow me as well.

Breaking Up

At first glance, I can see that steps like the API calls, data frame creation and BigQuery load can be broken into functions called by a main() function, instead of repeating the same code (*counts again*) 9 times.

Let’s start backward with the function to load to BigQuery.

def bq_load(table_id: str, df: pd.DataFrame):
    
    table_ref = bq_client.dataset("reddit_news")
    table_id = table_ref.table(table_id)
    
    job_config = bigquery.LoadJobConfig()
    job_config.write_disposition='WRITE_APPEND'
    job_config.source_format = bigquery.SourceFormat.CSV
    job_config.schema = cfg.schema

    job = bq_client.load_table_from_dataframe(
    df,
    table_id,
    location='US',
    job_config=job_config)
    
    return job.result()

Unlike the prior code, the bq_load function lets the developer specify the destination table and the data frame that will be uploaded to BigQuery.

More on that process below.
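
As an aside, recent versions of the google-cloud-bigquery library also accept these settings as keyword arguments to the LoadJobConfig constructor. This is an optional variation on the code above, not the version used in the rest of the article:

job_config = bigquery.LoadJobConfig(
    write_disposition=bigquery.WriteDisposition.WRITE_APPEND,  # same effect as setting the attribute
    schema=cfg.schema,
)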

Next, we’ll create a function for getting the access token returned from the Reddit API.

def get_reddit_token():

    auth = requests.auth.HTTPBasicAuth('********', '********')
    data = {
    'grant_type': 'client_credentials',
    'username': cfg.user,
    'password': cfg.password
    }
    headers = {'User-Agent': 'News/0.0.1'}
    request = requests.post(cfg.base_access_url, auth=auth, data=data, headers=headers)
    token = request.json()['access_token']
    headers = {**headers, **{'Authorization': f"bearer {token}"}}
    
    return headers

For the next function, which makes the actual request, to work, I'll store the headers returned by get_reddit_token(), including the bearer token, in a “headers” variable.

def make_request(url: str):
    
    headers = get_reddit_token()

    request = requests.get(url, headers=headers, params={'limit': '100'})
    
    return request
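
One defensive tweak worth considering here, though it isn't part of the original script, is failing fast when the HTTP call errors out instead of letting a bad response surface later as a confusing KeyError. A sketch using the same function name:

def make_request(url: str) -> requests.Response:

    headers = get_reddit_token()

    response = requests.get(url, headers=headers, params={'limit': '100'}, timeout=30)
    response.raise_for_status()  # raise immediately on 4xx/5xx responses

    return response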

The final function takes care of the fun stuff: The data frame generation and formatting.

def format_df(end_point):
    
    df = pd.DataFrame()
    
    for post in end_point.json()['data']['children']:
        df = df.append({
        'title': post['data']['title'],
        'upvote_ratio': post['data']['upvote_ratio'],
        'score': post['data']['score'],
        'ups': post['data']['ups'],
        'domain': post['data']['domain'],
        'num_comments': post['data']['num_comments']
    }, ignore_index=True)
        
    return df

The format_df function takes the response returned from a given endpoint and iterates through each post in the JSON body to populate the columns of the newly created data frame.

This eliminates the need to repeat this operation later in the script.
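
One caveat worth noting: DataFrame.append was deprecated in pandas 1.4 and removed in pandas 2.0, so on newer pandas versions this loop needs a small rewrite. A sketch of an equivalent version that collects the rows in a list first:

def format_df(end_point):

    # Build a list of row dictionaries, then create the data frame in one pass.
    rows = [{
        'title': post['data']['title'],
        'upvote_ratio': post['data']['upvote_ratio'],
        'score': post['data']['score'],
        'ups': post['data']['ups'],
        'domain': post['data']['domain'],
        'num_comments': post['data']['num_comments']
    } for post in end_point.json()['data']['children']]

    return pd.DataFrame(rows)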

Main

The main() function calls these helper functions to create the “meat” of this script, consisting of the request, formatting and loading steps, which we can now reduce to three lines.

def main():
  r_news = make_request("https://oauth.reddit.com/r/news/hot")
  r_news_df = format_df(r_news)
  bq_load("r_news", r_news_df)

  return "r_news table loaded." 

This code is the essence of refactoring: Reducing redundancy and making code more readable.

However, there's still a crucial component missing from this script.

Logging

Readability doesn’t just concern written code.

It’s also important to create legible and transparent metadata that we can use to confirm program execution and, if necessary, to troubleshoot.

To do this, we create and include logs in Python scripts.

import logging

logging.basicConfig(format='%(asctime)s %(message)s',level=logging.INFO)

logging.info("This is an INFO level logging entry.")

While more logs are generally better than fewer, we want to be strategic about placement.

I typically include a logging statement at the beginning of every major operation within a script. You can also note the conclusion of a function.

from datetime import date
import logging

current_dt = date.today()

logging.info(f"Getting data from Reddit API for {current_dt}")
    
logging.info("Getting data for r/news...")

r_news = make_request("https://oauth.reddit.com/r/news/hot")
r_news_df = format_df(r_news)

logging.info("Loading to BigQuery...")

bq_load("r_news", r_news_df)

logging.info(f"r_news table is updated as of {current_dt}.")

Make sure every log you write is concise and precise so you can confirm that a script is executing and identify breakage points when it isn't.
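
For troubleshooting in particular, one pattern worth considering (it isn't used in the script below) is wrapping a risky step in a try/except and logging the full traceback before re-raising:

try:
    r_news = make_request("https://oauth.reddit.com/r/news/hot")
except Exception:
    logging.exception("Request to r/news failed")  # records the message plus the traceback
    raise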

Config

Another easy fix for readability is to move your string variables to a config file and then reference that file in your main script.

This is especially helpful for list values like column names and schemas.

Below, you can see the config file for this script, “news_config.”

# Reddit News Configuration File

from google.cloud import bigquery

user = '******'
password = '******'

base_access_url = 'https://www.reddit.com/api/v1/access_token'

r_news = "https://oauth.reddit.com/r/news/hot"
not_the_onion = "https://oauth.reddit.com/r/nottheonion/hot"
offbeat = "https://oauth.reddit.com/r/offbeat/hot"
the_news = "https://oauth.reddit.com/r/thenews/hot"
us_news = "https://oauth.reddit.com/r/USNews/hot"
full_news = "https://oauth.reddit.com/r/Full_news/hot"
quality_news = "https://oauth.reddit.com/r/qualitynews/hot"
uplifting_news = "https://oauth.reddit.com/r/upliftingnews"
in_the_news = "https://oauth.reddit.com/r/inthenews"

tables = ["r_news", "not_the_onion", "offbeat", "the_news", "us_news", "full_news", "quality_news",
          "uplifting_news", "in_the_news"]

schema = [bigquery.SchemaField("title", "STRING"),
          bigquery.SchemaField("upvote_ratio", "FLOAT"),
          bigquery.SchemaField("score", "FLOAT"),
          bigquery.SchemaField("ups", "FLOAT"),
          bigquery.SchemaField("domain", "STRING"),
          bigquery.SchemaField("num_comments", "FLOAT"),
          bigquery.SchemaField("dt_updated", "TIMESTAMP")
         ]

It’s important to prioritize what you include in a config file and not create a file that is just as cluttered as the main script.

For instance, if you have a single string value, you most likely don’t need to include it in the config file.

Here, I could have just as easily included the full URL string in the make_request() function. It’s a matter of preference.
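
One related caveat: real credentials arguably don't belong in a config file at all. A common alternative, sketched here with hypothetical environment variable names, is to read secrets from the environment and keep only non-sensitive values in news_config:

import os

# Hypothetical variable names; set these outside the codebase.
user = os.environ["REDDIT_USERNAME"]
password = os.environ["REDDIT_PASSWORD"]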

Idempotency

Another opportunity you can look for when refactoring an ETL script is to ensure your code is idempotent, or that the result doesn’t change no matter how many times the script runs.

For instance, if I ran this code, which appends today’s data to a BigQuery table, more than once, I would end up with duplicated values.

To avoid this possibility, I can delete any data from a prior run before loading, so that I end up with only one set of data for a given day.

Since this data doesn’t include a date field, I’ll add one using a timestamp with today’s date and time of request.

data_frames = [r_news_df, nto_df, offbeat_df, the_news_df, us_news_df, full_news_df, quality_news_df,
                  uplifting_news_df, in_the_news_df]
    
for df in data_frames:
    df["dt_updated"] = pd.Timestamp.today()

I’m simply iterating through all of the data frames that I’ve created and am adding this “dt_updated” field to each.

When I conduct the deletion, I want to make sure I’m deleting from every table, so I can iterate through a list of tables (stored in the config file).

for tab in cfg.tables:
    bq_client.query(
        "DELETE FROM `*********.reddit_news." + tab +
        "` WHERE dt_updated >= CURRENT_DATE('America/New_York')"
    ).result()  # wait for each delete to finish before loading

Finally, since we’ve created all of the data frames in the context of the script, we can conduct a single batch load job.

for tabs, dfs in zip(cfg.tables, data_frames):
        bq_load(tabs, dfs)
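
One small safeguard worth considering, though it isn't in the original: on Python 3.10 and later, zip() accepts strict=True, which raises a ValueError if the table list and the list of data frames ever drift out of sync, instead of silently skipping the extras.

for tab, df in zip(cfg.tables, data_frames, strict=True):
    bq_load(tab, df)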

Full Script

After completing each refactoring step, we’re left with a more legible and concise script.

In the process, we’ve reduced 333 lines of code to approximately 175.

Here’s what it looks like:

import pandas as pd 
import requests
import os 
import pydata_google_auth
from google.oauth2 import service_account
from google.cloud import language
import json 
from google.cloud import storage 
from google.cloud import bigquery
import news_config as cfg
import logging
from datetime import date

os.environ["GOOGLE_APPLICATION_CREDENTIALS"]="*************************"

logging.basicConfig(format='%(asctime)s %(message)s',level=logging.INFO)

current_dt = date.today()

bq_client = bigquery.Client()

def bq_load(table_id: str, df: pd.DataFrame):
    
    table_ref = bq_client.dataset("reddit_news")
    table_id = table_ref.table(table_id)
    
    job_config = bigquery.LoadJobConfig()
    job_config.write_disposition='WRITE_APPEND'
    job_config.source_format = bigquery.SourceFormat.CSV
    job_config.schema = cfg.schema

    job = bq_client.load_table_from_dataframe(
    df,
    table_id,
    location='US',
    job_config=job_config)
    
    return job.result()

def get_reddit_token():

    auth = requests.auth.HTTPBasicAuth('***********', '*************')
    data = {
    'grant_type': 'client_credentials',
    'username': cfg.user,
    'password': cfg.password
    }
    headers = {'User-Agent': 'News/0.0.1'}
    request = requests.post(cfg.base_access_url, auth=auth, data=data, headers=headers)
    token = request.json()['access_token']
    headers = {**headers, **{'Authorization': f"bearer {token}"}}
    
    return headers

def make_request(url: str):
    
    headers = get_reddit_token()

    request = requests.get(url, headers=headers, params={'limit': '100'})
    
    return request

def format_df(end_point: requests.Response):
    
    df = pd.DataFrame()
    
    for post in end_point.json()['data']['children']:
        df = df.append({
        'title': post['data']['title'],
        'upvote_ratio': post['data']['upvote_ratio'],
        'score': post['data']['score'],
        'ups': post['data']['ups'],
        'domain': post['data']['domain'],
        'num_comments': post['data']['num_comments']
    }, ignore_index=True)
        
    return df

def main():
    
    logging.info("Getting data for r/news...")
    
    # r/news.
    
    r_news = make_request(cfg.r_news)
    r_news_df = format_df(r_news)
    
    logging.info(f"r/news loaded successfully for {current_dt}. Getting data for r/nottheonion...")
    
    # r/nottheonion.
    
    not_the_onion = make_request(cfg.not_the_onion)
    nto_df = format_df(not_the_onion)
    
    logging.info(f"r/nottheonion loaded successfully for {current_dt}. Getting data for r/offbeat...")
    
    # r/offbeat.
    
    offbeat = make_request(cfg.offbeat)
    offbeat_df = format_df(offbeat)
    
    logging.info(f"r/offbeat loaded successfully for {current_dt}. Getting data for r/thenews...")
    
    # r/thenews.
    
    the_news = make_request(cfg.the_news)
    the_news_df = format_df(the_news)
    
    logging.info(f"r/thenews loaded successfully for {current_dt}. Getting data for r/USNews...")
    
    # r/USNews.
    
    us_news = make_request(cfg.us_news)
    us_news_df = format_df(us_news)
    
    logging.info(f"r/USNews loaded successfully for {current_dt}. Getting data for r/Full_news...")
    
    # r/Full_news
    
    full_news = make_request(cfg.full_news)
    full_news_df = format_df(full_news)
    
    logging.info(f"r/Full_news loaded successfully for {current_dt}. Getting data for r/quality_news...")
    
    # r/quality_news
    
    quality_news = make_request(cfg.quality_news)
    quality_news_df = format_df(quality_news)
    
    logging.info(f"r/quality_news loaded successfully for {current_dt}. Getting data for r/upliftingnews...")
    
    # r/uplifting_news
    
    uplifting_news = make_request(cfg.uplifting_news)
    uplifting_news_df = format_df(uplifting_news)
    
    logging.info(f"r/uplifting_news loaded successfully for {current_dt}. Getting data for r/inthenews...")
    
    # r/inthenews
    
    in_the_news = make_request(cfg.in_the_news)
    in_the_news_df = format_df(in_the_news)
    
    logging.info(f"r/inthenews loaded successfully. Creating dt_updated column...")
    
    data_frames = [r_news_df, nto_df, offbeat_df, the_news_df, us_news_df, full_news_df, quality_news_df,
                  uplifting_news_df, in_the_news_df]
    
    for df in data_frames:
        df["dt_updated"] = pd.Timestamp.today()
    
    logging.info(f"Deleting data for {current_dt}...")
                 
    # Delete data from today.
    
    for tab in cfg.tables:
        bq_client.query(
            "DELETE FROM `************.reddit_news." + tab +
            "` WHERE dt_updated >= CURRENT_DATE('America/New_York')"
        ).result()  # wait for each delete to finish before appending fresh data
    
    # Load all at once.
    
    logging.info("Loading to BigQuery...")
                 
    for tabs, dfs in zip(cfg.tables, data_frames):
        bq_load(tabs, dfs)
    
    logging.info(f"All tables loaded successfully for {current_dt}")

if __name__ == "__main__":
  logging.info(f"Getting data from Reddit API for {current_dt}")
  main()

If I wanted to reduce the number of lines further, I could concatenate the data frames and load everything to a single table.
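
That single-table variant might look something like the sketch below, where "all_news" is a hypothetical table name; in practice the combined frame would also need a column identifying which subreddit each row came from.

# Hypothetical single-table load (not part of the final pipeline).
combined_df = pd.concat(data_frames, ignore_index=True)
bq_load("all_news", combined_df)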

However, if I were to receive a request from a stakeholder to retrieve data from different endpoints, it would make more sense to store each endpoint's data in a separate table.

I hope you can see the value not only of revisiting your work months or years down the line, but also of applying present knowledge to past builds.

It’ll help future developers better understand your code and will help you understand just how far you’ve come.
