This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 66e8bd58bb [SYSTEMDS-3325] Multi-threaded tokenization
66e8bd58bb is described below
commit 66e8bd58bb5ed9c51afdfe391d5f0246b8bf383d
Author: Lukas Erlbacher <[email protected]>
AuthorDate: Mon Jul 11 11:30:40 2022 +0200
[SYSTEMDS-3325] Multi-threaded tokenization
This patch replaces the existing tokenization builtin with
a transformencode-style multithreaded implementation.
In the build stage, the input is split into tokens and stored
in an internal representation; in addition, metadata needed by
the apply phase is computed. During the apply phase, the computed
data is retrieved and written to the output.
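For illustration, the staged control flow corresponds roughly to the
following sketch of the single-threaded path (simplified from the new
Tokenizer below; the multi-threaded path wires the same stages into a
DAG of dependency tasks):

    // build stage: tokenize the input and collect applier metadata
    allocateInternalRepresentation(in.getNumRows());
    tokenizerBuilder.createInternalRepresentation(in, internalRepresentation);
    tokenizerApplier.build(internalRepresentation, 0, -1);
    // apply stage: materialize the output frame from the internal representation
    FrameBlock out = new FrameBlock(getSchema());
    out.ensureAllocatedColumns(tokenizerApplier.getNumRows(internalRepresentation));
    tokenizerApplier.applyInternalRepresentation(internalRepresentation, out);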
The current implementation splits the input frame into row
partitions; the default of 64 partitions can be changed via the
sysds.parallel.tokenize.numBlocks configuration.
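The feature is disabled by default and is enabled by setting
sysds.parallel.tokenize to true in SystemDS-config.xml. The CP
instruction then derives the degree of parallelism roughly as in the
following sketch (simplified from the OptimizerUtils and
ParameterizedBuiltinCPInstruction changes below; data, spec, and
maxTokens stand for the instruction inputs):

    // vcores by default, capped by an externally set thread limit,
    // and forced to 1 if sysds.parallel.tokenize is disabled
    int k = OptimizerUtils.getTokenizeNumThreads();
    Tokenizer tokenizer = TokenizerFactory.createTokenizer(spec, maxTokens);
    FrameBlock fbout = tokenizer.tokenize(data, k);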
Closes #1639.
---
conf/SystemDS-config.xml.template | 8 +-
.../apache/sysds/conf/ConfigurationManager.java | 8 +
src/main/java/org/apache/sysds/conf/DMLConfig.java | 4 +
.../java/org/apache/sysds/hops/OptimizerUtils.java | 21 +-
.../cp/ParameterizedBuiltinCPInstruction.java | 2 +-
.../spark/ParameterizedBuiltinSPInstruction.java | 5 +-
...kenizerPre.java => DocumentRepresentation.java} | 32 ++-
.../sysds/runtime/transform/tokenize/Token.java | 124 ++++++++++++
.../runtime/transform/tokenize/Tokenizer.java | 216 ++++++++++++++++++---
.../transform/tokenize/TokenizerFactory.java | 36 ++--
.../runtime/transform/tokenize/TokenizerPost.java | 33 ----
.../transform/tokenize/TokenizerPostCount.java | 121 ------------
.../transform/tokenize/TokenizerPostHash.java | 159 ---------------
.../transform/tokenize/TokenizerPostPosition.java | 137 -------------
.../transform/tokenize/TokenizerPreNgram.java | 100 ----------
.../tokenize/TokenizerPreWhitespaceSplit.java | 92 ---------
.../tokenize/applier/TokenizerApplier.java | 213 ++++++++++++++++++++
.../tokenize/applier/TokenizerApplierCount.java | 138 +++++++++++++
.../tokenize/applier/TokenizerApplierHash.java | 172 ++++++++++++++++
.../tokenize/applier/TokenizerApplierPosition.java | 121 ++++++++++++
.../tokenize/builder/TokenizerBuilder.java | 88 +++++++++
.../tokenize/builder/TokenizerBuilderNgram.java | 101 ++++++++++
.../builder/TokenizerBuilderWhitespaceSplit.java | 90 +++++++++
.../transform/TokenizeMultithreadedTest.java | 192 ++++++++++++++++++
24 files changed, 1520 insertions(+), 693 deletions(-)
diff --git a/conf/SystemDS-config.xml.template
b/conf/SystemDS-config.xml.template
index 88658b44e0..b799d5d7ce 100644
--- a/conf/SystemDS-config.xml.template
+++ b/conf/SystemDS-config.xml.template
@@ -50,7 +50,13 @@
<!-- #threads in multi-threaded transformencode -->
<sysds.parallel.encode.numThreads>-1</sysds.parallel.encode.numThreads>
-
+
+ <!-- enable multi-threaded tokenize -->
+ <sysds.parallel.tokenize>false</sysds.parallel.tokenize>
+
+ <!-- #blocks into which the input frame is split for multi-threaded tokenization -->
+ <sysds.parallel.tokenize.numBlocks>64</sysds.parallel.tokenize.numBlocks>
+
<!-- enables compressed linear algebra, experimental feature -->
<sysds.compressed.linalg>false</sysds.compressed.linalg>
diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index 505f33b19f..936b6eabda 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -173,10 +173,18 @@ public class ConfigurationManager
return
getDMLConfig().getBooleanValue(DMLConfig.PARALLEL_ENCODE);
}
+ public static boolean isParallelTokenize() {
+ return
getDMLConfig().getBooleanValue(DMLConfig.PARALLEL_TOKENIZE);
+ }
+
public static boolean isStagedParallelTransform() {
return
getDMLConfig().getBooleanValue(DMLConfig.PARALLEL_ENCODE_STAGED);
}
+ public static int getNumberTokenizeBlocks(){
+ return
getDMLConfig().getIntValue(DMLConfig.PARALLEL_TOKENIZE_NUM_BLOCKS);
+ }
+
public static int getParallelApplyBlocks(){
return
getDMLConfig().getIntValue(DMLConfig.PARALLEL_ENCODE_APPLY_BLOCKS);
}
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java
b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index 1b730f6a3c..5ab151f7d5 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -75,6 +75,8 @@ public class DMLConfig
public static final String PARALLEL_ENCODE_APPLY_BLOCKS =
"sysds.parallel.encode.applyBlocks";
public static final String PARALLEL_ENCODE_BUILD_BLOCKS =
"sysds.parallel.encode.buildBlocks";
public static final String PARALLEL_ENCODE_NUM_THREADS =
"sysds.parallel.encode.numThreads";
+ public static final String PARALLEL_TOKENIZE =
"sysds.parallel.tokenize";
+ public static final String PARALLEL_TOKENIZE_NUM_BLOCKS =
"sysds.parallel.tokenize.numBlocks";
public static final String COMPRESSED_LINALG =
"sysds.compressed.linalg";
public static final String COMPRESSED_LOSSY =
"sysds.compressed.lossy";
public static final String COMPRESSED_VALID_COMPRESSIONS =
"sysds.compressed.valid.compressions";
@@ -141,6 +143,8 @@ public class DMLConfig
_defaultVals.put(DEFAULT_BLOCK_SIZE,
String.valueOf(OptimizerUtils.DEFAULT_BLOCKSIZE) );
_defaultVals.put(CP_PARALLEL_OPS, "true" );
_defaultVals.put(CP_PARALLEL_IO, "true" );
+ _defaultVals.put(PARALLEL_TOKENIZE, "false");
+ _defaultVals.put(PARALLEL_TOKENIZE_NUM_BLOCKS, "64");
_defaultVals.put(PARALLEL_ENCODE, "false" );
_defaultVals.put(PARALLEL_ENCODE_STAGED, "false" );
_defaultVals.put(PARALLEL_ENCODE_APPLY_BLOCKS, "-1");
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index a58450d5f3..fc5db7a737 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -1118,7 +1118,26 @@ public class OptimizerUtils
return ret;
}
-
+
+ public static int getTokenizeNumThreads()
+ {
+ //by default max local parallelism (vcores)
+ int ret = InfrastructureAnalyzer.getLocalParallelism();
+ int maxNumThreads = ConfigurationManager.getNumThreads();
+
+ //apply external max constraint (e.g., set by parfor or other
rewrites)
+ if( maxNumThreads > 0 ) {
+ ret = Math.min(ret, maxNumThreads);
+ }
+
+ //check if enabled in config.xml
+ if( !ConfigurationManager.isParallelTokenize() ) {
+ ret = 1;
+ }
+
+ return ret;
+ }
+
public static Level getDefaultLogLevel() {
Level log = Logger.getRootLogger().getLevel();
return (log != null) ? log : Level.INFO;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index d679a1291c..0d8d860bd0 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -289,7 +289,7 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
// compute tokenizer
Tokenizer tokenizer =
TokenizerFactory.createTokenizer(getParameterMap().get("spec"),
Integer.parseInt(getParameterMap().get("max_tokens")));
- FrameBlock fbout = tokenizer.tokenize(data, new
FrameBlock(tokenizer.getSchema()));
+ FrameBlock fbout = tokenizer.tokenize(data,
OptimizerUtils.getTokenizeNumThreads());
// release locks
ec.setFrameOutput(output.getName(), fbout);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 0494f42ea8..f513898aa7 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -33,6 +33,7 @@ import org.apache.spark.broadcast.Broadcast;
import org.apache.sysds.common.Types.CorrectionLocationType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -473,7 +474,7 @@ public class ParameterizedBuiltinSPInstruction extends
ComputationSPInstruction
sec.addLineageRDD(output.getName(),
params.get("target"));
// get max tokens for row upper bound
- long numRows = tokenizer.getNumRows(mc.getRows());
+ long numRows =
tokenizer.getMaxNumRows((int)mc.getRows());
long numCols = tokenizer.getNumCols();
sec.getDataCharacteristics(output.getName()).set(numRows, numCols,
mc.getBlocksize());
@@ -852,7 +853,7 @@ public class ParameterizedBuiltinSPInstruction extends
ComputationSPInstruction
long key = in._1();
FrameBlock blk = in._2();
- FrameBlock fbout = _tokenizer.tokenize(blk, new
FrameBlock(_tokenizer.getSchema()));
+ FrameBlock fbout = _tokenizer.tokenize(blk,
OptimizerUtils.getTokenizeNumThreads());
return new Tuple2<>(key, fbout);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPre.java
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/DocumentRepresentation.java
similarity index 52%
rename from
src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPre.java
rename to
src/main/java/org/apache/sysds/runtime/transform/tokenize/DocumentRepresentation.java
index 640bb5ba56..b52ef34b46 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPre.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/DocumentRepresentation.java
@@ -19,11 +19,33 @@
package org.apache.sysds.runtime.transform.tokenize;
-import org.apache.sysds.runtime.matrix.data.FrameBlock;
-
-import java.io.Serializable;
+import java.util.ArrayList;
import java.util.List;
+import java.util.stream.Collectors;
+
+public class DocumentRepresentation {
+ public List<Object> keys;
+ public List<Token> tokens;
+
+ public DocumentRepresentation(List<Object> keys, List<Token> tokens) {
+ this.keys = keys;
+ this.tokens = tokens;
+ }
+
+ public List<Token> getTokens() {
+ return tokens;
+ }
+
-public interface TokenizerPre extends Serializable {
- List<Tokenizer.DocumentToTokens> tokenizePre(FrameBlock in);
+ public void splitIntoNgrams(int minGram, int maxGram){
+ List<Token> ngramTokens = new ArrayList<>();
+ for(int n = minGram; n <= maxGram; n++){
+ for(int i = 0; i < tokens.size() - n + 1; i++){
+ List<Token> subList = tokens.subList(i, i+n);
+ Token token = new Token(subList);
+ ngramTokens.add(token);
+ }
+ }
+ tokens = ngramTokens;
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/Token.java
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/Token.java
new file mode 100644
index 0000000000..990f7e0f71
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/Token.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.tokenize;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+public class Token {
+
+ public static final String EMPTY_TOKEN = "";
+
+ public static class SubToken {
+ private final String text;
+ private final long startIndex;
+ private final long endIndex;
+
+ public SubToken(String token, long startIndex) {
+ this.text = token;
+ this.startIndex = startIndex;
+ this.endIndex = startIndex + token.length();
+ }
+
+ @Override
+ public String toString() {
+ return "SubToken{" +
+ "textToken='" + text + '\'' +
+ ", startIndex=" + startIndex +
+ ", endIndex=" + endIndex +
+ '}';
+ }
+ }
+
+ private List<SubToken> subTokens;
+
+ private Token(int subListSize){
+ subTokens = new ArrayList<>(subListSize);
+ }
+
+ public Token(String token, long startIndex) {
+ this(1);
+ subTokens.add(new SubToken(token, startIndex));
+ }
+
+ public Token(List<String> tokens, List<Long> startIndex){
+ this(tokens.size());
+ if(tokens.size() != startIndex.size())
+ throw new DMLRuntimeException("Cannot create token from mismatched
input sizes");
+ for(int i = 0; i < tokens.size(); i++){
+ subTokens.add(new SubToken(tokens.get(i), startIndex.get(i)));
+ }
+ }
+
+ public Token(List<Token> subList) {
+ this(getNumSubTokens(subList));
+ for(Token token: subList){
+ subTokens.addAll(token.subTokens);
+ }
+ }
+
+ private static int getNumSubTokens(List<Token> tokens){
+ int sum = 0;
+ for (Token token : tokens) {
+ sum += token.getNumSubTokens();
+ }
+ return sum;
+ }
+
+ public int getNumSubTokens(){
+ return subTokens.size();
+ }
+
+ public long getStartIndex(int subTokenIndex){
+ return subTokens.get(subTokenIndex).startIndex;
+ }
+
+ @Override
+ public int hashCode() {
+ return toString().hashCode();
+ }
+
+ @Override
+ public String toString() {
+ if(subTokens.size() == 0){
+ return EMPTY_TOKEN;
+ }
+ if(subTokens.size() == 1){
+ return subTokens.get(0).text;
+ }
+ StringBuilder sb = new StringBuilder().append("\"('");
+ for(int i = 0; i < subTokens.size(); i++){
+ sb.append(subTokens.get(i).text);
+ if(i < subTokens.size()-1)
+ sb.append("', '");
+ }
+ sb.append("')\"");
+ //return "\"('" + subTokens.stream().map(subToken ->
subToken.text).collect(Collectors.joining("', '")) + "')\"";
+ return sb.toString();
+ }
+
+
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/Tokenizer.java
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/Tokenizer.java
index dd4982a2db..6ba0dcb4f8 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/tokenize/Tokenizer.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/Tokenizer.java
@@ -22,63 +22,221 @@ package org.apache.sysds.runtime.transform.tokenize;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.transform.tokenize.applier.TokenizerApplier;
+import
org.apache.sysds.runtime.transform.tokenize.applier.TokenizerApplierCount;
+import
org.apache.sysds.runtime.transform.tokenize.applier.TokenizerApplierHash;
+import
org.apache.sysds.runtime.transform.tokenize.applier.TokenizerApplierPosition;
+import org.apache.sysds.runtime.transform.tokenize.builder.TokenizerBuilder;
+import org.apache.sysds.runtime.util.DependencyTask;
+import org.apache.sysds.runtime.util.DependencyThreadPool;
import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
public class Tokenizer implements Serializable {
private static final long serialVersionUID = 7155673772374114577L;
protected static final Log LOG =
LogFactory.getLog(Tokenizer.class.getName());
+ private static final boolean MULTI_THREADED_STAGES_TOKENIZER = false;
+ public static final int TOKENIZE_NUM_BLOCKS =
ConfigurationManager.getNumberTokenizeBlocks();
- private final TokenizerPre tokenizerPre;
- private final TokenizerPost tokenizerPost;
+ private DocumentRepresentation[] internalRepresentation = null;
+ private final TokenizerBuilder tokenizerBuilder;
+ private final TokenizerApplier tokenizerApplier;
- protected Tokenizer(TokenizerPre tokenizerPre, TokenizerPost
tokenizerPost) {
-
- this.tokenizerPre = tokenizerPre;
- this.tokenizerPost = tokenizerPost;
+ protected Tokenizer(TokenizerBuilder tokenizerBuilder, TokenizerApplier
tokenizerApplier) {
+ this.tokenizerBuilder = tokenizerBuilder;
+ this.tokenizerApplier = tokenizerApplier;
}
public Types.ValueType[] getSchema() {
- return tokenizerPost.getOutSchema();
+ return tokenizerApplier.getOutSchema();
+ }
+
+ public int getMaxNumRows(int inRows) {
+ return tokenizerApplier.getMaxNumRows(inRows);
}
- public long getNumRows(long inRows) {
- return tokenizerPost.getNumRows(inRows);
+ public int getNumRowsEstimate(){
+ // Estimate an upper bound; e.g., the Count applier produces fewer rows since it outputs each unique token only once
+ if(internalRepresentation != null){
+ if(tokenizerApplier.isWideFormat()) {
+ return internalRepresentation.length;
+ }
+ if(tokenizerApplier.hasPadding()) {
+ return internalRepresentation.length *
tokenizerApplier.getMaxTokens();
+ }
+ return Arrays.stream(internalRepresentation).mapToInt(doc ->
Math.min(doc.tokens.size(), tokenizerApplier.getMaxTokens())).sum();
+ }
+ throw new DMLRuntimeException("Internal token representation was not computed yet; cannot get exact size.");
}
public long getNumCols() {
- return tokenizerPost.getNumCols();
+ return tokenizerApplier.getNumCols();
}
- public FrameBlock tokenize(FrameBlock in, FrameBlock out) {
- // First convert to internal representation
- List<DocumentToTokens> documentsToTokenList =
tokenizerPre.tokenizePre(in);
- // Then convert to output representation
- return tokenizerPost.tokenizePost(documentsToTokenList, out);
+ public void allocateInternalRepresentation(int numDocuments){
+ internalRepresentation = new DocumentRepresentation[numDocuments];
+ tokenizerApplier.allocateInternalMeta(numDocuments);
}
- static class Token {
- String textToken;
- long startIndex;
- long endIndex;
+ public FrameBlock tokenize(FrameBlock in) {
+ return tokenize(in, 1);
+ }
- public Token(String token, long startIndex) {
- this.textToken = token;
- this.startIndex = startIndex;
- this.endIndex = startIndex + token.length();
+ public FrameBlock tokenize(FrameBlock in, int k) {
+ allocateInternalRepresentation(in.getNumRows());
+ FrameBlock out = new FrameBlock(this.getSchema());
+ if (k > 1 && !MULTI_THREADED_STAGES_TOKENIZER) {
+ DependencyThreadPool pool = new DependencyThreadPool(k);
+ LOG.debug("Tokenizing with full DAG on " + k + " Threads");
+ try {
+ List<DependencyTask<?>> tokenizeTasks = getTokenizeTasks(in,
out, pool);
+ int lastRow =
pool.submitAllAndWait(tokenizeTasks).stream().map(s -> s == null? 0
:(Integer)s).max(Integer::compare).get();
+ if(lastRow != out.getNumRows()){
+ out = out.slice(0, lastRow - 1, 0, out.getNumColumns() -
1, null);
+ }
+ } catch (ExecutionException | InterruptedException e) {
+ LOG.error("MT tokenize failed");
+ e.printStackTrace();
+ }
+ pool.shutdown();
+ } else {
+ build(in, k);
+
out.ensureAllocatedColumns(tokenizerApplier.getNumRows(this.internalRepresentation));
+ out = apply(out, k);
}
+ return out;
}
- static class DocumentToTokens {
- List<Object> keys;
- List<Tokenizer.Token> tokens;
+ private List<DependencyTask<?>> getTokenizeTasks(FrameBlock in, FrameBlock
out, DependencyThreadPool pool) {
+ // TODO further optimisation of task graph to reduce memory usage!
+ // TODO add cache awareness
+ List<DependencyTask<?>> tasks = new ArrayList<>();
+ Map<Integer[], Integer[]> depMap = new HashMap<>();
+ tasks.add(DependencyThreadPool.createDependencyTask(new
AllocateOutputFrame(this, out)));
+ List<DependencyTask<?>> buildTasks = getBuildTasks(in); // first half are builder build tasks, second half applier build tasks; dependencies already set
+ tasks.addAll(buildTasks);
+ List<DependencyTask<?>> applyTasks =
tokenizerApplier.getApplyTasks(this.internalRepresentation, out);
+ if(applyTasks.size() != buildTasks.size() / 2)
+ throw new DMLRuntimeException("Different block sizes between build
and apply tasks currently not supported");
+ // The builder creates the internal representation for a given section.
+ // The applier build creates additional metadata that is needed in the apply step.
+ // If the representation is in long format and has no padding:
+ // - Count and Hash apply tasks depend on the metadata build tasks of all previous chunks due to "getOutputRow",
+ //   e.g., an apply task starting at row 100 with block size 50 depends on the ApplierBuildTasks for sections [0-49] and [50-99].
+ // - The same holds for Position, except its tasks depend only on the internal representation creation, since Position has no metadata.
+ if(!tokenizerApplier.isWideFormat() || !tokenizerApplier.hasPadding()){
+ int buildTaskOffset;
+ if(tokenizerApplier instanceof TokenizerApplierPosition){
+ buildTaskOffset = 0;
+ }
+ else if (tokenizerApplier instanceof TokenizerApplierCount ||
tokenizerApplier instanceof TokenizerApplierHash) {
+ buildTaskOffset = applyTasks.size();
+ }
+ else{
+ throw new DMLRuntimeException("Unknown TokenizerApplier");
+ }
+ depMap.put(new Integer[] {0, 1}, new Integer[]{1,
(buildTasks.size()/2) + 1});
+ depMap.put(new Integer[] {tasks.size(),
tasks.size()+applyTasks.size()}, new Integer[]{0, 1});
+ for(int i = 0; i < applyTasks.size(); i++){
+ depMap.put(new Integer[] {tasks.size() + i,
tasks.size()+applyTasks.size()}, new Integer[]{1+buildTaskOffset + i,
2+buildTaskOffset + i});
+ }
+ }
+ tasks.addAll(applyTasks);
+ List<List<? extends Callable<?>>> deps = new
ArrayList<>(Collections.nCopies(tasks.size(), null));
+ DependencyThreadPool.createDependencyList(tasks, depMap, deps);
+ return DependencyThreadPool.createDependencyTasks(tasks, deps);
+ }
- public DocumentToTokens(List<Object> keys, List<Tokenizer.Token>
tokens) {
- this.keys = keys;
- this.tokens = tokens;
+ public FrameBlock apply(FrameBlock out, int k) {
+ int lastRow = -1;
+ if(k > 1){
+ DependencyThreadPool pool = new DependencyThreadPool(k);
+ try{
+ List<DependencyTask<?>> taskList =
tokenizerApplier.getApplyTasks(this.internalRepresentation, out);
+ lastRow = pool.submitAllAndWait(taskList).stream().map(s ->
(Integer)s).max(Integer::compare).get();
+ }
+ catch(ExecutionException | InterruptedException e) {
+ LOG.error("MT Tokenizer apply failed");
+ e.printStackTrace();
+ }
+ pool.shutdown();
+
+ }else{
+ lastRow =
tokenizerApplier.applyInternalRepresentation(this.internalRepresentation, out);
}
+ if(lastRow != out.getNumRows()){
+ out = out.slice(0, lastRow - 1, 0, out.getNumColumns() - 1, null);
+ }
+
+ return out;
+ }
+
+ public List<DependencyTask<?>> getBuildTasks(FrameBlock in){
+ List<DependencyTask<?>> tasks = tokenizerBuilder.getTasks(in,
this.internalRepresentation);
+ List<DependencyTask<?>> applierBuildTaskList =
tokenizerApplier.getBuildTasks(this.internalRepresentation);
+ if(tasks.size() != applierBuildTaskList.size())
+ throw new DMLRuntimeException("Cannot create dependencies for
mismatched array sizes");
+ tasks.addAll(applierBuildTaskList);
+ List<List<? extends Callable<?>>> deps = new
ArrayList<>(Collections.nCopies(tasks.size(), null));
+ Map<Integer[], Integer[]> depMap = new HashMap<>();
+ for(int i = 0; i < tasks.size() / 2; i++){
+ depMap.put(new Integer[]{i+applierBuildTaskList.size(),
i+applierBuildTaskList.size() + 1}, new Integer[] {i, i+1});
+ }
+ DependencyThreadPool.createDependencyList(tasks, depMap, deps);
+ tasks = DependencyThreadPool.createDependencyTasks(tasks, deps);
+ return tasks;
}
+
+ public void build(FrameBlock in, int k){
+ tokenizerApplier.allocateInternalMeta(in.getNumRows());
+ if(k > 1){
+ DependencyThreadPool pool = new DependencyThreadPool(k);
+ try{
+ pool.submitAllAndWait(getBuildTasks(in));
+ }
+ catch(ExecutionException | InterruptedException e) {
+ LOG.error("MT Tokenizer build failed");
+ e.printStackTrace();
+ }
+ pool.shutdown();
+
+ }else{
+ tokenizerBuilder.createInternalRepresentation(in,
this.internalRepresentation);
+ tokenizerApplier.build(this.internalRepresentation, 0, -1);
+ }
+ }
+
+
+ protected static class AllocateOutputFrame implements Callable<Object>{
+
+ protected final Tokenizer _tokenizer;
+ protected final FrameBlock _out;
+
+ protected AllocateOutputFrame(Tokenizer tokenizer,
+ FrameBlock out){
+ this._tokenizer = tokenizer;
+ this._out = out;
+ }
+
+ @Override
+ public Object call() throws Exception {
+ _out.ensureAllocatedColumns(_tokenizer.getNumRowsEstimate());
+ return null;
+ }
+ }
+
+
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerFactory.java
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerFactory.java
index 18c4bff89a..218bb5ee4e 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerFactory.java
@@ -20,6 +20,13 @@
package org.apache.sysds.runtime.transform.tokenize;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.transform.tokenize.applier.TokenizerApplier;
+import
org.apache.sysds.runtime.transform.tokenize.applier.TokenizerApplierCount;
+import
org.apache.sysds.runtime.transform.tokenize.applier.TokenizerApplierHash;
+import
org.apache.sysds.runtime.transform.tokenize.applier.TokenizerApplierPosition;
+import org.apache.sysds.runtime.transform.tokenize.builder.TokenizerBuilder;
+import
org.apache.sysds.runtime.transform.tokenize.builder.TokenizerBuilderNgram;
+import
org.apache.sysds.runtime.transform.tokenize.builder.TokenizerBuilderWhitespaceSplit;
import org.apache.wink.json4j.JSONObject;
import org.apache.wink.json4j.JSONArray;
@@ -53,13 +60,13 @@ public class TokenizerFactory {
int tokenizeCol = jSpec.getInt("tokenize_col");
// tokenization needs one or more idCols that define the document
and are replicated per token
- List<Integer> idCols = new ArrayList<>();
JSONArray idColsJsonArray = jSpec.getJSONArray("id_cols");
+ int[] idCols = new int[idColsJsonArray.length()];
for (int i=0; i < idColsJsonArray.length(); i++) {
- idCols.add(idColsJsonArray.getInt(i));
+ idCols[i] = idColsJsonArray.getInt(i);
}
// Output schema is derived from specified id cols
- int numIdCols = idCols.size();
+ int numIdCols = idCols.length;
// get difference between long and wide format
boolean wideFormat = false; // long format is default
@@ -67,18 +74,23 @@ public class TokenizerFactory {
wideFormat = jSpec.getBoolean("format_wide");
}
- TokenizerPre tokenizerPre;
- TokenizerPost tokenizerPost;
+ boolean applyPadding = false; // no padding is default
+ if (jSpec.has("apply_padding")) {
+ applyPadding = jSpec.getBoolean("apply_padding");
+ }
+
+ TokenizerBuilder tokenizerBuilder;
+ TokenizerApplier tokenizerApplier;
- // Note that internal representation should be independent from
output representation
+ // Note that internal representation should be independent of
output representation
// Algorithm to transform tokens into internal token representation
switch (algo) {
case "split":
- tokenizerPre = new TokenizerPreWhitespaceSplit(idCols,
tokenizeCol, algoParams);
+ tokenizerBuilder = new
TokenizerBuilderWhitespaceSplit(idCols, tokenizeCol, algoParams);
break;
case "ngram":
- tokenizerPre = new TokenizerPreNgram(idCols, tokenizeCol,
algoParams);
+ tokenizerBuilder = new TokenizerBuilderNgram(idCols,
tokenizeCol, algoParams);
break;
default:
throw new IllegalArgumentException("Algorithm {algo=" +
algo + "} is not supported.");
@@ -87,19 +99,19 @@ public class TokenizerFactory {
// Transform tokens to output representation
switch (out) {
case "count":
- tokenizerPost = new TokenizerPostCount(outParams,
numIdCols, maxTokens, wideFormat);
+ tokenizerApplier = new TokenizerApplierCount(numIdCols,
maxTokens, wideFormat, applyPadding, outParams);
break;
case "position":
- tokenizerPost = new TokenizerPostPosition(outParams,
numIdCols, maxTokens, wideFormat);
+ tokenizerApplier = new TokenizerApplierPosition(numIdCols,
maxTokens, wideFormat, applyPadding);
break;
case "hash":
- tokenizerPost = new TokenizerPostHash(outParams,
numIdCols, maxTokens, wideFormat);
+ tokenizerApplier = new TokenizerApplierHash(numIdCols,
maxTokens, wideFormat, applyPadding, outParams);
break;
default:
throw new IllegalArgumentException("Output representation
{out=" + out + "} is not supported.");
}
- tokenizer = new Tokenizer(tokenizerPre, tokenizerPost);
+ tokenizer = new Tokenizer(tokenizerBuilder, tokenizerApplier);
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPost.java
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPost.java
deleted file mode 100644
index 5f35c89307..0000000000
---
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPost.java
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.runtime.transform.tokenize;
-
-import org.apache.sysds.common.Types;
-import org.apache.sysds.runtime.matrix.data.FrameBlock;
-
-import java.io.Serializable;
-import java.util.List;
-
-public interface TokenizerPost extends Serializable {
- FrameBlock tokenizePost(List<Tokenizer.DocumentToTokens> tl, FrameBlock
out);
- Types.ValueType[] getOutSchema();
- long getNumRows(long inRows);
- long getNumCols();
-}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPostCount.java
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPostCount.java
deleted file mode 100644
index f1f9e8156c..0000000000
---
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPostCount.java
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.runtime.transform.tokenize;
-
-import org.apache.sysds.common.Types;
-import org.apache.sysds.runtime.matrix.data.FrameBlock;
-import org.apache.sysds.runtime.util.UtilFunctions;
-import org.apache.wink.json4j.JSONException;
-import org.apache.wink.json4j.JSONObject;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-public class TokenizerPostCount implements TokenizerPost{
-
- private static final long serialVersionUID = 6382000606237705019L;
- private final Params params;
- private final int numIdCols;
- private final int maxTokens;
- private final boolean wideFormat;
-
- static class Params implements Serializable {
-
- private static final long serialVersionUID = 5121697674346781880L;
-
- public boolean sort_alpha = false;
-
- public Params(JSONObject json) throws JSONException {
- if (json != null && json.has("sort_alpha")) {
- this.sort_alpha = json.getBoolean("sort_alpha");
- }
- }
- }
-
- public TokenizerPostCount(JSONObject params, int numIdCols, int maxTokens,
boolean wideFormat) throws JSONException {
- this.params = new Params(params);
- this.numIdCols = numIdCols;
- this.maxTokens = maxTokens;
- this.wideFormat = wideFormat;
- }
-
- @Override
- public FrameBlock tokenizePost(List<Tokenizer.DocumentToTokens> tl,
FrameBlock out) {
- for (Tokenizer.DocumentToTokens docToToken: tl) {
- List<Object> keys = docToToken.keys;
- List<Tokenizer.Token> tokenList = docToToken.tokens;
- // Creating the counts for BoW
- Map<String, Long> tokenCounts =
tokenList.stream().collect(Collectors.groupingBy(token ->
- token.textToken, Collectors.counting()));
- // Remove duplicate strings
- Stream<String> distinctTokenStream = tokenList.stream().map(token
-> token.textToken).distinct();
- if (params.sort_alpha) {
- // Sort alphabetically
- distinctTokenStream = distinctTokenStream.sorted();
- }
- List<String> outputTokens =
distinctTokenStream.collect(Collectors.toList());
-
- int numTokens = 0;
- for (String token: outputTokens) {
- if (numTokens >= maxTokens) {
- break;
- }
- // Create a row per token
- long count = tokenCounts.get(token);
- List<Object> rowList = new ArrayList<>(keys);
- rowList.add(token);
- rowList.add(count);
- Object[] row = new Object[rowList.size()];
- rowList.toArray(row);
- out.appendRow(row);
- numTokens++;
- }
- }
-
- return out;
- }
-
- @Override
- public Types.ValueType[] getOutSchema() {
- if (wideFormat) {
- throw new IllegalArgumentException("Wide Format is not supported
for Count Representation.");
- }
- // Long format only depends on numIdCols
- Types.ValueType[] schema = UtilFunctions.nCopies(numIdCols +
2,Types.ValueType.STRING );
- schema[numIdCols + 1] = Types.ValueType.INT64;
- return schema;
- }
-
- public long getNumRows(long inRows) {
- if (wideFormat) {
- return inRows;
- } else {
- return inRows * maxTokens;
- }
- }
-
- public long getNumCols() {
- return this.getOutSchema().length;
- }
-}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPostHash.java
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPostHash.java
deleted file mode 100644
index dc64538fb1..0000000000
---
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPostHash.java
+++ /dev/null
@@ -1,159 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.runtime.transform.tokenize;
-
-import org.apache.sysds.common.Types;
-import org.apache.sysds.runtime.matrix.data.FrameBlock;
-import org.apache.sysds.runtime.util.UtilFunctions;
-import org.apache.wink.json4j.JSONException;
-import org.apache.wink.json4j.JSONObject;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.TreeMap;
-import java.util.function.Function;
-import java.util.stream.Collectors;
-
-public class TokenizerPostHash implements TokenizerPost{
-
- private static final long serialVersionUID = 4763889041868044668L;
- private final Params params;
- private final int numIdCols;
- private final int maxTokens;
- private final boolean wideFormat;
-
- static class Params implements Serializable {
-
- private static final long serialVersionUID =
-256069061414241795L;
-
- public int num_features = 1048576; // 2^20
-
- public Params(JSONObject json) throws JSONException {
- if (json != null && json.has("num_features")) {
- this.num_features = json.getInt("num_features");
- }
- }
- }
-
- public TokenizerPostHash(JSONObject params, int numIdCols, int
maxTokens, boolean wideFormat) throws JSONException {
- this.params = new Params(params);
- this.numIdCols = numIdCols;
- this.maxTokens = maxTokens;
- this.wideFormat = wideFormat;
- }
-
- @Override
- public FrameBlock tokenizePost(List<Tokenizer.DocumentToTokens> tl,
FrameBlock out) {
- for (Tokenizer.DocumentToTokens docToToken: tl) {
- List<Object> keys = docToToken.keys;
- List<Tokenizer.Token> tokenList = docToToken.tokens;
- // Transform to hashes
- List<Integer> hashList = tokenList.stream().map(token
-> token.textToken.hashCode() %
-
params.num_features).collect(Collectors.toList());
- // Counting the hashes
- Map<Integer, Long> hashCounts =
hashList.stream().collect(Collectors.groupingBy(Function.identity(),
- Collectors.counting()));
- // Sorted by hash
- Map<Integer, Long> sortedHashes = new
TreeMap<>(hashCounts);
-
- if (wideFormat) {
- this.appendTokensWide(keys, sortedHashes, out);
- } else {
- this.appendTokensLong(keys, sortedHashes, out);
- }
- }
-
- return out;
- }
-
- private void appendTokensLong(List<Object> keys, Map<Integer, Long>
sortedHashes, FrameBlock out) {
- int numTokens = 0;
- for (Map.Entry<Integer, Long> hashCount:
sortedHashes.entrySet()) {
- if (numTokens >= maxTokens) {
- break;
- }
- // Create a row per token
- int hash = hashCount.getKey() + 1;
- long count = hashCount.getValue();
- List<Object> rowList = new ArrayList<>(keys);
- rowList.add((long) hash);
- rowList.add(count);
- Object[] row = new Object[rowList.size()];
- rowList.toArray(row);
- out.appendRow(row);
- numTokens++;
- }
- }
-
- private void appendTokensWide(List<Object> keys, Map<Integer, Long>
sortedHashes, FrameBlock out) {
- // Create one row with keys as prefix
- List<Object> rowList = new ArrayList<>(keys);
-
- for (int tokenPos = 0; tokenPos < maxTokens; tokenPos++) {
- long positionHash = sortedHashes.getOrDefault(tokenPos,
0L);
- rowList.add(positionHash);
- }
- Object[] row = new Object[rowList.size()];
- rowList.toArray(row);
- out.appendRow(row);
- }
-
- @Override
- public Types.ValueType[] getOutSchema() {
- if (wideFormat) {
- return getOutSchemaWide(numIdCols, maxTokens);
- } else {
- return getOutSchemaLong(numIdCols);
- }
- }
-
- private static Types.ValueType[] getOutSchemaWide(int numIdCols, int
maxTokens) {
- Types.ValueType[] schema = new Types.ValueType[numIdCols +
maxTokens];
- int i = 0;
- for (; i < numIdCols; i++) {
- schema[i] = Types.ValueType.STRING;
- }
- for (int j = 0; j < maxTokens; j++, i++) {
- schema[i] = Types.ValueType.INT64;
- }
- return schema;
- }
-
- private static Types.ValueType[] getOutSchemaLong(int numIdCols) {
- Types.ValueType[] schema = UtilFunctions.nCopies(numIdCols +
2,Types.ValueType.STRING );
- schema[numIdCols] = Types.ValueType.INT64;
- schema[numIdCols+1] = Types.ValueType.INT64;
- return schema;
- }
-
- public long getNumRows(long inRows) {
- if (wideFormat) {
- return inRows;
- } else {
- return inRows * maxTokens;
- }
- }
-
- public long getNumCols() {
- return this.getOutSchema().length;
- }
-}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPostPosition.java
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPostPosition.java
deleted file mode 100644
index 64a24aa7fb..0000000000
---
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPostPosition.java
+++ /dev/null
@@ -1,137 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.runtime.transform.tokenize;
-
-import org.apache.sysds.common.Types;
-import org.apache.sysds.runtime.matrix.data.FrameBlock;
-
-import org.apache.sysds.runtime.util.UtilFunctions;
-import org.apache.wink.json4j.JSONObject;
-
-import java.util.ArrayList;
-import java.util.List;
-
-public class TokenizerPostPosition implements TokenizerPost{
-
- private static final long serialVersionUID = 3563407270742660830L;
- private final int numIdCols;
- private final int maxTokens;
- private final boolean wideFormat;
-
- public TokenizerPostPosition(JSONObject params, int numIdCols, int
maxTokens, boolean wideFormat) {
- // No configurable params yet
- this.numIdCols = numIdCols;
- this.maxTokens = maxTokens;
- this.wideFormat = wideFormat;
- }
-
- @Override
- public FrameBlock tokenizePost(List<Tokenizer.DocumentToTokens> tl,
FrameBlock out) {
- for (Tokenizer.DocumentToTokens docToToken: tl) {
- List<Object> keys = docToToken.keys;
- List<Tokenizer.Token> tokenList = docToToken.tokens;
-
- if (wideFormat) {
- this.appendTokensWide(keys, tokenList, out);
- } else {
- this.appendTokensLong(keys, tokenList, out);
- }
- }
-
- return out;
- }
-
- public void appendTokensLong(List<Object> keys, List<Tokenizer.Token>
tokenList, FrameBlock out) {
- int numTokens = 0;
- for (Tokenizer.Token token: tokenList) {
- if (numTokens >= maxTokens) {
- break;
- }
- // Create a row per token
- List<Object> rowList = new ArrayList<>(keys);
- // Convert to 1-based index for DML
- rowList.add(token.startIndex + 1);
- rowList.add(token.textToken);
- Object[] row = new Object[rowList.size()];
- rowList.toArray(row);
- out.appendRow(row);
- numTokens++;
- }
- }
-
- public void appendTokensWide(List<Object> keys, List<Tokenizer.Token>
tokenList, FrameBlock out) {
- // Create one row with keys as prefix
- List<Object> rowList = new ArrayList<>(keys);
-
- int numTokens = 0;
- for (Tokenizer.Token token: tokenList) {
- if (numTokens >= maxTokens) {
- break;
- }
- rowList.add(token.textToken);
- numTokens++;
- }
- // Remaining positions need to be filled with empty tokens
- for (; numTokens < maxTokens; numTokens++) {
- rowList.add("");
- }
- Object[] row = new Object[rowList.size()];
- rowList.toArray(row);
- out.appendRow(row);
- }
-
- @Override
- public Types.ValueType[] getOutSchema() {
- if (wideFormat) {
- return getOutSchemaWide(numIdCols, maxTokens);
- } else {
- return getOutSchemaLong(numIdCols);
- }
-
- }
-
- private static Types.ValueType[] getOutSchemaWide(int numIdCols, int
maxTokens) {
- Types.ValueType[] schema = UtilFunctions.nCopies(numIdCols +
maxTokens,Types.ValueType.STRING );
- return schema;
- }
-
- private static Types.ValueType[] getOutSchemaLong(int numIdCols) {
- Types.ValueType[] schema = new Types.ValueType[numIdCols + 2];
- int i = 0;
- for (; i < numIdCols; i++) {
- schema[i] = Types.ValueType.STRING;
- }
- schema[i] = Types.ValueType.INT64;
- schema[i+1] = Types.ValueType.STRING;
- return schema;
- }
-
- public long getNumRows(long inRows) {
- if (wideFormat) {
- return inRows;
- } else {
- return inRows * maxTokens;
- }
- }
-
- public long getNumCols() {
- return this.getOutSchema().length;
- }
-}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPreNgram.java
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPreNgram.java
deleted file mode 100644
index a602c2b9de..0000000000
---
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPreNgram.java
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.runtime.transform.tokenize;
-
-import org.apache.sysds.runtime.matrix.data.FrameBlock;
-import org.apache.wink.json4j.JSONException;
-import org.apache.wink.json4j.JSONObject;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-
-public class TokenizerPreNgram implements TokenizerPre {
-
- private static final long serialVersionUID = -6297904316677723802L;
-
- private final TokenizerPreWhitespaceSplit tokenizerPreWhitespaceSplit;
- private final Params params;
-
- static class Params implements Serializable {
-
- private static final long serialVersionUID = -6516419749810062677L;
-
- public int minGram = 1;
- public int maxGram = 2;
-
- public Params(JSONObject json) throws JSONException {
- if (json != null && json.has("min_gram")) {
- this.minGram = json.getInt("min_gram");
- }
- if (json != null && json.has("max_gram")) {
- this.maxGram = json.getInt("max_gram");
- }
- }
- }
-
- public TokenizerPreNgram(List<Integer> idCols, int tokenizeCol, JSONObject
params) throws JSONException {
- this.tokenizerPreWhitespaceSplit = new
TokenizerPreWhitespaceSplit(idCols, tokenizeCol, params);
- this.params = new Params(params);
- }
-
- public List<Tokenizer.Token> wordTokenToNgrams(Tokenizer.Token wordTokens)
{
- List<Tokenizer.Token> ngramTokens = new ArrayList<>();
-
- int tokenLen = wordTokens.textToken.length();
- int startPos = params.minGram - params.maxGram;
- int endPos = Math.max(tokenLen - params.minGram, startPos);
-
- for (int i = startPos; i <= endPos; i++) {
- int startSlice = Math.max(i, 0);
- int endSlice = Math.min(i + params.maxGram, tokenLen);
- String substring = wordTokens.textToken.substring(startSlice,
endSlice);
- long tokenStart = wordTokens.startIndex + startSlice;
- ngramTokens.add(new Tokenizer.Token(substring, tokenStart));
- }
-
- return ngramTokens;
- }
-
- public List<Tokenizer.Token> wordTokenListToNgrams(List<Tokenizer.Token>
wordTokens) {
- List<Tokenizer.Token> ngramTokens = new ArrayList<>();
-
- for (Tokenizer.Token wordToken: wordTokens) {
- List<Tokenizer.Token> ngramTokensForWord =
wordTokenToNgrams(wordToken);
- ngramTokens.addAll(ngramTokensForWord);
- }
- return ngramTokens;
- }
-
- @Override
- public List<Tokenizer.DocumentToTokens> tokenizePre(FrameBlock in) {
- List<Tokenizer.DocumentToTokens> docToWordTokens =
tokenizerPreWhitespaceSplit.tokenizePre(in);
-
- List<Tokenizer.DocumentToTokens> docToNgramTokens = new ArrayList<>();
- for (Tokenizer.DocumentToTokens docToTokens: docToWordTokens) {
- List<Object> keys = docToTokens.keys;
- List<Tokenizer.Token> wordTokens = docToTokens.tokens;
- List<Tokenizer.Token> ngramTokens =
wordTokenListToNgrams(wordTokens);
- docToNgramTokens.add(new Tokenizer.DocumentToTokens(keys,
ngramTokens));
- }
- return docToNgramTokens;
- }
-}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPreWhitespaceSplit.java
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPreWhitespaceSplit.java
deleted file mode 100644
index 2653fc0be9..0000000000
---
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerPreWhitespaceSplit.java
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.runtime.transform.tokenize;
-
-import org.apache.sysds.runtime.matrix.data.FrameBlock;
-import org.apache.wink.json4j.JSONException;
-import org.apache.wink.json4j.JSONObject;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-
-public class TokenizerPreWhitespaceSplit implements TokenizerPre {
-
- private static final long serialVersionUID = 539127244034913364L;
-
- private final Params params;
-
- private final List<Integer> idCols;
- private final int tokenizeCol;
-
- static class Params implements Serializable {
-
- private static final long serialVersionUID = -4368552847660442628L;
-
- public String regex = "\\s+"; // whitespace
-
- public Params(JSONObject json) throws JSONException {
- if (json != null && json.has("regex")) {
- this.regex = json.getString("regex");
- }
- }
- }
-
- public TokenizerPreWhitespaceSplit(List<Integer> idCols, int tokenizeCol,
JSONObject params) throws JSONException {
- this.idCols = idCols;
- this.tokenizeCol = tokenizeCol;
- this.params = new Params(params);
- }
-
- public List<Tokenizer.Token> splitToTokens(String text) {
- List<Tokenizer.Token> tokenList = new ArrayList<>();
- String[] textTokens = text.split(params.regex);
- int curIndex = 0;
- for(String textToken: textTokens) {
- int tokenIndex = text.indexOf(textToken, curIndex);
- curIndex = tokenIndex;
- tokenList.add(new Tokenizer.Token(textToken, tokenIndex));
- }
- return tokenList;
- }
-
- @Override
- public List<Tokenizer.DocumentToTokens> tokenizePre(FrameBlock in) {
- List<Tokenizer.DocumentToTokens> documentsToTokenList = new
ArrayList<>();
-
- Iterator<String[]> iterator = in.getStringRowIterator();
- iterator.forEachRemaining(s -> {
- // Convert index value to Java (0-based) from DML (1-based)
- String text = s[tokenizeCol - 1];
- List<Object> keys = new ArrayList<>();
- for (Integer idCol: idCols) {
- Object key = s[idCol - 1];
- keys.add(key);
- }
-
- // Transform to Bag format internally
- List<Tokenizer.Token> tokenList = splitToTokens(text);
- documentsToTokenList.add(new Tokenizer.DocumentToTokens(keys,
tokenList));
- });
-
- return documentsToTokenList;
- }
-}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplier.java
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplier.java
new file mode 100644
index 0000000000..de37e51516
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplier.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.tokenize.applier;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.transform.tokenize.DocumentRepresentation;
+import org.apache.sysds.runtime.util.DependencyTask;
+import org.apache.sysds.runtime.util.DependencyThreadPool;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.Callable;
+
+import static
org.apache.sysds.runtime.transform.tokenize.Tokenizer.TOKENIZE_NUM_BLOCKS;
+import static org.apache.sysds.runtime.util.UtilFunctions.getBlockSizes;
+
+public abstract class TokenizerApplier implements Serializable {
+
+ protected static final Log LOG =
LogFactory.getLog(TokenizerApplier.class.getName());
+
+ public static final String PADDING_STRING = "";
+
+ protected final int numIdCols;
+ protected final int maxTokens;
+ protected final boolean wideFormat;
+ protected final boolean applyPadding;
+
+ public TokenizerApplier(int numIdCols, int maxTokens, boolean wideFormat,
boolean applyPadding){
+ this.numIdCols = numIdCols;
+ this.maxTokens = maxTokens;
+ this.wideFormat = wideFormat;
+ this.applyPadding = applyPadding;
+ }
+
+ public int applyInternalRepresentation(DocumentRepresentation[]
internalRepresentation, FrameBlock out){
+ return applyInternalRepresentation(internalRepresentation, out, 0, -1);
+ }
+ abstract int applyInternalRepresentation(DocumentRepresentation[]
internalRepresentation, FrameBlock out, int startRow, int blk);
+
+ public void build(DocumentRepresentation[] internalRepresentation, int
inputRowStart, int blk){ }
+
+ public List<DependencyTask<?>> getBuildTasks(DocumentRepresentation[]
internalRepresentation){
+ int nRows = internalRepresentation.length;
+ List<Callable<Object>> tasks = new ArrayList<>();
+ int[] blockSizes = getBlockSizes(nRows, TOKENIZE_NUM_BLOCKS);
+ if(blockSizes.length == 1){
+ tasks.add(new TokenizerApplierBuildTask<>(this,
internalRepresentation, 0, -1));
+ }
+ else {
+ for(int startRow = 0, i = 0; i < blockSizes.length;
startRow+=blockSizes[i], i++){
+ tasks.add(new TokenizerApplierBuildTask<>(this,
internalRepresentation, startRow, blockSizes[i]));
+ }
+ }
+ return DependencyThreadPool.createDependencyTasks(tasks, null);
+ }
+
+ public List<DependencyTask<?>> getApplyTasks(DocumentRepresentation[]
internalRepresentation, FrameBlock out) {
+ int nRows = internalRepresentation.length;
+ List<Callable<Object>> tasks = new ArrayList<>();
+ int[] blockSizes = getBlockSizes(nRows, TOKENIZE_NUM_BLOCKS);
+ if(blockSizes.length == 1){
+ tasks.add(new TokenizerApplyTask<>(this, out,
internalRepresentation, 0, -1));
+ }
+ else {
+ for(int startRow = 0, i = 0; i < blockSizes.length;
startRow+=blockSizes[i], i++){
+ tasks.add(new TokenizerApplyTask<>(this, out,
internalRepresentation, startRow, blockSizes[i]));
+ }
+ }
+ return DependencyThreadPool.createDependencyTasks(tasks, null);
+ }
+
+ protected int setKeys(int row, List<Object> keys, FrameBlock out){
+ int col = 0;
+ for(; col < keys.size(); col++){
+ out.set(row, col, keys.get(col));
+ }
+ return col;
+ }
+
+ protected int applyPaddingLong(int startRow, int numTokens, List<Object>
keys, FrameBlock out, Object val1, Object val2){
+ int row = startRow;
+ for (; numTokens < maxTokens; numTokens++, row++){
+ int col = setKeys(row, keys, out);
+ out.set(row, col, val1);
+ out.set(row, col+1, val2);
+ }
+ return row;
+ }
+
+	protected void applyPaddingWide(int row, int offset, int startToken, FrameBlock out, Object padding){
+ int token = startToken;
+ for (; token < maxTokens; token++) {
+ out.set(row, offset+token, padding);
+ }
+ }
+
+ public abstract Types.ValueType[] getOutSchema();
+
+ public boolean hasPadding(){
+ return applyPadding;
+ }
+
+ public int getMaxTokens(){
+ return maxTokens;
+ }
+
+ public int getMaxNumRows(int inRows) {
+ if (wideFormat) {
+ return inRows;
+ } else {
+ return inRows * maxTokens;
+ }
+ }
+
+	public abstract int getNumRows(DocumentRepresentation[] internalRepresentation);
+
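+	// Compute the first output row for a given input row: 1:1 in wide format, maxTokens rows per document with padding, otherwise the sum of min(#tokens, maxTokens) over all preceding documents.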
+	public <T, E> int getOutputRow(int inputRowStart, List<Map<T, E>> internalData){
+		if(wideFormat)
+			return inputRowStart;
+		if(applyPadding)
+			return maxTokens * inputRowStart;
+		return internalData.stream().limit(inputRowStart).mapToInt(hashMap -> Math.min(hashMap.size(), maxTokens)).sum();
+ }
+
+	public int getOutputRow(int inputRowStart, DocumentRepresentation[] internalData){
+		if(wideFormat)
+			return inputRowStart;
+		if(applyPadding)
+			return maxTokens * inputRowStart;
+		return Arrays.stream(internalData).limit(inputRowStart).mapToInt(doc -> Math.min(doc.tokens.size(), maxTokens)).sum();
+ }
+
+ public long getNumCols() {
+ return this.getOutSchema().length;
+ }
+
+ public boolean isWideFormat() {
+ return wideFormat;
+ }
+
+ public void allocateInternalMeta(int numDocuments) { }
+
+
+	protected static class TokenizerApplyTask<T extends TokenizerApplier> implements Callable<Object>{
+
+ protected final T _tokenizerApplier;
+ protected final FrameBlock _output;
+ protected final DocumentRepresentation[] _internalRepresentation;
+ protected final int _rowStart;
+ protected final int _blk;
+
+		protected TokenizerApplyTask(T tokenizerApplier, FrameBlock out,
+				DocumentRepresentation[] internalRepresentation, int rowStart, int blk){
+ this._tokenizerApplier = tokenizerApplier;
+ this._output = out;
+ this._internalRepresentation = internalRepresentation;
+ this._rowStart = rowStart;
+ this._blk = blk;
+ }
+
+ @Override
+ public Object call() throws Exception {
+			return this._tokenizerApplier.applyInternalRepresentation(this._internalRepresentation, this._output, this._rowStart, this._blk);
+ }
+ }
+
+	protected static class TokenizerApplierBuildTask<T extends TokenizerApplier> implements Callable<Object>{
+
+ protected final T _tokenizerApplier;
+ protected final DocumentRepresentation[] _internalRepresentation;
+ protected final int _rowStart;
+ protected final int _blk;
+
+		protected TokenizerApplierBuildTask(T tokenizerApplier,
+				DocumentRepresentation[] internalRepresentation, int rowStart, int blk){
+ this._tokenizerApplier = tokenizerApplier;
+ this._internalRepresentation = internalRepresentation;
+ this._rowStart = rowStart;
+ this._blk = blk;
+ }
+
+ @Override
+ public Object call() throws Exception {
+			this._tokenizerApplier.build(this._internalRepresentation, this._rowStart, this._blk);
+ return null;
+ }
+ }
+
+}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierCount.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierCount.java
new file mode 100644
index 0000000000..a67467f51e
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierCount.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.tokenize.applier;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.transform.tokenize.DocumentRepresentation;
+import org.apache.sysds.runtime.transform.tokenize.Token;
+import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.wink.json4j.JSONException;
+import org.apache.wink.json4j.JSONObject;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeSet;
+
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
+public class TokenizerApplierCount extends TokenizerApplier {
+
+ private static final long serialVersionUID = 6382000606237705019L;
+ public boolean sort_alpha = false;
+
+ private List<Map<String, Integer>> counts;
+
+	public TokenizerApplierCount(int numIdCols, int maxTokens, boolean wideFormat, boolean applyPadding, JSONObject params) throws JSONException {
+ super(numIdCols, maxTokens, wideFormat, applyPadding);
+ if (params != null && params.has("sort_alpha")) {
+ this.sort_alpha = params.getBoolean("sort_alpha");
+ }
+ }
+
+ @Override
+ public int getNumRows(DocumentRepresentation[] internalRepresentation) {
+ if(wideFormat)
+ return internalRepresentation.length;
+ if(applyPadding)
+ return maxTokens * internalRepresentation.length;
+		return counts.stream().mapToInt(hashMap -> Math.min(hashMap.size(), maxTokens)).sum();
+ }
+
+ @Override
+ public void allocateInternalMeta(int numDocuments) {
+ counts = new ArrayList<>(Collections.nCopies(numDocuments,null));
+ }
+
+ @Override
+	public void build(DocumentRepresentation[] internalRepresentation, int inputRowStart, int blk){
+		int endIndex = getEndIndex(internalRepresentation.length, inputRowStart, blk);
+ for(int i = inputRowStart; i < endIndex; i++){
+ Map<String, Integer> tokenCounts = new HashMap<>();
+ for(Token token: internalRepresentation[i].tokens){
+ String txt = token.toString();
+ Integer count = tokenCounts.getOrDefault(txt, null);
+ if(count != null)
+ tokenCounts.put(txt, count + 1);
+ else
+ tokenCounts.put(txt, 1);
+ }
+ counts.set(i, tokenCounts);
+ }
+ }
+
+ @Override
+	public int applyInternalRepresentation(DocumentRepresentation[] internalRepresentation, FrameBlock out, int inputRowStart, int blk) {
+		int endIndex = getEndIndex(internalRepresentation.length, inputRowStart, blk);
+ int outputRow = getOutputRow(inputRowStart, counts);
+ for(int i = inputRowStart; i < endIndex; i++) {
+ List<Object> keys = internalRepresentation[i].keys;
+ // Creating the counts for BoW
+ Map<String, Integer> tokenCounts = counts.get(i);
+ // Remove duplicate strings
+ Collection<String> distinctTokens = tokenCounts.keySet();
+ if (this.sort_alpha) {
+ // Sort alphabetically
+ distinctTokens = new TreeSet<>(distinctTokens);
+ }
+
+ int numTokens = 0;
+ for (String token: distinctTokens) {
+ if (numTokens >= maxTokens) {
+ break;
+ }
+ int col = setKeys(outputRow, keys, out);
+ // Create a row per token
+ long count = tokenCounts.get(token);
+ out.set(outputRow, col, token);
+ out.set(outputRow, col+1, count);
+ outputRow++;
+ numTokens++;
+ }
+ if(applyPadding){
+				outputRow = applyPaddingLong(outputRow, numTokens, keys, out, PADDING_STRING, -1);
+ }
+ }
+ return outputRow;
+ }
+
+ @Override
+ public Types.ValueType[] getOutSchema() {
+ if (wideFormat) {
+			throw new IllegalArgumentException("Wide Format is not supported for Count Representation.");
+ }
+ // Long format only depends on numIdCols
+		Types.ValueType[] schema = UtilFunctions.nCopies(numIdCols + 2, Types.ValueType.STRING);
+ schema[numIdCols + 1] = Types.ValueType.INT64;
+ return schema;
+ }
+
+}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierHash.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierHash.java
new file mode 100644
index 0000000000..e9e125ca73
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierHash.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.tokenize.applier;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.transform.tokenize.DocumentRepresentation;
+import org.apache.sysds.runtime.transform.tokenize.Token;
+import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.wink.json4j.JSONException;
+import org.apache.wink.json4j.JSONObject;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
+public class TokenizerApplierHash extends TokenizerApplier {
+
+ private static final long serialVersionUID = 4763889041868044668L;
+
+ public int num_features = 1048576; // 2^20
+
+ private List<Map<Integer, Long>> hashes;
+
+	public TokenizerApplierHash(int numIdCols, int maxTokens, boolean wideFormat, boolean applyPadding, JSONObject params) throws JSONException {
+		super(numIdCols, maxTokens, wideFormat, applyPadding);
+		if(!applyPadding && wideFormat){
+			LOG.warn("applyPadding was set to 'false', but the hash tokenizer in wide format always applies padding");
+ }
+ if (params != null && params.has("num_features")) {
+ this.num_features = params.getInt("num_features");
+ }
+ }
+
+
+ @Override
+ public int getNumRows(DocumentRepresentation[] internalRepresentation) {
+ if(wideFormat)
+ return internalRepresentation.length;
+ if(applyPadding)
+ return maxTokens * internalRepresentation.length;
+		return hashes.stream().mapToInt(hashMap -> Math.min(hashMap.size(), maxTokens)).sum();
+ }
+
+ @Override
+ public void allocateInternalMeta(int numDocuments) {
+		hashes = new ArrayList<>(Collections.nCopies(numDocuments, null));
+ }
+
+ @Override
+	public void build(DocumentRepresentation[] internalRepresentation, int inputRowStart, int blk){
+		int endIndex = getEndIndex(internalRepresentation.length, inputRowStart, blk);
+ for(int i = inputRowStart; i < endIndex; i++){
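+			// Map each token to a bucket in [0, num_features); Java's % can be negative, hence the correction below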
+			List<Integer> hashList = internalRepresentation[i].tokens.stream().map(token -> {
+				int mod = (token.hashCode() % this.num_features);
+				if(mod < 0)
+					mod += this.num_features;
+				return mod;
+			}).collect(Collectors.toList());
+			Map<Integer, Long> hashCounts = hashList.stream().collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
+			hashes.set(i, new TreeMap<>(hashCounts));
+ }
+ }
+
+ @Override
+	public int applyInternalRepresentation(DocumentRepresentation[] internalRepresentation, FrameBlock out, int inputRowStart, int blk) {
+		int endIndex = getEndIndex(internalRepresentation.length, inputRowStart, blk);
+		int outputRow = getOutputRow(inputRowStart, hashes);
+		for(int i = inputRowStart; i < endIndex; i++) {
+			List<Object> keys = internalRepresentation[i].keys;
+			Map<Integer, Long> sortedHashes = hashes.get(i);
+			if (wideFormat) {
+				outputRow = this.setTokensWide(outputRow, keys, sortedHashes, out);
+			} else {
+				outputRow = this.setTokensLong(outputRow, keys, sortedHashes, out);
+ }
+ }
+ return outputRow;
+ }
+
+
+	private int setTokensLong(int row, List<Object> keys, Map<Integer, Long> sortedHashes, FrameBlock out) {
+		int numTokens = 0;
+		for (Map.Entry<Integer, Long> hashCount: sortedHashes.entrySet()) {
+ if (numTokens >= maxTokens) {
+ break;
+ }
+ int col = setKeys(row, keys, out);
+ // Create a row per token
+ int hash = hashCount.getKey() + 1;
+ long count = hashCount.getValue();
+ out.set(row, col, (long)hash);
+ out.set(row, col + 1, count);
+ numTokens++;
+ row++;
+ }
+ if(applyPadding){
+			row = applyPaddingLong(row, numTokens, keys, out, PADDING_STRING, 0L);
+ }
+ return row;
+ }
+
+	private int setTokensWide(int row, List<Object> keys, Map<Integer, Long> sortedHashes, FrameBlock out) {
+		// Create one row with keys as prefix
+		int numKeys = setKeys(row, keys, out);
+		for (int tokenPos = 0; tokenPos < maxTokens; tokenPos++) {
+			long positionHash = sortedHashes.getOrDefault(tokenPos, 0L);
+ out.set(row, numKeys + tokenPos, positionHash);
+ }
+ return ++row;
+ }
+
+ @Override
+ public Types.ValueType[] getOutSchema() {
+ if (wideFormat) {
+ return getOutSchemaWide(numIdCols, maxTokens);
+ } else {
+ return getOutSchemaLong(numIdCols);
+ }
+ }
+
+	private static Types.ValueType[] getOutSchemaWide(int numIdCols, int maxTokens) {
+		Types.ValueType[] schema = new Types.ValueType[numIdCols + maxTokens];
+ int i = 0;
+ for (; i < numIdCols; i++) {
+ schema[i] = Types.ValueType.STRING;
+ }
+ for (int j = 0; j < maxTokens; j++, i++) {
+ schema[i] = Types.ValueType.INT64;
+ }
+ return schema;
+ }
+
+ private static Types.ValueType[] getOutSchemaLong(int numIdCols) {
+		Types.ValueType[] schema = UtilFunctions.nCopies(numIdCols + 2, Types.ValueType.STRING);
+ schema[numIdCols] = Types.ValueType.INT64;
+ schema[numIdCols+1] = Types.ValueType.INT64;
+ return schema;
+ }
+
+}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierPosition.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierPosition.java
new file mode 100644
index 0000000000..c92e86b28d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierPosition.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.tokenize.applier;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+
+import org.apache.sysds.runtime.transform.tokenize.DocumentRepresentation;
+import org.apache.sysds.runtime.transform.tokenize.Token;
+import org.apache.sysds.runtime.util.UtilFunctions;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
+public class TokenizerApplierPosition extends TokenizerApplier {
+
+ private static final long serialVersionUID = 3563407270742660830L;
+
+	public TokenizerApplierPosition(int numIdCols, int maxTokens, boolean wideFormat, boolean applyPadding) {
+ super(numIdCols, maxTokens, wideFormat, applyPadding);
+ }
+
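+	// Wide format: one output row per document; long format: one row per kept token, or maxTokens rows per document when padding is applied.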
+ public int getNumRows(DocumentRepresentation[] internalRepresentation){
+		return wideFormat ? internalRepresentation.length :
+			Arrays.stream(internalRepresentation).mapToInt(doc -> applyPadding ? maxTokens : Math.min(doc.tokens.size(), maxTokens)).sum();
+	}
+
+ @Override
+	public int applyInternalRepresentation(DocumentRepresentation[] internalRepresentation, FrameBlock out, int inputRowStart, int blk) {
+		int endIndex = getEndIndex(internalRepresentation.length, inputRowStart, blk);
+		int outputRow = getOutputRow(inputRowStart, internalRepresentation);
+		for(int i = inputRowStart; i < endIndex; i++) {
+			List<Object> keys = internalRepresentation[i].keys;
+			List<Token> tokenList = internalRepresentation[i].tokens;
+
+			if (wideFormat) {
+				outputRow = this.appendTokensWide(outputRow, keys, tokenList, out);
+			} else {
+				outputRow = this.appendTokensLong(outputRow, keys, tokenList, out);
+ }
+ }
+ return outputRow;
+ }
+
+
+	public int appendTokensLong(int row, List<Object> keys, List<Token> tokenList, FrameBlock out) {
+ int numTokens = 0;
+ for (Token token: tokenList) {
+ if (numTokens >= maxTokens) {
+ break;
+ }
+ int col = setKeys(row, keys, out);
+ out.set(row, col, token.getStartIndex(0) + 1);
+ out.set(row, col + 1, token.toString());
+ row++;
+ numTokens++;
+ }
+ if(applyPadding){
+			row = applyPaddingLong(row, numTokens, keys, out, -1, PADDING_STRING);
+ }
+
+ return row;
+ }
+
+	public int appendTokensWide(int row, List<Object> keys, List<Token> tokenList, FrameBlock out) {
+ // Create one row with keys as prefix
+ int numKeys = setKeys(row, keys, out);
+ int token = 0;
+ for (; token < tokenList.size(); token++) {
+ if (token >= maxTokens) {
+ break;
+ }
+			out.set(row, numKeys+token, tokenList.get(token).toString());
+		}
+		if(applyPadding){
+			// Remaining positions need to be filled with empty tokens
+			applyPaddingWide(row, numKeys, token, out, PADDING_STRING);
+ }
+ return ++row;
+ }
+
+ @Override
+ public Types.ValueType[] getOutSchema() {
+ if (wideFormat) {
+ return getOutSchemaWide(numIdCols, maxTokens);
+ } else {
+ return getOutSchemaLong(numIdCols);
+ }
+
+ }
+
+	private static Types.ValueType[] getOutSchemaWide(int numIdCols, int maxTokens) {
+		return UtilFunctions.nCopies(numIdCols + maxTokens, Types.ValueType.STRING);
+ }
+
+ private static Types.ValueType[] getOutSchemaLong(int numIdCols) {
+		Types.ValueType[] schema = UtilFunctions.nCopies(numIdCols + 2, Types.ValueType.STRING);
+ schema[numIdCols] = Types.ValueType.INT64;
+ return schema;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilder.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilder.java
new file mode 100644
index 0000000000..36c0c26e25
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilder.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.tokenize.builder;
+
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.transform.tokenize.DocumentRepresentation;
+import org.apache.sysds.runtime.transform.tokenize.Tokenizer;
+import org.apache.sysds.runtime.util.DependencyTask;
+import org.apache.sysds.runtime.util.DependencyThreadPool;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Callable;
+
+import static org.apache.sysds.runtime.transform.tokenize.Tokenizer.TOKENIZE_NUM_BLOCKS;
+import static org.apache.sysds.runtime.util.UtilFunctions.getBlockSizes;
+
+public abstract class TokenizerBuilder implements Serializable {
+
+
+	public void createInternalRepresentation(FrameBlock in, DocumentRepresentation[] internalRepresentation) {
+		createInternalRepresentation(in, internalRepresentation, 0, -1);
+	}
+
+	public abstract void createInternalRepresentation(FrameBlock in, DocumentRepresentation[] internalRepresentation, int rowStart, int blk);
+
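+	// Create one tokenization (build) task per row partition of the input frame.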
+	public List<DependencyTask<?>> getTasks(FrameBlock in, DocumentRepresentation[] internalRepresentation) {
+		int nRows = in.getNumRows();
+		List<Callable<Object>> tasks = new ArrayList<>();
+		int[] blockSizes = getBlockSizes(nRows, TOKENIZE_NUM_BLOCKS);
+		if(blockSizes.length == 1){
+			tasks.add(new TokenizerBuildTask<>(this, in, internalRepresentation, 0, -1));
+		}
+		else {
+			for(int startRow = 0, i = 0; i < blockSizes.length; startRow+=blockSizes[i], i++){
+				tasks.add(new TokenizerBuildTask<>(this, in, internalRepresentation, startRow, blockSizes[i]));
+			}
+		}
+		return DependencyThreadPool.createDependencyTasks(tasks, null);
+	}
+
+
+	protected static class TokenizerBuildTask<T extends TokenizerBuilder> implements Callable<Object>{
+
+ protected final T _tokenizerBuilder;
+ protected final FrameBlock _input;
+ protected final DocumentRepresentation[] _internalRepresentation;
+ protected final int _rowStart;
+ protected final int _blk;
+
+		protected TokenizerBuildTask(T tokenizerBuilder, FrameBlock input,
+				DocumentRepresentation[] internalRepresentation, int rowStart, int blk){
+ this._tokenizerBuilder = tokenizerBuilder;
+ this._input = input;
+ this._internalRepresentation = internalRepresentation;
+ this._rowStart = rowStart;
+ this._blk = blk;
+ }
+
+ @Override
+ public Object call() throws Exception {
+			this._tokenizerBuilder.createInternalRepresentation(this._input, this._internalRepresentation, this._rowStart, this._blk);
+ return null;
+ }
+ }
+
+
+
+}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderNgram.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderNgram.java
new file mode 100644
index 0000000000..5ea87288b2
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderNgram.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.tokenize.builder;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.transform.tokenize.DocumentRepresentation;
+import org.apache.sysds.runtime.transform.tokenize.Token;
+import org.apache.wink.json4j.JSONException;
+import org.apache.wink.json4j.JSONObject;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
+public class TokenizerBuilderNgram extends TokenizerBuilderWhitespaceSplit {
+
+ private static final long serialVersionUID = -6297904316677723802L;
+
+ private enum NgramType{
+ DOCUMENT,
+ TOKEN
+ }
+
+ public int minGram = 1;
+ public int maxGram = 2;
+ public NgramType ngramType = NgramType.DOCUMENT;
+
+	public TokenizerBuilderNgram(int[] idCols, int tokenizeCol, JSONObject params) throws JSONException {
+ super(idCols, tokenizeCol, params);
+ if (params != null && params.has("min_gram")) {
+ this.minGram = params.getInt("min_gram");
+ }
+ if (params != null && params.has("max_gram")) {
+ this.maxGram = params.getInt("max_gram");
+ }
+ if (params != null && params.has("ngram_type")){
+ String type = params.getString("ngram_type").toLowerCase();
+ if(type.equals("document")){
+ this.ngramType = NgramType.DOCUMENT;
+ } else if (type.equals("token")) {
+ this.ngramType = NgramType.TOKEN;
+ }else {
+ throw new DMLRuntimeException("Invalid ngram type, choose
between 'token' and 'document'");
+ }
+ }
+ }
+
+ public List<Token> splitIntoNgrams(Token token, int minGram, int maxGram){
+ if(token.getNumSubTokens() == 0)
+ throw new DMLRuntimeException("Cannot create ngram of token where
there are no subTokens");
+ if(token.getNumSubTokens() != 1)
+ throw new DMLRuntimeException("Cannot create ngram of token where
there are more than 1 subTokens");
+ String tokenText = token.toString();
+ List<Token> newTokens = new ArrayList<>();
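+		// Enumerate all character n-grams of length minGram..maxGram, preserving each n-gram's start offset within the document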
+ for(int n = minGram; n <= maxGram; n++){
+ for(int i = 0; i < tokenText.length() - n + 1; i++){
+ String substring = tokenText.substring(i, i+n);
+				newTokens.add(new Token(substring, token.getStartIndex(0) + i));
+ }
+ }
+ return newTokens;
+ }
+
+	@Override
+	public void createInternalRepresentation(FrameBlock in, DocumentRepresentation[] internalRepresentation, int rowStart, int blk) {
+		super.createInternalRepresentation(in, internalRepresentation, rowStart, blk);
+		int endIndex = getEndIndex(in.getNumRows(), rowStart, blk);
+		for(int row = rowStart; row < endIndex; row++){
+			DocumentRepresentation documentRepresentation = internalRepresentation[row];
+
+			if(this.ngramType == NgramType.DOCUMENT){
+				documentRepresentation.splitIntoNgrams(this.minGram, this.maxGram);
+			} else if (this.ngramType == NgramType.TOKEN) {
+				List<Token> newTokens = new ArrayList<>();
+				for (Token wordToken: documentRepresentation.getTokens()) {
+					newTokens.addAll(splitIntoNgrams(wordToken, this.minGram, this.maxGram));
+ }
+ documentRepresentation.tokens = newTokens;
+ }
+ }
+ }
+
+}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderWhitespaceSplit.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderWhitespaceSplit.java
new file mode 100644
index 0000000000..c1ba7916b6
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderWhitespaceSplit.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.tokenize.builder;
+
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.transform.tokenize.DocumentRepresentation;
+import org.apache.sysds.runtime.transform.tokenize.Token;
+import org.apache.wink.json4j.JSONException;
+import org.apache.wink.json4j.JSONObject;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
+public class TokenizerBuilderWhitespaceSplit extends TokenizerBuilder {
+
+ private static final long serialVersionUID = 539127244034913364L;
+
+ private final int[] idCols;
+ private final int tokenizeCol;
+
+ public String regex = "\\s+"; // whitespace
+
+	public TokenizerBuilderWhitespaceSplit(int[] idCols, int tokenizeCol, JSONObject params) throws JSONException {
+ if (params != null && params.has("regex")) {
+ this.regex = params.getString("regex");
+ }
+ this.idCols = idCols;
+ this.tokenizeCol = tokenizeCol;
+ }
+
+	public List<Token> splitToTokens(String text) {
+		List<Token> tokenList = new ArrayList<>();
+		if(text == null)
+			return tokenList;
+		String[] textTokens = text.split(this.regex);
+		int curIndex = 0;
+		for(String textToken: textTokens) {
+			if(Objects.equals(textToken, "")){
+				continue;
+			}
+			int tokenIndex = text.indexOf(textToken, curIndex);
+			// Advance past this occurrence so repeated tokens get their actual start positions
+			curIndex = tokenIndex + textToken.length();
+			tokenList.add(new Token(textToken, tokenIndex));
+		}
+		return tokenList;
+	}
+
+	@Override
+	public void createInternalRepresentation(FrameBlock in, DocumentRepresentation[] internalRepresentation, int rowStart, int blk) {
+		int endIndex = getEndIndex(in.getNumRows(), rowStart, blk);
+		for (int i = rowStart; i < endIndex; i++) {
+			String text = in.getString(i, tokenizeCol - 1);
+			List<Token> tokenList = splitToTokens(text);
+			List<Object> keys = new ArrayList<>();
+			for (Integer idCol : idCols) {
+				keys.add(in.get(i, idCol - 1));
+			}
+			// Create the document once per row, after all id-column keys have been collected
+			internalRepresentation[i] = new DocumentRepresentation(keys, tokenList);
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TokenizeMultithreadedTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TokenizeMultithreadedTest.java
new file mode 100644
index 0000000000..2b28848da8
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TokenizeMultithreadedTest.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
+import org.apache.sysds.runtime.io.FrameReaderFactory;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.transform.tokenize.Tokenizer;
+import org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import javax.json.Json;
+import javax.json.JsonObject;
+import javax.json.JsonObjectBuilder;
+import java.io.IOException;
+
+
+public class TokenizeMultithreadedTest extends AutomatedTestBase {
+ private static final String TEST_DIR = "functions/transform/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + TokenizeMultithreadedTest.class.getSimpleName() + "/";
+
+	//dataset and transform tasks without missing values
+	private final static String DATASET = "20news/20news_subset_untokenized.csv";
+
+	private final static JsonObject ngram_algo_params0 = Json.createObjectBuilder()
+		.add("min_gram", 2)
+		.add("max_gram", 3)
+		.add("regex", "\\W+")
+		.build();
+
+	private final static JsonObject count_out_params0 = Json.createObjectBuilder().add("sort_alpha", false).build();
+	private final static JsonObject count_out_params1 = Json.createObjectBuilder().add("sort_alpha", true).build();
+
+	private final static JsonObject hash_out_params0 = Json.createObjectBuilder().add("num_features", 128).build();
+
+ public enum TokenizerBuilder {
+ WHITESPACE_SPLIT,
+ NGRAM,
+ }
+
+ public enum TokenizerApplier {
+ COUNT,
+ HASH,
+ POSITION,
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+		addTestConfiguration(this.getClass().getSimpleName(),
+			new TestConfiguration(TEST_CLASS_DIR, this.getClass().getSimpleName(), new String[] { "R" }));
+ }
+
+ @Test
+ public void testTokenizeSplitCountLong() {
+		runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT, TokenizerApplier.COUNT,
+			2000, false, null, count_out_params0);
+ }
+
+ @Test
+ public void testTokenizeNgramCountLong() {
+		runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.COUNT,
+			2000, false, ngram_algo_params0, count_out_params0);
+ }
+
+ @Test
+ public void testTokenizeSplitPositionLong() {
+		runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT, TokenizerApplier.POSITION,
+			2000, false, null, null);
+ }
+
+ @Test
+ public void testTokenizeNgramPositionLong() {
+		runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.POSITION,
+			2000, false, ngram_algo_params0, null);
+ }
+
+ @Test
+ public void testTokenizeSplitHashLong() {
+		runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT, TokenizerApplier.HASH,
+			2000, false, null, hash_out_params0);
+ }
+
+ @Test
+ public void testTokenizeNgramHashLong() {
+		runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.HASH,
+			2000, false, ngram_algo_params0, hash_out_params0);
+ }
+
+	@Test
+	public void testTokenizeSplitCountWide() {
+		runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT, TokenizerApplier.POSITION,
+			2000, true, null, count_out_params0);
+ }
+
+ @Test
+ public void testTokenizeNgramCountWide() {
+		runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.POSITION,
+			2000, true, ngram_algo_params0, count_out_params0);
+ }
+
+ @Test
+ public void testTokenizeSplitHashWide() {
+		runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT, TokenizerApplier.HASH,
+			2000, true, null, hash_out_params0);
+ }
+
+ @Test
+ public void testTokenizeNgramHashWide() {
+		runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.HASH,
+			2000, true, ngram_algo_params0, hash_out_params0);
+ }
+
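+	// Tokenizes the input once single-threaded (k=1) and once with 12 threads and checks that both produce identical frames.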
+	private void runTokenizeTest(ExecMode rt, TokenizerBuilder builder, TokenizerApplier applier,
+			int max_tokens, boolean format_wide, JsonObject algo_params, JsonObject out_params) {
+		try{
+			getAndLoadTestConfiguration(this.getClass().getSimpleName());
+			FileFormatPropertiesCSV props = new FileFormatPropertiesCSV();
+			props.setHeader(false);
+			FrameBlock input = FrameReaderFactory.createFrameReader(Types.FileFormat.CSV, props)
+				.readFrameFromHDFS(DATASET_DIR+DATASET, -1L, -1L);
+			String spec = createTokenizerSpec(builder, applier, format_wide, algo_params, out_params);
+			Tokenizer tokenizer = TokenizerFactory.createTokenizer(spec, max_tokens);
+			FrameBlock outS = tokenizer.tokenize(input, 1);
+			FrameBlock outM = tokenizer.tokenize(input, 12);
+			Assert.assertEquals(outS.getNumRows(), outM.getNumRows());
+			Assert.assertEquals(outS.getNumColumns(), outM.getNumColumns());
+			TestUtils.compareFrames(DataConverter.convertToStringFrame(outS),
+				DataConverter.convertToStringFrame(outM), outS.getNumRows(), outS.getNumColumns());
+
+ } catch (Exception ex){
+ throw new RuntimeException(ex);
+ }
+
+ }
+
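+	// Assembles the JSON tokenization spec: algo (split/ngram), out (count/position/hash), optional algo/out params, output format, and id/tokenize columns.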
+	private String createTokenizerSpec(TokenizerBuilder builder, TokenizerApplier applier, boolean format_wide, JsonObject algo_params, JsonObject out_params) {
+ JsonObjectBuilder spec = Json.createObjectBuilder();
+ switch (builder){
+ case WHITESPACE_SPLIT:
+ spec.add("algo", "split");
+ break;
+ case NGRAM:
+ spec.add("algo", "ngram");
+ break;
+ }
+ switch (applier){
+ case COUNT:
+ spec.add("out", "count");
+ break;
+ case POSITION:
+ spec.add("out", "position");
+ break;
+ case HASH:
+ spec.add("out", "hash");
+ break;
+ }
+ if(out_params != null)
+ spec.add("out_params", out_params);
+ if(algo_params != null)
+ spec.add("algo_params", algo_params);
+ spec.add("format_wide", format_wide);
+ spec.add("id_cols",Json.createArrayBuilder().add(2).add(3));
+ spec.add("tokenize_col", 4);
+ return spec.build().toString();
+ }
+}