This is an automated email from the ASF dual-hosted git repository.
nielsbasjes pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/avro.git
The following commit(s) were added to refs/heads/master by this push:
new fa0bb7098 AVRO-3717: [Java] Fix NPE when basic type with Nullable annotation.
fa0bb7098 is described below
commit fa0bb7098083aba41c9aa9d9cf6383eb3e2c2696
Author: Yan Zhao <[email protected]>
AuthorDate: Sat Feb 18 23:41:56 2023 +0800
AVRO-3717: [Java] Fix NPE when basic type with Nullable annotation.
---
.../java/org/apache/avro/reflect/FieldAccess.java | 16 ++
.../apache/avro/reflect/FieldAccessReflect.java | 24 ++-
.../org/apache/avro/reflect/FieldAccessUnsafe.java | 16 +-
.../avro/reflect/TestReflectDatumReader.java | 209 +++++++++++++++++++++
4 files changed, 256 insertions(+), 9 deletions(-)
diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccess.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccess.java
index 961884951..dce1aed98 100644
--- a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccess.java
+++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccess.java
@@ -21,6 +21,22 @@ import java.lang.reflect.Field;
abstract class FieldAccess {
+ protected static final int INT_DEFAULT_VALUE = 0;
+
+ protected static final float FLOAT_DEFAULT_VALUE = 0.0f;
+
+ protected static final short SHORT_DEFAULT_VALUE = (short) 0;
+
+ protected static final byte BYTE_DEFAULT_VALUE = (byte) 0;
+
+ protected static final boolean BOOLEAN_DEFAULT_VALUE = false;
+
+ protected static final char CHAR_DEFAULT_VALUE = '\u0000';
+
+ protected static final long LONG_DEFAULT_VALUE = 0L;
+
+ protected static final double DOUBLE_DEFAULT_VALUE = 0.0d;
+
protected abstract FieldAccessor getAccessor(Field field);
}
diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessReflect.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessReflect.java
index c790dbfb8..5d51be054 100644
--- a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessReflect.java
+++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessReflect.java
@@ -62,7 +62,29 @@ class FieldAccessReflect extends FieldAccess {
@Override
public void set(Object object, Object value) throws
IllegalAccessException, IOException {
- field.set(object, value);
+ if (value == null && field.getType().isPrimitive()) {
+ Object defaultValue = null;
+ if (int.class.equals(field.getType())) {
+ defaultValue = INT_DEFAULT_VALUE;
+ } else if (float.class.equals(field.getType())) {
+ defaultValue = FLOAT_DEFAULT_VALUE;
+ } else if (short.class.equals(field.getType())) {
+ defaultValue = SHORT_DEFAULT_VALUE;
+ } else if (byte.class.equals(field.getType())) {
+ defaultValue = BYTE_DEFAULT_VALUE;
+ } else if (boolean.class.equals(field.getType())) {
+ defaultValue = BOOLEAN_DEFAULT_VALUE;
+ } else if (char.class.equals(field.getType())) {
+ defaultValue = CHAR_DEFAULT_VALUE;
+ } else if (long.class.equals(field.getType())) {
+ defaultValue = LONG_DEFAULT_VALUE;
+ } else if (double.class.equals(field.getType())) {
+ defaultValue = DOUBLE_DEFAULT_VALUE;
+ }
+ field.set(object, defaultValue);
+ } else {
+ field.set(object, value);
+ }
}
@Override
diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessUnsafe.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessUnsafe.java
index f555df49a..a2c5c4e1b 100644
--- a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessUnsafe.java
+++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessUnsafe.java
@@ -106,7 +106,7 @@ class FieldAccessUnsafe extends FieldAccess {
@Override
protected void set(Object object, Object value) {
- UNSAFE.putInt(object, offset, (Integer) value);
+ UNSAFE.putInt(object, offset, value == null ? INT_DEFAULT_VALUE :
(Integer) value);
}
@Override
@@ -132,7 +132,7 @@ class FieldAccessUnsafe extends FieldAccess {
@Override
protected void set(Object object, Object value) {
- UNSAFE.putFloat(object, offset, (Float) value);
+ UNSAFE.putFloat(object, offset, value == null ? FLOAT_DEFAULT_VALUE :
(Float) value);
}
@Override
@@ -158,7 +158,7 @@ class FieldAccessUnsafe extends FieldAccess {
@Override
protected void set(Object object, Object value) {
- UNSAFE.putShort(object, offset, (Short) value);
+ UNSAFE.putShort(object, offset, value == null ? SHORT_DEFAULT_VALUE :
(Short) value);
}
@Override
@@ -184,7 +184,7 @@ class FieldAccessUnsafe extends FieldAccess {
@Override
protected void set(Object object, Object value) {
- UNSAFE.putByte(object, offset, (Byte) value);
+ UNSAFE.putByte(object, offset, value == null ? BYTE_DEFAULT_VALUE :
(Byte) value);
}
@Override
@@ -210,7 +210,7 @@ class FieldAccessUnsafe extends FieldAccess {
@Override
protected void set(Object object, Object value) {
- UNSAFE.putBoolean(object, offset, (Boolean) value);
+ UNSAFE.putBoolean(object, offset, value == null ? BOOLEAN_DEFAULT_VALUE
: (Boolean) value);
}
@Override
@@ -236,7 +236,7 @@ class FieldAccessUnsafe extends FieldAccess {
@Override
protected void set(Object object, Object value) {
- UNSAFE.putChar(object, offset, (Character) value);
+ UNSAFE.putChar(object, offset, value == null ? CHAR_DEFAULT_VALUE :
(Character) value);
}
@Override
@@ -262,7 +262,7 @@ class FieldAccessUnsafe extends FieldAccess {
@Override
protected void set(Object object, Object value) {
- UNSAFE.putLong(object, offset, (Long) value);
+ UNSAFE.putLong(object, offset, value == null ? LONG_DEFAULT_VALUE :
(Long) value);
}
@Override
@@ -288,7 +288,7 @@ class FieldAccessUnsafe extends FieldAccess {
@Override
protected void set(Object object, Object value) {
- UNSAFE.putDouble(object, offset, (Double) value);
+ UNSAFE.putDouble(object, offset, value == null ? DOUBLE_DEFAULT_VALUE :
(Double) value);
}
@Override
diff --git a/lang/java/avro/src/test/java/org/apache/avro/reflect/TestReflectDatumReader.java b/lang/java/avro/src/test/java/org/apache/avro/reflect/TestReflectDatumReader.java
index 65d01307e..52b40b87b 100644
--- a/lang/java/avro/src/test/java/org/apache/avro/reflect/TestReflectDatumReader.java
+++ b/lang/java/avro/src/test/java/org/apache/avro/reflect/TestReflectDatumReader.java
@@ -30,6 +30,7 @@ import java.util.Set;
import java.util.Map;
import java.util.Optional;
+import org.apache.avro.Schema;
import org.apache.avro.io.Decoder;
import org.apache.avro.io.DecoderFactory;
import org.apache.avro.io.Encoder;
@@ -160,6 +161,36 @@ public class TestReflectDatumReader {
assertEquals(pojoWithOptional, deserialized);
}
+ @Test
+ public void testRead_PojoWithNullableAnnotation() throws IOException {
+ PojoWithBasicTypeNullableAnnotationV1 v1Pojo = new
PojoWithBasicTypeNullableAnnotationV1();
+ int idValue = 1;
+ v1Pojo.setId(idValue);
+ byte[] serializedBytes = serializeWithReflectDatumWriter(v1Pojo,
PojoWithBasicTypeNullableAnnotationV1.class);
+ Decoder decoder = DecoderFactory.get().binaryDecoder(serializedBytes,
null);
+
+ ReflectData reflectData = ReflectData.get();
+ Schema schemaV1 =
reflectData.getSchema(PojoWithBasicTypeNullableAnnotationV1.class);
+ Schema schemaV2 =
reflectData.getSchema(PojoWithBasicTypeNullableAnnotationV2.class);
+
+ ReflectDatumReader<PojoWithBasicTypeNullableAnnotationV2>
reflectDatumReader = new ReflectDatumReader<>(schemaV1,
+ schemaV2);
+
+ PojoWithBasicTypeNullableAnnotationV2 v2Pojo = new
PojoWithBasicTypeNullableAnnotationV2();
+ reflectDatumReader.read(v2Pojo, decoder);
+
+ assertEquals(v1Pojo.id, v2Pojo.id);
+ assertEquals(v2Pojo.id, idValue);
+ assertEquals(v2Pojo.intId, FieldAccess.INT_DEFAULT_VALUE);
+ assertEquals(v2Pojo.floatId, FieldAccess.FLOAT_DEFAULT_VALUE);
+ assertEquals(v2Pojo.shortId, FieldAccess.SHORT_DEFAULT_VALUE);
+ assertEquals(v2Pojo.byteId, FieldAccess.BYTE_DEFAULT_VALUE);
+ assertEquals(v2Pojo.booleanId, FieldAccess.BOOLEAN_DEFAULT_VALUE);
+ assertEquals(v2Pojo.charId, FieldAccess.CHAR_DEFAULT_VALUE);
+ assertEquals(v2Pojo.longId, FieldAccess.LONG_DEFAULT_VALUE);
+ assertEquals(v2Pojo.doubleId, FieldAccess.DOUBLE_DEFAULT_VALUE);
+ }
+
public static class PojoWithList {
private int id;
private List<Integer> relatedIds;
@@ -392,4 +423,182 @@ public class TestReflectDatumReader {
return relatedId.equals(other.relatedId);
}
}
+
+ public static class PojoWithBasicTypeNullableAnnotationV1 {
+
+ private int id;
+
+ public int getId() {
+ return id;
+ }
+
+ public void setId(int id) {
+ this.id = id;
+ }
+
+ @Override
+ public int hashCode() {
+ final int prime = 31;
+ int result = 1;
+ result = prime * result + id;
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj)
+ return true;
+ if (obj == null)
+ return false;
+ if (getClass() != obj.getClass())
+ return false;
+ PojoWithBasicTypeNullableAnnotationV1 other =
(PojoWithBasicTypeNullableAnnotationV1) obj;
+ return id == other.id;
+ }
+ }
+
+ public static class PojoWithBasicTypeNullableAnnotationV2 {
+
+ private int id;
+
+ @Nullable
+ private int intId;
+
+ @Nullable
+ private float floatId;
+
+ @Nullable
+ private short shortId;
+
+ @Nullable
+ private byte byteId;
+
+ @Nullable
+ private boolean booleanId;
+
+ @Nullable
+ private char charId;
+
+ @Nullable
+ private long longId;
+
+ @Nullable
+ private double doubleId;
+
+ public int getId() {
+ return id;
+ }
+
+ public void setId(int id) {
+ this.id = id;
+ }
+
+ public int getIntId() {
+ return intId;
+ }
+
+ public void setIntId(int intId) {
+ this.intId = intId;
+ }
+
+ public float getFloatId() {
+ return floatId;
+ }
+
+ public void setFloatId(float floatId) {
+ this.floatId = floatId;
+ }
+
+ public short getShortId() {
+ return shortId;
+ }
+
+ public void setShortId(short shortId) {
+ this.shortId = shortId;
+ }
+
+ public byte getByteId() {
+ return byteId;
+ }
+
+ public void setByteId(byte byteId) {
+ this.byteId = byteId;
+ }
+
+ public boolean isBooleanId() {
+ return booleanId;
+ }
+
+ public void setBooleanId(boolean booleanId) {
+ this.booleanId = booleanId;
+ }
+
+ public char getCharId() {
+ return charId;
+ }
+
+ public void setCharId(char charId) {
+ this.charId = charId;
+ }
+
+ public long getLongId() {
+ return longId;
+ }
+
+ public void setLongId(long longId) {
+ this.longId = longId;
+ }
+
+ public double getDoubleId() {
+ return doubleId;
+ }
+
+ public void setDoubleId(double doubleId) {
+ this.doubleId = doubleId;
+ }
+
+ @Override
+ public int hashCode() {
+ final int prime = 31;
+ long temp;
+ int result = 1;
+ result = prime * result + id;
+ result = prime * result + intId;
+ result = prime * result + (floatId != 0.0f ?
Float.floatToIntBits(floatId) : 0);
+ result = prime * result + (int) shortId;
+ result = prime * result + (int) byteId;
+ result = prime * result + (booleanId ? 1 : 0);
+ result = prime * result + (int) charId;
+ result = prime * result + (int) (longId ^ (longId >>> 32));
+ temp = Double.doubleToLongBits(doubleId);
+ result = 31 * result + (int) (temp ^ (temp >>> 32));
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o)
+ return true;
+ if (o == null || getClass() != o.getClass())
+ return false;
+ PojoWithBasicTypeNullableAnnotationV2 that =
(PojoWithBasicTypeNullableAnnotationV2) o;
+ if (id != that.id)
+ return false;
+ if (intId != that.intId)
+ return false;
+ if (Float.compare(that.floatId, floatId) != 0)
+ return false;
+ if (shortId != that.shortId)
+ return false;
+ if (byteId != that.byteId)
+ return false;
+ if (booleanId != that.booleanId)
+ return false;
+ if (charId != that.charId)
+ return false;
+ if (longId != that.longId)
+ return false;
+ return Double.compare(that.doubleId, doubleId) == 0;
+ }
+ }
}