diff --git a/src/main/kotlin/fr/postgresjson/connexion/Connection.kt b/src/main/kotlin/fr/postgresjson/connexion/Connection.kt index 1c33eab..9108304 100644 --- a/src/main/kotlin/fr/postgresjson/connexion/Connection.kt +++ b/src/main/kotlin/fr/postgresjson/connexion/Connection.kt @@ -114,7 +114,7 @@ class Requester ( val functionRegex = """create .*(procedure|function) *(?[^(\s]+)\s*\((?(\s*((IN|OUT|INOUT|VARIADIC)?\s+)?([^\s,)]+\s+)?([^\s,)]+)(\s+(?:default\s|=)\s*[^\s,)]+)?\s*(,|(?=\))))*)\) *(?RETURNS *[^ ]+)?""" .toRegex(setOf(IGNORE_CASE, MULTILINE)) - val paramsRegex = """\s*(?((?IN|OUT|INOUT|VARIADIC)?\s+)?(?[^\s,)]+\s+)?(?[^\s,)]+)(\s+(?:default\s|=)\s*[^\s,)]+)?)\s*(,|$)""" + val paramsRegex = """\s*(?((?IN|OUT|INOUT|VARIADIC)?\s+)?(?[^\s,)]+\s+)?(?[^\s,)]+)(\s+(?default\s|=)\s*[^\s,)]+)?)\s*(,|$)""" .toRegex(setOf(IGNORE_CASE, MULTILINE)) return functionRegex.findAll(functionContent).map { queryMatch -> @@ -129,7 +129,8 @@ class Requester ( Function.Parameter( paramsMatch.groups["name"]!!.value.trim(), paramsMatch.groups["type"]!!.value.trim(), - paramsMatch.groups["direction"]?.value?.trim()) + paramsMatch.groups["direction"]?.value?.trim(), + paramsMatch.groups["default"]?.value?.trim()) }.toList() } else { listOf() @@ -173,7 +174,7 @@ class Requester ( class Function(val name: String, val parameters: List, private val connection : Connection) { - class Parameter(val name: String, val type: String, direction: Direction? = Direction.IN) + class Parameter(val name: String, val type: String, direction: Direction? = Direction.IN, val default: Any? = null) { val direction: Direction @@ -184,10 +185,11 @@ class Requester ( this.direction = direction } } - constructor(name: String, type: String, direction: String? = "IN") : this( + constructor(name: String, type: String, direction: String? = "IN", default: Any? = null) : this( name = name, type = type, - direction = direction?.let { Direction.valueOf(direction.toUpperCase())} + direction = direction?.let { Direction.valueOf(direction.toUpperCase())}, + default = default ) enum class Direction { IN, OUT, INOUT } } @@ -197,8 +199,8 @@ class Requester ( } fun ?> selectOne(typeReference: TypeReference, values: List = emptyList()): R? { - val placeholder = List(values.size) {"?"}.joinToString(separator=", ") - val sql = "SELECT * FROM $name (${placeholder})" + val args = compileArgs(values) + val sql = "SELECT * FROM $name ($args)" return connection.selectOne(sql, typeReference, values) } @@ -206,13 +208,25 @@ class Requester ( inline fun ?> selectOne(values: List = emptyList()): R? = selectOne(object: TypeReference() {}, values) fun ?>> select(typeReference: TypeReference, values: List = emptyList()): R? { - val placeholder = List(values.size) {"?"}.joinToString(separator=", ") - val sql = "SELECT * FROM $name ($placeholder)" + val args = compileArgs(values) + val sql = "SELECT * FROM $name ($args)" return connection.select(sql, typeReference, values) } inline fun ?>> select(values: List = emptyList()): R? = select(object: TypeReference() {}, values) + + private fun compileArgs(values: List): String { + val placeholders = values + .filterIndexed { index, any -> + this.parameters[index].default === null || any !== null + } + .mapIndexed { index, any -> + "?::" + this.parameters[index].type + } + + return placeholders.joinToString(separator=", ") + } } } diff --git a/src/test/kotlin/fr/postgresjson/ConnectionTest.kt b/src/test/kotlin/fr/postgresjson/ConnectionTest.kt index 7f1f19f..be06db6 100644 --- a/src/test/kotlin/fr/postgresjson/ConnectionTest.kt +++ b/src/test/kotlin/fr/postgresjson/ConnectionTest.kt @@ -2,7 +2,7 @@ package fr.postgresjson import fr.postgresjson.connexion.Connection import fr.postgresjson.entity.IdEntity -import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.junit.jupiter.api.TestInstance @@ -46,4 +46,11 @@ class ConnectionTest(): TestAbstract() { assertTrue(objs[0].id == 1) assertTrue(objs[0].test!!.id == 1) } + + @Test + fun callRequestWithArgs() { + val result: ObjTest? = connection.selectOne("select json_build_object('id', 1, 'name', ?::text)", listOf("myName")) + assertNotNull(result) + assertEquals("myName", result!!.name) + } } \ No newline at end of file diff --git a/src/test/kotlin/fr/postgresjson/RequestTest.kt b/src/test/kotlin/fr/postgresjson/RequestTest.kt index efc3e8e..79aa602 100644 --- a/src/test/kotlin/fr/postgresjson/RequestTest.kt +++ b/src/test/kotlin/fr/postgresjson/RequestTest.kt @@ -26,7 +26,7 @@ class RequestTest: TestAbstract() { val objTest: ObjTest? = Requester(getConnextion()) .addFunction(resources) .getFunction("test_function") - .selectOne(listOf("ploop", "plip")) + .selectOne(listOf("test", "plip")) assertEquals(objTest!!.id, 3) assertEquals(objTest.name, "test") } diff --git a/src/test/resources/sql/function/Test/function_test.sql b/src/test/resources/sql/function/Test/function_test.sql index 8e81032..3e3698d 100644 --- a/src/test/resources/sql/function/Test/function_test.sql +++ b/src/test/resources/sql/function/Test/function_test.sql @@ -3,6 +3,6 @@ LANGUAGE plpgsql AS $$ BEGIN - result = json_build_object('id', 3, 'name', 'test'); + result = json_build_object('id', 3, 'name', name); END; $$ \ No newline at end of file