Skip to content

Commit

Permalink
Add statement param setting
Browse files Browse the repository at this point in the history
  • Loading branch information
esaounkine committed Jul 24, 2017
1 parent 750bf0a commit 98d0f78
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 44 deletions.
104 changes: 60 additions & 44 deletions src/main/kotlin/kotliquery/Session.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,58 +27,74 @@ open class Session(

private val logger = LoggerFactory.getLogger(Session::class.java)

private fun createPreparedStatement(query: Query): PreparedStatement {
val stmt = if (returnGeneratedKeys) {
if (connection.driverName == "oracle.jdbc.driver.OracleDriver") {
connection.underlying.prepareStatement(query.statement, autoGeneratedKeys.toTypedArray())
} else {
connection.underlying.prepareStatement(query.statement, Statement.RETURN_GENERATED_KEYS)
}
private fun PreparedStatement.setParam(idx: Int, v: Any?) {
if (v == null) {
this.setObject(idx, null)
} else {
connection.underlying.prepareStatement(query.statement)
when (v) {
is String -> this.setString(idx, v)
is Byte -> this.setByte(idx, v)
is Boolean -> this.setBoolean(idx, v)
is Int -> this.setInt(idx, v)
is Long -> this.setLong(idx, v)
is Short -> this.setShort(idx, v)
is Double -> this.setDouble(idx, v)
is Float -> this.setFloat(idx, v)
is ZonedDateTime -> this.setTimestamp(idx, Timestamp(Date.from(v.toInstant()).time))
is OffsetDateTime -> this.setTimestamp(idx, Timestamp(Date.from(v.toInstant()).time))
is Instant -> this.setTimestamp(idx, Timestamp(Date.from(v).time))
is LocalDateTime -> this.setTimestamp(idx, Timestamp(org.joda.time.LocalDateTime.parse(v.toString()).toDate().time))
is LocalDate -> this.setDate(idx, java.sql.Date(org.joda.time.LocalDate.parse(v.toString()).toDate().time))
is LocalTime -> this.setTime(idx, java.sql.Time(org.joda.time.LocalTime.parse(v.toString()).toDateTimeToday().millis))
is org.joda.time.DateTime -> this.setTimestamp(idx, Timestamp(v.toDate().time))
is org.joda.time.LocalDateTime -> this.setTimestamp(idx, Timestamp(v.toDate().time))
is org.joda.time.LocalDate -> this.setDate(idx, java.sql.Date(v.toDate().time))
is org.joda.time.LocalTime -> this.setTime(idx, java.sql.Time(v.toDateTimeToday().millis))
is java.util.Date -> this.setTimestamp(idx, Timestamp(v.time))
is java.sql.Timestamp -> this.setTimestamp(idx, v)
is java.sql.Time -> this.setTime(idx, v)
is java.sql.Date -> this.setTimestamp(idx, Timestamp(v.time))
is java.sql.SQLXML -> this.setSQLXML(idx, v)
is ByteArray -> this.setBytes(idx, v)
is InputStream -> this.setBinaryStream(idx, v)
is BigDecimal -> this.setBigDecimal(idx, v)
is java.sql.Array -> this.setArray(idx, v)
is URL -> this.setURL(idx, v)
else -> this.setObject(idx, v)
}
}
query.params.withIndex().forEach { param ->
val v = param.value
val idx = param.index + 1
if (v == null) {
stmt.setObject(idx, null)
} else {
when (v) {
is String -> stmt.setString(idx, v)
is Byte -> stmt.setByte(idx, v)
is Boolean -> stmt.setBoolean(idx, v)
is Int -> stmt.setInt(idx, v)
is Long -> stmt.setLong(idx, v)
is Short -> stmt.setShort(idx, v)
is Double -> stmt.setDouble(idx, v)
is Float -> stmt.setFloat(idx, v)
is ZonedDateTime -> stmt.setTimestamp(idx, Timestamp(Date.from(v.toInstant()).time))
is OffsetDateTime -> stmt.setTimestamp(idx, Timestamp(Date.from(v.toInstant()).time))
is Instant -> stmt.setTimestamp(idx, Timestamp(Date.from(v).time))
is LocalDateTime -> stmt.setTimestamp(idx, Timestamp(org.joda.time.LocalDateTime.parse(v.toString()).toDate().time))
is LocalDate -> stmt.setDate(idx, java.sql.Date(org.joda.time.LocalDate.parse(v.toString()).toDate().time))
is LocalTime -> stmt.setTime(idx, java.sql.Time(org.joda.time.LocalTime.parse(v.toString()).toDateTimeToday().millis))
is org.joda.time.DateTime -> stmt.setTimestamp(idx, Timestamp(v.toDate().time))
is org.joda.time.LocalDateTime -> stmt.setTimestamp(idx, Timestamp(v.toDate().time))
is org.joda.time.LocalDate -> stmt.setDate(idx, java.sql.Date(v.toDate().time))
is org.joda.time.LocalTime -> stmt.setTime(idx, java.sql.Time(v.toDateTimeToday().millis))
is java.util.Date -> stmt.setTimestamp(idx, Timestamp(v.time))
is java.sql.Timestamp -> stmt.setTimestamp(idx, v)
is java.sql.Time -> stmt.setTime(idx, v)
is java.sql.Date -> stmt.setTimestamp(idx, Timestamp(v.time))
is java.sql.SQLXML -> stmt.setSQLXML(idx, v)
is ByteArray -> stmt.setBytes(idx, v)
is InputStream -> stmt.setBinaryStream(idx, v)
is BigDecimal -> stmt.setBigDecimal(idx, v)
is java.sql.Array -> stmt.setArray(idx, v)
is URL -> stmt.setURL(idx, v)
else -> stmt.setObject(idx, v)
}

fun populateParams(query: Query, stmt: PreparedStatement): PreparedStatement {
if(query.replacementMap.isNotEmpty()) {
query.replacementMap.forEach { paramName, occurrences ->
occurrences.forEach {
stmt.setParam(it + 1, query.paramMap[paramName])
}
}
} else {
query.params.forEachIndexed { index, value ->
stmt.setParam(index + 1, value)
}
}

return stmt
}

fun createPreparedStatement(query: Query): PreparedStatement {
val stmt = if (returnGeneratedKeys) {
if (connection.driverName == "oracle.jdbc.driver.OracleDriver") {
connection.underlying.prepareStatement(query.cleanStatement, autoGeneratedKeys.toTypedArray())
} else {
connection.underlying.prepareStatement(query.cleanStatement, Statement.RETURN_GENERATED_KEYS)
}
} else {
connection.underlying.prepareStatement(query.cleanStatement)
}

return populateParams(query, stmt)
}

private fun <A> rows(query: Query, extractor: (Row) -> A?): List<A> {
return using(createPreparedStatement(query)) { stmt ->
using(stmt.executeQuery()) { rs ->
Expand Down
38 changes: 38 additions & 0 deletions src/test/kotlin/kotliquery/UsageTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package kotliquery

import org.junit.Test
import java.sql.DriverManager
import java.sql.PreparedStatement
import java.time.ZonedDateTime
import java.util.*
import kotlin.test.assertEquals
Expand Down Expand Up @@ -195,4 +196,41 @@ create table members (
}
}


@Test
fun stmtParamPopulation() {
withPreparedStmt(queryOf("""SELECT * FROM dual t
WHERE (:param1 IS NULL OR :param2 = :param2)
AND (:param2 IS NULL OR :param1 = :param3)
AND (:param3 IS NULL OR :param3 = :param1)""",
paramMap = mapOf("param1" to "1",
"param2" to 2,
"param3" to true))
) { preparedStmt ->
assertEquals("""SELECT * FROM dual t
WHERE (? IS NULL OR ? = ?) AND (? IS NULL OR ? = ?) AND (? IS NULL OR ? = ?)
{1: '1', 2: 2, 3: 2, 4: 2, 5: '1', 6: TRUE, 7: TRUE, 8: TRUE, 9: '1'}""".normalizeSpaces(),
preparedStmt.toString().split(": ", limit = 2)[1].normalizeSpaces())
}

withPreparedStmt(queryOf("""SELECT * FROM dual t WHERE (:param1 IS NULL OR :param2 = :param2)""",
paramMap = mapOf("param2" to 2))
) { preparedStmt ->
assertEquals("""SELECT * FROM dual t WHERE (? IS NULL OR ? = ?)
{1: NULL, 2: 2, 3: 2}""".normalizeSpaces(),
preparedStmt.toString().split(": ", limit = 2)[1].normalizeSpaces())
}

}

fun withPreparedStmt(query: Query, closure: (PreparedStatement) -> Unit) {
using(borrowConnection()) { conn ->
val session = Session(Connection(conn, driverName))

val preparedStmt = session.createPreparedStatement(query)

closure(preparedStmt)
}
}

}

0 comments on commit 98d0f78

Please sign in to comment.