Multi-Tenancy with PostgreSQL RLS

 RLS (row-level security) is a PostgreSQL feature for defining policies that control which rows of a table a given database user can see or modify. In essence, it is an additional filter that PostgreSQL applies to every query against the table.

Add discriminator column(s)

<!-- Adds the tenant discriminator column; existing rows fall into the default tenant "0". -->
<changeSet id="add-tenancy-id">
    <addColumn tableName="Test_Table">
        <column name="Tenant_Id" type="VARCHAR(30)" defaultValue="0">
            <constraints nullable="false" />
        </column>
    </addColumn>
</changeSet>

Define row level policies

<!-- Enables row level security on the table and (re)creates the tenant isolation policy. -->
<changeSet id="Test_row_level_security">
    <sql dbms="postgresql" splitStatements="true">
        ALTER TABLE TEST_TABLE ENABLE ROW LEVEL SECURITY;
        DROP POLICY IF EXISTS TEST_tenant_isolation_policy ON TEST_TABLE;
        CREATE POLICY TEST_tenant_isolation_policy ON TEST_TABLE
        USING (tenant_id = current_setting('app.tenant_id')::VARCHAR);
    </sql>
</changeSet>

Add an app-level database user

<!-- Creates the application login role and sets default privileges for objects created later in the schema. -->
<changeSet id="add_app_user_privileges">
    <sql dbms="postgresql" splitStatements="true">
        CREATE USER ${user_name} WITH PASSWORD '${password}';
        GRANT CONNECT ON DATABASE ${database} TO ${user_name};
        ALTER DEFAULT PRIVILEGES IN SCHEMA ${schema} GRANT SELECT, INSERT, UPDATE, DELETE, REFERENCES ON TABLES TO ${user_name};
        ALTER DEFAULT PRIVILEGES IN SCHEMA ${schema} GRANT USAGE ON SEQUENCES TO ${user_name};
        ALTER DEFAULT PRIVILEGES IN SCHEMA ${schema} GRANT EXECUTE ON FUNCTIONS TO ${user_name};
    </sql>
</changeSet>

<changeSet id="grant_permission_to_app_user">
<!-- Revokes existing table privileges in the schema, then grants the application user SELECT/INSERT/UPDATE/DELETE on all current tables. -->
<!-- NOTE(review): the REVOKE below names ${schema} as the grantee, not ${user_name}. With schema=public this presumably resolves to the PUBLIC pseudo-role; confirm which grantee was intended. -->
<sql dbms="postgresql" splitStatements="true">
REVOKE ALL
ON ALL TABLES IN SCHEMA ${schema}
FROM ${schema};

GRANT SELECT, INSERT, UPDATE, DELETE
ON ALL TABLES IN SCHEMA ${schema}
TO ${user_name};
</sql>
</changeSet>
The placeholders for user_name, password, etc. can be defined in application.yml as below:
spring:
  liquibase:
    enabled: true
    changeLog: classpath:/db/changelog/db.changelog-master.xml
    # Values substituted for the ${...} placeholders used in the changesets.
    parameters:
      database: db
      schema: public
      user_name: user
      password: password

Data Source Configuration

  master data source: the data source Liquibase uses to run database migrations.

  tenancy data source: the default (primary) data source used by the microservice at runtime.

/**
 * Declares the properties holder for the master data source, registered under the
 * {@code multitenancy.master.datasource} configuration prefix.
 *
 * @return a new {@link DataSourceProperties} instance
 */
@Bean
@ConfigurationProperties("multitenancy.master.datasource")
public DataSourceProperties masterDataSourceProperties() {
    final DataSourceProperties masterProperties = new DataSourceProperties();
    return masterProperties;
}

/**
 * Builds the Hikari connection pool for the master data source, which is marked as the
 * data source Liquibase should use ({@code @LiquibaseDataSource}).
 *
 * <p>When {@code tlsEnabled} is set, the truststore content from the supplied properties is
 * written to the configured location and the truststore path, type and password are passed
 * to the pool as data source properties.
 *
 * @param truststoreProperties truststore content, target location, type and password
 * @return the configured pool named {@code masterDataSource}
 * @throws Exception if the truststore file cannot be written
 */
@Bean
@LiquibaseDataSource
@ConfigurationProperties("multitenancy.master.datasource.hikari")
public DataSource masterDataSource(DataSourceTruststoreProperties truststoreProperties) throws Exception {
    HikariDataSource dataSource = masterDataSourceProperties()
            .initializeDataSourceBuilder()
            .type(HikariDataSource.class)
            .build();
    dataSource.setPoolName("masterDataSource");
    if (tlsEnabled) {
        // Use the Charset overload: no charset-name lookup and no checked UnsupportedEncodingException.
        byte[] encodedTruststore = truststoreProperties.getContent().getBytes(StandardCharsets.UTF_8);
        // The content is presumably Base64 text; it is streamed through Base64InputStream into the truststore file.
        try (InputStream in = new Base64InputStream(new ByteArrayInputStream(encodedTruststore))) {
            Files.copy(in, Paths.get(truststoreProperties.getLocation()), StandardCopyOption.REPLACE_EXISTING);
        }
        dataSource.addDataSourceProperty("javax.net.ssl.trustStore", truststoreProperties.getLocation());
        dataSource.addDataSourceProperty("javax.net.ssl.trustStoreType", truststoreProperties.getType());
        dataSource.addDataSourceProperty("javax.net.ssl.trustStorePassword", truststoreProperties.getPassword());
    }
    return dataSource;
}

/**
 * Declares the primary properties holder for the tenancy data source, registered under the
 * {@code multitenancy.tenancy.datasource} configuration prefix.
 *
 * @return a new {@link DataSourceProperties} instance
 */
@Bean
@Primary
@ConfigurationProperties("multitenancy.tenancy.datasource")
public DataSourceProperties tenancyDataSourceProperties() {
    final DataSourceProperties tenancyProperties = new DataSourceProperties();
    return tenancyProperties;
}

/**
 * Builds the primary data source: a Hikari pool wrapped in a {@link TenancyAwareDataSource}.
 *
 * <p>When {@code tlsEnabled} is set, the truststore content from the supplied properties is
 * written to the configured location and the truststore path, type and password are passed
 * to the pool as data source properties.
 *
 * <p>NOTE(review): {@code @ConfigurationProperties} on a bean method binds to the returned
 * object, which here is the {@code TenancyAwareDataSource} wrapper rather than the
 * {@code HikariDataSource}. Confirm the {@code hikari.*} pool settings are actually applied.
 *
 * @param truststoreProperties truststore content, target location, type and password
 * @return the tenancy-aware wrapper around the pool named {@code tenancyDataSource}
 * @throws Exception if the truststore file cannot be written
 */
@Bean
@Primary
@ConfigurationProperties("multitenancy.tenancy.datasource.hikari")
public DataSource tenancyDataSource(DataSourceTruststoreProperties truststoreProperties) throws Exception {
    HikariDataSource dataSource = tenancyDataSourceProperties()
            .initializeDataSourceBuilder()
            .type(HikariDataSource.class)
            .build();
    dataSource.setPoolName("tenancyDataSource");
    if (tlsEnabled) {
        // Use the Charset overload: no charset-name lookup and no checked UnsupportedEncodingException.
        byte[] encodedTruststore = truststoreProperties.getContent().getBytes(StandardCharsets.UTF_8);
        // The content is presumably Base64 text; it is streamed through Base64InputStream into the truststore file.
        try (InputStream in = new Base64InputStream(new ByteArrayInputStream(encodedTruststore))) {
            Files.copy(in, Paths.get(truststoreProperties.getLocation()), StandardCopyOption.REPLACE_EXISTING);
        }
        dataSource.addDataSourceProperty("javax.net.ssl.trustStore", truststoreProperties.getLocation());
        dataSource.addDataSourceProperty("javax.net.ssl.trustStoreType", truststoreProperties.getType());
        dataSource.addDataSourceProperty("javax.net.ssl.trustStorePassword", truststoreProperties.getPassword());
    }
    return new TenancyAwareDataSource(dataSource);
}

Tenancy aware database connection

/**
 * Data source decorator that tags every connection it hands out with the current tenant.
 *
 * <p>On checkout the session setting {@code app.tenant_id} is set from {@link TenancyContext}
 * (or to {@code ZERO} when no tenant is bound). The returned connection is a proxy that
 * resets the setting before the underlying connection is closed.
 */
public class TenancyAwareDataSource extends DelegatingDataSource {

    /**
     * @param targetDataSource the data source whose connections are decorated
     */
    public TenancyAwareDataSource(DataSource targetDataSource) {
        super(targetDataSource);
    }

    @Override
    public Connection getConnection() throws SQLException {
        return bindTenant(getTargetDataSource().getConnection());
    }

    @Override
    public Connection getConnection(String username, String password) throws SQLException {
        return bindTenant(getTargetDataSource().getConnection(username, password));
    }

    /**
     * Applies the tenant setting to a freshly obtained connection and wraps it in the proxy.
     * If the setting cannot be applied, the connection is closed before the failure is
     * rethrown so that it is not leaked.
     */
    private Connection bindTenant(Connection connection) throws SQLException {
        try {
            setTenancyId(connection);
        } catch (SQLException | RuntimeException e) {
            try {
                connection.close();
            } catch (SQLException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
        return getTenancyAwareConnectionProxy(connection);
    }

    /**
     * Sets {@code app.tenant_id} on the connection's session.
     *
     * <p>The tenant id is passed as a bind parameter to {@code set_config} instead of being
     * concatenated into a {@code SET} statement: the value may originate from a request
     * header, and a crafted value must not be able to alter the statement. A missing context
     * or a missing tenant entry falls back to {@code ZERO}.
     */
    private void setTenancyId(Connection connection) throws SQLException {
        Map<String, String> tenancyMap = TenancyContext.getTennancyId();
        String tenantId = (tenancyMap == null) ? null : tenancyMap.get(TENANT);
        if (tenantId == null) {
            tenantId = ZERO;
        }
        // The third argument "false" makes the setting session-scoped, like a plain SET.
        try (java.sql.PreparedStatement sql =
                     connection.prepareStatement("SELECT set_config('app.tenant_id', ?, false)")) {
            sql.setString(1, tenantId);
            sql.execute();
        }
    }

    /** Resets {@code app.tenant_id} on the connection's session. */
    private static void clearTenantId(Connection connection) throws SQLException {
        try (Statement sql = connection.createStatement()) {
            sql.execute("RESET app.tenant_id");
        }
    }

    /**
     * Wraps the connection in a {@code ConnectionProxy} whose {@code close()} resets the
     * tenant setting first.
     *
     * @param connection the underlying connection
     * @return the proxy to hand to callers
     */
    protected Connection getTenancyAwareConnectionProxy(Connection connection) {
        return (Connection) Proxy.newProxyInstance(ConnectionProxy.class.getClassLoader(),
                new Class<?>[]{ConnectionProxy.class},
                new TenancyAwareInvocationHandler(connection));
    }

    /** Invocation handler behind the connection proxy. */
    private static final class TenancyAwareInvocationHandler implements InvocationHandler {
        private final Connection target;

        TenancyAwareInvocationHandler(Connection target) {
            this.target = target;
        }

        @Override
        @Nullable
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            switch (method.getName()) {
                case "equals":
                    return proxy == args[0];
                case "hashCode":
                    return System.identityHashCode(proxy);
                case "toString":
                    return this.target.toString();
                case "unwrap":
                    if (((Class<?>) args[0]).isInstance(proxy)) {
                        return proxy;
                    }
                    return invokeOnTarget(method, args);
                case "isWrapperFor":
                    if (((Class<?>) args[0]).isInstance(proxy)) {
                        return true;
                    }
                    return invokeOnTarget(method, args);
                case "getTargetConnection":
                    return target;
                case "close":
                    return closeTarget(method, args);
                default:
                    return invokeOnTarget(method, args);
            }
        }

        /**
         * Resets the tenant setting and then closes the target. The close runs in a
         * {@code finally} so a failing reset cannot leak the connection, and the reset is
         * skipped for an already-closed connection so repeated {@code close()} stays a no-op.
         */
        private Object closeTarget(Method method, Object[] args) throws Throwable {
            try {
                if (!target.isClosed()) {
                    clearTenantId(target);
                }
            } finally {
                invokeOnTarget(method, args);
            }
            return null;
        }

        /**
         * Invokes the method on the target and unwraps reflection's wrapper exception, so
         * callers see the original {@code SQLException} rather than an
         * {@code UndeclaredThrowableException}.
         */
        private Object invokeOnTarget(Method method, Object[] args) throws Throwable {
            try {
                return method.invoke(target, args);
            } catch (java.lang.reflect.InvocationTargetException e) {
                throw e.getTargetException();
            }
        }
    }

}

TenancyAware Interface
/**
 * Contract for objects that carry a tenant identifier.
 */
public interface TenancyAware {
/** @return the tenant identifier held by this object */
String getTenantId();

/** @param tenantId the tenant identifier to store on this object */
void setTenantId(String tenantId);
}

TenancyContext
/**
 * Static holder for the tenant bound to the current thread.
 *
 * <p>The value is kept in an {@link InheritableThreadLocal}, so a thread created while a
 * tenant is bound starts with the same map instance as its parent.
 */
public final class TenancyContext {

    private static final InheritableThreadLocal<Map<String, String>> CURRENT_TENANT = new InheritableThreadLocal<>();

    private TenancyContext() {
    }

    /**
     * Binds the given tenant id to the current thread, replacing any previous binding.
     *
     * @param tenancyId the tenant id to store under the {@code TENANT} key
     */
    public static void setTenancyId(String tenancyId) {
        Map<String, String> tenancyMap = new HashMap<>();
        tenancyMap.put(TENANT, tenancyId);
        CURRENT_TENANT.set(tenancyMap);
    }

    /**
     * Returns the map bound to the current thread.
     *
     * @return the map holding the tenant id under the {@code TENANT} key, or {@code null}
     *         when nothing is bound
     */
    public static Map<String, String> getTenancyId() {
        return CURRENT_TENANT.get();
    }

    /**
     * Misspelled alias of {@link #getTenancyId()}, kept so existing callers keep compiling.
     *
     * @return the same value as {@link #getTenancyId()}
     */
    public static Map<String, String> getTennancyId() {
        return getTenancyId();
    }

    /** Removes the binding for the current thread. */
    public static void clear() {
        CURRENT_TENANT.remove();
    }
}

JPA EntityListener
/**
 * JPA entity listener that copies the current tenant id onto an entity before it is
 * persisted, updated or removed.
 */
public class TenancyListener {

    /**
     * Stamps the entity with the tenant id held in {@link TenancyContext}.
     *
     * @param entity the entity about to be written
     * @throws IllegalStateException if no tenant is bound to the current thread
     */
    @PreUpdate
    @PreRemove
    @PrePersist
    public void setTenancy(TenancyAware entity) {
        final Map<String, String> map = TenancyContext.getTennancyId();
        if (map == null) {
            // Fail with a descriptive error instead of a bare NullPointerException, e.g. when
            // the entity is written from a thread on which no tenant was ever set.
            throw new IllegalStateException(
                    "No tenant bound to the current thread; cannot set tenant id on "
                            + entity.getClass().getName());
        }
        entity.setTenantId(map.get(TENANT));
    }
}

TenancyAwareBaseEntity
/**
 * Base class for entities that carry a tenant id in the {@code TENANT_ID} column.
 * {@link TenancyListener} is registered as its entity listener.
 */
@Getter
@Setter
@NoArgsConstructor
@SuperBuilder
@MappedSuperclass
@EntityListeners(TenancyListener.class)
public abstract class TenancyAwareBaseEntity implements TenancyAware, Serializable {

    /** Tenant discriminator, mapped to the {@code TENANT_ID} column. */
    @Column(name = "TENANT_ID")
    private String tenantId;
...
}

Capture Tenancy from incoming request
/**
 * Web request interceptor that binds the tenant for the duration of a request.
 *
 * <p>The tenant id is read from the {@code X_TENANT_ID} request header. When the header is
 * absent, the configured default tenant is used, and {@code ZERO} when none is configured.
 *
 * <p>NOTE(review): the header value is taken as-is; confirm that an upstream component
 * authenticates or validates it before it reaches this interceptor.
 */
@Component
public class TenancyInterceptor implements WebRequestInterceptor {
    private final String defaultTenant;

    /**
     * @param defaultTenant value of {@code multitenancy.tenant.default-tenant}, or
     *                      {@code null} when the property is not set
     */
    public TenancyInterceptor(@Value("${multitenancy.tenant.default-tenant:#{null}}") String defaultTenant) {
        this.defaultTenant = defaultTenant;
    }

    /** Resolves the tenant for the request and binds it to the current thread. */
    @Override
    public void preHandle(WebRequest request) {
        String tenantId = request.getHeader(X_TENANT_ID);
        if (tenantId == null) {
            tenantId = (defaultTenant != null) ? defaultTenant : ZERO;
        }
        TenancyContext.setTenancyId(tenantId);
    }

    /** Nothing to do here; the tenant stays bound until the request has completed. */
    @Override
    public void postHandle(WebRequest request, ModelMap modelMap) {
        // Intentionally empty: cleanup happens in afterCompletion, which also runs when the
        // handler throws, so a stale tenant cannot remain on a reused thread.
    }

    /** Clears the tenant binding once request processing has finished. */
    @Override
    public void afterCompletion(WebRequest request, Exception ex) {
        TenancyContext.clear();
    }
}

/**
 * MVC configuration that registers the {@link TenancyInterceptor} as a web request
 * interceptor.
 */
@Configuration
public class WebConfiguration implements WebMvcConfigurer {

    /** Interceptor to register; supplied through the constructor. */
    private final TenancyInterceptor interceptor;

    /**
     * @param tenancyInterceptor the interceptor to register
     */
    public WebConfiguration(TenancyInterceptor tenancyInterceptor) {
        interceptor = tenancyInterceptor;
    }

    /** Adds the tenancy interceptor to the given registry. */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addWebRequestInterceptor(this.interceptor);
    }
}




Comments