refactoring: force cast args

This commit is contained in:
2019-06-17 14:13:12 +02:00
parent 00d2fa335d
commit f960e4a66a
4 changed files with 33 additions and 12 deletions

View File

@@ -114,7 +114,7 @@ class Requester (
val functionRegex = """create .*(procedure|function) *(?<name>[^(\s]+)\s*\((?<params>(\s*((IN|OUT|INOUT|VARIADIC)?\s+)?([^\s,)]+\s+)?([^\s,)]+)(\s+(?:default\s|=)\s*[^\s,)]+)?\s*(,|(?=\))))*)\) *(?<return>RETURNS *[^ ]+)?""" val functionRegex = """create .*(procedure|function) *(?<name>[^(\s]+)\s*\((?<params>(\s*((IN|OUT|INOUT|VARIADIC)?\s+)?([^\s,)]+\s+)?([^\s,)]+)(\s+(?:default\s|=)\s*[^\s,)]+)?\s*(,|(?=\))))*)\) *(?<return>RETURNS *[^ ]+)?"""
.toRegex(setOf(IGNORE_CASE, MULTILINE)) .toRegex(setOf(IGNORE_CASE, MULTILINE))
val paramsRegex = """\s*(?<param>((?<direction>IN|OUT|INOUT|VARIADIC)?\s+)?(?<name>[^\s,)]+\s+)?(?<type>[^\s,)]+)(\s+(?:default\s|=)\s*[^\s,)]+)?)\s*(,|$)""" val paramsRegex = """\s*(?<param>((?<direction>IN|OUT|INOUT|VARIADIC)?\s+)?(?<name>[^\s,)]+\s+)?(?<type>[^\s,)]+)(\s+(?<default>default\s|=)\s*[^\s,)]+)?)\s*(,|$)"""
.toRegex(setOf(IGNORE_CASE, MULTILINE)) .toRegex(setOf(IGNORE_CASE, MULTILINE))
return functionRegex.findAll(functionContent).map { queryMatch -> return functionRegex.findAll(functionContent).map { queryMatch ->
@@ -129,7 +129,8 @@ class Requester (
Function.Parameter( Function.Parameter(
paramsMatch.groups["name"]!!.value.trim(), paramsMatch.groups["name"]!!.value.trim(),
paramsMatch.groups["type"]!!.value.trim(), paramsMatch.groups["type"]!!.value.trim(),
paramsMatch.groups["direction"]?.value?.trim()) paramsMatch.groups["direction"]?.value?.trim(),
paramsMatch.groups["default"]?.value?.trim())
}.toList() }.toList()
} else { } else {
listOf() listOf()
@@ -173,7 +174,7 @@ class Requester (
class Function(val name: String, val parameters: List<Parameter>, private val connection : Connection) { class Function(val name: String, val parameters: List<Parameter>, private val connection : Connection) {
class Parameter(val name: String, val type: String, direction: Direction? = Direction.IN) class Parameter(val name: String, val type: String, direction: Direction? = Direction.IN, val default: Any? = null)
{ {
val direction: Direction val direction: Direction
@@ -184,10 +185,11 @@ class Requester (
this.direction = direction this.direction = direction
} }
} }
constructor(name: String, type: String, direction: String? = "IN") : this( constructor(name: String, type: String, direction: String? = "IN", default: Any? = null) : this(
name = name, name = name,
type = type, type = type,
direction = direction?.let { Direction.valueOf(direction.toUpperCase())} direction = direction?.let { Direction.valueOf(direction.toUpperCase())},
default = default
) )
enum class Direction { IN, OUT, INOUT } enum class Direction { IN, OUT, INOUT }
} }
@@ -197,8 +199,8 @@ class Requester (
} }
fun <T, R : EntityI<T?>?> selectOne(typeReference: TypeReference<R>, values: List<String?> = emptyList()): R? { fun <T, R : EntityI<T?>?> selectOne(typeReference: TypeReference<R>, values: List<String?> = emptyList()): R? {
val placeholder = List(values.size) {"?"}.joinToString(separator=", ") val args = compileArgs(values)
val sql = "SELECT * FROM $name (${placeholder})" val sql = "SELECT * FROM $name ($args)"
return connection.selectOne(sql, typeReference, values) return connection.selectOne(sql, typeReference, values)
} }
@@ -206,13 +208,25 @@ class Requester (
inline fun <T, reified R: EntityI<T?>?> selectOne(values: List<String?> = emptyList()): R? = selectOne(object: TypeReference<R>() {}, values) inline fun <T, reified R: EntityI<T?>?> selectOne(values: List<String?> = emptyList()): R? = selectOne(object: TypeReference<R>() {}, values)
fun <T, R : List<EntityI<T?>?>> select(typeReference: TypeReference<R>, values: List<Any?> = emptyList()): R? { fun <T, R : List<EntityI<T?>?>> select(typeReference: TypeReference<R>, values: List<Any?> = emptyList()): R? {
val placeholder = List(values.size) {"?"}.joinToString(separator=", ") val args = compileArgs(values)
val sql = "SELECT * FROM $name ($placeholder)" val sql = "SELECT * FROM $name ($args)"
return connection.select(sql, typeReference, values) return connection.select(sql, typeReference, values)
} }
inline fun <T, reified R: List<EntityI<T?>?>> select(values: List<Any?> = emptyList()): R? = select(object: TypeReference<R>() {}, values) inline fun <T, reified R: List<EntityI<T?>?>> select(values: List<Any?> = emptyList()): R? = select(object: TypeReference<R>() {}, values)
private fun compileArgs(values: List<Any?>): String {
val placeholders = values
.filterIndexed { index, any ->
this.parameters[index].default === null || any !== null
}
.mapIndexed { index, any ->
"?::" + this.parameters[index].type
}
return placeholders.joinToString(separator=", ")
}
} }
} }

View File

@@ -2,7 +2,7 @@ package fr.postgresjson
import fr.postgresjson.connexion.Connection import fr.postgresjson.connexion.Connection
import fr.postgresjson.entity.IdEntity import fr.postgresjson.entity.IdEntity
import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.TestInstance
@@ -46,4 +46,11 @@ class ConnectionTest(): TestAbstract() {
assertTrue(objs[0].id == 1) assertTrue(objs[0].id == 1)
assertTrue(objs[0].test!!.id == 1) assertTrue(objs[0].test!!.id == 1)
} }
@Test
fun callRequestWithArgs() {
val result: ObjTest? = connection.selectOne("select json_build_object('id', 1, 'name', ?::text)", listOf("myName"))
assertNotNull(result)
assertEquals("myName", result!!.name)
}
} }

View File

@@ -26,7 +26,7 @@ class RequestTest: TestAbstract() {
val objTest: ObjTest? = Requester(getConnextion()) val objTest: ObjTest? = Requester(getConnextion())
.addFunction(resources) .addFunction(resources)
.getFunction("test_function") .getFunction("test_function")
.selectOne(listOf("ploop", "plip")) .selectOne(listOf("test", "plip"))
assertEquals(objTest!!.id, 3) assertEquals(objTest!!.id, 3)
assertEquals(objTest.name, "test") assertEquals(objTest.name, "test")
} }

View File

@@ -3,6 +3,6 @@ LANGUAGE plpgsql
AS AS
$$ $$
BEGIN BEGIN
result = json_build_object('id', 3, 'name', 'test'); result = json_build_object('id', 3, 'name', name);
END; END;
$$ $$