// NOTE(review): this copy of the file looks damaged. Type-parameter lists appear to have been
// stripped (e.g. `SelectOneCallback` uses `T` without declaring it, and `TypeReference>` has an
// unmatched `>`), and part of `replaceArgs` / `replaceArgsIntoSql` is missing (see the note above
// `replaceArgs`). Tokens are reproduced unchanged below; restore the file from version control
// before relying on it.
package fr.postgresjson.connexion

import com.fasterxml.jackson.core.type.TypeReference
import com.github.jasync.sql.db.Connection
import com.github.jasync.sql.db.QueryResult
import com.github.jasync.sql.db.pool.ConnectionPool
import com.github.jasync.sql.db.postgresql.PostgreSQLConnection
import com.github.jasync.sql.db.postgresql.PostgreSQLConnectionBuilder
import com.github.jasync.sql.db.util.length
import fr.postgresjson.entity.EntityI
import fr.postgresjson.entity.Serializable
import fr.postgresjson.serializer.Serializer
import fr.postgresjson.utils.LoggerDelegate
import org.slf4j.Logger
import java.util.concurrent.*

/** Callback for single-entity selects: the raw [QueryResult] is the receiver, the argument is the deserialized entity or `null`. */
typealias SelectOneCallback = QueryResult.(T?) -> Unit

/** Callback for list selects: the raw [QueryResult] is the receiver, the argument is the deserialized list. */
typealias SelectCallback = QueryResult.(List) -> Unit

/** Callback for paginated selects: the raw [QueryResult] is the receiver, the argument is the [Paginated] page. */
typealias SelectPaginatedCallback = QueryResult.(Paginated) -> Unit

/**
 * Blocking PostgreSQL access point built on a jasync [ConnectionPool] (every query is awaited with `join()`).
 *
 * The `select` overloads read column 0 of the first result row as a JSON string and deserialize it with
 * [Serializer]. The named-parameter overloads (`values: Map`) go through [replaceArgs], which rewrites
 * `:name` placeholders into positional `?` placeholders, and then delegate to the positional overloads.
 *
 * @param database database name, placed in the JDBC URL path
 * @param username login user, placed in the JDBC URL query string
 * @param password login password, placed in the JDBC URL query string
 * @param host server host, `"localhost"` by default
 * @param port server port, `5432` by default
 */
class Connection(
    private val database: String,
    private val username: String,
    private val password: String,
    private val host: String = "localhost",
    private val port: Int = 5432
) : Executable {
    // Created on first use by connect(); replaced there when it reports it is not connected.
    private lateinit var connection: ConnectionPool

    // Used for both argument serialization (compileArgs) and result deserialization (select*).
    private val serializer = Serializer()

    private val logger: Logger? by LoggerDelegate()

    /**
     * Returns the connection pool, creating it on first use or when the current pool reports that it
     * is not connected.
     *
     * NOTE(review): the credentials are interpolated into the URL without URL-encoding, so characters
     * such as `&`, `#` or `%` in the password would likely corrupt the URL. A replaced pool is never
     * disconnected. The check-then-assign is not synchronized, so concurrent first calls could each
     * create a pool. Confirm whether callers can race here.
     */
    internal fun connect(): ConnectionPool {
        if (!::connection.isInitialized || !connection.isConnected()) {
            connection = PostgreSQLConnectionBuilder.createConnectionPool(
                "jdbc:postgresql://$host:$port/$database?user=$username&password=$password"
            )
        }
        return connection
    }

    /**
     * Runs [f] in a transaction by delegating to the pool's `inTransaction`.
     *
     * NOTE(review): `Connection` in this signature presumably resolves to the imported jasync
     * `Connection`, not to this class; confirm.
     */
    fun inTransaction(f: (Connection) -> CompletableFuture) = connect().inTransaction(f)

    /**
     * Runs [sql] with positional [values] and deserializes the JSON in column 0 of the first row.
     *
     * If one of [values] is an [EntityI] whose runtime class name equals the requested type's name, the
     * JSON is deserialized with that object as the target (`serializer.deserialize(json, primaryObject)`)
     * instead of through [typeReference]. This presumably updates the object in place; confirm in [Serializer].
     *
     * @return the deserialized entity, or `null` when column 0 is null. [block] is called with the raw
     *   result and that same value before returning.
     *
     * NOTE(review): `rows[0]` will throw if the query returns no rows. [values] pass through
     * [compileArgs] here and again inside [exec]. The second pass is presumably a no-op on values that
     * are already serialized; confirm.
     */
    override fun select(
        sql: String,
        typeReference: TypeReference,
        values: List,
        block: (QueryResult, R?) -> Unit
    ): R? {
        // Unchecked cast: the match above only compares class names.
        val primaryObject = values.firstOrNull { it is EntityI && typeReference.type.typeName == it::class.java.name } as R?
        val result = exec(sql, compileArgs(values))
        val json = result.rows[0].getString(0)
        return if (json === null) {
            null
        } else {
            if (primaryObject != null) {
                serializer.deserialize(json, primaryObject)
            } else {
                serializer.deserialize(json, typeReference)
            }
        }.also { block(result, it) }
    }

    /**
     * Convenience form of the positional single-entity `select`. It builds the [TypeReference] from an
     * anonymous subclass; the type parameter it needs appears to be missing from this copy.
     */
    inline fun selectOne(
        sql: String,
        values: List = emptyList(),
        noinline block: SelectOneCallback = {}
    ): R? = select(sql, object : TypeReference() {}, values, block)

    /** Named-parameter single-entity select: [replaceArgs] rewrites placeholders, then the positional overload runs. */
    override fun select(
        sql: String,
        typeReference: TypeReference,
        values: Map,
        block: (QueryResult, R?) -> Unit
    ): R? {
        return replaceArgs(sql, values) {
            select(this.sql, typeReference, this.parameters, block)
        }
    }

    /** Convenience form of the named-parameter single-entity `select`, building the [TypeReference] from an anonymous subclass. */
    inline fun selectOne(
        sql: String,
        values: Map,
        noinline block: SelectOneCallback = {}
    ): R? = select(sql, object : TypeReference() {}, values, block)

    /**
     * Runs [sql] with positional [values] and deserializes the JSON in column 0 of the first row as a list.
     *
     * @return the deserialized list, or an empty list when column 0 is null. [block] receives the raw
     *   result and the same list.
     *
     * NOTE(review): `rows[0]` will throw if the query returns no rows.
     */
    override fun select(
        sql: String,
        typeReference: TypeReference>,
        values: List,
        block: (QueryResult, List) -> Unit
    ): List {
        val result = exec(sql, compileArgs(values))
        val json = result.rows[0].getString(0)
        return if (json === null) {
            listOf() as List
        } else {
            serializer.deserializeList(json, typeReference)
        }.also { block(result, it) }
    }

    /** Convenience form of the positional list `select`, building the [TypeReference] from an anonymous subclass. */
    inline fun select(
        sql: String,
        values: List = emptyList(),
        noinline block: SelectCallback = {}
    ): List = select(sql, object : TypeReference>() {}, values, block)

    /**
     * Paginated select. Adds `offset` = `(page - 1) * limit` and `limit` to the named parameters, which
     * [sql] presumably references as `:offset` / `:limit`. The first row must carry the JSON list in
     * column 0 and a `total` column.
     *
     * @param page 1-based page number
     * @param limit page size
     * @return a [Paginated] built from the entities (empty when column 0 is null), offset, limit and total.
     * @throws IllegalStateException via `error(...)` when `total` is null.
     *
     * NOTE(review): `offset` and `limit` silently overwrite caller entries with the same keys. A [page]
     * below 1 yields a negative offset. `rows[0]` will throw on an empty result.
     */
    override fun select(
        sql: String,
        page: Int,
        limit: Int,
        typeReference: TypeReference>,
        values: Map,
        block: (QueryResult, Paginated) -> Unit
    ): Paginated {
        val offset = (page - 1) * limit
        val newValues = values
            .plus("offset" to offset)
            .plus("limit" to limit)
        val line = replaceArgs(sql, newValues) {
            exec(this.sql, this.parameters)
        }
        return line.run {
            val json = rows[0].getString(0)
            val entities = if (json === null) {
                listOf() as List
            } else {
                serializer.deserializeList(json, typeReference)
            }
            Paginated(
                entities,
                offset,
                limit,
                rows[0].getInt("total") ?: error("The query not return total")
            )
        }.also { block(line, it) }
    }

    /** Convenience form of the paginated `select`, building the [TypeReference] from an anonymous subclass. */
    inline fun select(
        sql: String,
        page: Int,
        limit: Int,
        values: Map = emptyMap(),
        noinline block: SelectPaginatedCallback = {}
    ): Paginated = select(sql, page, limit, object : TypeReference>() {}, values, block)

    /** Named-parameter list select: [replaceArgs] rewrites placeholders, then the positional list overload runs. */
    override fun select(
        sql: String,
        typeReference: TypeReference>,
        values: Map,
        block: (QueryResult, List) -> Unit
    ): List {
        return replaceArgs(sql, values) {
            select(this.sql, typeReference, this.parameters, block)
        }
    }

    /** Convenience form of the named-parameter list `select`, building the [TypeReference] from an anonymous subclass. */
    inline fun select(
        sql: String,
        values: Map,
        noinline block: SelectCallback = {}
    ): List = select(sql, object : TypeReference>() {}, values, block)

    /**
     * Executes [sql] as a prepared statement with the compiled positional [values] and blocks on the
     * result via `join()`. Timing and logging are done by [stopwatchQuery].
     */
    override fun exec(sql: String, values: List): QueryResult {
        val compiledValues = compileArgs(values)
        return stopwatchQuery(sql, compiledValues) {
            connect().sendPreparedStatement(sql, compiledValues).join()
        }
    }

    /** Named-parameter variant of [exec]: [replaceArgs] rewrites placeholders, then the positional overload runs. */
    override fun exec(sql: String, values: Map): QueryResult {
        return replaceArgs(sql, values) {
            exec(this.sql, this.parameters)
        }
    }

    /**
     * Sends [sql] as a plain (non-prepared) query after [replaceArgsIntoSql] has produced the final SQL
     * text from the compiled [values]. Blocks on the result via `join()`.
     *
     * @return the affected row count, narrowed with `toInt()`.
     *
     * NOTE(review): the body of `replaceArgsIntoSql` is truncated in this copy, so it cannot be checked
     * here whether values are escaped before being placed in the SQL text. If they are not, this path is
     * open to SQL injection with untrusted input.
     */
    override fun sendQuery(sql: String, values: List): Int {
        val compiledValues = compileArgs(values)
        return stopwatchQuery(sql, compiledValues) {
            replaceArgsIntoSql(sql, compiledValues) {
                connect().sendQuery(it).join().rowsAffected.toInt()
            }
        }
    }

    /** Named-parameter variant of [sendQuery]: [replaceArgs] rewrites placeholders, then the positional overload runs. */
    override fun sendQuery(sql: String, values: Map): Int {
        return replaceArgs(sql, values) {
            sendQuery(this.sql, this.parameters)
        }
    }

    /**
     * Serializes, with [serializer], each value that is a `fr.postgresjson.entity.Serializable` or a
     * `List` whose first element is one. All other values pass through unchanged.
     *
     * NOTE(review): only the first element of a list is inspected, so a mixed list is classified by its head.
     */
    private fun compileArgs(values: List): List {
        return values.map {
            if (it is Serializable || (it is List<*> && it.firstOrNull() is Serializable)) {
                serializer.serialize(it)
            } else {
                it
            }
        }
    }

    /**
     * Converts named placeholders in [sql] into positional `?` placeholders and runs [block] on the
     * resulting `ParametersQuery` (rewritten SQL plus the ordered argument list).
     *
     * From the visible code:
     *  - each matched placeholder name is looked up as-is, then with leading `_` removed, and otherwise
     *    fails with `error("Parameter $name missing")`;
     *  - afterwards, `:key` / `:_key` is replaced with `?` for every key of [values].
     *
     * NOTE(review): the next line is damaged in this copy and is kept verbatim; restore it from version
     * control.
     *  - The regex literal is cut off.
     *  - `match` and `newArgs` are used without being declared, presumably in a missing
     *    `findAll(...).map { match -> ... }` call.
     *  - The rest of `replaceArgsIntoSql` is truncated.
     *  - The quotes no longer pair up.
     *
     * Issues visible in what remains:
     *  - A key mapped to `null` is reported as missing, because `?:` treats a null value like an
     *    absent key.
     *  - `":_?$key"` has no trailing word boundary. Depending on map iteration order, a key that is a
     *    prefix of another placeholder (`:id` vs `:id_user`) can rewrite part of the longer one.
     *  - Keys are not regex-escaped before `toRegex()`.
     */
    private fun replaceArgs(sql: String, values: Map, block: ParametersQuery.() -> T): T {
        val paramRegex = "(? val name = match.groups[1]!!.value values[name] ?: values[name.trimStart('_')] ?: error("Parameter $name missing") }.toList() var newSql = sql values.forEach { (key, _) -> val regex = ":_?$key".toRegex() newSql = newSql.replace(regex, "?") } return block(ParametersQuery(newSql, newArgs)) } private fun replaceArgsIntoSql(sql: String, values: List, block: (String) -> T): T { val paramRegex = "(?)

    /**
     * Runs [callback] and returns its result. It measures the wall-clock duration with
     * `System.currentTimeMillis()` and logs the SQL, arguments and a result preview at debug level.
     *
     * Result preview:
     *  - for a [QueryResult]: the first row joined with ", ", truncated to 100 characters plus a size
     *    note;
     *  - "with no result" for `null` or a result with no rows;
     *  - "unknown" for any other type.
     *
     * Any [Throwable] is logged at info level with the SQL and arguments, then rethrown.
     *
     * NOTE(review):
     *  - Argument values are written to the log verbatim, which may expose sensitive data.
     *  - Failures are logged at info rather than warn/error.
     *  - The two `"""` templates below appear to have lost their line breaks in this copy: the `|`
     *    margin markers now sit on a single line, so `trimMargin()` no longer yields one item per line.
     *    They are kept verbatim.
     */
    private fun stopwatchQuery(sql: String, values: List = emptyList(), callback: () -> T): T {
        try {
            val start = System.currentTimeMillis()
            val result = callback()
            val duration = (System.currentTimeMillis() - start)
            val resultText = when (result) {
                null -> "with no result"
                is QueryResult -> result.rows.firstOrNull()?.joinToString(", ")?.let { text ->
                    if (text.length > 100) "${text.take(100)}... (size: ${text.length})" else text
                } ?: "with no result"
                else -> "unknown"
            }
            val args = """ |Query ($duration ms): |${sql.trimIndent().prependIndent()} |Arguments (${values.length}): |${values.joinToString("\n").ifBlank { "No arguments" }.prependIndent()} |Result: |${resultText.trimIndent().prependIndent()} """.trimMargin().prependIndent(" > ")
            logger?.debug("Query executed in $duration ms \n{}", args)
            return result
        } catch (e: Throwable) {
            logger?.info(""" Query Error: ${sql.prependIndent()}, ${values.joinToString(", ").prependIndent()} """.trimIndent(), e)
            throw e
        }
    }
}