This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 23935c2a20 [SYSTEMDS-3581] New dense block to deduplicate rows
23935c2a20 is described below
commit 23935c2a2000778ab67b1eb24a4aab85482de783
Author: e-strauss <[email protected]>
AuthorDate: Thu Jun 15 12:04:58 2023 +0200
[SYSTEMDS-3581] New dense block to deduplicate rows
This patch adds a new dense block that deduplicates rows by letting
all duplicate rows point to a single shared row array, which is
allocated only once. Internally, it holds a 2D array in which each
row entry is a pointer to such a shared row. The driving use case is
word embedding, which produces many duplicate dense rows.
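
For illustration, a minimal usage sketch of the new deduplicating block,
based on the factory and setter methods introduced in this patch; the
example class name, sizes, and values are illustrative only:

    import org.apache.sysds.runtime.data.DenseBlock;
    import org.apache.sysds.runtime.data.DenseBlockFactory;

    public class DedupBlockExample {
        public static void main(String[] args) {
            // 4 rows x 3 cols; dedup=true selects the deduplicating FP64 block
            DenseBlock db = DenseBlockFactory.createDenseBlock(4, 3, true);
            // one shared embedding vector, allocated only once
            double[] emb = new double[]{1.0, 2.0, 3.0};
            // duplicate rows store a pointer to the same array
            db.set(0, emb);
            db.set(2, emb);
            db.set(3, emb);
            // reads resolve the row pointer (row 1 stays unallocated, i.e. all zero)
            System.out.println(db.get(2, 1));       // 2.0
            System.out.println(db.countNonZeros()); // 9
        }
    }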
Closes #1842
---
.../sysds/runtime/data/DenseBlockFP64DEDUP.java | 253 +++++++++++++++++++++
.../sysds/runtime/data/DenseBlockFactory.java | 35 ++-
.../sysds/runtime/data/DenseBlockLFP64DEDUP.java | 241 ++++++++++++++++++++
.../sysds/runtime/matrix/data/LibMatrixMult.java | 2 +-
.../sysds/runtime/matrix/data/MatrixBlock.java | 22 +-
.../transform/encode/ColumnEncoderRecode.java | 9 +
.../encode/ColumnEncoderWordEmbedding.java | 69 ++++--
.../runtime/transform/encode/EncoderFactory.java | 10 +-
.../transform/encode/MultiColumnEncoder.java | 17 +-
...=> TransformFrameEncodeWordEmbedding1Test.java} | 8 +-
.../TransformFrameEncodeWordEmbedding2Test.java | 66 ++++--
.../TransformFrameEncodeWordEmbeddings.dml | 2 +-
.../TransformFrameEncodeWordEmbeddings2.dml | 3 +-
13 files changed, 679 insertions(+), 58 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
new file mode 100644
index 0000000000..bf0e83ec5b
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.data;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.util.UtilFunctions;
+
+import java.util.Arrays;
+import java.util.HashMap;
+
+public class DenseBlockFP64DEDUP extends DenseBlockDRB{
+ private double[][] _data;
+
+ protected DenseBlockFP64DEDUP(int[] dims) {
+ super(dims);
+ reset(_rlen, _odims, 0);
+ }
+
+ @Override
+ protected void allocateBlock(int bix, int length) {
+ _data[bix] = new double[length];
+ }
+
+ @Override
+ public void reset(int rlen, int[] odims, double v) {
+ if(rlen > capacity() / _odims[0])
+ _data = new double[rlen][];
+ else{
+ if(v == 0.0){
+ for(int i = 0; i < rlen; i++)
+ _data[i] = null;
+ }
+ else {
+ for(int i = 0; i < rlen; i++){
+ if(odims[0] > _odims[0] ||_data[i] == null )
+ allocateBlock(i, odims[0]);
+ Arrays.fill(_data[i], 0, odims[0], v);
+ }
+ }
+ }
+ _rlen = rlen;
+ _odims = odims;
+ }
+
+ @Override
+ public void resetNoFill(int rlen, int[] odims) {
+ if(_data == null || rlen > _rlen){
+ _data = new double[rlen][];
+ }
+ _rlen = rlen;
+ _odims = odims;
+ }
+
+ @Override
+ public boolean isNumeric() {
+ return true;
+ }
+
+ @Override
+ public boolean isNumeric(Types.ValueType vt) {
+ return Types.ValueType.FP64 == vt;
+ }
+
+ @Override
+ public long capacity() {
+ return (_data != null) ? _data.length*_odims[0] : -1;
+ }
+
+ @Override
+ public long countNonZeros(){
+ long nnz = 0;
+ HashMap<double[], Long> cache = new HashMap<double[], Long>();
+ for (int i = 0; i < _rlen; i++) {
+ double[] row = this._data[i];
+ if(row == null)
+ continue;
+ Long count = cache.getOrDefault(row, null);
+ if(count == null){
+ count = Long.valueOf(countNonZeros(i));
+ cache.put(row, count);
+ }
+ nnz += count;
+ }
+ return nnz;
+ }
+
+ @Override
+ public int countNonZeros(int r) {
+ return _data[r] == null ? 0 : UtilFunctions.computeNnz(_data[r], 0, _odims[0]);
+ }
+
+ @Override
+ protected long computeNnz(int bix, int start, int length) {
+ int nnz = 0;
+ int row_start = (int) Math.floor(start / _odims[0]);
+ int col_start = start % _odims[0];
+ for (int i = 0; i < length; i++) {
+ if(_data[row_start] == null){
+ i += _odims[0] - 1 - col_start;
+ col_start = 0;
+ row_start += 1;
+ continue;
+ }
+ nnz += _data[row_start][col_start] != 0 ? 1 : 0;
+ col_start += 1;
+ if(col_start == _odims[0]) {
+ col_start = 0;
+ row_start += 1;
+ }
+ }
+ return nnz;
+ }
+
+ @Override
+ public int pos(int r){
+ return 0;
+ }
+
+ @Override
+ public int pos(int[] ix){
+ int pos = ix[ix.length - 1];
+ for(int i = 1; i < ix.length - 1; i++)
+ pos += ix[i] * _odims[i];
+ return pos;
+ }
+
+ @Override
+ public double[] values(int r) {
+ return valuesAt(r);
+ }
+
+ @Override
+ public double[] valuesAt(int bix) {
+ return _data[bix] == null ? new double[_odims[0]] : _data[bix];
+ }
+
+ @Override
+ public int index(int r) {
+ return r;
+ }
+
+ @Override
+ public int numBlocks(){
+ return _data.length;
+ }
+
+ @Override
+ public int size(int bix) {
+ return _odims[0];
+ }
+
+ @Override
+ public void incr(int r, int c) {
+ incr(r,c,1.0);
+ }
+
+ @Override
+ public void incr(int r, int c, double delta) {
+ if(_data[r] == null)
+ allocateBlock(r, _odims[0]);
+ _data[r][c] += delta;
+ }
+
+ @Override
+ protected void fillBlock(int bix, int fromIndex, int toIndex, double v) {
+ if(_data[bix] == null)
+ allocateBlock(bix, _odims[0]);
+ Arrays.fill(_data[bix], fromIndex, toIndex, v);
+ }
+
+ @Override
+ protected void setInternal(int bix, int ix, double v) {
+ set(bix, ix, v);
+ }
+
+ @Override
+ public DenseBlock set(int r, int c, double v) {
+ if(_data[r] == null)
+ _data[r] = new double[_odims[0]];
+ _data[r][c] = v;
+ return this;
+ }
+
+ @Override
+ public DenseBlock set(int r, double[] v) {
+ if(v.length == _odims[0])
+ _data[r] = v;
+ else
+ throw new RuntimeException("set Denseblock called with an array
length [" + v.length +"], array to overwrite is of length [" + _odims[0] + "]");
+ return this;
+ }
+
+ @Override
+ public DenseBlock set(DenseBlock db) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public DenseBlock set(int[] ix, double v) {
+ return set(ix[0], pos(ix), v);
+ }
+
+ @Override
+ public DenseBlock set(int[] ix, long v) {
+ return set(ix[0], pos(ix), v);
+ }
+
+ @Override
+ public DenseBlock set(int[] ix, String v) {
+ return set(ix[0], pos(ix), Double.parseDouble(v));
+ }
+
+ @Override
+ public double get(int r, int c) {
+ if(_data[r] == null)
+ return 0.0;
+ else
+ return _data[r][c];
+ }
+
+ @Override
+ public double get(int[] ix) {
+ return get(ix[0], pos(ix));
+ }
+
+ @Override
+ public String getString(int[] ix) {
+ return String.valueOf(get(ix[0], pos(ix)));
+ }
+
+ @Override
+ public long getLong(int[] ix) {
+ return UtilFunctions.toLong(get(ix[0], pos(ix)));
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFactory.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFactory.java
index e840585f49..c48a9d0aca 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFactory.java
@@ -32,15 +32,29 @@ public abstract class DenseBlockFactory
public static DenseBlock createDenseBlock(int rlen, int clen) {
return createDenseBlock(new int[]{rlen, clen});
}
+
+ public static DenseBlock createDenseBlock(int rlen, int clen, boolean dedup) {
+ return createDenseBlock(new int[]{rlen, clen}, dedup);
+ }
public static DenseBlock createDenseBlock(int[] dims) {
return createDenseBlock(ValueType.FP64, dims);
}
+
+ public static DenseBlock createDenseBlock(int[] dims, boolean dedup) {
+ return createDenseBlock(ValueType.FP64, dims, dedup);
+ }
public static DenseBlock createDenseBlock(ValueType vt, int[] dims) {
DenseBlock.Type type = (UtilFunctions.prod(dims) < Integer.MAX_VALUE) ?
DenseBlock.Type.DRB : DenseBlock.Type.LDRB;
- return createDenseBlock(vt, type, dims);
+ return createDenseBlock(vt, type, dims, false);
+ }
+
+ public static DenseBlock createDenseBlock(ValueType vt, int[] dims, boolean dedup) {
+ DenseBlock.Type type = (UtilFunctions.prod(dims) < Integer.MAX_VALUE) ?
+ DenseBlock.Type.DRB : DenseBlock.Type.LDRB;
+ return createDenseBlock(vt, type, dims, dedup);
}
public static DenseBlock createDenseBlock(BitSet data, int[] dims) {
@@ -75,7 +89,24 @@ public abstract class DenseBlockFactory
return createDenseBlock(data, new int[]{rlen, clen});
}
- public static DenseBlock createDenseBlock(ValueType vt, DenseBlock.Type type, int[] dims) {
+ public static DenseBlock createDenseBlock(ValueType vt, DenseBlock.Type type, int[] dims, boolean dedup) {
+ if( dedup ) {
+ switch( type ) {
+ case DRB:
+ switch(vt) {
+ case FP64: return new DenseBlockFP64DEDUP(dims);
+ default:
+ throw new DMLRuntimeException("Unsupported dense block value type with deduplication enabled: "+vt.name());
+ }
+ case LDRB:
+ switch(vt) {
+ default:
+ throw new NotImplementedException();
+ }
+ default:
+ throw new DMLRuntimeException("Unexpected dense block type: "+type.name());
+ }
+ }
switch( type ) {
case DRB:
switch(vt) {
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java
new file mode 100644
index 0000000000..5a4249c920
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.data;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.util.UtilFunctions;
+
+import java.util.Arrays;
+import java.util.HashMap;
+
+public class DenseBlockLFP64DEDUP extends DenseBlockLDRB{
+ //WIP
+ private double[][] _data;
+
+ protected DenseBlockLFP64DEDUP(int[] dims) {
+ super(dims);
+ reset(_rlen, _odims, 0);
+ }
+
+ @Override
+ protected void allocateBlocks(int numBlocks) {
+ _data = new double[numBlocks][];
+ }
+
+ @Override
+ protected void allocateBlock(int bix, int length) {
+ _data[bix] = new double[length];
+ }
+
+ @Override
+ public void reset(int rlen, int[] odims, double v) {
+ if(rlen > capacity() / _odims[0]) {
+ this.allocateBlocks(rlen);
+ if (v != 0.0) {
+ for (int i = 0; i < rlen; i++) {
+ allocateBlock(i, odims[0]);
+ Arrays.fill(_data[i], 0, odims[0], v);
+ }
+ }
+ }
+ else{
+ if(v == 0.0){
+ for(int i = 0; i < rlen; i++)
+ _data[i] = null;
+ }
+ else {
+ for(int i = 0; i < rlen; i++){
+ if(odims[0] > _odims[0] ||_data[i] == null )
+ allocateBlock(i, odims[0]);
+ Arrays.fill(_data[i], 0, odims[0], v);
+ }
+ }
+ }
+ _blen = 1;
+ _rlen = rlen;
+ _odims = odims;
+ }
+
+ @Override
+ public boolean isNumeric() {
+ return true;
+ }
+
+ @Override
+ public boolean isNumeric(Types.ValueType vt) {
+ return Types.ValueType.FP64 == vt;
+ }
+
+ @Override
+ public boolean isContiguous() {
+ return false;
+ }
+
+ @Override
+ public long capacity() {
+ return (_data != null) ? _data.length*_odims[0] : -1;
+ }
+
+ @Override
+ public long countNonZeros(){
+ long nnz = 0;
+ HashMap<double[], Long> cache = new HashMap<double[], Long>();
+ for (int i = 0; i < _rlen; i++) {
+ double[] row = this._data[i];
+ if(row == null)
+ continue;
+ Long count = cache.getOrDefault(row, null);
+ if(count == null){
+ count = Long.valueOf(countNonZeros(i));
+ cache.put(row, count);
+ }
+ nnz += count;
+ }
+ return nnz;
+ }
+
+ @Override
+ public int countNonZeros(int r) {
+ return _data[r] == null ? 0 : UtilFunctions.computeNnz(_data[r], 0, _odims[0]);
+ }
+
+ @Override
+ protected long computeNnz(int bix, int start, int length) {
+ int nnz = 0;
+ int row_start = (int) Math.floor(start / _odims[0]);
+ int col_start = start % _odims[0];
+ for (int i = 0; i < length; i++) {
+ if(_data[row_start] == null){
+ i += _odims[0] - 1 - col_start;
+ col_start = 0;
+ row_start += 1;
+ continue;
+ }
+ nnz += _data[row_start][col_start] != 0 ? 1 : 0;
+ col_start += 1;
+ if(col_start == _odims[0]) {
+ col_start = 0;
+ row_start += 1;
+ }
+ }
+ return nnz;
+ }
+
+ @Override
+ public int pos(int r){
+ return 0;
+ }
+
+ @Override
+ public double[] values(int r) {
+ if(_data[r] == null)
+ allocateBlock(r, _odims[0]);
+ return _data[r];
+ }
+
+ @Override
+ public double[] valuesAt(int bix) {
+ return values(bix);
+ }
+
+
+ @Override
+ public int numBlocks(){
+ return _data.length;
+ }
+
+ @Override
+ public void incr(int r, int c) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public void incr(int r, int c, double delta) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ protected void fillBlock(int bix, int fromIndex, int toIndex, double v) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ protected void setInternal(int bix, int ix, double v) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public DenseBlock set(int r, int c, double v) {
+ if(_data[r] == null)
+ _data[r] = new double[_odims[0]];
+ _data[r][c] = v;
+ return this;
+ }
+
+ @Override
+ public DenseBlock set(int r, double[] v) {
+ if(v.length == _odims[0])
+ _data[r] = v;
+ else
+ throw new RuntimeException("set Denseblock called with an array
length [" + v.length +"], array to overwrite is of length [" + _odims[0] + "]");
+ return this;
+ }
+
+ @Override
+ public DenseBlock set(DenseBlock db) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public DenseBlock set(int[] ix, double v) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public DenseBlock set(int[] ix, long v) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public DenseBlock set(int[] ix, String v) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public double get(int r, int c) {
+ return _data[r][c];
+ }
+
+ @Override
+ public double get(int[] ix) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public String getString(int[] ix) {
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public long getLong(int[] ix) {
+ throw new NotImplementedException();
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index 99fb4b30ca..ce44e2e343 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -1738,7 +1738,7 @@ public class LibMatrixMult
int alen = a.size(i);
int[] aixs = a.indexes(i);
double[] avals = a.values(i);
- if( alen==1 ) {
+ if( alen==1 ) {
//row selection (now aggregation) with potential scaling
int aix = aixs[apos];
int lnnz = 0;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index dcc17076c2..c075a9fd29 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -35,6 +35,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.concurrent.ConcurrentUtils;
import org.apache.commons.logging.Log;
@@ -383,8 +384,12 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>,
allocateDenseBlock();
return this;
}
-
- public boolean allocateDenseBlock(boolean clearNNZ) {
+
+ public boolean allocateDenseBlock(boolean clearNNZ){
+ return allocateDenseBlock(clearNNZ, false);
+ }
+
+ public boolean allocateDenseBlock(boolean clearNNZ, boolean containsDuplicates) {
//allocate block if non-existing or too small (guaranteed to be 0-initialized),
long limit = (long)rlen * clen;
//clear nnz if necessary
@@ -393,7 +398,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>,
sparse = false;
if( denseBlock == null ){
- denseBlock = DenseBlockFactory.createDenseBlock(rlen, clen);
+ denseBlock = DenseBlockFactory.createDenseBlock(rlen, clen, containsDuplicates);
return true;
}
else if( denseBlock.capacity() < limit ){
@@ -669,6 +674,17 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>,
}
}
+ public void quickSetRow(int r, double[] values){
+ if(sparse)
+ throw new NotImplementedException();
+ else{
+ //allocate and init dense block (w/o overwriting nnz)
+ allocateDenseBlock(false);
+ nonZeros += UtilFunctions.computeNnz(values, 0, values.length) - denseBlock.countNonZeros(r);
+ denseBlock.set(r, values);
+ }
+ }
+
public double quickGetValueThreadSafe(int r, int c) {
if(sparse) {
if(!(sparseBlock instanceof SparseBlockMCSR))
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
index eb7e706e0c..334149515b 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
@@ -88,6 +88,15 @@ public class ColumnEncoderRecode extends ColumnEncoder {
* @return string array of token and code
*/
public static String[] splitRecodeMapEntry(String value) {
+ // remove " chars from string (if the string contains comma in
the csv file, then it must contained by double quotes)
+ /*if(value.contains("\"")){
+ //remove just last and first appearance
+ int firstIndex = value.indexOf("\"");
+ int lastIndex = value.lastIndexOf("\"");
+ if (firstIndex != lastIndex)
+ value = value.substring(0, firstIndex) + value.substring(firstIndex + 1, lastIndex) + value.substring(lastIndex + 1);
+ }*/
+
// Instead of using splitCSV which is forcing string with RFC-4180 format,
// using Lop.DATATYPE_PREFIX separator to split token and code
int pos = value.lastIndexOf(Lop.DATATYPE_PREFIX);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
index 03584cf5ee..9d08c3bc20 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
@@ -20,20 +20,31 @@
package org.apache.sysds.runtime.transform.encode;
import org.apache.commons.lang.NotImplementedException;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import java.util.HashMap;
+
import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
public class ColumnEncoderWordEmbedding extends ColumnEncoder {
- private MatrixBlock wordEmbeddings;
+ private MatrixBlock _wordEmbeddings;
+ private HashMap<Object, Long> _rcdMap;
+ private HashMap<String, double[]> _embMap;
+
+ private long lookupRCDMap(Object key) {
+ return _rcdMap.getOrDefault(key, -1L);
+ }
- //domain size is equal to the number columns of the embedding column (equal to length of an embedding vector)
+ private double[] lookupEMBMap(Object key) {
+ return _embMap.getOrDefault(key, null);
+ }
+
+ //domain size is equal to the number columns of the embeddings column thats equal to length of an embedding vector
@Override
public int getDomainSize(){
- return wordEmbeddings.getNumColumns();
+ return _wordEmbeddings.getNumColumns();
}
protected ColumnEncoderWordEmbedding(int colID) {
super(colID);
@@ -49,31 +60,44 @@ public class ColumnEncoderWordEmbedding extends ColumnEncoder {
throw new NotImplementedException();
}
- //previous recode replaced strings with indices of the corresponding matrix row index
+ //previously recode replaced strings with indices of the corresponding matrix row index
//now, the indices are replaced with actual word embedding vectors
- //current limitation: in case the transform is done on multiple cols, the same embedding
- //matrix is used for both transform
+ //current limitation: in case the transform is done on multiple cols, the same embedding matrix is used for both transform
+
+ private double[] getEmbeddedingFromEmbeddingMatrix(long r){
+ double[] embedding = new double[getDomainSize()];
+ for (int i = 0; i < getDomainSize(); i++) {
+ embedding[i] = this._wordEmbeddings.quickGetValue((int) r, _colID - 1 + i);
+ }
+ return embedding;
+
+ }
+
@Override
public void applyDense(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk){
- if (!(in instanceof MatrixBlock)){
+ /*if (!(in instanceof MatrixBlock)){
throw new DMLRuntimeException("ColumnEncoderWordEmbedding called
with: " + in.getClass().getSimpleName() +
" and not MatrixBlock");
- }
+ }*/
int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk);
- //map each recoded index to the corresponding embedding vector
+
+ //map each string to the corresponding embedding vector
for(int i=rowStart; i<rowEnd; i++){
- double embeddingIndex = in.getDouble(i, outputCol);
- //fill row with zeroes
- if(Double.isNaN(embeddingIndex)){
- for (int j = outputCol; j < outputCol + getDomainSize(); j++)
- out.quickSetValue(i, j, 0.0);
+ String key = in.getString(i, _colID-1);
+ if(key == null || key.isEmpty()) {
+ //codes[i-startInd] = Double.NaN;
+ continue;
}
- //array copy
- else{
- for (int j = outputCol; j < outputCol + getDomainSize(); j++){
- out.quickSetValue(i, j, wordEmbeddings.quickGetValue((int) embeddingIndex - 1,j - outputCol ));
+ double[] embedding = lookupEMBMap(key);
+ if(embedding == null){
+ long code = lookupRCDMap(key);
+ if(code == -1L){
+ continue;
}
+ embedding = getEmbeddedingFromEmbeddingMatrix(code - 1);
+ _embMap.put(key, embedding);
}
+ out.quickSetRow(i, embedding);
}
}
@@ -100,12 +124,15 @@ public class ColumnEncoderWordEmbedding extends ColumnEncoder {
@Override
public void initMetaData(FrameBlock meta) {
- return;
+ if(meta == null || meta.getNumRows() <= 0)
+ return;
+ _rcdMap = meta.getRecodeMap(_colID - 1); // 1-based
}
//save embeddings matrix reference for apply step
@Override
public void initEmbeddings(MatrixBlock embeddings){
- this.wordEmbeddings = embeddings;
+ this._wordEmbeddings = embeddings;
+ this._embMap = new HashMap<>((int) (embeddings.getNumRows()*1.2),1.0f);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 313258831a..07e84a81e7 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -114,13 +114,13 @@ public interface EncoderFactory {
// column follows binning or feature hashing
rcIDs = unionDistinct(rcIDs, except(except(dcIDs, binIDs), haIDs));
// NOTE: Word Embeddings requires recode as preparation
- rcIDs = unionDistinct(rcIDs, weIDs);
+ //rcIDs = unionDistinct(rcIDs, weIDs);
// Error out if the first level encoders have overlaps
- if (intersect(rcIDs, binIDs, haIDs))
- throw new DMLRuntimeException("More than one
encoders (recode, binning, hashing) on one column is not allowed");
+ if (intersect(rcIDs, binIDs, haIDs, weIDs))
+ throw new DMLRuntimeException("More than one
encoders (recode, binning, hashing, word_embedding) on one column is not
allowed");
- List<Integer> ptIDs = except(except(UtilFunctions.getSeqList(1, clen, 1), unionDistinct(rcIDs, haIDs)),
- binIDs);
+ List<Integer> ptIDs = except(except(except(UtilFunctions.getSeqList(1, clen, 1), unionDistinct(rcIDs, haIDs)),
+ binIDs), weIDs);
List<Integer> oIDs = Arrays.asList(ArrayUtils
.toObject(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.OMIT.toString(), minCol, maxCol)));
List<Integer> mvIDs = Arrays.asList(ArrayUtils.toObject(
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 59c22f5640..2e64709b64 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -343,9 +343,13 @@ public class MultiColumnEncoder implements Encoder {
throw new DMLRuntimeException("Invalid input with wrong
number or rows");
boolean hasDC = false;
- for(ColumnEncoderComposite columnEncoder : _columnEncoders)
- hasDC = columnEncoder.hasEncoder(ColumnEncoderDummycode.class);
- outputMatrixPreProcessing(out, in, hasDC);
+ boolean hasWE = false;
+ for(ColumnEncoderComposite columnEncoder : _columnEncoders) {
+ hasDC |= columnEncoder.hasEncoder(ColumnEncoderDummycode.class);
+ hasWE |= columnEncoder.hasEncoder(ColumnEncoderWordEmbedding.class);
+ }
+ //hasWE = false;
+ outputMatrixPreProcessing(out, in, hasDC, hasWE);
if(k > 1) {
if(!_partitionDone) //happens if this method is directly called
deriveNumRowPartitions(in, k);
@@ -533,7 +537,7 @@ public class MultiColumnEncoder implements Encoder {
return totMemOverhead;
}
- private static void outputMatrixPreProcessing(MatrixBlock output, CacheBlock<?> input, boolean hasDC) {
+ private static void outputMatrixPreProcessing(MatrixBlock output, CacheBlock<?> input, boolean hasDC, boolean hasWE) {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
if(output.isInSparseFormat()) {
if (MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.CSR
@@ -580,7 +584,7 @@ public class MultiColumnEncoder implements Encoder {
}
else {
// Allocate dense block and set nnz to total #entries
- output.allocateBlock();
+ output.allocateDenseBlock(true, hasWE);
//output.setAllNonZeros();
}
@@ -1119,12 +1123,13 @@ public class MultiColumnEncoder implements Encoder {
@Override
public Object call() throws Exception {
boolean hasUDF = _encoder.getColumnEncoders().stream().anyMatch(e -> e.hasEncoder(ColumnEncoderUDF.class));
+ boolean hasWE = _encoder.getColumnEncoders().stream().anyMatch(e -> e.hasEncoder(ColumnEncoderWordEmbedding.class));
int numCols = _encoder.getNumOutCols();
boolean hasDC = _encoder.getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
long estNNz = (long) _input.getNumRows() * (hasUDF ? numCols : (long) _input.getNumColumns());
boolean sparse = MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) && !hasUDF;
_output.reset(_input.getNumRows(), numCols, sparse, estNNz);
- outputMatrixPreProcessing(_output, _input, hasDC);
+ outputMatrixPreProcessing(_output, _input, hasDC, hasWE);
return null;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
similarity index 96%
rename from src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingTest.java
rename to src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
index 9972f6b1c6..a69e287d33 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
@@ -37,11 +37,11 @@ import java.util.List;
import java.util.Map;
import java.util.Random;
-public class TransformFrameEncodeWordEmbeddingTest extends AutomatedTestBase
+public class TransformFrameEncodeWordEmbedding1Test extends AutomatedTestBase
{
private final static String TEST_NAME1 = "TransformFrameEncodeWordEmbeddings";
private final static String TEST_DIR = "functions/transform/";
- private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeWordEmbeddingTest.class.getSimpleName() + "/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeWordEmbedding1Test.class.getSimpleName() + "/";
@Override
public void setUp() {
@@ -61,7 +61,7 @@ public class TransformFrameEncodeWordEmbeddingTest extends AutomatedTestBase
try
{
int rows = 100;
- int cols = 100;
+ int cols = 300;
getAndLoadTestConfiguration(testname);
fullDMLScriptName = getScript();
@@ -72,7 +72,7 @@ public class TransformFrameEncodeWordEmbeddingTest extends AutomatedTestBase
// Generate the dictionary by assigning unique ID to each distinct token
Map<String,Integer> map = writeDictToCsvFile(strings, baseDirectory + INPUT_DIR + "dict");
// Create the dataset by repeating and shuffling the distinct tokens
- List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 10);
+ List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 32);
writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + "data");
programArgs = new String[]{"-stats","-args", input("embeddings"), input("data"), input("dict"), output("result")};
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
index 8ab52d9f64..06c8b6ee0b 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
@@ -42,18 +42,18 @@ import java.util.Random;
public class TransformFrameEncodeWordEmbedding2Test extends AutomatedTestBase
{
private final static String TEST_NAME1 = "TransformFrameEncodeWordEmbeddings2";
- private final static String TEST_NAME2 = "TransformFrameEncodeWordEmbeddings2MultiCols1";
- private final static String TEST_NAME3 = "TransformFrameEncodeWordEmbeddings2MultiCols2";
+ private final static String TEST_NAME2a = "TransformFrameEncodeWordEmbeddings2MultiCols1";
+ private final static String TEST_NAME2b = "TransformFrameEncodeWordEmbeddings2MultiCols2";
private final static String TEST_DIR = "functions/transform/";
- private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeWordEmbeddingTest.class.getSimpleName() + "/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeWordEmbedding1Test.class.getSimpleName() + "/";
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_DIR, TEST_NAME1));
- addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_DIR, TEST_NAME2));
- addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_DIR, TEST_NAME3));
+ addTestConfiguration(TEST_NAME2a, new TestConfiguration(TEST_DIR, TEST_NAME2a));
+ addTestConfiguration(TEST_NAME2b, new TestConfiguration(TEST_DIR, TEST_NAME2b));
}
@Test
@@ -64,23 +64,30 @@ public class TransformFrameEncodeWordEmbedding2Test extends AutomatedTestBase
@Test
@Ignore
public void testNonRandomTransformToWordEmbeddings2Cols() {
- runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE);
+ runTransformTest(TEST_NAME2a, ExecMode.SINGLE_NODE);
}
@Test
@Ignore
public void testRandomTransformToWordEmbeddings4Cols() {
- runTransformTestMultiCols(TEST_NAME3, ExecMode.SINGLE_NODE);
+ runTransformTestMultiCols(TEST_NAME2b, ExecMode.SINGLE_NODE);
}
- private void runTransformTest(String testname, ExecMode rt)
+ @Test
+ @Ignore
+ public void runBenchmark(){
+ runBenchmark(TEST_NAME1, ExecMode.SINGLE_NODE);
+ }
+
+
+ private void runBenchmark(String testname, ExecMode rt)
{
//set runtime platform
ExecMode rtold = setExecMode(rt);
try
{
int rows = 100;
- int cols = 100;
+ int cols = 300;
getAndLoadTestConfiguration(testname);
fullDMLScriptName = getScript();
@@ -100,6 +107,43 @@ public class TransformFrameEncodeWordEmbedding2Test extends AutomatedTestBase
//run script
programArgs = new String[]{"-stats","-args", input("embeddings"),
input("data"), input("dict"), output("result")};
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+ }
+ catch(Exception ex) {
+ throw new RuntimeException(ex);
+
+ }
+ finally {
+ resetExecMode(rtold);
+ }
+ }
+
+ private void runTransformTest(String testname, ExecMode rt)
+ {
+ //set runtime platform
+ ExecMode rtold = setExecMode(rt);
+ try
+ {
+ int rows = 100;
+ int cols = 300;
+ getAndLoadTestConfiguration(testname);
+ fullDMLScriptName = getScript();
+
+ // Generate random embeddings for the distinct tokens
+ double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10,
1, new Date().getTime());
+
+ // Generate random distinct tokens
+ List<String> strings = generateRandomStrings(rows, 10);
+
+ // Generate the dictionary by assigning unique ID to each distinct token
+ Map<String,Integer> map = writeDictToCsvFile(strings, baseDirectory + INPUT_DIR + "dict");
+
+ // Create the dataset by repeating and shuffling the distinct tokens
+ List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 32);
+ writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + "data");
+
+ //run script
+ programArgs = new String[]{"-stats","-args", input("embeddings"),
input("data"), input("dict"), output("result")};
+ runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
// Manually derive the expected result
double[][] res_expected = manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn);
@@ -107,10 +151,6 @@ public class TransformFrameEncodeWordEmbedding2Test extends AutomatedTestBase
// Compare results
HashMap<MatrixValue.CellIndex, Double> res_actual = readDMLMatrixFromOutputDir("result");
double[][] resultActualDouble = TestUtils.convertHashMapToDoubleArray(res_actual);
- //System.out.println("Actual Result [" + resultActualDouble.length + "x" + resultActualDouble[0].length + "]:");
- //print2DimDoubleArray(resultActualDouble);
- //System.out.println("\nExpected Result [" + res_expected.length + "x" + res_expected[0].length + "]:");
- //print2DimDoubleArray(res_expected);
TestUtils.compareMatrices(resultActualDouble, res_expected, 1e-6);
}
catch(Exception ex) {
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
index 1aa1fb0fed..dcab56b0fd 100644
--- a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
@@ -20,7 +20,7 @@
#-------------------------------------------------------------
# Read the pre-trained word embeddings
-E = read($1, rows=100, cols=100, format="text");
+E = read($1, rows=100, cols=300, format="text");
# Read the token sequence (1K) w/ 100 distinct tokens
Data = read($2, data_type="frame", format="csv");
# Read the recode map for the distinct tokens
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml
index 29a4bfab74..139bf2b9f4 100644
--- a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml
@@ -20,7 +20,7 @@
#-------------------------------------------------------------
# Read the pre-trained word embeddings
-E = read($1, rows=100, cols=100, format="text");
+E = read($1, rows=100, cols=300, format="text");
# Read the token sequence (1K) w/ 100 distinct tokens
Data = read($2, data_type="frame", format="csv");
# Read the recode map for the distinct tokens
@@ -28,7 +28,6 @@ Meta = read($3, data_type="frame", format="csv");
jspec = "{ids: true, word_embedding: [1]}";
Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
-
write(Data_enc, $4, format="text");