Sharding in Spring Boot for Relational DB

Introduction

Our microservice project consisted of various services. One of them stored user-related data. The backend was written in Spring Boot and deployed on Kubernetes (k8s) in GCP, with a MySQL database also hosted in GCP.

Requirements

As our user base grew, performance suffered and latency went up. Because all user data lived in a single table in a single database, many queries and updates failed with lock exceptions. The database itself also kept growing, which degraded performance further. We needed a solution that could keep up with an ever-growing user base.

Solution

Table Sharding

Our first approach was to create multiple identical tables in a single database and use user_id as the sharding key. Our hash function was simply user_id mod 10.

We created 10 copies of every table that had a user_id column. This required two changes in the code: first, getting the user_id from the request, and second, replacing the table name in the queries generated by Hibernate.
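In plain Java, the mod-10 mapping and the resulting table name can be sketched as below. TableShardResolver, shardOf and tableFor are hypothetical names, and the "_N" suffix follows the user_profile_7 naming used in this post.

```java
// Hypothetical helper: maps a user_id to one of 10 table copies using user_id mod 10.
final class TableShardResolver {

    static final int TABLE_COPIES = 10;

    private TableShardResolver() {}

    // Shard index in [0, 9]; floorMod keeps it non-negative even for a negative id.
    static int shardOf(long userId) {
        return (int) Math.floorMod(userId, (long) TABLE_COPIES);
    }

    // e.g. ("user_profile", 77) -> "user_profile_7"
    static String tableFor(String baseTable, long userId) {
        return baseTable + "_" + shardOf(userId);
    }
}
```

Keeping this mapping in one place means the interceptor, any jOOQ listener and any migration script all agree on where a given user lives.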

The first change was easy, since user_id was already being sent in our request header.

For the second change, we extended Hibernate's EmptyInterceptor class and overrode its onPrepareStatement method, which is called when the SQL string is being prepared.

This method receives the SQL statement as a string, and that string includes the table names. Based on the user_id in the request header, we simply replace each table name with the name of the corresponding copy.

For example, for user_id 77 we took 77 mod 10, which is 7, and replaced the table name user_profile with user_profile_7 (we had already created the 10 copies in the database). The class that extends EmptyInterceptor is shown below.

Important note: if you are using Spring Boot 3 (Hibernate 6), EmptyInterceptor is deprecated. Implement the StatementInspector interface instead, override its inspect method, and move the logic from onPrepareStatement into inspect.

```java
import org.hibernate.EmptyInterceptor;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

public class DynamicTableNameSharding extends EmptyInterceptor {

    @Override
    public String onPrepareStatement(String sql) {
        // rewrite table names only when table sharding is enabled
        if (Boolean.parseBoolean(DatabaseEnvironment.TABLE_SHARDING_ENABLED.label)) {
            for (String tableName : SHARDED_TABLES) {
                if (sql.contains(tableName)) {
                    // the current HTTP request carries the user_id header
                    ServletRequestAttributes attr = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
                    String shardingNumber = getSharding(attr); // e.g. "7" for user_id 77
                    // user_profile -> user_profile_7
                    sql = sql.replace(tableName, tableName + "_" + shardingNumber);
                    // no break: a query can contain several sharded tables, and all of them must be renamed
                }
            }
        }
        return super.onPrepareStatement(sql);
    }
}
```

In the code above, SHARDED_TABLES is the list of tables for which sharding is enabled, and getSharding returns the shard number derived from the user_id passed in the header. Because a single query can reference several tables (joins or more complex logic), we loop over all of them so that every sharded table in the query is renamed.
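One caveat with String.replace is that it matches substrings: if both user_profile and a hypothetical user_profile_settings table appear in a query, the user_profile prefix inside user_profile_settings is rewritten too. A framework-free sketch of the same loop with whole-word matching is below; ShardedSqlRewriter and rewrite are hypothetical names, and the body of onPrepareStatement (or inspect, on Hibernate 6) could delegate to something like it.

```java
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical, framework-free version of the table-name rewrite.
final class ShardedSqlRewriter {

    private ShardedSqlRewriter() {}

    // Renames every whole-word occurrence of each sharded table to "<table>_<shard>".
    static String rewrite(String sql, List<String> shardedTables, int shard) {
        String result = sql;
        for (String table : shardedTables) {
            // \b on both sides: "user_profile" matches, but the start of "user_profile_settings"
            // does not, because '_' is a word character, so there is no boundary after "user_profile".
            Pattern whole = Pattern.compile("\\b" + Pattern.quote(table) + "\\b");
            // quoteReplacement keeps any '$' or '\' in the new name literal
            result = whole.matcher(result).replaceAll(Matcher.quoteReplacement(table + "_" + shard));
            // no break: a join can reference several sharded tables
        }
        return result;
    }
}
```

This still rewrites matching text inside SQL string literals, which is usually acceptable for Hibernate-generated SQL because values are sent as bound parameters rather than inlined.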

We also used jOOQ for some operations, so we did something similar there by extending jOOQ's DefaultVisitListener class.

Database Sharding

Multiple db servers:

Table sharding solved our problem, but there was still room for improvement, so we went a step further and sharded our databases as well. Just as we had created copies of tables, we created copies of our database servers/instances. We now had 10 database servers running, each with 10 copies of every sharded table, for a total of 100 copies of each table.

Having 10 database servers also meant that every query had to be routed to the correct database.

First, we created 10 data sources in our Spring Boot application, each with a different database URL. Next, we needed a way to route each connection request to the correct data source. For that we used AbstractRoutingDataSource, an abstract DataSource implementation that routes getConnection() calls to one of several target DataSources based on a lookup key, and we overrode its determineCurrentLookupKey method.
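determineCurrentLookupKey() takes no arguments, so the key has to come from somewhere the routing data source can reach. A minimal sketch, assuming a ThreadLocal holder filled in per request (a common pattern, not necessarily the exact mechanism we used); ShardContext and ShardRoutingDataSource are hypothetical names.

```java
// Hypothetical per-thread holder for the database shard of the current request.
// In Spring, an AbstractRoutingDataSource subclass would read it, e.g.:
//   class ShardRoutingDataSource extends AbstractRoutingDataSource {
//       @Override
//       protected Object determineCurrentLookupKey() { return ShardContext.get(); }
//   }
// with its targetDataSources map keyed 0..9 for the 10 data sources.
final class ShardContext {

    private static final ThreadLocal<Integer> CURRENT = new ThreadLocal<>();

    private ShardContext() {}

    // Called at the start of a request, e.g. with the shard derived from user_id.
    static void set(int dbIndex) {
        CURRENT.set(dbIndex);
    }

    // null when no shard was set; AbstractRoutingDataSource then falls back to
    // its default target data source, if one is configured.
    static Integer get() {
        return CURRENT.get();
    }

    // Must be called when the request ends, because server threads are pooled and reused.
    static void clear() {
        CURRENT.remove();
    }
}
```

With @Transactional methods the connection is typically obtained when the transaction begins, so the key should be set before that point, for example in a servlet filter, and cleared in a finally block.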

determineCurrentLookupKey returns the key that identifies one of the 10 data sources we had defined. We also changed our logic for choosing the table and database slightly: the units digit of user_id selects the database server and the tens digit selects the table copy. For example, user_id 447 is routed to database server 7 and table copy 4 on that server. This gave us 100 tables across 10 database servers and improved our performance a lot.
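The digit-based routing can be sketched as follows, assuming non-negative user ids; ShardRoute, of and tableName are hypothetical names, and dbIndex is the value the routing data source would return as its lookup key.

```java
// Hypothetical router for the 10 servers x 10 tables layout:
// units digit of user_id -> database server, tens digit -> table copy.
final class ShardRoute {

    final int dbIndex;     // which of the 10 data sources (lookup key)
    final int tableIndex;  // which of the 10 table copies on that server

    private ShardRoute(int dbIndex, int tableIndex) {
        this.dbIndex = dbIndex;
        this.tableIndex = tableIndex;
    }

    static ShardRoute of(long userId) {
        if (userId < 0) {
            throw new IllegalArgumentException("user_id must be non-negative: " + userId);
        }
        int db = (int) (userId % 10);            // units digit
        int table = (int) ((userId / 10) % 10);  // tens digit
        return new ShardRoute(db, table);
    }

    // e.g. "user_profile" -> "user_profile_4" for user_id 447
    String tableName(String baseTable) {
        return baseTable + "_" + tableIndex;
    }
}
```

Taking the two digits separately is equivalent to splitting user_id mod 100 into 100 buckets, so with roughly uniform ids each (server, table) pair holds about 1% of users. Changing the number of servers or tables later remaps most users, which is why existing data then has to be migrated.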

Miscellaneous

We used DBngin for local testing, since it lets you run multiple database servers on localhost on different ports. It can be installed with Homebrew:

brew install --cask dbngin

If you already have existing data, you may also need a data migration to move each row into the correct table and the correct database.
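A first step of such a migration is working out where each existing user belongs. A minimal sketch under the 10 x 10 digit scheme, assuming non-negative ids; MigrationPlanner, plan and the "dbN.table_M" target labels are hypothetical, and the actual copy would then run one batch per target.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical planning step: group existing user_ids by the (database, table) pair
// they must move to, so each target can be filled in a single batch.
final class MigrationPlanner {

    private MigrationPlanner() {}

    static Map<String, List<Long>> plan(List<Long> userIds, String baseTable) {
        Map<String, List<Long>> byTarget = new TreeMap<>(); // sorted for stable output
        for (long id : userIds) {
            int db = (int) (id % 10);            // units digit -> database server
            int table = (int) ((id / 10) % 10);  // tens digit -> table copy
            String target = "db" + db + "." + baseTable + "_" + table;
            byTarget.computeIfAbsent(target, k -> new ArrayList<>()).add(id);
        }
        return byTarget;
    }
}
```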

Connection pooling caused no issues: we had 10 Hikari data sources, and Hikari managed a separate pool for each one.

Conclusion

We used both table sharding and database sharding. We could improve further by adding more databases to each server, possibly reaching 1000 copies of a table in total. This may not yield a proportional performance gain, but it would at least reduce locking issues and help considerably under high concurrency. Please comment if you have any questions. Thanks.