This is an automated email from the ASF dual-hosted git repository. sergeykamov pushed a commit to branch intent_trait_api in repository https://gitbox.apache.org/repos/asf/incubator-nlpcraft.git
commit 092244e87556a6ff2c66e63dfa9ddc4306af6340 Author: Sergey Kamov <[email protected]> AuthorDate: Mon Nov 14 13:09:30 2022 +0400 Intent trait API. --- .../main/scala/org/apache/nlpcraft/NCIntent2.scala | 6 ++++-- .../{NCIntent2.scala => NCMatchInput.scala} | 17 ++++++++++++---- .../org/apache/nlpcraft/NCMatchedCallback2.scala | 2 +- .../intent/matcher/NCIntentSolverManager.scala | 13 +++++++++--- .../apache/nlpcraft/nlp/test/TimeTestModel.scala | 23 ++++++++++++++-------- .../nlpcraft/nlp/test/TimeTestModelSpec.scala | 2 +- 6 files changed, 44 insertions(+), 19 deletions(-) diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCIntent2.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCIntent2.scala index c3e56a63..8f8aff91 100644 --- a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCIntent2.scala +++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCIntent2.scala @@ -21,7 +21,9 @@ import org.apache.nlpcraft.* trait NCIntent2[T] { // Gets callback argument if matched. - def tryMatch(ctx: NCContext, variant: NCVariant): Option[T] + def tryMatch(mi: NCMatchInput): Option[T] - def mkResult(ctx: NCContext, arg: T): NCResult + def mkResult(mi: NCMatchInput, arg: T): NCResult + + def getId: String } diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCIntent2.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCMatchInput.scala similarity index 65% copy from nlpcraft/src/main/scala/org/apache/nlpcraft/NCIntent2.scala copy to nlpcraft/src/main/scala/org/apache/nlpcraft/NCMatchInput.scala index c3e56a63..8289dfa1 100644 --- a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCIntent2.scala +++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCMatchInput.scala @@ -19,9 +19,18 @@ package org.apache.nlpcraft import org.apache.nlpcraft.* -trait NCIntent2[T] { - // Gets callback argument if matched. - def tryMatch(ctx: NCContext, variant: NCVariant): Option[T] +trait NCMatchInput { + // Context data. + def getModelConfig: NCModelConfig + def getRequest: NCRequest + def getTokens: List[NCToken] + def getConversation: NCConversation + + // Variant, passed one by one from sorted list. + def getVariant4Match: NCVariant + + // Helper methods. + def getAllEntities: List[NCEntity] = getVariant4Match.getEntities ++ getConversation.getStm + def hasDialogIdsBefore(idsRegex: String): Boolean = true // TBI - def mkResult(ctx: NCContext, arg: T): NCResult } diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCMatchedCallback2.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCMatchedCallback2.scala index 25342712..62274c9a 100644 --- a/nlpcraft/src/main/scala/org/apache/nlpcraft/NCMatchedCallback2.scala +++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/NCMatchedCallback2.scala @@ -23,5 +23,5 @@ package org.apache.nlpcraft */ trait NCMatchedCallback2[T]: def getIntent: NCIntent2[T] - def getContext: NCContext + def getInput: NCMatchInput def getIntentArgument: T diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/matcher/NCIntentSolverManager.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/matcher/NCIntentSolverManager.scala index de8cc7e7..c5a497f5 100644 --- a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/matcher/NCIntentSolverManager.scala +++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/intent/matcher/NCIntentSolverManager.scala @@ -762,13 +762,20 @@ class NCIntentSolverManager( ) val i: NCIntent2[_] = intent - i.tryMatch(ctx, v) match + val input = new NCMatchInput: + override def getModelConfig: NCModelConfig = ctx.getModelConfig + override def getRequest: NCRequest = ctx.getRequest + override def getTokens: List[NCToken] = ctx.getTokens + override def getConversation: NCConversation = ctx.getConversation + override def getVariant4Match: NCVariant = v + + i.tryMatch(input) match case Some(data) => typ match - case REGULAR => return Left(i.mkResult(ctx, data)) + case REGULAR => return Left(i.mkResult(input, data)) case _ => return Right(new NCMatchedCallback2[Any](): override def getIntent: NCIntent2[Any] = i.asInstanceOf[NCIntent2[Any]] - override def getContext: NCContext = ctx + override def getInput: NCMatchInput = input override def getIntentArgument: Any = data.asInstanceOf[Any] ) diff --git a/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/test/TimeTestModel.scala b/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/test/TimeTestModel.scala index 1d68127e..19ad93a6 100644 --- a/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/test/TimeTestModel.scala +++ b/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/test/TimeTestModel.scala @@ -18,14 +18,21 @@ class TimeTestModel extends NCModel( // Add this method to API. override def getIntents: List[NCIntent2[_]] = List( new NCIntent2[TimeData] : - override def tryMatch(ctx: NCContext, v: NCVariant): Option[TimeData] = - Option.when(v.getEntities.length == 1 && v.getEntities.count(_.getId == "x:time") == 1)(TimeData()) - override def mkResult(ctx: NCContext, data: TimeData): NCResult = NCResult("Asked for local") // TBI. + override def getId: String = "id1" + override def tryMatch(mi: NCMatchInput): Option[TimeData] = + val varEnts = mi.getVariant4Match.getEntities + + Option.when(varEnts.length == 1 && varEnts.count(_.getId == "x:time") == 1)(TimeData()) + override def mkResult(mi: NCMatchInput, data: TimeData): NCResult = NCResult("Asked for local") // TBI. , new NCIntent2[CityTimeData] : - override def tryMatch(ctx: NCContext, v: NCVariant): Option[CityTimeData] = - val cities = v.getEntities.filter(_.getId == "opennlp:location") - val times = v.getEntities.filter(_.getId == "x:time") - Option.when(v.getEntities.length == 2 && cities.length == 1 && times.length == 1)(CityTimeData(cities.head.mkText)) - override def mkResult(ctx: NCContext, data: CityTimeData): NCResult = NCResult(s"Asked for ${data.city}") // TBI. + override def getId: String = "id2" + override def tryMatch(mi: NCMatchInput): Option[CityTimeData] = + val varEnts = mi.getVariant4Match.getEntities + val allEnts = mi.getAllEntities + + val cities = varEnts.filter(_.getId == "opennlp:location") + val times = allEnts.filter(_.getId == "x:time") + Option.when(cities.length == 1 && times.length == 1 && varEnts.forall(p => p.getId == "opennlp:location" || p.getId == "x:time"))(CityTimeData(cities.head.mkText)) + override def mkResult(mi: NCMatchInput, data: CityTimeData): NCResult = NCResult(s"Asked for ${data.city}") // TBI. ) diff --git a/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/test/TimeTestModelSpec.scala b/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/test/TimeTestModelSpec.scala index af54adfb..40d7a699 100644 --- a/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/test/TimeTestModelSpec.scala +++ b/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/test/TimeTestModelSpec.scala @@ -47,7 +47,7 @@ class TimeTestModelSpec extends AnyFunSuite: System.out.println(s"Argument: ${intent.getIntentArgument}") - val res = intent.getIntent.mkResult(intent.getContext, intent.getIntentArgument) + val res = intent.getIntent.mkResult(intent.getInput, intent.getIntentArgument) System.out.println(s"Body: ${res.getBody}")
