Jackie-Jiang commented on a change in pull request #7542:
URL: https://github.com/apache/pinot/pull/7542#discussion_r726442106



##########
File path: pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/InTransformFunction.java
##########
@@ -86,125 +139,239 @@ public TransformResultMetadata getResultMetadata() {
 
   @Override
   public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
-    if (_results == null) {
-      _results = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    if (_intValuesSV == null) {
+      _intValuesSV = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    } else {
+      Arrays.fill(_intValuesSV, 0);
     }
 
     int length = projectionBlock.getNumDocs();
-    FieldSpec.DataType storedType = 
_transformFunction.getResultMetadata().getDataType().getStoredType();
-    switch (storedType) {
-      case INT:
-        int[] intValues = 
_transformFunction.transformToIntValuesSV(projectionBlock);
-        if (!_stringValueSet.isEmpty()) {
-          Set<Integer> inIntValues = new HashSet<>();
-          for (String inValue : _stringValueSet) {
-            inIntValues.add(Integer.parseInt(inValue));
-          }
-          for (int i = 0; i < length; i++) {
-            _results[i] = inIntValues.contains(intValues[i]) ? 1 : 0;
-          }
-        } else {
-          int[][] inIntValues = new int[_valueTransformFunctions.length][];
-          for (int i = 0; i < _valueTransformFunctions.length; i++) {
-            inIntValues[i] = 
_valueTransformFunctions[i].transformToIntValuesSV(projectionBlock);
-          }
-          for (int i = 0; i < length; i++) {
-            for (int j = 0; j < inIntValues.length; j++) {
-              _results[i] = inIntValues[j][i] == intValues[i] ? 1 : 
_results[i];
+    TransformResultMetadata mainFunctionMetadata = 
_mainFunction.getResultMetadata();
+    DataType storedType = mainFunctionMetadata.getDataType().getStoredType();
+    if (_valueSet != null) {
+      if (_mainFunction.getResultMetadata().isSingleValue()) {
+        switch (storedType) {
+          case INT:
+            IntOpenHashSet inIntValues = (IntOpenHashSet) _valueSet;
+            int[] intValues = 
_mainFunction.transformToIntValuesSV(projectionBlock);
+            for (int i = 0; i < length; i++) {
+              if (inIntValues.contains(intValues[i])) {
+                _intValuesSV[i] = 1;
+              }
             }
-          }
+            break;
+          case LONG:
+            LongOpenHashSet inLongValues = (LongOpenHashSet) _valueSet;
+            long[] longValues = 
_mainFunction.transformToLongValuesSV(projectionBlock);
+            for (int i = 0; i < length; i++) {
+              if (inLongValues.contains(longValues[i])) {
+                _intValuesSV[i] = 1;
+              }
+            }
+            break;
+          case FLOAT:
+            FloatOpenHashSet inFloatValues = (FloatOpenHashSet) _valueSet;
+            float[] floatValues = 
_mainFunction.transformToFloatValuesSV(projectionBlock);
+            for (int i = 0; i < length; i++) {
+              if (inFloatValues.contains(floatValues[i])) {
+                _intValuesSV[i] = 1;
+              }
+            }
+            break;
+          case DOUBLE:
+            DoubleOpenHashSet inDoubleValues = (DoubleOpenHashSet) _valueSet;
+            double[] doubleValues = 
_mainFunction.transformToDoubleValuesSV(projectionBlock);
+            for (int i = 0; i < length; i++) {
+              if (inDoubleValues.contains(doubleValues[i])) {
+                _intValuesSV[i] = 1;
+              }
+            }
+            break;
+          case STRING:
+            ObjectOpenHashSet<String> inStringValues = 
(ObjectOpenHashSet<String>) _valueSet;
+            String[] stringValues = 
_mainFunction.transformToStringValuesSV(projectionBlock);
+            for (int i = 0; i < length; i++) {
+              if (inStringValues.contains(stringValues[i])) {
+                _intValuesSV[i] = 1;
+              }
+            }
+            break;
+          case BYTES:
+            ObjectOpenHashSet<ByteArray> inBytesValues = 
(ObjectOpenHashSet<ByteArray>) _valueSet;
+            byte[][] bytesValues = 
_mainFunction.transformToBytesValuesSV(projectionBlock);
+            for (int i = 0; i < length; i++) {
+              if (inBytesValues.contains(new ByteArray(bytesValues[i]))) {
+                _intValuesSV[i] = 1;
+              }
+            }
+            break;
+          default:
+            throw new IllegalStateException();
         }
-        break;
-      case LONG:
-        long[] longValues = 
_transformFunction.transformToLongValuesSV(projectionBlock);
-        if (!_stringValueSet.isEmpty()) {
-          Set<Long> inLongValues = new HashSet<>();
-          for (String inValue : _stringValueSet) {
-            inLongValues.add(Long.parseLong(inValue));
-          }
-          for (int i = 0; i < length; i++) {
-            _results[i] = inLongValues.contains(longValues[i]) ? 1 : 0;
-          }
-        } else {
-          long[][] inLongValues = new long[_valueTransformFunctions.length][];
-          for (int i = 0; i < _valueTransformFunctions.length; i++) {
-            inLongValues[i] = 
_valueTransformFunctions[i].transformToLongValuesSV(projectionBlock);
-          }
-          for (int i = 0; i < length; i++) {
-            for (int j = 0; j < inLongValues.length; j++) {
-              _results[i] = inLongValues[j][i] == longValues[i] ? 1 : 
_results[i];
+      } else {
+        switch (storedType) {
+          case INT:
+            IntOpenHashSet inIntValues = (IntOpenHashSet) _valueSet;
+            int[][] intValues = 
_mainFunction.transformToIntValuesMV(projectionBlock);
+            for (int i = 0; i < length; i++) {
+              for (int intValue : intValues[i]) {
+                if (inIntValues.contains(intValue)) {
+                  _intValuesSV[i] = 1;
+                  break;
+                }
+              }
             }
-          }
+            break;
+          case LONG:
+            LongOpenHashSet inLongValues = (LongOpenHashSet) _valueSet;
+            long[][] longValues = 
_mainFunction.transformToLongValuesMV(projectionBlock);
+            for (int i = 0; i < length; i++) {
+              for (long longValue : longValues[i]) {
+                if (inLongValues.contains(longValue)) {
+                  _intValuesSV[i] = 1;
+                  break;
+                }
+              }
+            }
+            break;
+          case FLOAT:
+            FloatOpenHashSet inFloatValues = (FloatOpenHashSet) _valueSet;
+            float[][] floatValues = 
_mainFunction.transformToFloatValuesMV(projectionBlock);
+            for (int i = 0; i < length; i++) {
+              for (float floatValue : floatValues[i]) {
+                if (inFloatValues.contains(floatValue)) {
+                  _intValuesSV[i] = 1;
+                  break;
+                }
+              }
+            }
+            break;
+          case DOUBLE:
+            DoubleOpenHashSet inDoubleValues = (DoubleOpenHashSet) _valueSet;
+            double[][] doubleValues = 
_mainFunction.transformToDoubleValuesMV(projectionBlock);
+            for (int i = 0; i < length; i++) {
+              for (double doubleValue : doubleValues[i]) {
+                if (inDoubleValues.contains(doubleValue)) {
+                  _intValuesSV[i] = 1;
+                  break;
+                }
+              }
+            }
+            break;
+          case STRING:
+            ObjectOpenHashSet<String> inStringValues = 
(ObjectOpenHashSet<String>) _valueSet;
+            String[][] stringValues = 
_mainFunction.transformToStringValuesMV(projectionBlock);
+            for (int i = 0; i < length; i++) {
+              for (String stringValue : stringValues[i]) {
+                if (inStringValues.contains(stringValue)) {
+                  _intValuesSV[i] = 1;
+                  break;
+                }
+              }
+            }
+            break;
+          default:
+            throw new IllegalStateException();
         }
-        break;
-      case FLOAT:
-        float[] floatValues = 
_transformFunction.transformToFloatValuesSV(projectionBlock);
-        if (!_stringValueSet.isEmpty()) {
-          Set<Float> inFloatValues = new HashSet<>();
-          for (String inValue : _stringValueSet) {
-            inFloatValues.add(Float.parseFloat(inValue));
+      }
+    } else {
+      int numValues = _valueFunctions.length;
+      switch (storedType) {
+        case INT:
+          int[] intValues = 
_mainFunction.transformToIntValuesSV(projectionBlock);
+          int[][] inIntValues = new int[numValues][];
+          for (int i = 0; i < numValues; i++) {
+            inIntValues[i] = 
_valueFunctions[i].transformToIntValuesSV(projectionBlock);
           }
           for (int i = 0; i < length; i++) {
-            _results[i] = inFloatValues.contains(floatValues[i]) ? 1 : 0;
+            for (int[] inIntValue : inIntValues) {
+              if (intValues[i] == inIntValue[i]) {
+                _intValuesSV[i] = 1;
+                break;
+              }
+            }
           }
-        } else {
-          float[][] inFloatValues = new 
float[_valueTransformFunctions.length][];
-          for (int i = 0; i < _valueTransformFunctions.length; i++) {
-            inFloatValues[i] = 
_valueTransformFunctions[i].transformToFloatValuesSV(projectionBlock);
+          break;
+        case LONG:
+          long[] longValues = 
_mainFunction.transformToLongValuesSV(projectionBlock);
+          long[][] inLongValues = new long[numValues][];
+          for (int i = 0; i < numValues; i++) {
+            inLongValues[i] = 
_valueFunctions[i].transformToLongValuesSV(projectionBlock);
           }
           for (int i = 0; i < length; i++) {
-            for (int j = 0; j < inFloatValues.length; j++) {
-              _results[i] = Float.compare(inFloatValues[j][i], floatValues[i]) 
== 0 ? 1 : _results[i];
+            for (long[] inLongValue : inLongValues) {
+              if (longValues[i] == inLongValue[i]) {
+                _intValuesSV[i] = 1;
+                break;
+              }
             }
           }
-        }
-        break;
-      case DOUBLE:
-        double[] doubleValues = 
_transformFunction.transformToDoubleValuesSV(projectionBlock);
-        if (!_stringValueSet.isEmpty()) {
-          Set<Double> inDoubleValues = new HashSet<>();
-          for (String inValue : _stringValueSet) {
-            inDoubleValues.add(Double.parseDouble(inValue));
+          break;
+        case FLOAT:
+          float[] floatValues = 
_mainFunction.transformToFloatValuesSV(projectionBlock);
+          float[][] inFloatValues = new float[numValues][];
+          for (int i = 0; i < numValues; i++) {
+            inFloatValues[i] = 
_valueFunctions[i].transformToFloatValuesSV(projectionBlock);
           }
           for (int i = 0; i < length; i++) {
-            _results[i] = inDoubleValues.contains(doubleValues[i]) ? 1 : 0;
+            for (float[] inFloatValue : inFloatValues) {
+              if (floatValues[i] == inFloatValue[i]) {
+                _intValuesSV[i] = 1;
+                break;
+              }
+            }
           }
-        } else {
-          double[][] inDoubleValues = new 
double[_valueTransformFunctions.length][];
-          for (int i = 0; i < _valueTransformFunctions.length; i++) {
-            inDoubleValues[i] = 
_valueTransformFunctions[i].transformToDoubleValuesSV(projectionBlock);
+          break;
+        case DOUBLE:
+          double[] doubleValues = 
_mainFunction.transformToDoubleValuesSV(projectionBlock);
+          double[][] inDoubleValues = new double[numValues][];
+          for (int i = 0; i < numValues; i++) {
+            inDoubleValues[i] = 
_valueFunctions[i].transformToDoubleValuesSV(projectionBlock);
           }
           for (int i = 0; i < length; i++) {
-            for (int j = 0; j < inDoubleValues.length; j++) {
-              _results[i] = Double.compare(inDoubleValues[j][i], 
doubleValues[i]) == 0 ? 1 : _results[i];
+            for (double[] inDoubleValue : inDoubleValues) {
+              if (doubleValues[i] == inDoubleValue[i]) {

Review comment:
       Good point. Changed the comparison to work the same way as `equals()`,
so that the behavior is aligned with that of the `HashSet`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to