diff --git a/src/main/kotlin/fr/postgresjson/definition/Function.kt b/src/main/kotlin/fr/postgresjson/definition/Function.kt index 8078d5e..9aaae8d 100644 --- a/src/main/kotlin/fr/postgresjson/definition/Function.kt +++ b/src/main/kotlin/fr/postgresjson/definition/Function.kt @@ -1,11 +1,11 @@ package fr.postgresjson.definition -import com.github.jasync.sql.db.util.length import fr.postgresjson.definition.Parameter.Direction import java.nio.file.Path import kotlin.contracts.ExperimentalContracts import kotlin.contracts.InvocationKind.EXACTLY_ONCE import kotlin.contracts.contract +import kotlin.text.RegexOption.IGNORE_CASE class Function( override val script: String, @@ -46,6 +46,8 @@ class Function( private class NextScript(val value: T, val restOfScript: String) { val nextScriptPart: ScriptPart = ScriptPart(restOfScript) + fun isLast() = restOfScript == "" + fun isEmptyValue() = value == "" || value == null } /** @@ -79,7 +81,7 @@ class Function( private fun ScriptPart.getFunctionName(): NextScript { try { - return getNextScript { status.isNotEscaped() && listOf("(", " ", "\n").any { afterBeginBy(it) } } + return getNextScript { status.isNotEscaped() && afterBeginBy("(", " ", "\n") } } catch (e: NameMalformed) { throw FunctionNameMalformed(null, e) } @@ -112,17 +114,17 @@ class Function( val script: String, ) { fun afterBeginBy(vararg texts: String): Boolean = texts.any { - script.substring(index+1).take(it.length) == it + script.substring(index + 1).take(it.length) == it } - val nextChar: Char? get() = script.substring(index+1).getOrNull(0) + val nextChar: Char? get() = script.substring(index + 1).getOrNull(0) } /** * Get next part of script. * You can define a list of characters that end the part of script. Like `(` or space. */ - private fun ScriptPart.getNextScript(isEnd: Context.() -> Boolean): NextScript { + private fun ScriptPart.getNextScript(isEnd: Context.() -> Boolean = { false }): NextScript { val status = Status() fun String.unescape(): String { @@ -157,7 +159,7 @@ class Function( } if (isEnd(Context(index, c, status.copy(), restOfScript))) { - return NextScript(restOfScript.take(index+1).unescape(), restOfScript.drop(index+1)) + return NextScript(restOfScript.take(index + 1).unescape(), restOfScript.drop(index + 1)) } } if (status.isNotEscaped()) { @@ -224,6 +226,22 @@ class Function( return NextScript(value, restOfScript.apply { dropWhile { it in chars } }) } + private fun NextScript.changeValue(block: (T) -> T): NextScript { + return NextScript(block(value), restOfScript) + } + + private fun NextScript.changeScript(block: (String) -> String): NextScript { + return NextScript(value, block(restOfScript)) + } + + private fun NextScript.dropOneOf(vararg endTextList: String): NextScript { + return changeScript { script -> + endTextList + .filter { script.startsWith(it) } + .let { script.drop(it.size) } + } + } + private fun ScriptPart.toArgument(): Parameter { var script: ScriptPart = this.trimSpace() return Parameter( @@ -245,10 +263,7 @@ class Function( private fun ScriptPart.getArgName(): NextScript { try { - return getNextScript { - listOf(" ", "\n") - .any { afterBeginBy(it) } - } + return getNextScript { afterBeginBy(" ", "\n") } } catch (e: NameMalformed) { throw ArgNameMalformed(null, e) } @@ -256,10 +271,8 @@ class Function( private fun ScriptPart.getArgType(): NextScript { val fullType = try { - getNextScript { - listOf(" default ", "=", ")") - .any { afterBeginBy(it) } - } + val endTextList = arrayOf(" default ", "=", ")") + getNextScript { afterBeginBy(texts = endTextList) } } catch (e: ParseError) { throw ArgTypeMalformed(null, e) } @@ -285,11 +298,17 @@ class Function( ) } - /** - * TODO implement this method - */ private fun ScriptPart.getArgDefault(): NextScript { - return NextScript("plop", "") + return if (this.isEmpty() || this.restOfScript == ")") { + NextScript(null, "") + } else { + """^(\s*=\s*|\s+default\s+)(.+)\s*$""" + .toRegex(IGNORE_CASE) + .find(restOfScript) + .let { it ?: throw ArgDefaultMalformed() } + .let { it.groups[2]!!.value } + .let { NextScript(it, "") } + } } /** @@ -303,11 +322,16 @@ class Function( class ArgumentNotFound(cause: Throwable? = null): Resource.ParseException("Argument not found in script", cause) class FunctionNameMalformed(message: String? = null, cause: Throwable? = null): Resource.ParseException(message ?: "Function name is malformed", cause) + class ArgNameMalformed(message: String? = null, cause: Throwable? = null): Resource.ParseException(message ?: "Arg name is malformed", cause) + class ArgTypeMalformed(message: String? = null, cause: Throwable? = null): Resource.ParseException(message ?: "Arg type is malformed", cause) + class ArgDefaultMalformed(message: String? = null, cause: Throwable? = null): + Resource.ParseException(message ?: "Arg default is malformed", cause) + class NameMalformed(message: String? = null, cause: Throwable? = null): Resource.ParseException(message ?: "name is malformed", cause) diff --git a/src/test/kotlin/fr/postgresjson/definition/FunctionTest.kt b/src/test/kotlin/fr/postgresjson/definition/FunctionTest.kt index 4f0722a..703f959 100644 --- a/src/test/kotlin/fr/postgresjson/definition/FunctionTest.kt +++ b/src/test/kotlin/fr/postgresjson/definition/FunctionTest.kt @@ -186,6 +186,32 @@ class FunctionTest: FreeSpec({ param[0].type.scale shouldBe 8 } } + + "parameters with default text" - { + val param = Function( + // language=PostgreSQL + """ + create or replace function myfun(one text default 'example') returns text language plpgsql as + $$ begin end;$$; + """.trimIndent() + ).parameters + + "should have 1 parameters" { + param shouldHaveSize 1 + } + + "should have name" { + param[0].name shouldBe "one" + } + + "should have type name" { + param[0].type.name shouldBe "text" + } + + "should have default text" { + param[0].default shouldBe "'example'" + } + } } // "function returns" - {