This is an automated email from the ASF dual-hosted git repository. sergeykamov pushed a commit to branch NLPCRAFT-472 in repository https://gitbox.apache.org/repos/asf/incubator-nlpcraft.git
The following commit(s) were added to refs/heads/NLPCRAFT-472 by this push: new bf40a9d WIP. bf40a9d is described below commit bf40a9d7eb812b3f71848f4ec40e01d0e6722e8f Author: Sergey Kamov <skhdlem...@gmail.com> AuthorDate: Tue Jan 11 15:25:31 2022 +0300 WIP. --- .../nlpcraft/internal/NCModelClientImpl.scala | 13 +- .../nlpcraft/internal/NCRequestProcessor.scala | 184 +++++++++++++++++++ .../apache/nlpcraft/internal/NCSentenceHelper.java | 197 +++++++++++++++++++++ .../nlpcraft/internal/NCRequestProcessorSpec.scala | 54 ++++++ .../semantic/NCSemanticEntityParserSpec.scala | 19 +- .../apache/nlpcraft/nlp/util/NCTestConfig.scala | 32 ++-- 6 files changed, 476 insertions(+), 23 deletions(-) diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/NCModelClientImpl.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/NCModelClientImpl.scala index f668e99..6f4c051 100644 --- a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/NCModelClientImpl.scala +++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/NCModelClientImpl.scala @@ -37,6 +37,8 @@ import scala.jdk.CollectionConverters.* class NCModelClientImpl(mdl: NCModel) extends LazyLogging: verify(mdl.getConfig) + private val proc = NCRequestProcessor(mdl) + /** * * @throws NCException @@ -49,13 +51,12 @@ class NCModelClientImpl(mdl: NCModel) extends LazyLogging: if list == null then throw new NCException(s"List cannot be null: '$name'") else if list.isEmpty then throw new NCException(s"List cannot be empty: '$name'") - check(cfg.getId, "ID") + check(cfg.getId, "Id") check(cfg.getName, "Name") check(cfg.getVersion, "Version") check(cfg.getTokenParser, "Token parser") - checkList(cfg.getEntityParsers, "Entity parsers tapser") + checkList(cfg.getEntityParsers, "Entity parsers") - // TODO: implement /** * * @param txt @@ -63,7 +64,7 @@ class NCModelClientImpl(mdl: NCModel) extends LazyLogging: * @param usrId * @return */ - def ask(txt: String, data: JMap[String, AnyRef], usrId: String): CompletableFuture[NCResult] = null + def ask(txt: String, data: JMap[String, AnyRef], usrId: String): CompletableFuture[NCResult] = proc.ask(txt, data, usrId) /** * @@ -72,7 +73,9 @@ class NCModelClientImpl(mdl: NCModel) extends LazyLogging: * @param usrId * @return */ - def askSync(txt: String, data: JMap[String, AnyRef], usrId: String): NCResult = null + def askSync(txt: String, data: JMap[String, AnyRef], usrId: String): NCResult = proc.askSync(txt, data, usrId) + + // TODO: implement /** * diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/NCRequestProcessor.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/NCRequestProcessor.scala new file mode 100644 index 0000000..59a937a --- /dev/null +++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/NCRequestProcessor.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nlpcraft.internal + +import com.typesafe.scalalogging.LazyLogging +import org.apache.nlpcraft.{NCTokenEnricher, *} +import org.apache.nlpcraft.internal.util.NCUtils + +import java.util +import java.util.concurrent.* +import java.util.concurrent.atomic.AtomicReference +import java.util.{ArrayList, UUID, List as JList, Map as JMap} +import scala.concurrent.ExecutionContext +import scala.jdk.CollectionConverters.* + +// TODO: move it to right package. + +/** + * + */ +object NCRequestProcessor { + case class VariantsHolder(request: NCRequest, variants: Seq[NCVariant], checkCancel: () => Unit) +} + +import NCRequestProcessor._ + +/** + * + * @param mdl */ +class NCRequestProcessor(mdl: NCModel) extends LazyLogging : + require(mdl != null) + + // TODO: shutdown. + private val pool = new java.util.concurrent.ForkJoinPool() + + private val cfg = mdl.getConfig + private var tokParser: NCTokenParser = _ + private var tokEnrichers: Seq[NCTokenEnricher] = _ + private var entEnrichers: Seq[NCEntityEnricher] = _ + private var entParsers: Seq[NCEntityParser] = _ + private var tokenValidators: Seq[NCTokenValidator] = _ + private var entityValidators: Seq[NCEntityValidator] = _ + private var variantValidators: Seq[NCVariantValidator] = _ + + init() + + /** + * */ + private def init(): Unit = + def nvl[T](list: JList[T]): Seq[T] = if list == null then Seq.empty else list.asScala.toSeq + + this.tokParser = cfg.getTokenParser + this.tokEnrichers = nvl(cfg.getTokenEnrichers) + this.entEnrichers = nvl(cfg.getEntityEnrichers) + this.entParsers = nvl(cfg.getEntityParsers) + this.tokenValidators = nvl(cfg.getTokenValidators) + this.entityValidators = nvl(cfg.getEntityValidators) + this.variantValidators = nvl(cfg.getVariantValidators) + + require(tokParser != null && entParsers.nonEmpty) + + /** + * + * @param h + * @return + */ + private def matchAndExecute(h: VariantsHolder): NCResult = ??? // TODO: implement. + + /** + * + * @param txt + * @param data + * @param usrId + * @param checkCancel + * @return + */ + // It returns intermediate variants holder just for test reasons. + private[internal] def prepare(txt: String, data: JMap[String, AnyRef], usrId: String, checkCancel: () => Unit): VariantsHolder = + require(txt != null && usrId != null) + + val toks = tokParser.tokenize(txt) + if toks.isEmpty then throw new NCException(s"Invalid request: $txt") // TODO: error text + + val req: NCRequest = new NCRequest: + override val getUserId: String = usrId + override val getRequestId: String = UUID.randomUUID().toString + override val getText: String = txt + override val getReceiveTimestamp: Long = System.currentTimeMillis() + override val getRequestData: JMap[String, AnyRef] = data + + for (enricher <- tokEnrichers) + checkCancel() + enricher.enrich(req, cfg, toks) + + for (validator <- tokenValidators) + checkCancel() + validator.validate(req, cfg, toks) + + val entsList = new util.ArrayList[NCEntity]() + + for (parser <- entParsers) + checkCancel() + val ents = parser.parse(req, cfg, toks) + if ents == null then + // TODO: error text. + throw new NCException(s"Invalid entities parser null result [text=$txt, parser=${parser.getClass.getName}]") + entsList.addAll(ents) + + // TODO: error text. + if entsList.isEmpty then throw new NCException(s"No entities found for text: $txt") + + for (enricher <- entEnrichers) + checkCancel() + enricher.enrich(req, cfg, entsList) + for (validator <- entityValidators) + checkCancel() + validator.validate(req, cfg, entsList) + + val entities = entsList.asScala.toSeq + + val over = + toks.asScala. + map(t => t.getIndex -> entities.filter(_.getTokens.contains(t))). + filter(_._2.size > 1). + flatMap(_._2). + toSet. + flatMap(ent => ent.getTokens.asScala.map(_.getIndex).map(_ -> ent)). + groupBy { case (idx, _) => idx }. + map { case (_, seq) => seq.map { case (_, ent) => ent } }. + toSeq. + sortBy(-_.size) + + val dels = NCSentenceHelper.findCombinations(over.map(_.asJava).asJava, pool).asScala.map(_.asScala) + + var variantsList: JList[NCVariant] = dels.map(delComb => + new NCVariant: + override def getEntities: JList[NCEntity] = entities.filter(e => !delComb.contains(e)).asJava + ).asJava + + for (validator <- variantValidators) + checkCancel() + variantsList = validator.filter(req, cfg, variantsList) + + VariantsHolder(req, variantsList.asScala.toSeq, checkCancel) + + /** + * + * @param txt + * @param data + * @param usrId + * @return + */ + def askSync(txt: String, data: JMap[String, AnyRef], usrId: String): NCResult = + matchAndExecute(prepare(txt, data, usrId, () => ())) + + /** + * + * @param txt + * @param data + * @param usrId + * @return + */ + def ask(txt: String, data: JMap[String, AnyRef], usrId: String): CompletableFuture[NCResult] = + val fut = new CompletableFuture[NCResult] + + // TODO: error text. + def check = () => if fut.isCancelled then throw new NCException("Task was cancelled") + + fut.completeAsync(() => matchAndExecute(prepare(txt, data, usrId, check))) diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/NCSentenceHelper.java b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/NCSentenceHelper.java new file mode 100644 index 0000000..fd02253 --- /dev/null +++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/internal/NCSentenceHelper.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nlpcraft.internal; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.RecursiveTask; + +import static java.util.stream.Collectors.toList; + +/** + * It is not converted to scala because scala and java long values implicit conversion performance problems. + */ +class NCSentenceHelper extends RecursiveTask<List<Long>> { + private static final long THRESHOLD = (long)Math.pow(2, 20); + + private final long lo; + private final long hi; + private final long[] wordBits; + private final int[] wordCounts; + + private NCSentenceHelper(long lo, long hi, long[] wordBits, int[] wordCounts) { + this.lo = lo; + this.hi = hi; + this.wordBits = wordBits; + this.wordCounts = wordCounts; + } + + private List<Long> computeLocal() { + List<Long> res = new ArrayList<>(); + + for (long comboBits = lo; comboBits < hi; comboBits++) { + boolean match = true; + + // For each input row we check if subtracting the current combination of words + // from the input row would give us the expected result. + for (int j = 0; j < wordBits.length; j++) { + // Get bitmask of how many words can be subtracted from the row. + // Check if there is more than 1 word remaining after subtraction. + if (wordCounts[j] - Long.bitCount(wordBits[j] & comboBits) > 1) { + // Skip this combination. + match = false; + + break; + } + } + + if (match && excludes(comboBits, res)) + res.add(comboBits); + } + + return res; + } + + private List<Long> forkJoin() { + long mid = lo + hi >>> 1L; + + NCSentenceHelper t1 = new NCSentenceHelper(lo, mid, wordBits, wordCounts); + NCSentenceHelper t2 = new NCSentenceHelper(mid, hi, wordBits, wordCounts); + + t2.fork(); + + return merge(t1.compute(), t2.join()); + } + + private static List<Long> merge(List<Long> l1, List<Long> l2) { + if (l1.isEmpty()) + return l2; + else if (l2.isEmpty()) + return l1; + + int size1 = l1.size(); + int size2 = l2.size(); + + if (size1 == 1 && size2 > 1 || size2 == 1 && size1 > 1) { + // Minor optimization in case if one of the lists has only one element. + List<Long> res = size1 == 1 ? l2 : l1; + Long val = size1 == 1 ? l1.get(0) : l2.get(0); + + if (excludes(val, res)) + res.add(val); + + return res; + } + + List<Long> res = new ArrayList<>(size1 + size2); + + for (int i = 0, max = Math.max(size1, size2); i < max; i++) { + Long v1 = i < size1 ? l1.get(i) : null; + Long v2 = i < size2 ? l2.get(i) : null; + + if (v1 != null && v2 != null) { + if (containsAllBits(v1, v2)) + v1 = null; + else if (containsAllBits(v2, v1)) + v2 = null; + } + + if (v1 != null && excludes(v1, res)) + res.add(v1); + + if (v2 != null && excludes(v2, res)) + res.add(v2); + } + + return res; + } + + private static boolean excludes(long bits, List<Long> allBits) { + for (Long allBit : allBits) + if (containsAllBits(bits, allBit)) + return false; + + return true; + } + + private static boolean containsAllBits(long bitSet1, long bitSet2) { + return (bitSet1 & bitSet2) == bitSet2; + } + + private static <T> long wordsToBits(Set<T> words, List<T> dict) { + long bits = 0; + + for (int i = 0, n = dict.size(); i < n; i++) + if (words.contains(dict.get(i))) + bits |= 1L << i; + + return bits; + } + + private static <T> List<T> bitsToWords(long bits, List<T> dict) { + List<T> words = new ArrayList<>(Long.bitCount(bits)); + + for (int i = 0, n = dict.size(); i < n; i++) + if ((bits & 1L << i) != 0) + words.add(dict.get(i)); + + return words; + } + + @Override + protected List<Long> compute() { + return hi - lo <= THRESHOLD ? computeLocal() : forkJoin(); + } + + /** + * + * @param words + * @param pool + * @param <T> + * @return + */ + static <T> List<List<T>> findCombinations(List<Set<T>> words, ForkJoinPool pool) { + assert words != null && !words.isEmpty(); + assert pool != null; + + if (words.stream().allMatch(p -> p.size() == 1)) + return Collections.singletonList(Collections.emptyList()); + + // Build dictionary of unique words. + List<T> dict = words.stream().flatMap(Collection::stream).distinct().collect(toList()); + + if (dict.size() > Long.SIZE) + // Note: Power set of 64 words results in 9223372036854775807 combinations. + throw new IllegalArgumentException("Dictionary is too long: " + dict.size()); + + // Convert words to bitmasks (each bit corresponds to an index in the dictionary). + long[] wordBits = words.stream().sorted(Comparator.comparingInt(Set::size)).mapToLong(row -> wordsToBits(row, dict)).toArray(); + // Cache words count per row. + int[] wordCounts = words.stream().sorted(Comparator.comparingInt(Set::size)).mapToInt(Set::size).toArray(); + + // Prepare Fork/Join task to iterate over the power set of all combinations. + return + pool.invoke(new NCSentenceHelper(1, (long)Math.pow(2, dict.size()), wordBits, wordCounts)). + stream().map(bits -> bitsToWords(bits, dict)).collect(toList()); + } +} diff --git a/nlpcraft/src/test/scala/org/apache/nlpcraft/internal/NCRequestProcessorSpec.scala b/nlpcraft/src/test/scala/org/apache/nlpcraft/internal/NCRequestProcessorSpec.scala new file mode 100644 index 0000000..51d332d --- /dev/null +++ b/nlpcraft/src/test/scala/org/apache/nlpcraft/internal/NCRequestProcessorSpec.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nlpcraft.internal + +import org.apache.nlpcraft.nlp.entity.parser.semantic.impl.en.NCEnPorterStemmer +import org.apache.nlpcraft.nlp.entity.parser.semantic.* +import org.apache.nlpcraft.nlp.util.* +import org.apache.nlpcraft.* +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters.* + +/** + * + */ +class NCRequestProcessorSpec: + private def test(txt: String, variantCnt: Int, elements: NCSemanticElement*): Unit = + val cfg = NCTestConfig.EN.clone() + + val parser = new NCSemanticEntityParser(new NCEnPorterStemmer, cfg.getTokenParser, elements.asJava) + cfg.getEntityParsers.clear() + cfg.getEntityParsers.add(parser) + + val res = new NCRequestProcessor(new NCModelAdapter(cfg)).prepare(txt, null, "userId", () => ()) + + println(s"Variants count: ${res.variants.size}") + + for ((v, idx) <- res.variants.zipWithIndex) + println(s"Variant: $idx") + NCTestUtils.printEntities(txt, v.getEntities.asScala.toSeq) + + require(res.variants.size == variantCnt) + + @Test + def test(): Unit = + import org.apache.nlpcraft.nlp.entity.parser.semantic.NCSemanticTestElement as Elem + + test("t1 t2", 4, Elem("t1", "t2"), Elem("t2", "t1")) + test("t1 t2", 2, Elem("t1", "t2"), Elem("t2")) \ No newline at end of file diff --git a/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/entity/parser/semantic/NCSemanticEntityParserSpec.scala b/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/entity/parser/semantic/NCSemanticEntityParserSpec.scala index a51a914..b909816 100644 --- a/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/entity/parser/semantic/NCSemanticEntityParserSpec.scala +++ b/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/entity/parser/semantic/NCSemanticEntityParserSpec.scala @@ -39,7 +39,7 @@ import scala.jdk.OptionConverters.RichOptional * @param values * @param groups */ -case class Element( +case class NCSemanticTestElement( id: String, synonyms: Set[String] = Set.empty, values: Map[String, Set[String]] = Map.empty, @@ -52,6 +52,9 @@ case class Element( override def getSynonyms: JSet[String] = synonyms.asJava override def getProperties: JMap[String, Object] = props.asJava +object NCSemanticTestElement { + def apply(id: String, synonyms: String*) = new NCSemanticTestElement(id, synonyms = synonyms.toSet) +} /** * */ @@ -62,19 +65,19 @@ class NCSemanticEntityParserSpec: NCTestConfig.EN.getTokenParser, Seq( // Standard. - Element("t1", synonyms = Set("t1")), + NCSemanticTestElement("t1", synonyms = Set("t1")), // No extra synonyms. - Element("t2"), + NCSemanticTestElement("t2"), // Multiple words. - Element("t3", synonyms = Set("t3 t3")), + NCSemanticTestElement("t3", synonyms = Set("t3 t3")), // Value. No extra synonyms. - Element("t4", values = Map("value4" -> Set.empty)), + NCSemanticTestElement("t4", values = Map("value4" -> Set.empty)), // Value. Multiple words. - Element("t5", values = Map("value5" -> Set("value 5"))), + NCSemanticTestElement("t5", values = Map("value5" -> Set("value 5"))), // Elements data. - Element("t6", props = Map("testKey" -> "testValue")), + NCSemanticTestElement("t6", props = Map("testKey" -> "testValue")), // Regex. - Element("t7", synonyms = Set("x //[a-d]+//")) + NCSemanticTestElement("t7", synonyms = Set("x //[a-d]+//")) ).asJava ) diff --git a/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/util/NCTestConfig.scala b/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/util/NCTestConfig.scala index 0b85193..bde08b7 100644 --- a/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/util/NCTestConfig.scala +++ b/nlpcraft/src/test/scala/org/apache/nlpcraft/nlp/util/NCTestConfig.scala @@ -27,19 +27,31 @@ import java.util.{Optional, ArrayList as JAList, List as JList} * * @param tokParser */ -case class NCTestConfig(tokParser: NCTokenParser) extends NCPropertyMapAdapter with NCModelConfig: +case class NCTestConfig(tokParser: NCTokenParser) extends NCPropertyMapAdapter with NCModelConfig with Cloneable: require(tokParser != null) - override val getId: String = "testId" - override val getName: String = "test" - override val getVersion: String = "1.0" + override val getId = "testId" + override val getName = "test" + override val getVersion = "1.0" override val getTokenParser: NCTokenParser = tokParser - override val getTokenEnrichers: JList[NCTokenEnricher] = new JAList[NCTokenEnricher]() - override val getEntityEnrichers: JList[NCEntityEnricher] = new JAList[NCEntityEnricher]() - override val getEntityParsers: JList[NCEntityParser] = new JAList[NCEntityParser]() - override val getTokenValidators: JList[NCTokenValidator] = new JAList[NCTokenValidator]() - override val getEntityValidators: JList[NCEntityValidator] = new JAList[NCEntityValidator]() - override val getVariantValidators: JList[NCVariantValidator] = new JAList[NCVariantValidator]() + override val getTokenEnrichers = new JAList[NCTokenEnricher]() + override val getEntityEnrichers = new JAList[NCEntityEnricher]() + override val getEntityParsers = new JAList[NCEntityParser]() + override val getTokenValidators = new JAList[NCTokenValidator]() + override val getEntityValidators = new JAList[NCEntityValidator]() + override val getVariantValidators = new JAList[NCVariantValidator]() + + override def clone(): NCTestConfig = + val copy = NCTestConfig(this.tokParser) + + copy.getTokenEnrichers.addAll(this.getTokenEnrichers) + copy.getEntityEnrichers.addAll(this.getEntityEnrichers) + copy.getEntityParsers.addAll(this.getEntityParsers) + copy.getTokenValidators.addAll(this.getTokenValidators) + copy.getEntityValidators.addAll(this.getEntityValidators) + copy.getVariantValidators.addAll(this.getVariantValidators) + + copy /** * */ object NCTestConfig: