This is an automated email from the ASF dual-hosted git repository.

rskraba pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/avro.git


The following commit(s) were added to refs/heads/master by this push:
     new 820ed6e5e AVRO-3531: Improve thread safety in GenericDatumReader 
(#1719)
820ed6e5e is described below

commit 820ed6e5ea4417b5735078bfd26c99f1305ea363
Author: clesaec <[email protected]>
AuthorDate: Mon Jul 4 18:38:51 2022 +0200

    AVRO-3531: Improve thread safety in GenericDatumReader (#1719)
    
    * avro-3531 : make generic datum reader thread safe
    
    * avro3531: use reentrant lock
    
    * avro-3531: apply spotless formatting
    
    * AVRO-3531: use ConcurrentHashMap
    
    * AVRO-3531: restore newInstanceFromString method for backward compatibility
    
    * Reduce visibility of the getReaderCache() method.
    
    * Fix typo from github editor
    
    Co-authored-by: Ryan Skraba <[email protected]>
---
 .../apache/avro/generic/GenericDatumReader.java    | 100 +++++++++++++----
 .../avro/generic/TestGenericDatumReader.java       | 120 +++++++++++++++++++++
 2 files changed, 197 insertions(+), 23 deletions(-)

diff --git 
a/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java 
b/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java
index 5b1bb31ef..b816ef365 100644
--- 
a/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java
+++ 
b/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java
@@ -18,13 +18,13 @@
 package org.apache.avro.generic;
 
 import java.io.IOException;
-import java.nio.ByteBuffer;
 import java.lang.reflect.Constructor;
-import java.lang.reflect.InvocationTargetException;
+import java.nio.ByteBuffer;
 import java.util.Collection;
 import java.util.HashMap;
-import java.util.IdentityHashMap;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Function;
 
 import org.apache.avro.AvroRuntimeException;
 import org.apache.avro.Conversion;
@@ -452,14 +452,14 @@ public class GenericDatumReader<D> implements 
DatumReader<D> {
    * representation. By default, this calls {@link 
#readString(Object,Decoder)}.
    */
   protected Object readString(Object old, Schema expected, Decoder in) throws 
IOException {
-    Class stringClass = getStringClass(expected);
+    Class stringClass = this.getReaderCache().getStringClass(expected);
     if (stringClass == String.class) {
       return in.readString();
     }
     if (stringClass == CharSequence.class) {
       return readString(old, in);
     }
-    return newInstanceFromString(stringClass, in.readString());
+    return this.getReaderCache().newInstanceFromString(stringClass, 
in.readString());
   }
 
   /**
@@ -498,34 +498,88 @@ public class GenericDatumReader<D> implements 
DatumReader<D> {
     }
   }
 
-  private Map<Schema, Class> stringClassCache = new IdentityHashMap<>();
+  /**
+   * This class is used to reproduce part of IdentityHashMap in 
ConcurrentHashMap
+   * code.
+   */
+  private static final class IdentitySchemaKey {
+    private final Schema schema;
+
+    private final int hashcode;
+
+    public IdentitySchemaKey(Schema schema) {
+      this.schema = schema;
+      this.hashcode = System.identityHashCode(schema);
+    }
 
-  private Class getStringClass(Schema s) {
-    Class c = stringClassCache.get(s);
-    if (c == null) {
-      c = findStringClass(s);
-      stringClassCache.put(s, c);
+    @Override
+    public int hashCode() {
+      return this.hashcode;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (obj == null || !(obj instanceof 
GenericDatumReader.IdentitySchemaKey)) {
+        return false;
+      }
+      IdentitySchemaKey key = (IdentitySchemaKey) obj;
+      return this == key || this.schema == key.schema;
     }
-    return c;
   }
 
-  private final Map<Class, Constructor> stringCtorCache = new HashMap<>();
+  // VisibleForTesting
+  static class ReaderCache {
+    private final Map<IdentitySchemaKey, Class> stringClassCache = new 
ConcurrentHashMap<>();
 
-  @SuppressWarnings("unchecked")
-  protected Object newInstanceFromString(Class c, String s) {
-    try {
-      Constructor ctor = stringCtorCache.get(c);
-      if (ctor == null) {
+    private final Map<Class, Function<String, Object>> stringCtorCache = new 
ConcurrentHashMap<>();
+
+    private final Function<Schema, Class> findStringClass;
+
+    public ReaderCache(Function<Schema, Class> findStringClass) {
+      this.findStringClass = findStringClass;
+    }
+
+    public Object newInstanceFromString(Class c, String s) {
+      final Function<String, Object> ctor = stringCtorCache.computeIfAbsent(c, 
this::buildFunction);
+      return ctor.apply(s);
+    }
+
+    private Function<String, Object> buildFunction(Class c) {
+      final Constructor ctor;
+      try {
         ctor = c.getDeclaredConstructor(String.class);
-        ctor.setAccessible(true);
-        stringCtorCache.put(c, ctor);
+      } catch (NoSuchMethodException e) {
+        throw new AvroRuntimeException(e);
       }
-      return ctor.newInstance(s);
-    } catch (NoSuchMethodException | InvocationTargetException | 
IllegalAccessException | InstantiationException e) {
-      throw new AvroRuntimeException(e);
+      ctor.setAccessible(true);
+
+      return (String s) -> {
+        try {
+          return ctor.newInstance(s);
+        } catch (ReflectiveOperationException e) {
+          throw new AvroRuntimeException(e);
+        }
+      };
+    }
+
+    public Class getStringClass(final Schema s) {
+      final IdentitySchemaKey key = new IdentitySchemaKey(s);
+      return this.stringClassCache.computeIfAbsent(key, (IdentitySchemaKey k) 
-> this.findStringClass.apply(k.schema));
     }
   }
 
+  private final ReaderCache readerCache = new 
ReaderCache(this::findStringClass);
+
+  // VisibleForTesting
+  ReaderCache getReaderCache() {
+    return readerCache;
+  }
+
+  @SuppressWarnings("unchecked")
+  protected Object newInstanceFromString(Class c, String s) {
+    return this.getReaderCache().newInstanceFromString(c, s);
+  }
+
   /**
    * Called to read byte arrays. Subclasses may override to use a different 
byte
    * array representation. By default, this calls
diff --git 
a/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericDatumReader.java
 
b/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericDatumReader.java
new file mode 100644
index 000000000..3ec67cd11
--- /dev/null
+++ 
b/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericDatumReader.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.avro.generic;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import org.apache.avro.Schema;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+public class TestGenericDatumReader {
+
+  private static final Random r = new Random(System.currentTimeMillis());
+
+  @Test
+  public void testReaderCache() {
+    final GenericDatumReader.ReaderCache cache = new 
GenericDatumReader.ReaderCache(this::findStringClass);
+    List<Thread> threads = IntStream.rangeClosed(1, 200).mapToObj((int index) 
-> {
+      final Schema schema = TestGenericDatumReader.this.build(index);
+      final WithSchema s = new WithSchema(schema, cache);
+      return (Runnable) () -> s.test();
+    }).map(Thread::new).collect(Collectors.toList());
+    threads.forEach(Thread::start);
+    threads.forEach((Thread t) -> {
+      try {
+        t.join();
+      } catch (InterruptedException e) {
+        throw new RuntimeException(e);
+      }
+    });
+  }
+
+  @Test
+  public void testNewInstanceFromString() {
+    final GenericDatumReader.ReaderCache cache = new 
GenericDatumReader.ReaderCache(this::findStringClass);
+
+    Object object = cache.newInstanceFromString(StringBuilder.class, "Hello");
+    assertEquals(StringBuilder.class, object.getClass());
+    StringBuilder builder = (StringBuilder) object;
+    assertEquals("Hello", builder.toString());
+
+  }
+
+  static class WithSchema {
+    private final Schema schema;
+
+    private final GenericDatumReader.ReaderCache cache;
+
+    public WithSchema(Schema schema, GenericDatumReader.ReaderCache cache) {
+      this.schema = schema;
+      this.cache = cache;
+    }
+
+    public void test() {
+      this.cache.getStringClass(schema);
+    }
+  }
+
+  private List<Schema> list = new ArrayList<>();
+
+  private Schema build(int index) {
+    int schemaNum = (index - 1) % 50;
+    if (index <= 50) {
+      Schema schema = Schema.createRecord("record_" + schemaNum, "doc", 
"namespace", false,
+          Arrays.asList(new Schema.Field("field" + schemaNum, 
Schema.create(Schema.Type.STRING))));
+      list.add(schema);
+    }
+
+    return list.get(schemaNum);
+  }
+
+  private Class findStringClass(Schema schema) {
+    this.sleep();
+    if (schema.getType() == Schema.Type.INT) {
+      return Integer.class;
+    }
+    if (schema.getType() == Schema.Type.STRING) {
+      return String.class;
+    }
+    if (schema.getType() == Schema.Type.LONG) {
+      return Long.class;
+    }
+    if (schema.getType() == Schema.Type.FLOAT) {
+      return Float.class;
+    }
+    return String.class;
+  }
+
+  private void sleep() {
+    long timeToSleep = r.nextInt(30) + 10L;
+    if (timeToSleep > 25) {
+      try {
+        Thread.sleep(timeToSleep);
+      } catch (InterruptedException e) {
+        throw new RuntimeException(e);
+      }
+    }
+  }
+}

Reply via email to