Implementing a ConnectionPool in Apache Spark’s foreachPartition(..)
I was in the middle of a project, and the situation, as usual, was not good at all in terms of achieving the required performance. There are about 1 billion rows in an employee table to be read from a MySQL database. I have to select some 400 million rows from this big table based on a filter criterion, say all employees who joined in the last seven years (based on a joining_num column). The client is one of the biggest in the transportation industry, with about thirty thousand offices across the United States and Latin America. For each row filtered from the above table, I need to look up another table and find the name of the department in which the employee initially joined, for a human-resources application. Solving this with traditional JDBC is not an option, because of the time it takes to fetch all the records and then, while iterating over them, fetch the related records for the next step. Remodeling the database is ruled out, since many web and reporting components in the existing system use these tables. A view or another copy of the combined table is not what I prefer either, since avoiding the proliferation of data is my principle in any data lake or transformation project. And in this big data world, I have enormous compute power available to me. Hence leveraging distributed computing technology is obviously the best way to approach this problem.

I won't spend a lot of time guiding you toward the best possible solution, but a decent way of doing this is the proper use of Apache Spark.
The Problem
The journey is definitely not smooth when I need to concentrate on performance, efficiency, scale and, of course, on making the solution shareable with a wider audience. Let me list the problems to be solved if we use Spark:

As illustrated above, when we invoke a Spark application, we need to ensure that the actions (the VERBS) happen on the worker nodes; that is where the power of distributed programming really lies. When you read data into Spark, either through Spark JDBC or with sc.textFile(…) and the like, Spark splits it into partitions, which are resilient. The underlying abstraction is the Resilient Distributed Dataset (RDD), on which the DataFrame and Dataset abstractions are built. When you run Spark on a cluster, these RDD partitions are held in memory on the worker nodes, not on the driver or client. The moment you execute a collect(..), however, all of that data is brought onto the driver, sometimes causing the infamous OOM (Out Of Memory) exception when dealing with large data. Actions such as foreach(..) and foreachPartition(..), in contrast, run on the workers, next to the data.
So what’s the big deal?
Remember that we want to iterate through each partition on the workers, connecting to the database to fetch additional information. But...

Note that we have millions of rows coming from the database as part of our first query, and each partition on a worker may contain tens of millions of rows. If we create a connection and close that connection within foreachPartition(..), that's expensive. On the other hand, if we create a Connection in the driver (client) and pass it to the workers, that's not going to work. Why?

Note that in a distributed computing cluster, the driver and the workers run on different nodes (unlike a single-node pseudo cluster, which we are not discussing here) and hence in different JVMs. So if we think that we can create a java.sql.Connection object in the driver and reuse it in a worker, we are wrong. This is a common mistake in the Spark world, and it gives you one of the most infamous errors in Spark: the Task not serializable exception. If you have Java installed, you can check whether a class is serializable by typing
serialver <<fully qualified name of your java class>>
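As a rough programmatic analogue of serialver, the following Scala sketch asks the JVM whether a class implements java.io.Serializable. The names SerializableCheck and isSerializable are hypothetical helpers, not part of any library. Implementing the interface is necessary but not sufficient: every non-transient field must be serializable too.

```scala
// Hypothetical helper, a rough programmatic analogue of `serialver`:
// a class can be Java-serialized only if it implements java.io.Serializable
// (necessary, not sufficient: its non-transient fields must be serializable too).
object SerializableCheck {
  def isSerializable(className: String): Boolean =
    classOf[java.io.Serializable].isAssignableFrom(Class.forName(className))
}

for (name <- Seq("java.sql.Connection", "java.sql.ResultSet", "java.net.Socket", "java.lang.String"))
  println(s"$name -> serializable: ${SerializableCheck.isSerializable(name)}")
```

For java.sql.Connection and java.sql.ResultSet this reports false, which is exactly why a connection created on the driver cannot be shipped to a worker.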

Many Google results solve this by marking the field as transient, telling Java not to serialize the field, but then your object has to be recreated on the workers, which means creating and closing the connection within each partition. Another option would be to create an explicit connection within foreachPartition. Isn't this what we wanted to eliminate in the first place?
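A minimal sketch of the transient workaround, using plain Java serialization from Scala to mimic what happens when a task is shipped to another JVM. FakeConnection, LookupTask and SerDe are hypothetical names; FakeConnection stands in for java.sql.Connection and is deliberately not serializable.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical stand-in for java.sql.Connection; deliberately NOT Serializable.
final class FakeConnection(val url: String)

// Hypothetical task shipped from driver to worker. Marking the field @transient
// makes Java serialization skip it instead of failing on it.
final class LookupTask(val jdbcUrl: String) extends Serializable {
  @transient private var conn: FakeConnection = new FakeConnection(jdbcUrl)

  def hasConnection: Boolean = conn != null

  // After deserialization the field is null, so the worker must rebuild it.
  def connection: FakeConnection = {
    if (conn == null) conn = new FakeConnection(jdbcUrl)
    conn
  }
}

object SerDe {
  // Serialize and deserialize in memory, roughly what shipping a task to another JVM does.
  def roundTrip[T <: AnyRef](value: T): T = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    try out.writeObject(value) finally out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    try in.readObject().asInstanceOf[T] finally in.close()
  }
}

val onDriver = new LookupTask("jdbc:mysql://db-host/hr") // hypothetical URL
val onWorker = SerDe.roundTrip(onDriver)
println(s"before: ${onDriver.hasConnection}, after: ${onWorker.hasConnection}") // before: true, after: false
```

After the round trip the transient field is null, so the worker has to build the connection again through connection; doing that in every partition is exactly the cost this article is trying to avoid.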
Serialization in Java — All you need to know
Serializable is a marker interface in Java. That means we don't have to implement any methods. Let me explain the concept with a small real-world scenario. Assume that I have a house where I have lived for the last few years, and it has electricity. I am playing a CD on a player that runs on electricity (not batteries), and I am halfway through my favorite song. My friend calls me over to show me his newly built home, where the electric connection is not available yet, but I want to continue listening to my song there. That's impossible with the current media player, as there is no electricity in my friend's new house. If you think of my home and my friend's home as two separate JVMs (like the Spark driver and a worker in a distributed cluster), the state of my media player cannot be replicated. In other words, the player depends on a connection in the original JVM (the electric connection; now think of a database connection) that cannot be replicated: the electric connection is not serializable. The same is true if you have read an image from your local file system, opened a socket connection, obtained a ResultSet over an underlying TCP/IP connection or created a JDBC connection. That's why most of the classes I mentioned, like java.sql.Connection, java.sql.ResultSet, java.awt.Image etc., are not serializable. If the JVM needs to convert your object to bytes and pass it through the network, the class must implement Serializable.
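The analogy can be sketched with hypothetical classes: a SongPosition is plain data and only needs the marker interface, while an ElectricConnection, like a socket or JDBC connection, is tied to the JVM that created it. The sketch also shows that marking a class Serializable is not enough on its own: a MediaPlayer that holds an ElectricConnection still cannot be serialized.

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Hypothetical classes for the analogy in the text.
// The song position is plain data: the marker interface is all it needs.
final class SongPosition(val title: String, val seconds: Int) extends Serializable

// The power supply belongs to one house (one JVM), like a socket or JDBC connection,
// so it does not implement Serializable.
final class ElectricConnection(val house: String)

// The player is marked Serializable, but one of its fields is not.
final class MediaPlayer(val position: SongPosition, val power: ElectricConnection) extends Serializable

object CarryCheck {
  // True if Java serialization can turn the whole object graph into bytes.
  def canSerialize(value: AnyRef): Boolean = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try { out.writeObject(value); true }
    catch { case _: NotSerializableException => false }
    finally out.close()
  }
}

val song = new SongPosition("My favorite song", 95)
println(CarryCheck.canSerialize(song))                                                      // true
println(CarryCheck.canSerialize(new ElectricConnection("my house")))                        // false
println(CarryCheck.canSerialize(new MediaPlayer(song, new ElectricConnection("my house")))) // false
```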
Workaround: The Secret
Well, the complete discussion above applies to objects, that is, to instances of classes. That gives us a little room for tricking the JVM, and you can probably guess at this point what I mean: static members are not serialized. Why? Static fields and methods belong to the class, not to an instance. A class is loaded once per JVM, that's true, but its static state is never part of an instance's serialized bytes, so it is never shipped across the network. (Serialization is also expensive; you can look up why later distributed frameworks such as Hadoop implemented their own Writable interface instead of relying on Serializable.) Voila! You've got it. Just remember this trick to avoid many Task not serializable exceptions in your code.
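How do you do it in Scala? A Scala object plays the role of a static class: there is exactly one instance per JVM, created lazily the first time it is accessed, and code that uses it reaches it through a static reference instead of carrying a copy of it. Its methods therefore behave like static methods. So let's write our code to implement a connection pool in Spark distributed programming.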
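A minimal sketch of the trick, with hypothetical names: ResourceHolder is a Scala object holding a non-serializable HeavyResource, and RowTask is the serializable unit of work. Serializing RowTask succeeds because the task carries only its rows; ResourceHolder is looked up in whichever JVM calls run() and is initialized there at most once. This assumes ordinary compiled top-level definitions; in a REPL, spark-shell or a notebook, definitions can be wrapped in generated classes, which may change what a task ends up capturing.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Assumes ordinary compiled top-level definitions (not REPL/notebook wrappers).

// Hypothetical non-serializable resource, e.g. a connection pool.
final class HeavyResource { val createdAt: Long = System.nanoTime() }

// Scala object: one instance per JVM (per class loader), initialized lazily
// the first time it is accessed in that JVM.
object ResourceHolder {
  @volatile var initCount = 0
  val resource: HeavyResource = { initCount += 1; new HeavyResource }
  def describe(row: Int): String = s"row $row used resource created at ${resource.createdAt}"
}

// Hypothetical unit of work shipped to workers. Its only state is `rows`;
// ResourceHolder is looked up in whichever JVM calls run(), not copied into the task.
final class RowTask(val rows: Seq[Int]) extends Serializable {
  def run(): Seq[String] = rows.map(ResourceHolder.describe)
}

val task = new RowTask(Seq(1, 2, 3))
val out = new ObjectOutputStream(new ByteArrayOutputStream())
out.writeObject(task) // succeeds: the task's bytes contain only `rows`
out.close()
task.run()
task.run()
println(s"ResourceHolder initialized ${ResourceHolder.initCount} time(s)") // 1
```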
The Complete Solution
We will use the well-known Apache Commons DBCP2 library to create a connection pool. The library provides many useful settings to manage the lifetime of pooled connections.
First, all relevant imports:

Let’s implement the static class (a.k.a. an object in Scala):

Because Container is a Scala object, its connection pool is created in a worker JVM the first time code running there accesses the object. Also remember that this object is not a single global instance shared across the cluster: each worker JVM gets its own copy, with a maximum of three connections in its pool. Still, it's much, much better than creating a connection within the iterative loop and then closing it explicitly.
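To show the per-JVM behavior without pulling in DBCP2 or Spark, the sketch below uses a hypothetical SimplePool object built on an ArrayBlockingQueue with three connections, loosely mirroring setInitialSize(3) and setMaxTotal(3); it is a stand-in, not the DBCP2 API. A fixed thread pool plays the role of one executor running several tasks at once, so all partitions processed in that JVM share the same pool and never hold more than three connections at a time.

```scala
import java.util.concurrent.{ArrayBlockingQueue, Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical stand-in for a pooled java.sql.Connection.
final class PooledConn(val id: Int)

// Hypothetical per-JVM pool in the spirit of the Container object; NOT the DBCP2 API.
// Three connections up front and never more, like setInitialSize(3) and setMaxTotal(3).
object SimplePool {
  val MaxTotal = 3
  val inUse = new AtomicInteger(0)
  val peakInUse = new AtomicInteger(0)
  private val idle = new ArrayBlockingQueue[PooledConn](MaxTotal)
  (1 to MaxTotal).foreach(i => idle.put(new PooledConn(i)))

  // Blocks until a connection is free, so at most MaxTotal are handed out at once.
  def borrow(): PooledConn = {
    val conn = idle.take()
    val now = inUse.incrementAndGet()
    peakInUse.accumulateAndGet(now, (a: Int, b: Int) => math.max(a, b))
    conn
  }

  // Plays the role of close() on a pooled connection: it goes back to the pool.
  def giveBack(conn: PooledConn): Unit = {
    inUse.decrementAndGet()
    idle.put(conn)
  }
}

object ExecutorSimulation {
  // Runs `partitions` tasks on `taskSlots` threads, like one executor with that many cores.
  // Returns the number of rows processed.
  def simulate(partitions: Int, rowsPerPartition: Int, taskSlots: Int): Int = {
    val executor = Executors.newFixedThreadPool(taskSlots)
    val processed = new AtomicInteger(0)
    for (_ <- 1 to partitions) {
      executor.execute(new Runnable {
        def run(): Unit = {
          val conn = SimplePool.borrow() // one connection for the whole partition
          try (1 to rowsPerPartition).foreach(_ => processed.incrementAndGet()) // row lookups would use conn
          finally SimplePool.giveBack(conn) // returned once, after the loop
        }
      })
    }
    executor.shutdown()
    executor.awaitTermination(1, TimeUnit.MINUTES)
    processed.get
  }
}

val rows = ExecutorSimulation.simulate(partitions = 8, rowsPerPartition = 1000, taskSlots = 6)
println(s"rows: $rows, peak connections in use: ${SimplePool.peakInUse.get}")
```

With six task slots and only three connections, some tasks simply wait in borrow() until another partition returns its connection; that is the trade-off of a small maxTotal.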
Now let's use it in our Spark code.

Observe lines 49 to 63, where we use foreachPartition on the partitions of the DataFrame that we read at line 46. At line 50, we invoke the getDataSource method described earlier. It is now guaranteed that each partition uses no more than 3 connections (our setInitialSize and setMaxTotal values in our first Scala object, Container). We close the connection only once, after the while loop (for a pooled connection, close() returns it to the pool), to avoid starving other partitions of connections. Additionally, though it's not the focus of this article, I have used a pushed-down predicate at line 46 for the query to the database. Within foreachPartition, we use a PreparedStatement to avoid creating the statement object multiple times.
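Why the connection is closed exactly once after the loop can be sketched with a hypothetical three-connection TinyPool (again a stand-in, not DBCP2): when every partition returns its connection, any number of partitions can share the pool; when partitions keep their connections, the fourth partition is starved.

```scala
import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}

// Hypothetical three-connection pool; a stand-in for illustration, NOT the DBCP2 API.
object TinyPool {
  private val idle = new ArrayBlockingQueue[String](3)
  Seq("conn-1", "conn-2", "conn-3").foreach(c => idle.put(c))

  // Waits up to 100 ms for a free connection; None means the pool is starved.
  def borrow(): Option[String] = Option(idle.poll(100, TimeUnit.MILLISECONDS))
  def giveBack(conn: String): Unit = idle.put(conn)
}

object PartitionJob {
  // Borrow once, loop over the rows, and (if returnConnection) give it back once after the loop.
  // Returns false if the partition could not get a connection at all.
  def processPartition(rows: Iterator[Int], returnConnection: Boolean): Boolean =
    TinyPool.borrow() match {
      case None => false
      case Some(conn) =>
        try {
          rows.foreach(_ => ()) // per-row lookups would reuse `conn` here
        } finally {
          if (returnConnection) TinyPool.giveBack(conn)
        }
        true
    }
}

// Well-behaved job: ten partitions share the three connections without trouble.
val wellBehaved = (1 to 10).map(_ => PartitionJob.processPartition(Iterator(1, 2, 3), returnConnection = true))
// Leaky job: every partition keeps its connection, so the fourth one is starved.
val leaky = (1 to 4).map(_ => PartitionJob.processPartition(Iterator(1, 2, 3), returnConnection = false))
println(s"well-behaved: $wellBehaved")
println(s"leaky: $leaky")
```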
You can optimize further by moving the statement creation outside the while loop, as follows:
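A minimal sketch of the idea, using hypothetical CountingConnection and CountingStatement classes that only count calls, compares preparing the statement inside the row loop with preparing it once per partition. In real JDBC code, the hoisted PreparedStatement would be filled with setInt(1, id) and run with executeQuery() for each row, and closed once after the loop.

```scala
// Hypothetical stand-ins for java.sql.Connection / PreparedStatement that only count calls.
final class CountingStatement(val sql: String) {
  var executions = 0
  // Stands in for setInt(1, id) followed by executeQuery() on a real PreparedStatement.
  def lookup(employeeId: Int): String = { executions += 1; s"department-of-$employeeId" }
}

final class CountingConnection {
  var prepared = 0
  def prepareStatement(sql: String): CountingStatement = { prepared += 1; new CountingStatement(sql) }
}

object StatementPlacement {
  // Hypothetical lookup query; the article does not show its SQL.
  val Sql = "SELECT department_name FROM departments WHERE employee_id = ?"

  // Statement created inside the row loop: one prepare per row.
  def preparePerRow(conn: CountingConnection, rows: Iterator[Int]): List[String] =
    rows.map(id => conn.prepareStatement(Sql).lookup(id)).toList

  // Statement hoisted out of the loop: one prepare per partition, reused for every row.
  def prepareOnce(conn: CountingConnection, rows: Iterator[Int]): List[String] = {
    val stmt = conn.prepareStatement(Sql)
    rows.map(id => stmt.lookup(id)).toList
  }
}

val inLoop = new CountingConnection
val hoisted = new CountingConnection
val perRowResults = StatementPlacement.preparePerRow(inLoop, (1 to 1000).iterator)
val hoistedResults = StatementPlacement.prepareOnce(hoisted, (1 to 1000).iterator)
println(s"same results: ${perRowResults == hoistedResults}; prepares in loop: ${inLoop.prepared}, hoisted: ${hoisted.prepared}")
```

Both versions return the same lookups; only the number of prepare calls differs (1,000 versus 1 for a partition of 1,000 rows).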

Someone may well suggest using a broadcast variable to achieve almost the same thing. But I prefer the method without it, since a broadcast variable is basically meant for shipping a read-only value (typically a collection) to all worker nodes, which your code then reads explicitly through .value. Moreover, a broadcast value has to be serialized to reach the workers, so a Connection or a connection pool cannot be broadcast in the first place.
Other than the DBCP library, you may use a library like BoneCP, which may further improve connection pooling performance.
My results: after implementing this, a job that used to take 6 hours finished in 1 hour 10 minutes, roughly five times faster. That shows how time-consuming it is to create connections within the partition code.
Any suggestions for improvement are welcome. Please connect with me on LinkedIn.
Ravishankar Nair