Ravishankar Nair

Summary

The author discusses the implementation of a ConnectionPool in Apache Spark's foreachPartition(..) method to efficiently handle large-scale data processing tasks involving database operations without causing performance bottlenecks or serialization issues.

Abstract

The article details the author's approach to a significant performance challenge involving the processing of approximately 400 million rows from an employee table in a MySQL database. The author's task was to filter this data and then enrich it with department names from another table, a process that was inefficient using traditional JDBC methods. The author explains why database re-modeling or creating views was not feasible and why leveraging Apache Spark's distributed computing capabilities was the preferred solution. The problem was compounded by the need to maintain the state of database connections across distributed Spark workers without causing serialization issues or out-of-memory (OOM) errors. The author describes the use of the Apache DBCP2 library to create a static ConnectionPool class in Scala, which avoids serialization of connection objects and optimizes the use of database connections across Spark partitions. This implementation reduced the processing time from 6 hours to 1 hour and 10 minutes.

Opinions

  • The author believes that using Apache Spark's distributed computing framework is the best approach for handling large-scale data processing tasks.
  • Creating and closing a database connection within each partition of a Spark job is considered expensive and inefficient.
  • The author emphasizes the importance of understanding Java serialization to avoid common pitfalls in Spark, such as the "Task not serializable" exception.
  • Scala singleton objects, which play the role of static classes, are recommended for managing resources like database connections in a Spark cluster, to avoid serializing those resources.
  • The author suggests that using a connection pool is a more efficient solution than creating a new connection for each partition, and it also helps to prevent connection starvation.
  • The author values performance optimization, as evidenced by the use of pushdown predicates and prepared statements within the Spark job.
  • The author is open to suggestions for further improvements and invites feedback from the community.

Implementing a ConnectionPool in Apache Spark’s foreachPartition(..)

I was in the middle of a project, and as usual the situation was far from good in terms of achieving the required performance. An employee table in a MySQL database held about 1 billion rows, and I had to select some 400 million of them based on a filter criterion, say all employees who joined in the last seven years (based on a joining_num column). The client is one of the biggest in the transportation industry, with about thirty thousand offices across the United States and Latin America. For each row filtered from that table, I needed to look up another table to find the name of the department the employee initially joined, for a human-resources application. Solving this with traditional JDBC was not an option, because of the time it took to fetch all the records and then, for each one, fetch the related record. Re-modelling the database was ruled out, since many web and reporting components in the existing system use these tables. A view or another copy of the combined table is not what I prefer either: avoiding proliferation of data is a principle I follow in any data lake or transformation project. And in this big data world, I have enormous compute power available to me, so leveraging distributed computing is the obvious way to approach this problem.

High level requirement that I am trying to solve

I won't spend much time arguing about what the best possible solution is, but a decent way to do this is with proper use of Apache Spark.

The Problem

The journey is definitely not smooth when I need to concentrate on performance, efficiency and scale, and of course keep the solution shareable with a wider audience. Let me list the problems to be solved if we use Spark.

This is how you should manage your code in Spark when dealing with large datasets

As illustrated above, when we run a Spark application, we need to make sure that the actions (the VERBS of our program) do their work on the worker nodes; that is where the power of distributed programming really lies. When you read data into Spark, either through Spark's JDBC data source or with sc.textFile(…) and the like, Spark splits the data into partitions. The underlying abstraction is the Resilient Distributed Dataset (RDD), on which higher-level abstractions like DataFrame and Dataset are built, and each RDD is made up of those partitions. When a Spark job runs on a cluster, the partitions are computed in memory on the worker nodes, not on the driver or client. Actions such as foreach(..) and foreachPartition(..) also run on the workers, but the moment you execute collect(..), every partition is brought back to the driver, which with large data often ends in the infamous OOM (out-of-memory) exception.
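To make this concrete, here is a minimal sketch in Spark local mode, assuming the spark-sql dependency is on the classpath (PartitionDemo is a hypothetical name, and `local[4]` merely stands in for a real cluster). The per-partition work in mapPartitions and foreachPartition runs on the executors and sends back at most one small value per partition; collect() is the call that ships every row to the driver.

```scala
import org.apache.spark.sql.SparkSession

object PartitionDemo {
  /** Returns the row count of each partition and the number of rows collected on the driver. */
  def run(): (Array[Int], Int) = {
    // local[4]: four executor threads inside one JVM, standing in for a cluster
    val spark = SparkSession.builder().master("local[4]").appName("partition-demo").getOrCreate()
    try {
      // 1,000 rows split into 4 partitions
      val rdd = spark.range(0L, 1000L, 1L, 4).rdd

      // Runs on the executors; only one Int per partition travels back to the driver
      val rowsPerPartition = rdd.mapPartitions(rows => Iterator(rows.size)).collect()

      // Also runs on the executors and returns nothing to the driver
      rdd.foreachPartition(rows => rows.foreach(_ => ()))

      // Ships every row to the driver: harmless for 1,000 rows, an OOM risk for 400 million
      val onDriver = rdd.collect().length

      (rowsPerPartition, onDriver)
    } finally {
      spark.stop()
    }
  }

  def main(args: Array[String]): Unit = {
    val (perPartition, onDriver) = run()
    println(s"rows per partition: ${perPartition.mkString(", ")}; rows on driver: $onDriver")
  }
}
```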

So what’s the big deal?

Remember, we want to iterate through each partition on the workers, connecting to the database to fetch additional information. But...

When, or where, should we create a connection?

Note that our first query returns millions of rows, and each partition on a worker may contain millions of rows itself. If we create a connection and close that connection within foreachPartition(..) for every partition, that's expensive. On the other hand, if we create a Connection in the driver (client) and pass it to the workers, that's not going to work. Why?

Unless an object is serialized (written out as bytes and loaded back), it cannot be loaded in another JVM!

Note that in a distributed computing cluster, the driver and the workers run on different nodes (unlike a single-node pseudo-cluster, which we are not discussing here) and hence in different JVMs. So if we think that we can create a java.sql.Connection object in the driver and reuse it on a worker, we are wrong. This is a common mistake in the Spark world, and it gives you the most infamous error in Spark: the "Task not serializable" exception. If you have a JDK installed, you can check whether a class is serializable by typing

serialver  <<fully qualified name of your java class>>
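The failure is easy to reproduce locally. In this sketch (assuming spark-sql on the classpath; NotSerializableDemo is a hypothetical name), an unconnected java.net.Socket stands in for a java.sql.Connection created on the driver, since neither class is Serializable. Spark tries to serialize the closure before launching any task, so the error is raised on the driver.

```scala
import java.net.Socket
import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession

object NotSerializableDemo {
  /** Returns the error message Spark raises on the driver, or "no error". */
  def run(): String = {
    val spark = SparkSession.builder().master("local[2]").appName("not-serializable-demo").getOrCreate()
    // An unconnected Socket stands in for java.sql.Connection: neither implements Serializable
    val driverSideResource = new Socket()
    try {
      spark.sparkContext.parallelize(1 to 10, 2).foreachPartition { rows =>
        // This closure captures driverSideResource, so Spark must serialize it
        rows.foreach(_ => driverSideResource.isConnected)
      }
      "no error"
    } catch {
      // Thrown on the driver while Spark checks that the closure is serializable
      case e: SparkException => e.getMessage
    } finally {
      driverSideResource.close()
      spark.stop()
    }
  }

  def main(args: Array[String]): Unit = println(run())
}
```

The cause attached to the SparkException is a java.io.NotSerializableException naming the offending class.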

Many answers you will find online solve this by marking the field as transient, which tells Java not to serialize it; but then the object has to be recreated on the workers, which again means creating and closing the connection within each partition. Another option would be to create the connection explicitly within foreachPartition, but isn't that exactly what we wanted to eliminate in the first place?
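For reference, the transient workaround usually looks like the sketch below. FakeConnection and Enricher are hypothetical stand-ins, and a plain Java serialization round trip plays the part of Spark shipping a closure to an executor. The field is skipped during serialization and rebuilt lazily on the receiving side, which is exactly why it saves nothing: in Spark every task deserializes its own copy, so a new connection is still opened for each task.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical stand-in for a JDBC connection; deliberately NOT Serializable
final class FakeConnection {
  val id: Int = FakeConnection.opened.incrementAndGet()
}
object FakeConnection {
  val opened = new AtomicInteger(0) // how many "connections" this JVM has opened
}

// Hypothetical helper that would be captured in a Spark closure
class Enricher(val url: String) extends Serializable {
  // Skipped by serialization; re-created on first access after deserialization
  @transient lazy val conn: FakeConnection = new FakeConnection
}

object TransientDemo {
  /** Serializes and deserializes obj, mimicking a trip from driver to executor. */
  def roundTrip[T <: AnyRef](obj: T): T = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    try out.writeObject(obj) finally out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    try in.readObject().asInstanceOf[T] finally in.close()
  }

  def main(args: Array[String]): Unit = {
    val onDriver = new Enricher("jdbc:mysql://db-host:3306/hr") // placeholder URL
    println(s"driver connection id: ${onDriver.conn.id}")
    val onExecutor = roundTrip(onDriver) // works only because conn is @transient
    println(s"executor connection id: ${onExecutor.conn.id}") // a second connection was opened
  }
}
```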

Serialization in Java — All you need to know

Serializable is a marker interface in Java, which means it declares no methods for us to implement. Let me explain the concept with a small real-world scenario. Assume that I have a house where I have lived for the last few years, and it has electricity. I am playing a CD player in my home which runs on mains power (it is not battery operated), and I am halfway through my favorite song. My friend calls me to show his newly built home, where the electricity connection is not available yet, but I want to continue listening to my song there. That's impossible with this media player, because there is no electricity at my friend's new house. If you think of my home and my friend's home as two separate JVMs (like the Spark driver and a worker in a distributed cluster), the state of my media player cannot be replicated: it depends on a connection in the original JVM (the electric connection; now think of a database connection) that cannot be carried over. In other words, the electric connection is not serializable. The same is true whenever an object holds something tied to its host: an image read from your local file system, a socket connection, a ResultSet backed by an underlying TCP/IP stream, or a JDBC connection. That's why classes like java.sql.Connection, java.sql.ResultSet and java.awt.Image are not serializable. For the JVM to convert an object to bytes and pass it through the network, its class must implement Serializable, and so must every non-transient field it holds.
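The analogy can be checked in a few lines. The sketch below uses hypothetical class names, with a java.net.Socket playing the electricity supply. It shows that extending Serializable requires no methods, that a Serializable object still fails to serialize when it holds a non-serializable field, and how to ask programmatically (much like serialver does) whether java.sql.Connection or java.sql.ResultSet is Serializable.

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}
import java.net.Socket

// Serializable declares no methods: extending it is the whole "implementation".
// Case classes already extend Serializable.
final case class Employee(empId: Int, deptId: Int)

// Hypothetical class from the analogy: the player itself is Serializable,
// but it holds a live connection (an unconnected Socket standing in for the power supply)
class MediaPlayer(val song: String) extends Serializable {
  val power: Socket = new Socket()
}

object MarkerDemo {
  /** True if Java serialization can turn obj into bytes. */
  def canSerialize(obj: AnyRef): Boolean = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try { out.writeObject(obj); true }
    catch { case _: NotSerializableException => false }
    finally out.close()
  }

  /** Programmatic version of the question serialver answers. */
  def isSerializableType(cls: Class[_]): Boolean =
    classOf[java.io.Serializable].isAssignableFrom(cls)

  def main(args: Array[String]): Unit = {
    println(canSerialize(Employee(1, 10)))                      // true
    println(canSerialize(new MediaPlayer("my favourite song"))) // false: java.net.Socket
    println(isSerializableType(classOf[java.sql.Connection]))   // false
    println(isSerializableType(classOf[java.sql.ResultSet]))    // false
  }
}
```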

Workaround: The Secret

Well, the whole discussion above applies to objects, that is, to instances of classes. That gives us a little room for tricking the JVM, and you can easily guess at this point what I mean: static members are not serialized. Why? Static methods and fields belong to the class, not to an instance. A class is loaded once per JVM, but its static state is never written out or shipped along with an instance. Serialization is also expensive, which is one reason Hadoop introduced its own Writable interface instead of relying on Java's Serializable. Voila! You got it. Just remember this trick to avoid "Task not serializable" exceptions in your code.

How do we do this in Scala? Scala has no static keyword. Instead, an object declaration defines a singleton: the compiler backs it with a class that holds a static reference to its single instance, created once per JVM on first access, so the object's methods behave like static methods. So let's write our connection pool for Spark as a Scala object.
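The following sketch contrasts the two situations with plain Java serialization. All names are hypothetical, and the two function classes stand in for the closures Spark ships to executors. A function that reaches a resource through a Scala object carries nothing with it, because the object is looked up through a static reference on whichever JVM runs the code; a function that holds an instance must serialize that instance, resource and all.

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical expensive resource, deliberately NOT Serializable
final class PooledResource {
  val serial: Int = PooledResource.created.incrementAndGet()
}
object PooledResource {
  val created = new AtomicInteger(0)
}

// Scala object: a single instance per JVM, created on first access, never serialized
object PerJvmHolder {
  lazy val resource: PooledResource = new PooledResource
}

// The same resource held by a regular instance has to travel with that instance
class InstanceHolder extends Serializable {
  val resource: PooledResource = new PooledResource
}

// Function objects playing the role of Spark closures
class ReadsFromObject extends Function0[Int] with Serializable {
  // Static reference to PerJvmHolder: no field, so nothing extra is serialized
  def apply(): Int = PerJvmHolder.resource.serial
}
class ReadsFromInstance(holder: InstanceHolder) extends Function0[Int] with Serializable {
  // holder becomes a field, so serializing this function serializes holder.resource too
  def apply(): Int = holder.resource.serial
}

object ObjectVsInstanceDemo {
  def canSerialize(obj: AnyRef): Boolean = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try { out.writeObject(obj); true }
    catch { case _: NotSerializableException => false }
    finally out.close()
  }

  def main(args: Array[String]): Unit = {
    println(canSerialize(new ReadsFromObject))                       // true
    println(canSerialize(new ReadsFromInstance(new InstanceHolder))) // false
  }
}
```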

The Complete Solution

We will use the widely used Apache Commons DBCP2 library to create the connection pool. The library also manages connection lifetime internally, with settings such as a maximum connection lifetime and the eviction of idle connections.

First, all relevant imports:

Relevant Imports for Our ConnectionPool

Let’s implement the static class (a.k.a. an object in Scala):

Our ConnectionPool as a Scala Object with getDataSource method

Being an object in Scala, the connection pool is created lazily, the first time a task running on a worker touches the object, and then exactly once per executor JVM. It is therefore a singleton per executor, shared by every partition that executor processes, rather than one pool per partition: with a maximum of three connections, at most three partitions on an executor hold a connection at any moment, and the others wait until one is returned. Still, it's much better than creating a connection within the iterative loop and then closing it explicitly.
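A minimal sketch of such an object, assuming commons-dbcp2 and MySQL Connector/J are on the executors' classpath. The URL, credentials and the com.mysql.cj.jdbc.Driver class name are placeholders, and the pool sizes follow the initial size and maximum of three mentioned in this article. BasicDataSource only opens connections on the first getConnection call, and a task blocks in getConnection while all three connections of its executor are in use.

```scala
import java.sql.Connection
import org.apache.commons.dbcp2.BasicDataSource

/**
 * Sketch of a per-JVM connection pool.
 * URL, user, password and driver class are placeholders (assumptions).
 */
object ConnectionPool {
  private val JdbcUrl  = "jdbc:mysql://db-host:3306/hr" // hypothetical host and schema
  private val User     = "spark_user"                   // hypothetical credentials
  private val Password = "change-me"

  // Created lazily, at most once per executor JVM, the first time a task asks for it.
  // Tasks reach it through the object's static reference, so it is never serialized.
  private lazy val dataSource: BasicDataSource = {
    val ds = new BasicDataSource()
    ds.setDriverClassName("com.mysql.cj.jdbc.Driver") // MySQL Connector/J 8.x driver class
    ds.setUrl(JdbcUrl)
    ds.setUsername(User)
    ds.setPassword(Password)
    ds.setInitialSize(3) // connections opened when the pool starts
    ds.setMaxTotal(3)    // never more than 3 open connections in this JVM
    // Release the pool's connections when the executor JVM shuts down
    sys.addShutdownHook(ds.close())
    ds
  }

  def getDataSource: BasicDataSource = dataSource

  /** Borrows a connection; calling close() on it returns it to the pool. */
  def getConnection: Connection = dataSource.getConnection
}
```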

Now let's use it in our Spark code.

The complete code

Observe how foreachPartition is applied to the partitions of the DataFrame we read from MySQL, and how the getDataSource method described earlier is invoked at the start of each partition. Because the pool lives in a Scala object, each executor JVM now holds at most 3 connections (the setInitialSize and setMaxTotal values in our first Scala object, Container), however many partitions it processes, and each partition uses just one of them. Each partition closes the connection it used only once, after the while loop; for a pooled connection, close() returns it to the pool instead of tearing it down, which keeps other partitions from being starved of connections. Additionally, though it's not the focus of this article, I pushed the filter predicate down into the query sent to the database, so MySQL does the filtering before any rows reach Spark. Within foreachPartition, we used a PreparedStatement for the department lookup.
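The sketch below shows one way these pieces could fit together; it is an illustration under stated assumptions, not the exact job described above. It assumes Spark 2.4 or later plus commons-dbcp2 and MySQL Connector/J on the classpath, and uses hypothetical table and column names (employee, department, emp_id, dept_id), a hypothetical joining_num threshold, placeholder connection settings and a no-op emit sink. It carries a tiny pool object of its own so that it compiles alone. The filter is pushed down as a subquery in dbtable, and partitionColumn, lowerBound, upperBound and numPartitions let Spark read the result in parallel; without them a JDBC read lands in a single partition.

```scala
import java.sql.{Connection, PreparedStatement, ResultSet}
import org.apache.commons.dbcp2.BasicDataSource
import org.apache.spark.sql.{Row, SparkSession}

// Tiny per-JVM pool so the sketch compiles on its own (placeholder URL and credentials)
object LookupPool {
  lazy val dataSource: BasicDataSource = {
    val ds = new BasicDataSource()
    ds.setUrl("jdbc:mysql://db-host:3306/hr")
    ds.setUsername("spark_user")
    ds.setPassword("change-me")
    ds.setInitialSize(3)
    ds.setMaxTotal(3)
    ds
  }
}

object EnrichEmployees {
  val JdbcUrl = "jdbc:mysql://db-host:3306/hr" // placeholder
  // Hypothetical schema; the WHERE clause runs inside MySQL (predicate pushdown)
  val EmployeeQuery = "SELECT emp_id, dept_id FROM employee WHERE joining_num >= 20170101"
  val DeptLookupSql = "SELECT dept_name FROM department WHERE dept_id = ?"

  // Hypothetical sink for each enriched record
  def emit(empId: Int, deptName: String): Unit = ()

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("enrich-employees").getOrCreate()
    try {
      val employees = spark.read.format("jdbc")
        .option("url", JdbcUrl)
        .option("user", "spark_user")
        .option("password", "change-me")
        .option("dbtable", s"($EmployeeQuery) AS e") // filtered subquery
        .option("partitionColumn", "emp_id")         // numeric column used to split the read
        .option("lowerBound", "1")                   // hypothetical key range
        .option("upperBound", "1000000000")
        .option("numPartitions", "200")
        .load()

      // The RDD variant avoids the overload ambiguity of Dataset.foreachPartition in Scala
      employees.rdd.foreachPartition { (rows: Iterator[Row]) =>
        val conn: Connection = LookupPool.dataSource.getConnection // borrowed once per partition
        try {
          while (rows.hasNext) {
            val row = rows.next()
            val stmt: PreparedStatement = conn.prepareStatement(DeptLookupSql)
            try {
              stmt.setInt(1, row.getAs[Int]("dept_id")) // assumes INT columns
              val rs: ResultSet = stmt.executeQuery()
              try { if (rs.next()) emit(row.getAs[Int]("emp_id"), rs.getString(1)) }
              finally rs.close()
            } finally stmt.close()
          }
        } finally conn.close() // hands the connection back to the pool
      }
    } finally spark.stop()
  }
}
```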

You can optimize further by moving the statement creation outside the while loop, as follows:

Further optimization to prepare Statement only once
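A sketch of that optimization as a standalone helper (PreparedOnce, enrichPartition and the department query are hypothetical names): the statement is prepared once per partition, and only its parameter changes from row to row. Inside foreachPartition you would pass it a connection borrowed from the pool and close that connection in a finally block after the call.

```scala
import java.sql.{Connection, PreparedStatement, ResultSet}

object PreparedOnce {
  // Hypothetical lookup; table and column names are assumptions
  val DeptLookupSql = "SELECT dept_name FROM department WHERE dept_id = ?"

  /**
   * Enriches one partition of (empId, deptId) pairs with a single PreparedStatement.
   * The statement is prepared once, before the loop; each row only rebinds the parameter.
   * Returns the number of rows looked up.
   */
  def enrichPartition(conn: Connection, rows: Iterator[(Int, Int)])(emit: (Int, String) => Unit): Int = {
    val stmt: PreparedStatement = conn.prepareStatement(DeptLookupSql)
    var lookedUp = 0
    try {
      while (rows.hasNext) {
        val (empId, deptId) = rows.next()
        stmt.setInt(1, deptId) // replaces the previous row's parameter value
        val rs: ResultSet = stmt.executeQuery()
        try { if (rs.next()) emit(empId, rs.getString(1)) } finally rs.close()
        lookedUp += 1
      }
    } finally stmt.close() // closed once, after the whole partition
    lookedUp
  }
}
```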

Some readers may suggest using a broadcast variable to achieve almost the same thing. But I prefer the approach without it: a broadcast variable is basically meant for shipping read-only data, such as a collection, to all worker nodes, and its value is only fetched when your code explicitly asks for it. It cannot carry a live connection anyway, since broadcast values have to be serialized too.

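Instead of the DBCP library, you may use another pooling library such as BoneCP, which may further improve connection pooling performance; note, however, that BoneCP has since been deprecated by its authors in favour of HikariCP.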
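Switching libraries mostly changes the configuration code. As an illustration, here is a sketch of the same per-JVM pool on HikariCP, assuming com.zaxxer:HikariCP and a MySQL driver on the executors' classpath, with placeholder connection settings and a hypothetical pool name. Unlike DBCP2's BasicDataSource, `new HikariDataSource(config)` starts the pool immediately, so the data source is kept lazy.

```scala
import java.sql.Connection
import com.zaxxer.hikari.{HikariConfig, HikariDataSource}

// Sketch: a per-JVM pool on HikariCP. URL, credentials and pool name are placeholders.
object HikariConnectionPool {
  val config: HikariConfig = {
    val c = new HikariConfig()
    c.setJdbcUrl("jdbc:mysql://db-host:3306/hr")
    c.setUsername("spark_user")
    c.setPassword("change-me")
    c.setMaximumPoolSize(3) // mirrors setMaxTotal(3): 3 connections per executor JVM
    c.setPoolName("spark-executor-pool")
    c
  }

  // new HikariDataSource(config) opens the pool right away, so defer it until first use
  lazy val dataSource: HikariDataSource = {
    val ds = new HikariDataSource(config)
    sys.addShutdownHook(ds.close())
    ds
  }

  /** Borrows a connection; close() returns it to the pool. */
  def getConnection: Connection = dataSource.getConnection
}
```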

My results: after implementing this, a job that previously took 6 hours finished in 1 hour 10 minutes, which shows how time-consuming it is to create a connection within the partition code.

Any suggestions for improvement are welcome. Please connect with me on LinkedIn.

Ravishankar Nair
