This is an automated email from the ASF dual-hosted git repository.
sergeykamov pushed a commit to branch NLPCRAFT-477
in repository https://gitbox.apache.org/repos/asf/incubator-nlpcraft.git
The following commit(s) were added to refs/heads/NLPCRAFT-477 by this push:
new 9c1e67d Dialog manager added.
9c1e67d is described below
commit 9c1e67d4ea93899b1a33fd42d44530d5d1ce9f2a
Author: Sergey Kamov <[email protected]>
AuthorDate: Thu Feb 3 14:08:11 2022 +0300
Dialog manager added.
---
.../conversation/NCConversationHolder.scala | 21 ++-
.../conversation/NCConversationManager.scala | 6 +-
.../internal/dialogflow/NCDialogFlowManager.scala | 175 +++++++++++++++++++++
.../apache/nlpcraft/internal/util/NCUtils.scala | 22 +--
.../dialogflow/NCDialogFlowManagerSpec.scala | 130 +++++++++++++++
5 files changed, 336 insertions(+), 18 deletions(-)
diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversationHolder.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversationHolder.scala
index bad168c..de4a0e5 100644
--- a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversationHolder.scala
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversationHolder.scala
@@ -46,13 +46,22 @@ case class NCConversationHolder(
// Short-Term-Memory.
private val stm = mutable.ArrayBuffer.empty[ConversationItem]
private val lastEnts = mutable.ArrayBuffer.empty[Iterable[NCEntity]]
+ private val ctx = mutable.ArrayBuffer.empty[NCEntity]
- @volatile private var ctx: util.List[NCEntity] = new util.ArrayList[NCEntity]()
@volatile private var lastUpdateTstamp = NCUtils.nowUtcMs()
@volatile private var depth = 0
/**
*
+ * @param newCtx
+ */
+ private def replaceContext(newCtx: mutable.ArrayBuffer[NCEntity]): Unit =
+ require(Thread.holdsLock(stm))
+ ctx.clear()
+ ctx ++= newCtx
+
+ /**
+ *
*/
private def squeezeEntities(): Unit =
require(Thread.holdsLock(stm))
@@ -93,7 +102,7 @@ case class NCConversationHolder(
squeezeEntities()
lastUpdateTstamp = now
- ctx = new util.ArrayList[NCEntity](stm.flatMap(_.holders.map(_.entity)).asJava)
+ replaceContext(stm.flatMap(_.holders.map(_.entity)))
ack()
}
@@ -106,7 +115,7 @@ case class NCConversationHolder(
stm.synchronized {
for (item <- stm) item.holders --= item.holders.filter(h => p.test(h.entity))
squeezeEntities()
- ctx = ctx.asScala.filter(ent => !p.test(ent)).asJava
+ replaceContext(ctx.filter(ent => !p.test(ent)))
}
logger.trace(s"STM is cleared [usrId=$usrId, mdlId=$mdlId]")
@@ -186,7 +195,7 @@ case class NCConversationHolder(
if ctx.isEmpty then logger.trace(s"STM is empty for [$z]")
else
val tbl = NCAsciiTable("Entity ID", "Groups", "Request ID")
- ctx.asScala.foreach(ent => tbl += (
+ ctx.foreach(ent => tbl += (
ent.getId,
ent.getGroups.asScala.mkString(", "),
ent.getRequestId
@@ -199,8 +208,8 @@ case class NCConversationHolder(
*/
def getEntity: util.List[NCEntity] =
stm.synchronized {
- val reqIds = ctx.asScala.map(_.getRequestId).distinct.zipWithIndex.toMap
- val ents = ctx.asScala.groupBy(_.getRequestId).toSeq.sortBy(p => reqIds(p._1)).reverse.flatMap(_._2)
+ val reqIds = ctx.map(_.getRequestId).distinct.zipWithIndex.toMap
+ val ents = ctx.groupBy(_.getRequestId).toSeq.sortBy(p => reqIds(p._1)).reverse.flatMap(_._2)
new util.ArrayList[NCEntity](ents.asJava)
}
diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversationManager.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversationManager.scala
index 2b87f6a..0a58184 100644
--- a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversationManager.scala
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/conversation/NCConversationManager.scala
@@ -37,12 +37,14 @@ class NCConversationManager(mdlCfg: NCModelConfig) extends LazyLogging:
* @return
*/
def start(): Unit =
- gc = NCUtils.mkThread(s"conv-mgr-gc-@${mdlCfg.getId}") { t =>
+ gc = NCUtils.mkThread("conv-mgr-gc", mdlCfg.getId) { t =>
while (!t.isInterrupted)
try
convs.synchronized {
val sleepTime = clearForTimeout() - NCUtils.now()
- if sleepTime > 0 then convs.wait(sleepTime)
+ if sleepTime > 0 then
+ logger.trace(s"${t.getName} waits for $sleepTime ms.")
+ convs.wait(sleepTime)
}
catch
case _: InterruptedException => // No-op.
diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/dialogflow/NCDialogFlowManager.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/dialogflow/NCDialogFlowManager.scala
new file mode 100644
index 0000000..0a35329
--- /dev/null
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/dialogflow/NCDialogFlowManager.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.nlpcraft.internal.dialogflow
+
+import com.typesafe.scalalogging.LazyLogging
+import org.apache.nlpcraft.*
+import org.apache.nlpcraft.internal.ascii.NCAsciiTable
+
+import java.text.DateFormat
+import java.util
+import java.util.*
+import scala.collection.*
+import com.typesafe.scalalogging.LazyLogging
+import org.apache.nlpcraft.internal.util.NCUtils
+
+import java.time.format.DateTimeFormatter
+
+/**
+ * Dialog flow manager.
+ */
+class NCDialogFlowManager(mdlCfg: NCModelConfig) extends LazyLogging:
+ private final val flow = mutable.HashMap.empty[String, mutable.ArrayBuffer[NCDialogFlowItem]]
+
+ @volatile private var gc: Thread = _
+
+ /**
+ *
+ * @return
+ */
+ def start(): Unit =
+ gc = NCUtils.mkThread("dialog-mgr-gc", mdlCfg.getId) { t =>
+ while (!t.isInterrupted)
+ try
+ flow.synchronized {
+ val sleepTime = clearForTimeout() - NCUtils.now()
+
+ if sleepTime > 0 then
+ logger.trace(s"${t.getName} waits for $sleepTime ms.")
+ flow.wait(sleepTime)
+ }
+ catch
+ case _: InterruptedException => // No-op.
+ case e: Throwable => logger.error(s"Unexpected error for thread: ${t.getName}", e)
+ }
+
+ gc.start()
+
+ /**
+ *
+ */
+ def stop(): Unit =
+ NCUtils.stopThread(gc)
+ gc = null
+ flow.clear()
+
+ /**
+ * Adds matched (winning) intent to the dialog flow.
+ *
+ * @param intentMatch Matched (winning) intent.
+ * @param res Intent callback result.
+ * @param ctx Original query context; its request supplies the user ID
+ * under which the item is stored.
+ */
+ def addMatchedIntent(intentMatch: NCIntentMatch, res: NCResult, ctx: NCContext): Unit =
+ val item: NCDialogFlowItem = new NCDialogFlowItem:
+ override val getIntentMatch: NCIntentMatch = intentMatch
+ override val getRequest: NCRequest = ctx.getRequest
+ override val getResult: NCResult = res
+
+ flow.synchronized {
+ flow.getOrElseUpdate(ctx.getRequest.getUserId, mutable.ArrayBuffer.empty[NCDialogFlowItem]).append(item)
+ flow.notifyAll()
+ }
+
+ /**
+ * Gets sequence of dialog flow items sorted from oldest to newest (i.e. dialog flow) for given user ID.
+ *
+ * @param usrId User ID.
+ * @return Dialog flow.
+ */
+ def getDialogFlow(usrId: String): Seq[NCDialogFlowItem] =
+ flow.synchronized { flow.get(usrId) } match
+ case Some(buf) => buf.toSeq
+ case None => Seq.empty
+
+ /**
+ * Prints out ASCII table for current dialog flow.
+ *
+ * @param usrId User ID.
+ */
+ def ack(usrId: String): Unit =
+ val tbl = NCAsciiTable(
+ "#",
+ "Intent ID",
+ "Request ID",
+ "Text",
+ "Received"
+ )
+
+ getDialogFlow(usrId).zipWithIndex.foreach { (itm, idx) =>
+ tbl += (
+ idx + 1,
+ itm.getIntentMatch.getIntentId,
+ itm.getRequest.getRequestId,
+ itm.getRequest.getText,
+ DateFormat.getDateTimeInstance.format(new Date(itm.getRequest.getReceiveTimestamp))
+ )
+ }
+
+ logger.info(s"""Current dialog flow (oldest first) for [mdlId=${mdlCfg.getId}, usrId=$usrId]\n${tbl.toString()}""")
+
+ /**
+ * Gets next clearing time.
+ */
+ private def clearForTimeout(): Long =
+ require(Thread.holdsLock(flow))
+
+ val timeout = mdlCfg.getConversationTimeout
+ val bound = NCUtils.now() - timeout
+ var next = Long.MaxValue
+
+ val delKeys = mutable.ArrayBuffer.empty[String]
+
+ for ((usrId, values) <- flow)
+ values --= values.filter(_.getRequest.getReceiveTimestamp < bound)
+
+ if values.nonEmpty then
+ val candidate = values.map(_.getRequest.getReceiveTimestamp).min + timeout
+ if next > candidate then next = candidate
+ else
+ delKeys += usrId
+
+ if delKeys.nonEmpty then flow --= delKeys
+
+ next
+
+ /**
+ * Clears dialog history for given user ID.
+ *
+ * @param usrId User ID.
+ */
+ def clear(usrId: String): Unit =
+ flow.synchronized {
+ flow -= usrId
+ flow.notifyAll()
+ }
+
+ /**
+ * Clears dialog history for given user ID and predicate: flow items whose
+ * intent ID satisfies the predicate are removed.
+ *
+ * @param usrId User ID; looked up with `flow(usrId)`, so it must already
+ * have a recorded flow.
+ * @param pred Intent ID predicate.
+ */
+ def clearForPredicate(usrId: String, pred: String => Boolean): Unit =
+ flow.synchronized {
+ flow(usrId) = flow(usrId).filterNot(v => pred(v.getIntentMatch.getIntentId))
+ flow.notifyAll()
+ }
\ No newline at end of file
diff --git
a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/util/NCUtils.scala
b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/util/NCUtils.scala
index 478e9f3..59b259e 100644
--- a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/util/NCUtils.scala
+++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/util/NCUtils.scala
@@ -347,7 +347,9 @@ object NCUtils extends LazyLogging:
@volatile private var stopped = false
override def isInterrupted: Boolean = super.isInterrupted || stopped
- override def interrupt(): Unit = stopped = true; super.interrupt()
+ override def interrupt(): Unit =
+ stopped = true
+ super.interrupt()
override def run(): Unit =
logger.trace(s"Thread started: $name")
@@ -362,6 +364,15 @@ object NCUtils extends LazyLogging:
stopped = true
/**
+ *
+ * @param prefix
+ * @param mdlId
+ * @param body
+ * @return
+ */
+ def mkThread(prefix: String, mdlId: String)(body: Thread => Unit): Thread = mkThread(s"$prefix-@$mdlId")(body)
+
+ /**
* Gets resource existing flag.
*
* @param res Resource.
@@ -403,15 +414,6 @@ object NCUtils extends LazyLogging:
else E(s"Source not found or unsupported: $src")
/**
- * Makes thread.
- *
- * @param name Name.
- * @param body Thread body.
- */
- def mkThread(name: String, body: Runnable): Thread =
- mkThread(name) { _ => body.run() }
-
- /**
* Sleeps number of milliseconds properly handling exceptions.
*
* @param delay Number of milliseconds to sleep.
diff --git a/nlpcraft/src/test/scala/org/apache/nlpcraft/internal/dialogflow/NCDialogFlowManagerSpec.scala b/nlpcraft/src/test/scala/org/apache/nlpcraft/internal/dialogflow/NCDialogFlowManagerSpec.scala
new file mode 100644
index 0000000..478a925
--- /dev/null
+++ b/nlpcraft/src/test/scala/org/apache/nlpcraft/internal/dialogflow/NCDialogFlowManagerSpec.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.nlpcraft.internal.dialogflow
+
+import org.apache.nlpcraft.internal.util.NCUtils
+import org.apache.nlpcraft.nlp.util.NCTestRequest
+import org.apache.nlpcraft.{NCContext, NCConversation, NCEntity, NCIntentMatch, NCModelConfig, NCRequest, NCResult, NCVariant}
+import org.junit.jupiter.api.{AfterEach, Test}
+
+import java.util
+
+/**
+ *
+ */
+class NCDialogFlowManagerSpec:
+ case class IntentMatchMock(intentId: String) extends NCIntentMatch:
+ override def getIntentId: String = intentId
+ override def getIntentEntities: util.List[util.List[NCEntity]] = null
+ override def getTermEntities(idx: Int): util.List[NCEntity] = null
+ override def getTermEntities(termId: String): util.List[NCEntity] = null
+ override def getVariant: NCVariant = null
+
+ case class ContextMock(userId: String, reqTs: Long = NCUtils.now()) extends NCContext:
+ override def isOwnerOf(ent: NCEntity): Boolean = false
+ override def getModelConfig: NCModelConfig = null
+ override def getRequest: NCRequest = NCTestRequest(txt = "Any", userId = userId, ts = reqTs)
+ override def getConversation: NCConversation = null
+ override def getVariants: util.Collection[NCVariant] = null
+
+ case class ModelConfigMock(timeout: Long = Long.MaxValue) extends NCModelConfig("testId", "test", "1.0", "Test description", "Test origin"):
+ override def getConversationTimeout: Long = timeout
+
+ private var mgr: NCDialogFlowManager = _
+
+ /**
+ *
+ * @param expSizes
+ */
+ private def check(expSizes: (String, Int)*): Unit =
+ for ((usrId, expSize) <- expSizes)
+ val size = mgr.getDialogFlow(usrId).size
+ require(size == expSize, s"Expected: $expSize for '$usrId', but found: $size")
+
+ /**
+ *
+ * @param userIds
+ */
+ private def ask(userIds: String*): Unit = for (userId <- userIds) mgr.ack(userId)
+
+ /**
+ *
+ * @param id
+ * @param ctx
+ */
+ private def addMatchedIntent(id: String, ctx: NCContext): Unit = mgr.addMatchedIntent(IntentMatchMock(id), null, ctx)
+
+ /**
+ *
+ */
+ @AfterEach
+ def cleanUp(): Unit = if mgr != null then mgr.stop()
+
+ @Test
+ def test(): Unit =
+ mgr = NCDialogFlowManager(ModelConfigMock())
+
+ val now = NCUtils.now()
+
+ addMatchedIntent("i11", ContextMock("user1"))
+ addMatchedIntent("i12", ContextMock("user1"))
+ addMatchedIntent("i21", ContextMock("user2"))
+ addMatchedIntent("i22", ContextMock("user2"))
+ addMatchedIntent("i31", ContextMock("user3"))
+
+ // Initial.
+ ask("user1", "user2", "user3", "user4")
+ check("user1" -> 2, "user2" -> 2, "user3" -> 1, "user4" -> 0)
+
+ mgr.clear(usrId = "user4")
+ check("user1" -> 2, "user2" -> 2, "user3" -> 1, "user4" -> 0)
+
+ mgr.clear(usrId = "user1")
+ check("user1" -> 0, "user2" -> 2, "user3" -> 1, "user4" -> 0)
+
+ mgr.clearForPredicate(usrId = "user2", _ == "i21")
+ check("user1" -> 0, "user2" -> 1, "user3" -> 1, "user4" -> 0)
+
+ mgr.clear(usrId = "user2")
+ mgr.clear(usrId = "user3")
+ check("user1" -> 0, "user2" -> 0, "user3" -> 0, "user4" -> 0)
+
+ @Test
+ def testTimeout(): Unit =
+ val delay = 500
+ val timeout = delay * 20
+
+ mgr = NCDialogFlowManager(ModelConfigMock(timeout))
+
+ val now = NCUtils.now()
+
+ addMatchedIntent("any", ContextMock("user1", now))
+ addMatchedIntent("any", ContextMock("user1", now - timeout - delay))
+ addMatchedIntent("any", ContextMock("user2", now - timeout))
+
+ // Initial.
+ ask("user1", "user2")
+ check("user1" -> 2, "user2" -> 1)
+
+ mgr.start()
+
+ Thread.sleep(delay * 3)
+ check("user1" -> 1, "user2" -> 0)
+
+ mgr.stop()
+ check("user1" -> 0, "user2" -> 0)