From 65ce2ea6b1fcd05c4d0f04ff10355891ae50a54c Mon Sep 17 00:00:00 2001 From: Aleksei Zinovev Date: Wed, 6 Nov 2024 17:10:15 +0100 Subject: [PATCH] Implemented custom SQL DB registration (#917) * Add optional dbType parameter to JDBC read functions This commit introduces an optional `dbType` parameter to various JDBC read functions. It allows for specifying the database type directly, enhancing flexibility and control over query execution. This change ensures backward compatibility by defaulting to type inference when `dbType` is not provided. * Add support for custom DB types in schema generation This update introduces the ability to specify custom database types in the data schema generation process. The changes include the addition of a `dbTypeClassName` field and modifications to pass `DbType` to relevant functions. This enhances flexibility in handling various database types beyond the default configurations. * Refactor database type handling with ServiceLoader This update refactors the DataSchemaGenerator to use ServiceLoader for loading DbTypeProvider and removes hard-coded class loading. It also changes the annotation parameter from dbTypeClassName to dbTypeKClass for better type safety. 
* Reverted support custom DB for plugins * Added tests for customDB for limited number of cases and API testing * Add dbType parameter to readSqlTable function * Update Javadoc to clarify automatic database type recognition --- .../jetbrains/kotlinx/dataframe/io/db/H2.kt | 2 +- .../kotlinx/dataframe/io/readJdbc.kt | 188 +++++++++++------- .../kotlinx/dataframe/io/h2/h2Test.kt | 132 +++++++++++- 3 files changed, 249 insertions(+), 73 deletions(-) diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt index 1e4172071..bd623046f 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt @@ -15,7 +15,7 @@ import kotlin.reflect.KType * * NOTE: All date and timestamp-related types are converted to String to avoid java.sql.* types. */ -public class H2(public val dialect: DbType = MySql) : DbType("h2") { +public open class H2(public val dialect: DbType = MySql) : DbType("h2") { init { require(dialect::class != H2::class) { "H2 database could not be specified with H2 dialect!" } } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index 75a101e8e..bb47209d1 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -113,6 +113,8 @@ public data class DbConnectionConfig(val url: String, val user: String = "", val * @param [tableName] the name of the table to read data from. * @param [limit] the maximum number of rows to retrieve from the table. * @param [inferNullability] indicates how the column nullability should be inferred. 
+ * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [dbConfig]. * @return the DataFrame containing the data from the SQL table. */ public fun DataFrame.Companion.readSqlTable( @@ -120,9 +122,10 @@ public fun DataFrame.Companion.readSqlTable( tableName: String, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): AnyFrame { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readSqlTable(connection, tableName, limit, inferNullability) + return readSqlTable(connection, tableName, limit, inferNullability, dbType) } } @@ -133,6 +136,8 @@ public fun DataFrame.Companion.readSqlTable( * @param [tableName] the name of the table to read data from. * @param [limit] the maximum number of rows to retrieve from the table. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [connection]. * @return the DataFrame containing the data from the SQL table. * * @see DriverManager.getConnection @@ -142,12 +147,13 @@ public fun DataFrame.Companion.readSqlTable( tableName: String, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? 
= null, ): AnyFrame { val url = connection.metaData.url - val dbType = extractDBTypeFromConnection(connection) + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) val selectAllQuery = if (limit > 0) { - dbType.sqlQueryLimit("SELECT * FROM $tableName", limit) + determinedDbType.sqlQueryLimit("SELECT * FROM $tableName", limit) } else { "SELECT * FROM $tableName" } @@ -157,7 +163,7 @@ public fun DataFrame.Companion.readSqlTable( st.executeQuery(selectAllQuery).use { rs -> val tableColumns = getTableColumnsMetadata(rs) - return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability) + return fetchAndConvertDataFromResultSet(tableColumns, rs, determinedDbType, limit, inferNullability) } } } @@ -172,6 +178,8 @@ public fun DataFrame.Companion.readSqlTable( * @param [sqlQuery] the SQL query to execute. * @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [dbConfig]. * @return the DataFrame containing the result of the SQL query. */ public fun DataFrame.Companion.readSqlQuery( @@ -179,9 +187,10 @@ public fun DataFrame.Companion.readSqlQuery( sqlQuery: String, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): AnyFrame { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readSqlQuery(connection, sqlQuery, limit, inferNullability) + return readSqlQuery(connection, sqlQuery, limit, inferNullability, dbType) } } @@ -195,6 +204,8 @@ public fun DataFrame.Companion.readSqlQuery( * @param [sqlQuery] the SQL query to execute. * @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution. 
* @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [connection]. * @return the DataFrame containing the result of the SQL query. * * @see DriverManager.getConnection @@ -204,22 +215,23 @@ public fun DataFrame.Companion.readSqlQuery( sqlQuery: String, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): AnyFrame { require(isValid(sqlQuery)) { "SQL query should start from SELECT and contain one query for reading data without any manipulation. " + "Also it should not contain any separators like `;`." } - val dbType = extractDBTypeFromConnection(connection) + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) - val internalSqlQuery = if (limit > 0) dbType.sqlQueryLimit(sqlQuery, limit) else sqlQuery + val internalSqlQuery = if (limit > 0) determinedDbType.sqlQueryLimit(sqlQuery, limit) else sqlQuery logger.debug { "Executing SQL query: $internalSqlQuery" } connection.createStatement().use { st -> st.executeQuery(internalSqlQuery).use { rs -> val tableColumns = getTableColumnsMetadata(rs) - return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability) + return fetchAndConvertDataFromResultSet(tableColumns, rs, determinedDbType, limit, inferNullability) } } } @@ -233,12 +245,15 @@ public fun DataFrame.Companion.readSqlQuery( * It should not contain `;` symbol. * @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [DbConnectionConfig]. 
* @return the DataFrame containing the result of the SQL query. */ public fun DbConnectionConfig.readDataFrame( sqlQueryOrTableName: String, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): AnyFrame = when { isSqlQuery(sqlQueryOrTableName) -> DataFrame.readSqlQuery( @@ -246,6 +261,7 @@ public fun DbConnectionConfig.readDataFrame( sqlQueryOrTableName, limit, inferNullability, + dbType, ) isSqlTableName(sqlQueryOrTableName) -> DataFrame.readSqlTable( @@ -253,6 +269,7 @@ public fun DbConnectionConfig.readDataFrame( sqlQueryOrTableName, limit, inferNullability, + dbType, ) else -> throw IllegalArgumentException( @@ -280,12 +297,15 @@ private fun isSqlTableName(sqlQueryOrTableName: String): Boolean { * It should not contain `;` symbol. * @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [Connection]. * @return the DataFrame containing the result of the SQL query. */ public fun Connection.readDataFrame( sqlQueryOrTableName: String, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): AnyFrame = when { isSqlQuery(sqlQueryOrTableName) -> DataFrame.readSqlQuery( @@ -293,6 +313,7 @@ public fun Connection.readDataFrame( sqlQueryOrTableName, limit, inferNullability, + dbType, ) isSqlTableName(sqlQueryOrTableName) -> DataFrame.readSqlTable( @@ -300,6 +321,7 @@ public fun Connection.readDataFrame( sqlQueryOrTableName, limit, inferNullability, + dbType, ) else -> throw IllegalArgumentException( @@ -386,6 +408,8 @@ public fun ResultSet.readDataFrame( * that the [ResultSet] belongs to. * @param [limit] the maximum number of rows to read from the [ResultSet][java.sql.ResultSet]. 
* @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [connection]. * @return the DataFrame generated from the [ResultSet][java.sql.ResultSet] data. * * [java.sql.ResultSet]: https://docs.oracle.com/javase/8/docs/api/java/sql/ResultSet.html @@ -395,10 +419,11 @@ public fun DataFrame.Companion.readResultSet( connection: Connection, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): AnyFrame { - val dbType = extractDBTypeFromConnection(connection) + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) - return readResultSet(resultSet, dbType, limit, inferNullability) + return readResultSet(resultSet, determinedDbType, limit, inferNullability) } /** @@ -416,6 +441,8 @@ public fun DataFrame.Companion.readResultSet( * that the [ResultSet] belongs to. * @param [limit] the maximum number of rows to read from the [ResultSet][java.sql.ResultSet]. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [connection]. * @return the DataFrame generated from the [ResultSet][java.sql.ResultSet] data. * * [java.sql.ResultSet]: https://docs.oracle.com/javase/8/docs/api/java/sql/ResultSet.html @@ -424,7 +451,8 @@ public fun ResultSet.readDataFrame( connection: Connection, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, -): AnyFrame = DataFrame.Companion.readResultSet(this, connection, limit, inferNullability) + dbType: DbType?
= null, +): AnyFrame = DataFrame.Companion.readResultSet(this, connection, limit, inferNullability, dbType) /** * Reads all non-system tables from a database and returns them @@ -434,6 +462,8 @@ public fun ResultSet.readDataFrame( * @param [limit] the maximum number of rows to read from each table. * @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [dbConfig]. * @return a map of [String] to [AnyFrame] objects representing the non-system tables from the database. */ public fun DataFrame.Companion.readAllSqlTables( @@ -441,9 +471,10 @@ public fun DataFrame.Companion.readAllSqlTables( catalogue: String? = null, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): Map { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readAllSqlTables(connection, catalogue, limit, inferNullability) + return readAllSqlTables(connection, catalogue, limit, inferNullability, dbType) } } @@ -455,6 +486,8 @@ public fun DataFrame.Companion.readAllSqlTables( * @param [limit] the maximum number of rows to read from each table. * @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [connection]. * @return a map of [String] to [AnyFrame] objects representing the non-system tables from the database. 
* * @see DriverManager.getConnection @@ -464,9 +497,10 @@ public fun DataFrame.Companion.readAllSqlTables( catalogue: String? = null, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): Map { val metaData = connection.metaData - val dbType = extractDBTypeFromConnection(connection) + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) // exclude a system and other tables without data, but it looks like it is supported badly for many databases val tables = metaData.getTables(catalogue, null, null, arrayOf("TABLE")) @@ -474,8 +508,8 @@ public fun DataFrame.Companion.readAllSqlTables( val dataFrames = mutableMapOf() while (tables.next()) { - val table = dbType.buildTableMetadata(tables) - if (!dbType.isSystemTable(table)) { + val table = determinedDbType.buildTableMetadata(tables) + if (!determinedDbType.isSystemTable(table)) { // we filter here a second time because of specific logic with SQLite and possible issues with future databases val tableName = when { catalogue != null && table.schemaName != null -> "$catalogue.${table.schemaName}.${table.name}" @@ -488,7 +522,7 @@ public fun DataFrame.Companion.readAllSqlTables( // could be Dialect/Database specific logger.debug { "Reading table: $tableName" } - val dataFrame = readSqlTable(connection, tableName, limit, inferNullability) + val dataFrame = readSqlTable(connection, tableName, limit, inferNullability, dbType) dataFrames += tableName to dataFrame logger.debug { "Finished reading table: $tableName" } } @@ -502,11 +536,17 @@ public fun DataFrame.Companion.readAllSqlTables( * * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. * @param [tableName] the name of the SQL table for which to retrieve the schema. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [dbConfig]. 
* @return the [DataFrameSchema] object representing the schema of the SQL table */ -public fun DataFrame.Companion.getSchemaForSqlTable(dbConfig: DbConnectionConfig, tableName: String): DataFrameSchema { +public fun DataFrame.Companion.getSchemaForSqlTable( + dbConfig: DbConnectionConfig, + tableName: String, + dbType: DbType? = null, +): DataFrameSchema { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return getSchemaForSqlTable(connection, tableName) + return getSchemaForSqlTable(connection, tableName, dbType) } } @@ -515,20 +555,26 @@ public fun DataFrame.Companion.getSchemaForSqlTable(dbConfig: DbConnectionConfig * * @param [connection] the database connection. * @param [tableName] the name of the SQL table for which to retrieve the schema. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [connection]. * @return the schema of the SQL table as a [DataFrameSchema] object. * * @see DriverManager.getConnection */ -public fun DataFrame.Companion.getSchemaForSqlTable(connection: Connection, tableName: String): DataFrameSchema { - val dbType = extractDBTypeFromConnection(connection) +public fun DataFrame.Companion.getSchemaForSqlTable( + connection: Connection, + tableName: String, + dbType: DbType? 
= null, +): DataFrameSchema { + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) val sqlQuery = "SELECT * FROM $tableName" - val selectFirstRowQuery = dbType.sqlQueryLimit(sqlQuery, limit = 1) + val selectFirstRowQuery = determinedDbType.sqlQueryLimit(sqlQuery, limit = 1) connection.createStatement().use { st -> st.executeQuery(selectFirstRowQuery).use { rs -> val tableColumns = getTableColumnsMetadata(rs) - return buildSchemaByTableColumns(tableColumns, dbType) + return buildSchemaByTableColumns(tableColumns, determinedDbType) } } } @@ -538,11 +584,17 @@ public fun DataFrame.Companion.getSchemaForSqlTable(connection: Connection, tabl * * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. * @param [sqlQuery] the SQL query to execute and retrieve the schema from. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [dbConfig]. * @return the schema of the SQL query as a [DataFrameSchema] object. */ -public fun DataFrame.Companion.getSchemaForSqlQuery(dbConfig: DbConnectionConfig, sqlQuery: String): DataFrameSchema { +public fun DataFrame.Companion.getSchemaForSqlQuery( + dbConfig: DbConnectionConfig, + sqlQuery: String, + dbType: DbType? = null, +): DataFrameSchema { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return getSchemaForSqlQuery(connection, sqlQuery) + return getSchemaForSqlQuery(connection, sqlQuery, dbType) } } @@ -551,17 +603,23 @@ public fun DataFrame.Companion.getSchemaForSqlQuery(dbConfig: DbConnectionConfig * * @param [connection] the database connection. * @param [sqlQuery] the SQL query to execute and retrieve the schema from. 
+ * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [connection]. * @return the schema of the SQL query as a [DataFrameSchema] object. * * @see DriverManager.getConnection */ -public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQuery: String): DataFrameSchema { - val dbType = extractDBTypeFromConnection(connection) +public fun DataFrame.Companion.getSchemaForSqlQuery( + connection: Connection, + sqlQuery: String, + dbType: DbType? = null, +): DataFrameSchema { + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) connection.createStatement().use { st -> st.executeQuery(sqlQuery).use { rs -> val tableColumns = getTableColumnsMetadata(rs) - return buildSchemaByTableColumns(tableColumns, dbType) + return buildSchemaByTableColumns(tableColumns, determinedDbType) } } } @@ -570,13 +628,18 @@ public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQ * Retrieves the schema of an SQL query result or the SQL table using the provided database configuration. * * @param [sqlQueryOrTableName] the SQL query to execute and retrieve the schema from. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [DbConnectionConfig]. * @return the schema of the SQL query as a [DataFrameSchema] object. */ -public fun DbConnectionConfig.getDataFrameSchema(sqlQueryOrTableName: String): DataFrameSchema = +public fun DbConnectionConfig.getDataFrameSchema( + sqlQueryOrTableName: String, + dbType: DbType? 
= null, +): DataFrameSchema = when { - isSqlQuery(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName) + isSqlQuery(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName, dbType) - isSqlTableName(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName) + isSqlTableName(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName, dbType) else -> throw IllegalArgumentException( "$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!", @@ -587,13 +650,15 @@ public fun DbConnectionConfig.getDataFrameSchema(sqlQueryOrTableName: String): D * Retrieves the schema of an SQL query result or the SQL table using the provided database configuration. * * @param [sqlQueryOrTableName] the SQL query to execute and retrieve the schema from. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [Connection]. * @return the schema of the SQL query as a [DataFrameSchema] object. */ -public fun Connection.getDataFrameSchema(sqlQueryOrTableName: String): DataFrameSchema = +public fun Connection.getDataFrameSchema(sqlQueryOrTableName: String, dbType: DbType? 
= null): DataFrameSchema = when { - isSqlQuery(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName) + isSqlQuery(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName, dbType) - isSqlTableName(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName) + isSqlTableName(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName, dbType) else -> throw IllegalArgumentException( "$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!", @@ -606,7 +671,7 @@ public fun Connection.getDataFrameSchema(sqlQueryOrTableName: String): DataFrame * NOTE: This function will not close connection and result set and not retrieve data from the result set. * * @param [resultSet] the [ResultSet] obtained from executing a database query. - * @param [dbType] the type of database that the [ResultSet] belongs to. + * @param [dbType] the type of database that the [ResultSet] belongs to, could be a custom object, provided by user. * @return the schema of the [ResultSet] as a [DataFrameSchema] object. */ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbType: DbType): DataFrameSchema { @@ -619,48 +684,25 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbTyp * * NOTE: This function will not close connection and result set and not retrieve data from the result set. * - * @param [dbType] the type of database that the [ResultSet] belongs to. + * @param [dbType] the type of database that the [ResultSet] belongs to, could be a custom object, provided by user. * @return the schema of the [ResultSet] as a [DataFrameSchema] object. */ public fun ResultSet.getDataFrameSchema(dbType: DbType): DataFrameSchema = DataFrame.getSchemaForResultSet(this, dbType) -/** - * Retrieves the schema from [ResultSet]. - * - * NOTE: [connection] is required to extract the database type. 
- * This function will not close connection and result set and not retrieve data from the result set. - * - * @param [resultSet] the [ResultSet] obtained from executing a database query. - * @param [connection] the connection to the database (it's required to extract the database type). - * @return the schema of the [ResultSet] as a [DataFrameSchema] object. - */ -public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, connection: Connection): DataFrameSchema { - val dbType = extractDBTypeFromConnection(connection) - - val tableColumns = getTableColumnsMetadata(resultSet) - return buildSchemaByTableColumns(tableColumns, dbType) -} - -/** - * Retrieves the schema from [ResultSet]. - * - * NOTE: This function will not close connection and result set and not retrieve data from the result set. - * - * @param [connection] the connection to the database (it's required to extract the database type). - * @return the schema of the [ResultSet] as a [DataFrameSchema] object. - */ -public fun ResultSet.getDataFrameSchema(connection: Connection): DataFrameSchema = - DataFrame.getSchemaForResultSet(this, connection) - /** * Retrieves the schemas of all non-system tables in the database using the provided database configuration. * * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [dbConfig]. * @return a map of [String, DataFrameSchema] objects representing the table name and its schema for each non-system table. */ -public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DbConnectionConfig): Map { +public fun DataFrame.Companion.getSchemaForAllSqlTables( + dbConfig: DbConnectionConfig, + dbType: DbType? 
= null, +): Map { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return getSchemaForAllSqlTables(connection) + return getSchemaForAllSqlTables(connection, dbType) } } @@ -668,11 +710,16 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DbConnectionCo * Retrieves the schemas of all non-system tables in the database using the provided database connection. * * @param [connection] the database connection. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`, + * in that case the [dbType] will be recognized from the [connection]. * @return a map of [String, DataFrameSchema] objects representing the table name and its schema for each non-system table. */ -public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): Map { +public fun DataFrame.Companion.getSchemaForAllSqlTables( + connection: Connection, + dbType: DbType? = null, +): Map { val metaData = connection.metaData - val dbType = extractDBTypeFromConnection(connection) + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) val tableTypes = arrayOf("TABLE") // exclude a system and other tables without data @@ -681,11 +728,11 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): val dataFrameSchemas = mutableMapOf() while (tables.next()) { - val jdbcTable = dbType.buildTableMetadata(tables) - if (!dbType.isSystemTable(jdbcTable)) { + val jdbcTable = determinedDbType.buildTableMetadata(tables) + if (!determinedDbType.isSystemTable(jdbcTable)) { // we filter her a second time because of specific logic with SQLite and possible issues with future databases val tableName = jdbcTable.name - val dataFrameSchema = getSchemaForSqlTable(connection, tableName) + val dataFrameSchema = getSchemaForSqlTable(connection, tableName, determinedDbType) dataFrameSchemas += tableName to dataFrameSchema } } @@ -826,6 +873,7 @@ private fun 
fetchAndConvertDataFromResultSet( } val dataFrame = data.mapIndexed { index, values -> + // TODO: add override handlers from dbType to intercept the final parsing before column creation val correctedValues = if (kotlinTypesForSqlColumns[index]!!.classifier == Array::class) { handleArrayValues(values) } else { diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt index c83d59158..f3c676ce7 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt @@ -166,6 +166,8 @@ class JdbcTest { val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName) dataSchema.columns.size shouldBe 2 dataSchema.columns["characterCol"]!!.type shouldBe typeOf() + + connection.createStatement().execute("DROP TABLE EmptyTestTable") } @Test @@ -299,6 +301,8 @@ class JdbcTest { schema.columns["realCol"]!!.type shouldBe typeOf() schema.columns["doublePrecisionCol"]!!.type shouldBe typeOf() schema.columns["decFloatCol"]!!.type shouldBe typeOf() + + connection.createStatement().execute("DROP TABLE $tableName") } @Test @@ -441,7 +445,7 @@ class JdbcTest { rs.beforeFirst() - val dataSchema1 = DataFrame.getSchemaForResultSet(rs, connection) + val dataSchema1 = DataFrame.getSchemaForResultSet(rs, H2(MySql)) dataSchema1.columns.size shouldBe 3 dataSchema1.columns["name"]!!.type shouldBe typeOf() } @@ -493,7 +497,7 @@ class JdbcTest { rs.beforeFirst() - val dataSchema1 = rs.getDataFrameSchema(connection) + val dataSchema1 = rs.getDataFrameSchema(H2(MySql)) dataSchema1.columns.size shouldBe 3 dataSchema1.columns["name"]!!.type shouldBe typeOf() } @@ -613,6 +617,7 @@ class JdbcTest { """ DataFrame.readSqlQuery(connection, selectFromWeirdTableSQL).rowsCount() shouldBe 0 + connection.createStatement().execute("DROP TABLE \"ALTER\"") } @Test @@ -967,4 +972,127 @@
class JdbcTest { } exception.message shouldBe "H2 database could not be specified with H2 dialect!" } + + // helper object created for API testing purposes + object CustomDB : H2(MySql) + + @Test + fun `read from table from custom database`() { + val tableName = "Customer" + val df = DataFrame.readSqlTable(connection, tableName, dbType = CustomDB).cast() + + df.rowsCount() shouldBe 4 + df.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 + df[0][1] shouldBe "John" + + val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName, dbType = CustomDB) + dataSchema.columns.size shouldBe 3 + dataSchema.columns["name"]!!.type shouldBe typeOf() + + val dbConfig = DbConnectionConfig(url = URL) + val df2 = DataFrame.readSqlTable(dbConfig, tableName, dbType = CustomDB).cast() + + df2.rowsCount() shouldBe 4 + df2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 + df2[0][1] shouldBe "John" + + val dataSchema1 = DataFrame.getSchemaForSqlTable(dbConfig, tableName, dbType = CustomDB) + dataSchema1.columns.size shouldBe 3 + dataSchema1.columns["name"]!!.type shouldBe typeOf() + } + + @Test + fun `read from query from custom database`() { + @Language("SQL") + val sqlQuery = + """ + SELECT c.name as customerName, SUM(s.amount) as totalSalesAmount + FROM Sale s + INNER JOIN Customer c ON s.customerId = c.id + WHERE c.age > 35 + GROUP BY s.customerId, c.name + """.trimIndent() + + val df = DataFrame.readSqlQuery(connection, sqlQuery, dbType = CustomDB).cast() + + df.rowsCount() shouldBe 2 + df.filter { it[CustomerSales::totalSalesAmount]!! 
> 100 }.rowsCount() shouldBe 1 + df[0][0] shouldBe "John" + + val dataSchema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery, dbType = CustomDB) + dataSchema.columns.size shouldBe 2 + dataSchema.columns["name"]!!.type shouldBe typeOf() + + val dbConfig = DbConnectionConfig(url = URL) + val df2 = DataFrame.readSqlQuery(dbConfig, sqlQuery, dbType = CustomDB).cast() + + df2.rowsCount() shouldBe 2 + df2.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1 + df2[0][0] shouldBe "John" + + val dataSchema1 = DataFrame.getSchemaForSqlQuery(dbConfig, sqlQuery, dbType = CustomDB) + dataSchema1.columns.size shouldBe 2 + dataSchema1.columns["name"]!!.type shouldBe typeOf() + } + + @Test + fun `read from all tables from custom database`() { + val dataFrameMap = DataFrame.readAllSqlTables(connection, dbType = CustomDB) + dataFrameMap.containsKey("Customer") shouldBe true + dataFrameMap.containsKey("Sale") shouldBe true + + val dataframes = dataFrameMap.values.toList() + + val customerDf = dataframes[0].cast() + + customerDf.rowsCount() shouldBe 4 + customerDf.filter { it[Customer::age] != null && it[Customer::age]!! 
> 30 }.rowsCount() shouldBe 2 + customerDf[0][1] shouldBe "John" + + val saleDf = dataframes[1].cast() + + saleDf.rowsCount() shouldBe 4 + saleDf.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 + (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 + + val dataFrameSchemaMap = DataFrame.getSchemaForAllSqlTables(connection, dbType = CustomDB) + dataFrameSchemaMap.containsKey("Customer") shouldBe true + dataFrameSchemaMap.containsKey("Sale") shouldBe true + + val dataSchemas = dataFrameSchemaMap.values.toList() + + val customerDataSchema = dataSchemas[0] + customerDataSchema.columns.size shouldBe 3 + customerDataSchema.columns["name"]!!.type shouldBe typeOf() + + val saleDataSchema = dataSchemas[1] + saleDataSchema.columns.size shouldBe 3 + // TODO: fix nullability + saleDataSchema.columns["amount"]!!.type shouldBe typeOf() + + val dbConfig = DbConnectionConfig(url = URL) + val dataframes2 = DataFrame.readAllSqlTables(dbConfig, dbType = CustomDB).values.toList() + + val customerDf2 = dataframes2[0].cast() + + customerDf2.rowsCount() shouldBe 4 + customerDf2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 + customerDf2[0][1] shouldBe "John" + + val saleDf2 = dataframes2[1].cast() + + saleDf2.rowsCount() shouldBe 4 + saleDf2.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 + (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 + + val dataSchemas1 = DataFrame.getSchemaForAllSqlTables(dbConfig, dbType = CustomDB).values.toList() + + val customerDataSchema1 = dataSchemas1[0] + customerDataSchema1.columns.size shouldBe 3 + customerDataSchema1.columns["name"]!!.type shouldBe typeOf() + + val saleDataSchema1 = dataSchemas1[1] + saleDataSchema1.columns.size shouldBe 3 + saleDataSchema1.columns["amount"]!!.type shouldBe typeOf() + } }