manolama commented on code in PR #38423:
URL: https://github.com/apache/arrow/pull/38423#discussion_r1385882993


##########
java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java:
##########
@@ -846,4 +862,293 @@ protected void validateListAsMapData(VectorSchemaRoot 
root) {
       }
     }
   }
+
+  /**
+   * Utility to write permutations of dictionary encoding.
+   *
+   * state == 1, one delta dictionary.
+   * state == 2, one standalone dictionary.
+   * state == 3, one of each
+   * state == 4, delta with nothing at start and end
+   * state == 5, both deltas
+   * state == 6, both deltas and standalone
+   * state == 7, replacement dictionary
+   */
+  protected void writeDataMultiBatchWithDictionaries(OutputStream stream, int 
state) throws IOException {
+    DictionaryProvider.MapDictionaryProvider provider = new 
DictionaryProvider.MapDictionaryProvider();
+    DictionaryEncoding deltaEncoding =
+        new DictionaryEncoding(42, false, new ArrowType.Int(16, false), true);
+    DictionaryEncoding replacementEncoding =
+        new DictionaryEncoding(24, false, new ArrowType.Int(16, false), false);
+    DictionaryEncoding deltaCEncoding =
+        new DictionaryEncoding(1, false, new ArrowType.Int(16, false), true);
+    DictionaryEncoding replacementEncodingUpdated =
+        new DictionaryEncoding(2, false, new ArrowType.Int(16, false), false);
+
+    boolean isFile = stream instanceof FileOutputStream;
+    try (BatchedDictionary vectorA = newDictionary("vectorA", deltaEncoding, 
isFile);
+         BatchedDictionary vectorB = newDictionary("vectorB", 
replacementEncoding, isFile);
+         BatchedDictionary vectorC = newDictionary("vectorC", deltaCEncoding, 
isFile);
+         BatchedDictionary vectorD = newDictionary("vectorD", 
replacementEncodingUpdated, isFile);) {
+      switch (state) {
+        case 1:
+          provider.put(vectorA);
+          break;
+        case 2:
+          provider.put(vectorB);
+          break;
+        case 3:
+          provider.put(vectorA);
+          provider.put(vectorB);
+          break;
+        case 4:
+          provider.put(vectorC);
+          break;
+        case 5:
+          provider.put(vectorA);
+          provider.put(vectorC);
+          break;
+        case 6:
+          provider.put(vectorA);
+          provider.put(vectorB);
+          provider.put(vectorC);
+          break;
+        case 7:
+          provider.put(vectorD);
+          break;
+        default:
+          throw new IllegalStateException("Unsupported state: " + state);
+      }
+
+      VectorSchemaRoot root = null;
+      switch (state) {
+        case 1:
+          root = VectorSchemaRoot.of(vectorA.getIndexVector());
+          break;
+        case 2:
+          root = VectorSchemaRoot.of(vectorB.getIndexVector());
+          break;
+        case 3:
+          root = VectorSchemaRoot.of(vectorA.getIndexVector(), 
vectorB.getIndexVector());
+          break;
+        case 4:
+          root = VectorSchemaRoot.of(vectorC.getIndexVector());
+          break;
+        case 5:
+          root = VectorSchemaRoot.of(vectorA.getIndexVector(), 
vectorC.getIndexVector());
+          break;
+        case 6:
+          root = VectorSchemaRoot.of(vectorA.getIndexVector(), 
vectorB.getIndexVector(), vectorC.getIndexVector());
+          break;
+        case 7:
+          root = VectorSchemaRoot.of(vectorD.getIndexVector());
+          break;
+        default:
+          throw new IllegalStateException("Unsupported state: " + state);
+      }
+
+      ArrowWriter arrowWriter = null;
+      try {
+        if (stream instanceof FileOutputStream) {
+          FileChannel channel = ((FileOutputStream) stream).getChannel();
+          arrowWriter = new ArrowFileWriter(root, provider, channel);
+        } else {
+          arrowWriter = new ArrowStreamWriter(root, provider, stream);
+        }
+
+        vectorA.setSafe(0, "foo".getBytes(StandardCharsets.UTF_8));
+        vectorA.setSafe(1, "bar".getBytes(StandardCharsets.UTF_8));
+        vectorB.setSafe(0, "lorem".getBytes(StandardCharsets.UTF_8));
+        vectorB.setSafe(1, "ipsum".getBytes(StandardCharsets.UTF_8));
+        vectorC.setNull(0);
+        vectorC.setNull(1);
+        vectorD.setSafe(0, "porro".getBytes(StandardCharsets.UTF_8));
+        vectorD.setSafe(1, "amet".getBytes(StandardCharsets.UTF_8));
+
+        // batch 1
+        arrowWriter.start();
+        root.setRowCount(2);
+        arrowWriter.writeBatch();
+
+        // batch 2
+        vectorA.setSafe(0, "meep".getBytes(StandardCharsets.UTF_8));
+        vectorA.setSafe(1, "bar".getBytes(StandardCharsets.UTF_8));
+        vectorB.setSafe(0, "ipsum".getBytes(StandardCharsets.UTF_8));
+        vectorB.setSafe(1, "lorem".getBytes(StandardCharsets.UTF_8));
+        vectorC.setSafe(0, "qui".getBytes(StandardCharsets.UTF_8));
+        vectorC.setSafe(1, "dolor".getBytes(StandardCharsets.UTF_8));
+        vectorD.setSafe(0, "amet".getBytes(StandardCharsets.UTF_8));
+        if (state == 7) {
+          vectorD.setSafe(1, "quia".getBytes(StandardCharsets.UTF_8));
+        }
+
+        root.setRowCount(2);
+        arrowWriter.writeBatch();
+
+        // batch 3
+        vectorA.setNull(0);
+        vectorA.setNull(1);
+        vectorB.setSafe(0, "ipsum".getBytes(StandardCharsets.UTF_8));
+        vectorB.setNull(1);
+        vectorC.setNull(0);
+        vectorC.setSafe(1, "qui".getBytes(StandardCharsets.UTF_8));
+        vectorD.setNull(0);
+        if (state == 7) {
+          vectorD.setSafe(1, "quia".getBytes(StandardCharsets.UTF_8));
+        }
+
+        root.setRowCount(2);
+        arrowWriter.writeBatch();
+
+        // batch 4
+        vectorA.setSafe(0, "bar".getBytes(StandardCharsets.UTF_8));
+        vectorA.setSafe(1, "zap".getBytes(StandardCharsets.UTF_8));
+        vectorB.setNull(0);
+        vectorB.setSafe(1, "lorem".getBytes(StandardCharsets.UTF_8));
+        vectorC.setNull(0);
+        vectorC.setNull(1);
+        if (state == 7) {
+          vectorD.setSafe(0, "quia".getBytes(StandardCharsets.UTF_8));
+        }
+        vectorD.setNull(1);
+
+        root.setRowCount(2);
+        arrowWriter.writeBatch();
+
+        arrowWriter.end();
+      } catch (Exception e) {
+        if (arrowWriter != null) {
+          arrowWriter.close();
+        }
+        throw e;
+      }
+    }
+  }
+
+  Map<Integer, String[][]> valuesPerBlock = new HashMap<Integer, String[][]>();
+
+  {
+    valuesPerBlock.put(0, new String[][]{
+        {"foo", "bar"},
+        {"lorem", "ipsum"},
+        {null, null},
+        {"porro", "amet"}
+    });
+    valuesPerBlock.put(1, new String[][]{
+        {"meep", "bar"},
+        {"ipsum", "lorem"},
+        {"qui", "dolor"},
+        {"amet", "quia"}
+    });
+    valuesPerBlock.put(2, new String[][]{
+        {null, null},
+        {"ipsum", null},
+        {null, "qui"},
+        {null, "quia"}
+    });
+    valuesPerBlock.put(3, new String[][]{
+        {"bar", "zap"},
+        {null, "lorem"},
+        {null, null},
+        {"quia", null}
+    });
+  }
+
+  protected void assertDictionary(FieldVector encoded, ArrowReader reader, 
String... expected) throws Exception {
+    DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary();
+    BaseDictionary dictionary = 
reader.getDictionaryVectors().get(dictionaryEncoding.getId());
+    try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) {
+      Assertions.assertEquals(expected.length, encoded.getValueCount());
+      for (int i = 0; i < expected.length; i++) {
+        if (expected[i] == null) {
+          Assertions.assertNull(decoded.getObject(i));
+        } else {
+          assertNotNull(decoded.getObject(i));
+          Assertions.assertEquals(expected[i], 
decoded.getObject(i).toString());
+        }
+      }
+    }
+  }
+
+  /**
+   * Opens {@code file} with an {@link ArrowFileReader}, loads the record block at index
+   * {@code block} and verifies it with {@link #assertBlock(ArrowReader, int, int)}.
+   *
+   * @param file Arrow file to read
+   * @param block index into the reader's record blocks
+   * @param state permutation the file was written with
+   * @throws Exception if reading fails or an assertion does not hold
+   */
+  protected void assertBlock(File file, int block, int state) throws Exception {
+    try (FileInputStream input = new FileInputStream(file);
+         ArrowFileReader fileReader = new ArrowFileReader(input.getChannel(), allocator)) {
+      fileReader.loadRecordBatch(fileReader.getRecordBlocks().get(block));
+      assertBlock(fileReader, block, state);
+    }
+  }
+
+  protected void assertBlock(ArrowReader reader, int block, int state) throws 
Exception {
+    VectorSchemaRoot r = reader.getVectorSchemaRoot();
+    FieldVector dictA = r.getVector("vectorA");
+    FieldVector dictB = r.getVector("vectorB");
+    FieldVector dictC = r.getVector("vectorC");
+    FieldVector dictD = r.getVector("vectorD");
+
+    switch (state) {
+      case 1:
+        assertDictionary(dictA, reader, valuesPerBlock.get(block)[0]);
+        assertNull(dictB);
+        assertNull(dictC);
+        assertNull(dictD);
+        break;
+      case 2:
+        assertNull(dictA);
+        assertDictionary(dictB, reader, valuesPerBlock.get(block)[1]);
+        assertNull(dictC);
+        assertNull(dictD);
+        break;
+      case 3:
+        assertDictionary(dictA, reader, valuesPerBlock.get(block)[0]);
+        assertDictionary(dictB, reader, valuesPerBlock.get(block)[1]);
+        assertNull(dictC);
+        assertNull(dictD);
+        break;
+      case 4:
+        assertNull(dictA);
+        assertNull(dictB);
+        assertDictionary(dictC, reader, valuesPerBlock.get(block)[2]);
+        assertNull(dictD);
+        break;
+      case 5:
+        assertDictionary(dictA, reader, valuesPerBlock.get(block)[0]);
+        assertNull(dictB);
+        assertDictionary(dictC, reader, valuesPerBlock.get(block)[2]);
+        assertNull(dictD);
+        break;
+      case 6:
+        assertDictionary(dictA, reader, valuesPerBlock.get(block)[0]);
+        assertDictionary(dictB, reader, valuesPerBlock.get(block)[1]);
+        assertDictionary(dictC, reader, valuesPerBlock.get(block)[2]);
+        assertNull(dictD);
+        break;
+      case 7:
+        assertNull(dictA);
+        assertNull(dictB);
+        assertNull(dictC);
+        assertDictionary(dictD, reader, valuesPerBlock.get(block)[3]);
+        break;
+      default:
+        throw new IllegalStateException("Unsupported state: " + state);
+    }
+  }
+
+  protected static Collection<Arguments> dictionaryParams() {
+    List<Arguments> params = new ArrayList<>();
+    for (int i = 1; i < 8; i++) {

Review Comment:
   Just the number of unique test states. I'll drop a note.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to