http://git-wip-us.apache.org/repos/asf/hive/blob/d467e172/serde/src/test/org/apache/hadoop/hive/serde2/SerdeRandomRowSource.java ---------------------------------------------------------------------- diff --git a/serde/src/test/org/apache/hadoop/hive/serde2/SerdeRandomRowSource.java b/serde/src/test/org/apache/hadoop/hive/serde2/SerdeRandomRowSource.java index 301ee8b..66e3a96 100644 --- a/serde/src/test/org/apache/hadoop/hive/serde2/SerdeRandomRowSource.java +++ b/serde/src/test/org/apache/hadoop/hive/serde2/SerdeRandomRowSource.java @@ -25,19 +25,29 @@ import java.util.HashSet; import java.util.List; import java.util.Random; -import org.apache.commons.lang.ArrayUtils; -import org.apache.commons.lang.StringUtils; import org.apache.hadoop.hive.common.type.HiveChar; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth; import org.apache.hadoop.hive.common.type.HiveVarchar; import org.apache.hadoop.hive.common.type.RandomTypeUtil; +import org.apache.hadoop.hive.serde2.io.HiveCharWritable; +import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector; +import 
org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector; +import 
org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector.StandardUnion; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.UnionObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableBooleanObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableByteObjectInspector; @@ -56,10 +66,20 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableStringObj import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableTimestampObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; import org.apache.hive.common.util.DateUtils; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.BytesWritable; + +import com.google.common.base.Preconditions; +import com.google.common.base.Charsets; /** * Generate object inspector and random row object[]. @@ -72,6 +92,14 @@ public class SerdeRandomRowSource { private List<String> typeNames; + private Category[] categories; + + private TypeInfo[] typeInfos; + + private List<ObjectInspector> objectInspectorList; + + // Primitive.
+ private PrimitiveCategory[] primitiveCategories; private PrimitiveTypeInfo[] primitiveTypeInfos; @@ -80,10 +108,25 @@ public class SerdeRandomRowSource { private StructObjectInspector rowStructObjectInspector; + private String[] alphabets; + + private boolean allowNull; + + private boolean addEscapables; + private String needsEscapeStr; + public List<String> typeNames() { return typeNames; } + public Category[] categories() { + return categories; + } + + public TypeInfo[] typeInfos() { + return typeInfos; + } + public PrimitiveCategory[] primitiveCategories() { return primitiveCategories; } @@ -97,30 +140,37 @@ public class SerdeRandomRowSource { } public StructObjectInspector partialRowStructObjectInspector(int partialFieldCount) { - ArrayList<ObjectInspector> partialPrimitiveObjectInspectorList = + ArrayList<ObjectInspector> partialObjectInspectorList = new ArrayList<ObjectInspector>(partialFieldCount); List<String> columnNames = new ArrayList<String>(partialFieldCount); for (int i = 0; i < partialFieldCount; i++) { columnNames.add(String.format("partial%d", i)); - partialPrimitiveObjectInspectorList.add( - PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( - primitiveTypeInfos[i])); + partialObjectInspectorList.add(getObjectInspector(typeInfos[i])); } return ObjectInspectorFactory.getStandardStructObjectInspector( - columnNames, primitiveObjectInspectorList); + columnNames, partialObjectInspectorList); + } + + public enum SupportedTypes { + ALL, PRIMITIVE, ALL_EXCEPT_MAP + } + + public void init(Random r, SupportedTypes supportedTypes, int maxComplexDepth) { + init(r, supportedTypes, maxComplexDepth, true); } - public void init(Random r) { + public void init(Random r, SupportedTypes supportedTypes, int maxComplexDepth, boolean allowNull) { this.r = r; - chooseSchema(); + this.allowNull = allowNull; + chooseSchema(supportedTypes, maxComplexDepth); } /* * For now, exclude CHAR until we determine why there is a difference (blank padding) * serializing
with LazyBinarySerializeWrite and the regular SerDe... */ - private static String[] possibleHiveTypeNames = { + private static String[] possibleHivePrimitiveTypeNames = { "boolean", "tinyint", "smallint", @@ -140,67 +190,284 @@ public class SerdeRandomRowSource { "decimal" }; - private void chooseSchema() { + private static String[] possibleHiveComplexTypeNames = { + "array", + "struct", + "uniontype", + "map" + }; + + private String getRandomTypeName(SupportedTypes supportedTypes) { + String typeName = null; + if (r.nextInt(10) != 0) { + typeName = possibleHivePrimitiveTypeNames[r.nextInt(possibleHivePrimitiveTypeNames.length)]; + } else { + switch (supportedTypes) { + case PRIMITIVE: + typeName = possibleHivePrimitiveTypeNames[r.nextInt(possibleHivePrimitiveTypeNames.length)]; + break; + case ALL_EXCEPT_MAP: + typeName = possibleHiveComplexTypeNames[r.nextInt(possibleHiveComplexTypeNames.length - 1)]; + break; + case ALL: + typeName = possibleHiveComplexTypeNames[r.nextInt(possibleHiveComplexTypeNames.length)]; + break; + } + } + return typeName; + } + + private String getDecoratedTypeName(String typeName, SupportedTypes supportedTypes, int depth, int maxDepth) { + depth++; + if (depth >= maxDepth) { // Nesting limit reached: element/field/value types must be primitive. + supportedTypes = SupportedTypes.PRIMITIVE; + } + if (typeName.equals("char")) { + int maxLength = 1 + r.nextInt(100); + typeName = String.format("char(%d)", maxLength); + } else if (typeName.equals("varchar")) { + int maxLength = 1 + r.nextInt(100); + typeName = String.format("varchar(%d)", maxLength); + } else if (typeName.equals("decimal")) { + typeName = String.format("decimal(%d,%d)", HiveDecimal.SYSTEM_DEFAULT_PRECISION, HiveDecimal.SYSTEM_DEFAULT_SCALE); + } else if (typeName.equals("array")) { + String elementTypeName = getRandomTypeName(supportedTypes); + elementTypeName = getDecoratedTypeName(elementTypeName, supportedTypes, depth, maxDepth); + typeName = String.format("array<%s>", elementTypeName); + } else if (typeName.equals("map")) { + String
keyTypeName = getRandomTypeName(SupportedTypes.PRIMITIVE); + keyTypeName = getDecoratedTypeName(keyTypeName, supportedTypes, depth, maxDepth); + String valueTypeName = getRandomTypeName(supportedTypes); + valueTypeName = getDecoratedTypeName(valueTypeName, supportedTypes, depth, maxDepth); + typeName = String.format("map<%s,%s>", keyTypeName, valueTypeName); + } else if (typeName.equals("struct")) { + final int fieldCount = 1 + r.nextInt(10); + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < fieldCount; i++) { + String fieldTypeName = getRandomTypeName(supportedTypes); + fieldTypeName = getDecoratedTypeName(fieldTypeName, supportedTypes, depth, maxDepth); + if (i > 0) { + sb.append(","); + } + sb.append("col"); + sb.append(i); + sb.append(":"); + sb.append(fieldTypeName); + } + typeName = String.format("struct<%s>", sb.toString()); + } else if (typeName.equals("uniontype")) { + // Union members are listed by type only (no field names). + final int fieldCount = 1 + r.nextInt(10); + final StringBuilder sb = new StringBuilder(); + for (int i = 0; i < fieldCount; i++) { + String fieldTypeName = getRandomTypeName(supportedTypes); + fieldTypeName = getDecoratedTypeName(fieldTypeName, supportedTypes, depth, maxDepth); + if (i > 0) { + sb.append(","); + } + sb.append(fieldTypeName); + } + typeName = String.format("uniontype<%s>", sb.toString()); + } + return typeName; + } + + private ObjectInspector getObjectInspector(TypeInfo typeInfo) { + ObjectInspector objectInspector; + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + final PrimitiveTypeInfo primitiveType = (PrimitiveTypeInfo) typeInfo; + objectInspector = + PrimitiveObjectInspectorFactory.
+ getPrimitiveWritableObjectInspector(primitiveType); + } + break; + case MAP: + { + final MapTypeInfo mapType = (MapTypeInfo) typeInfo; + final MapObjectInspector mapInspector = + ObjectInspectorFactory.getStandardMapObjectInspector( + getObjectInspector(mapType.getMapKeyTypeInfo()), + getObjectInspector(mapType.getMapValueTypeInfo())); + objectInspector = mapInspector; + } + break; + case LIST: + { + final ListTypeInfo listType = (ListTypeInfo) typeInfo; + final ListObjectInspector listInspector = + ObjectInspectorFactory.getStandardListObjectInspector( + getObjectInspector(listType.getListElementTypeInfo())); + objectInspector = listInspector; + } + break; + case STRUCT: + { + final StructTypeInfo structType = (StructTypeInfo) typeInfo; + final List<TypeInfo> fieldTypes = structType.getAllStructFieldTypeInfos(); + + final List<ObjectInspector> fieldInspectors = new ArrayList<ObjectInspector>(); + for (TypeInfo fieldType : fieldTypes) { + fieldInspectors.add(getObjectInspector(fieldType)); + } + + final StructObjectInspector structInspector = + ObjectInspectorFactory.getStandardStructObjectInspector( + structType.getAllStructFieldNames(), fieldInspectors); + objectInspector = structInspector; + } + break; + case UNION: + { + final UnionTypeInfo unionType = (UnionTypeInfo) typeInfo; + final List<TypeInfo> fieldTypes = unionType.getAllUnionObjectTypeInfos(); + + final List<ObjectInspector> fieldInspectors = new ArrayList<ObjectInspector>(); + for (TypeInfo fieldType : fieldTypes) { + fieldInspectors.add(getObjectInspector(fieldType)); + } + + final UnionObjectInspector unionInspector = + ObjectInspectorFactory.getStandardUnionObjectInspector( + fieldInspectors); + objectInspector = unionInspector; + } + break; + default: + throw new RuntimeException("Unexpected category " + typeInfo.getCategory()); + } + Preconditions.checkState(objectInspector != null); + return objectInspector; + } + + private void chooseSchema(SupportedTypes supportedTypes, int maxComplexDepth)
{ HashSet hashSet = null; - boolean allTypes; - boolean onlyOne = (r.nextInt(100) == 7); + final boolean allTypes; + final boolean onlyOne = (r.nextInt(100) == 7); if (onlyOne) { columnCount = 1; allTypes = false; } else { allTypes = r.nextBoolean(); if (allTypes) { - // One of each type. - columnCount = possibleHiveTypeNames.length; + switch (supportedTypes) { + case ALL: + columnCount = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length; + break; + case ALL_EXCEPT_MAP: + columnCount = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length - 1; + break; + case PRIMITIVE: + columnCount = possibleHivePrimitiveTypeNames.length; + break; + } hashSet = new HashSet<Integer>(); } else { columnCount = 1 + r.nextInt(20); } } typeNames = new ArrayList<String>(columnCount); + categories = new Category[columnCount]; + typeInfos = new TypeInfo[columnCount]; + objectInspectorList = new ArrayList<ObjectInspector>(columnCount); + primitiveCategories = new PrimitiveCategory[columnCount]; primitiveTypeInfos = new PrimitiveTypeInfo[columnCount]; primitiveObjectInspectorList = new ArrayList<ObjectInspector>(columnCount); - List<String> columnNames = new ArrayList<String>(columnCount); + final List<String> columnNames = new ArrayList<String>(columnCount); for (int c = 0; c < columnCount; c++) { columnNames.add(String.format("col%d", c)); String typeName; if (onlyOne) { - typeName = possibleHiveTypeNames[r.nextInt(possibleHiveTypeNames.length)]; + typeName = getRandomTypeName(supportedTypes); } else { int typeNum; if (allTypes) { + int maxTypeNum = 0; + switch (supportedTypes) { + case PRIMITIVE: + maxTypeNum = possibleHivePrimitiveTypeNames.length; + break; + case ALL_EXCEPT_MAP: + maxTypeNum = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length - 1; + break; + case ALL: + maxTypeNum = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length; + break; + } while (true) { - typeNum =
r.nextInt(possibleHiveTypeNames.length); - Integer typeNumInteger = new Integer(typeNum); + + typeNum = r.nextInt(maxTypeNum); + + final Integer typeNumInteger = Integer.valueOf(typeNum); if (!hashSet.contains(typeNumInteger)) { hashSet.add(typeNumInteger); break; } } } else { - typeNum = r.nextInt(possibleHiveTypeNames.length); + if (supportedTypes == SupportedTypes.PRIMITIVE || r.nextInt(10) != 0) { + typeNum = r.nextInt(possibleHivePrimitiveTypeNames.length); + } else { + // "map" is the last entry of possibleHiveComplexTypeNames, so ALL_EXCEPT_MAP drops the final slot. + final int complexTypeCount = (supportedTypes == SupportedTypes.ALL_EXCEPT_MAP ? + possibleHiveComplexTypeNames.length - 1 : possibleHiveComplexTypeNames.length); + typeNum = possibleHivePrimitiveTypeNames.length + r.nextInt(complexTypeCount); + } + } + if (typeNum < possibleHivePrimitiveTypeNames.length) { + typeName = possibleHivePrimitiveTypeNames[typeNum]; + } else { + typeName = possibleHiveComplexTypeNames[typeNum - possibleHivePrimitiveTypeNames.length]; + } + + } + + final String decoratedTypeName = getDecoratedTypeName(typeName, supportedTypes, 0, maxComplexDepth); + + final TypeInfo typeInfo; + try { + typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(decoratedTypeName); + } catch (Exception e) { + throw new RuntimeException("Cannot convert type name " + decoratedTypeName + " to a type " + e, e); + } + + typeInfos[c] = typeInfo; + final Category category = typeInfo.getCategory(); + categories[c] = category; + ObjectInspector objectInspector = getObjectInspector(typeInfo); + switch (category) { + case PRIMITIVE: + { + final PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + final PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory(); + objectInspector = PrimitiveObjectInspectorFactory.
+ getPrimitiveWritableObjectInspector(primitiveTypeInfo); + primitiveTypeInfos[c] = primitiveTypeInfo; + primitiveCategories[c] = primitiveCategory; + primitiveObjectInspectorList.add(objectInspector); } - typeName = possibleHiveTypeNames[typeNum]; + break; + case LIST: + case MAP: + case STRUCT: + case UNION: + primitiveObjectInspectorList.add(null); + break; + default: + throw new RuntimeException("Unexpected category " + category); } - if (typeName.equals("char")) { - int maxLength = 1 + r.nextInt(100); - typeName = String.format("char(%d)", maxLength); - } else if (typeName.equals("varchar")) { - int maxLength = 1 + r.nextInt(100); - typeName = String.format("varchar(%d)", maxLength); - } else if (typeName.equals("decimal")) { - typeName = String.format("decimal(%d,%d)", HiveDecimal.SYSTEM_DEFAULT_PRECISION, HiveDecimal.SYSTEM_DEFAULT_SCALE); + objectInspectorList.add(objectInspector); + + if (category == Category.PRIMITIVE) { // NOTE(review): empty block; primitive bookkeeping already happens in the switch above. } - PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) TypeInfoUtils.getTypeInfoFromTypeString(typeName); - primitiveTypeInfos[c] = primitiveTypeInfo; - PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory(); - primitiveCategories[c] = primitiveCategory; - primitiveObjectInspectorList.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(primitiveTypeInfo)); - typeNames.add(typeName); + typeNames.add(decoratedTypeName); } - rowStructObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(columnNames, primitiveObjectInspectorList); + rowStructObjectInspector = ObjectInspectorFactory.
+ getStandardStructObjectInspector(columnNames, objectInspectorList); + alphabets = new String[columnCount]; } public Object[][] randomRows(int n) { @@ -214,23 +481,71 @@ public class SerdeRandomRowSource { public Object[] randomRow() { Object row[] = new Object[columnCount]; for (int c = 0; c < columnCount; c++) { - Object object = randomObject(c); - if (object == null) { - throw new Error("Unexpected null for column " + c); - } - row[c] = getWritableObject(c, object); - if (row[c] == null) { - throw new Error("Unexpected null for writable for column " + c); - } + row[c] = randomWritable(c); } return row; } + public Object[] randomPrimitiveRow(int columnCount) { + return randomPrimitiveRow(columnCount, r, primitiveTypeInfos); + } + + public static Object[] randomPrimitiveRow(int columnCount, Random r, + PrimitiveTypeInfo[] primitiveTypeInfos) { + final Object row[] = new Object[columnCount]; + for (int c = 0; c < columnCount; c++) { + row[c] = randomPrimitiveObject(r, primitiveTypeInfos[c]); + } + return row; + } + + public static Object[] randomWritablePrimitiveRow( + int columnCount, Random r, PrimitiveTypeInfo[] primitiveTypeInfos) { + + final Object row[] = new Object[columnCount]; + for (int c = 0; c < columnCount; c++) { + final PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[c]; + final ObjectInspector objectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(primitiveTypeInfo); + final Object object = randomPrimitiveObject(r, primitiveTypeInfo); + row[c] = getWritablePrimitiveObject(primitiveTypeInfo, objectInspector, object); + } + return row; + } + + public void addBinarySortableAlphabets() { + for (int c = 0; c < columnCount; c++) { + if (primitiveCategories[c] == null) { continue; } // Non-primitive column: no primitive category, no alphabet. + switch (primitiveCategories[c]) { + case STRING: + case CHAR: + case VARCHAR: + byte[] bytes = new byte[10 + r.nextInt(10)]; + for (int i = 0; i < bytes.length; i++) { + bytes[i] = (byte) (32 + r.nextInt(96)); + } + final int alwaysIndex = r.nextInt(bytes.length); + bytes[alwaysIndex]
= 0; // Must be escaped by BinarySortable. + final int alwaysIndex2 = r.nextInt(bytes.length); + bytes[alwaysIndex2] = 1; // Must be escaped by BinarySortable. + alphabets[c] = new String(bytes, Charsets.UTF_8); + break; + default: // No alphabet needed. + break; + } + } + } + + public void addEscapables(String needsEscapeStr) { + addEscapables = true; + this.needsEscapeStr = needsEscapeStr; + } + public static void sort(Object[][] rows, ObjectInspector oi) { for (int i = 0; i < rows.length; i++) { for (int j = i + 1; j < rows.length; j++) { if (ObjectInspectorUtils.compare(rows[i], oi, rows[j], oi) > 0) { - Object[] t = rows[i]; + final Object[] t = rows[i]; rows[i] = rows[j]; rows[j] = t; } @@ -242,11 +557,9 @@ public class SerdeRandomRowSource { SerdeRandomRowSource.sort(rows, rowStructObjectInspector); } - public Object getWritableObject(int column, Object object) { - ObjectInspector objectInspector = primitiveObjectInspectorList.get(column); - PrimitiveCategory primitiveCategory = primitiveCategories[column]; - PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[column]; - switch (primitiveCategory) { + public static Object getWritablePrimitiveObject( + PrimitiveTypeInfo primitiveTypeInfo, ObjectInspector objectInspector, Object object) { + switch (primitiveTypeInfo.getPrimitiveCategory()) { case BOOLEAN: return ((WritableBooleanObjectInspector) objectInspector).create((boolean) object); case BYTE: @@ -267,13 +580,13 @@ public class SerdeRandomRowSource { return ((WritableStringObjectInspector) objectInspector).create((String) object); case CHAR: { - WritableHiveCharObjectInspector writableCharObjectInspector = + final WritableHiveCharObjectInspector writableCharObjectInspector = new WritableHiveCharObjectInspector( (CharTypeInfo) primitiveTypeInfo); return writableCharObjectInspector.create((HiveChar) object); } case VARCHAR: { - WritableHiveVarcharObjectInspector writableVarcharObjectInspector = + final WritableHiveVarcharObjectInspector
writableVarcharObjectInspector = new WritableHiveVarcharObjectInspector( (VarcharTypeInfo) primitiveTypeInfo); return writableVarcharObjectInspector.create((HiveVarchar) object); } @@ -287,21 +600,171 @@ public class SerdeRandomRowSource { return ((WritableHiveIntervalDayTimeObjectInspector) objectInspector).create((HiveIntervalDayTime) object); case DECIMAL: { - WritableHiveDecimalObjectInspector writableDecimalObjectInspector = + final WritableHiveDecimalObjectInspector writableDecimalObjectInspector = new WritableHiveDecimalObjectInspector((DecimalTypeInfo) primitiveTypeInfo); return writableDecimalObjectInspector.create((HiveDecimal) object); } default: - throw new Error("Unknown primitive category " + primitiveCategory); + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); + } + } + + public Object randomWritable(int column) { + return randomWritable(typeInfos[column], objectInspectorList.get(column)); + } + + public Object randomWritable(TypeInfo typeInfo, ObjectInspector objectInspector) { + return randomWritable(typeInfo, objectInspector, allowNull); + } + + public Object randomWritable(TypeInfo typeInfo, ObjectInspector objectInspector, boolean allowNull) { + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + final Object object = randomPrimitiveObject(r, (PrimitiveTypeInfo) typeInfo); + return getWritablePrimitiveObject((PrimitiveTypeInfo) typeInfo, objectInspector, object); + } + case LIST: + { + if (allowNull && r.nextInt(20) == 0) { + return null; + } + // Always generate a list with at least 1 value?
+ final int elementCount = 1 + r.nextInt(100); + final StandardListObjectInspector listObjectInspector = + (StandardListObjectInspector) objectInspector; + final ObjectInspector elementObjectInspector = + listObjectInspector.getListElementObjectInspector(); + final TypeInfo elementTypeInfo = + TypeInfoUtils.getTypeInfoFromObjectInspector( + elementObjectInspector); + boolean isStringFamily = false; + PrimitiveCategory primitiveCategory = null; + if (elementTypeInfo.getCategory() == Category.PRIMITIVE) { + primitiveCategory = ((PrimitiveTypeInfo) elementTypeInfo).getPrimitiveCategory(); + if (primitiveCategory == PrimitiveCategory.STRING || + primitiveCategory == PrimitiveCategory.BINARY || + primitiveCategory == PrimitiveCategory.CHAR || + primitiveCategory == PrimitiveCategory.VARCHAR) { + isStringFamily = true; + } + } + final Object listObj = listObjectInspector.create(elementCount); + for (int i = 0; i < elementCount; i++) { + final Object ele = randomWritable(elementTypeInfo, elementObjectInspector, allowNull); + // UNDONE: For now, a 1-element list with a null element is a null list...
+ if (ele == null && elementCount == 1) { + return null; + } + if (isStringFamily && elementCount == 1) { + switch (primitiveCategory) { + case STRING: + if (((Text) ele).getLength() == 0) { + return null; + } + break; + case BINARY: + if (((BytesWritable) ele).getLength() == 0) { + return null; + } + break; + case CHAR: + if (((HiveCharWritable) ele).getHiveChar().getStrippedValue().isEmpty()) { + return null; + } + break; + case VARCHAR: + if (((HiveVarcharWritable) ele).getHiveVarchar().getValue().isEmpty()) { + return null; + } + break; + default: + throw new RuntimeException("Unexpected primitive category " + primitiveCategory); + } + } + listObjectInspector.set(listObj, i, ele); + } + return listObj; + } + case MAP: + { + if (allowNull && r.nextInt(20) == 0) { + return null; + } + final int keyPairCount = r.nextInt(100); + final StandardMapObjectInspector mapObjectInspector = + (StandardMapObjectInspector) objectInspector; + final ObjectInspector keyObjectInspector = + mapObjectInspector.getMapKeyObjectInspector(); + final TypeInfo keyTypeInfo = + TypeInfoUtils.getTypeInfoFromObjectInspector( + keyObjectInspector); + final ObjectInspector valueObjectInspector = + mapObjectInspector.getMapValueObjectInspector(); + final TypeInfo valueTypeInfo = + TypeInfoUtils.getTypeInfoFromObjectInspector( + valueObjectInspector); + final Object mapObj = mapObjectInspector.create(); + for (int i = 0; i < keyPairCount; i++) { + final Object key = randomWritable(keyTypeInfo, keyObjectInspector); + final Object value = randomWritable(valueTypeInfo, valueObjectInspector); + mapObjectInspector.put(mapObj, key, value); + } + return mapObj; + } + case STRUCT: + { + if (allowNull && r.nextInt(20) == 0) { + return null; + } + final StandardStructObjectInspector structObjectInspector = + (StandardStructObjectInspector) objectInspector; + final List<?
extends StructField> fieldRefsList = structObjectInspector.getAllStructFieldRefs(); + final int fieldCount = fieldRefsList.size(); + final Object structObj = structObjectInspector.create(); + for (int i = 0; i < fieldCount; i++) { + final StructField fieldRef = fieldRefsList.get(i); + final ObjectInspector fieldObjectInspector = + fieldRef.getFieldObjectInspector(); + final TypeInfo fieldTypeInfo = + TypeInfoUtils.getTypeInfoFromObjectInspector( + fieldObjectInspector); + final Object fieldObj = randomWritable(fieldTypeInfo, fieldObjectInspector); + structObjectInspector.setStructFieldData(structObj, fieldRef, fieldObj); + } + return structObj; + } + case UNION: + { + final StandardUnionObjectInspector unionObjectInspector = + (StandardUnionObjectInspector) objectInspector; + final List<ObjectInspector> unionMemberObjectInspectors = unionObjectInspector.getObjectInspectors(); + final int unionCount = unionMemberObjectInspectors.size(); + final byte tag = (byte) r.nextInt(unionCount); + final ObjectInspector fieldObjectInspector = + unionMemberObjectInspectors.get(tag); + final TypeInfo fieldTypeInfo = + TypeInfoUtils.getTypeInfoFromObjectInspector( + fieldObjectInspector); + final Object fieldObj = randomWritable(fieldTypeInfo, fieldObjectInspector, false); + if (fieldObj == null) { + throw new RuntimeException("Unexpected null value for union tag " + tag); + } + return new StandardUnion(tag, fieldObj); + } + default: + throw new RuntimeException("Unexpected category " + typeInfo.getCategory()); } } - public Object randomObject(int column) { - PrimitiveCategory primitiveCategory = primitiveCategories[column]; - PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[column]; - switch (primitiveCategory) { + public Object randomPrimitiveObject(int column) { + return randomPrimitiveObject(r, primitiveTypeInfos[column]); + } + + public static Object randomPrimitiveObject(Random r, PrimitiveTypeInfo primitiveTypeInfo) { + switch (primitiveTypeInfo.getPrimitiveCategory()) { case 
BOOLEAN: - return Boolean.valueOf(r.nextInt(1) == 1); +
return Boolean.valueOf(r.nextBoolean()); case BYTE: return Byte.valueOf((byte) r.nextInt()); case SHORT: @@ -336,26 +799,30 @@ public class SerdeRandomRowSource { return dec; } default: - throw new Error("Unknown primitive category " + primitiveCategory); + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); } } public static HiveChar getRandHiveChar(Random r, CharTypeInfo charTypeInfo) { - int maxLength = 1 + r.nextInt(charTypeInfo.getLength()); - String randomString = RandomTypeUtil.getRandString(r, "abcdefghijklmnopqrstuvwxyz", 100); - HiveChar hiveChar = new HiveChar(randomString, maxLength); + final int maxLength = 1 + r.nextInt(charTypeInfo.getLength()); + final String randomString = RandomTypeUtil.getRandString(r, "abcdefghijklmnopqrstuvwxyz", 100); + final HiveChar hiveChar = new HiveChar(randomString, maxLength); return hiveChar; } - public static HiveVarchar getRandHiveVarchar(Random r, VarcharTypeInfo varcharTypeInfo) { - int maxLength = 1 + r.nextInt(varcharTypeInfo.getLength()); - String randomString = RandomTypeUtil.getRandString(r, "abcdefghijklmnopqrstuvwxyz", 100); - HiveVarchar hiveVarchar = new HiveVarchar(randomString, maxLength); + public static HiveVarchar getRandHiveVarchar(Random r, VarcharTypeInfo varcharTypeInfo, String alphabet) { + final int maxLength = 1 + r.nextInt(varcharTypeInfo.getLength()); + final String randomString = RandomTypeUtil.getRandString(r, alphabet, 100); + final HiveVarchar hiveVarchar = new HiveVarchar(randomString, maxLength); return hiveVarchar; } + public static HiveVarchar getRandHiveVarchar(Random r, VarcharTypeInfo varcharTypeInfo) { + return getRandHiveVarchar(r, varcharTypeInfo, "abcdefghijklmnopqrstuvwxyz"); + } + public static byte[] getRandBinary(Random r, int len){ - byte[] bytes = new byte[len]; + final byte[] bytes = new byte[len]; for (int j = 0; j < len; j++){ bytes[j] = Byte.valueOf((byte) r.nextInt()); } @@ -366,11 +833,11 @@ public class SerdeRandomRowSource { public
static HiveDecimal getRandHiveDecimal(Random r, DecimalTypeInfo decimalTypeInfo) { while (true) { - StringBuilder sb = new StringBuilder(); - int precision = 1 + r.nextInt(18); - int scale = 0 + r.nextInt(precision + 1); + final StringBuilder sb = new StringBuilder(); + final int precision = 1 + r.nextInt(18); + final int scale = 0 + r.nextInt(precision + 1); - int integerDigits = precision - scale; + final int integerDigits = precision - scale; if (r.nextBoolean()) { sb.append("-"); @@ -385,19 +852,17 @@ public class SerdeRandomRowSource { sb.append("."); sb.append(RandomTypeUtil.getRandString(r, DECIMAL_CHARS, scale)); } - HiveDecimal dec = HiveDecimal.create(sb.toString()); - - return dec; + return HiveDecimal.create(sb.toString()); } } public static HiveIntervalYearMonth getRandIntervalYearMonth(Random r) { - String yearMonthSignStr = r.nextInt(2) == 0 ? "" : "-"; - String intervalYearMonthStr = String.format("%s%d-%d", + final String yearMonthSignStr = r.nextInt(2) == 0 ? "" : "-"; + final String intervalYearMonthStr = String.format("%s%d-%d", yearMonthSignStr, Integer.valueOf(1800 + r.nextInt(500)), // year Integer.valueOf(0 + r.nextInt(12))); // month - HiveIntervalYearMonth intervalYearMonthVal = HiveIntervalYearMonth.valueOf(intervalYearMonthStr); + final HiveIntervalYearMonth intervalYearMonthVal = HiveIntervalYearMonth.valueOf(intervalYearMonthStr); return intervalYearMonthVal; } @@ -407,15 +872,15 @@ public class SerdeRandomRowSource { optionalNanos = String.format(".%09d", Integer.valueOf(0 + r.nextInt(DateUtils.NANOS_PER_SEC))); } - String yearMonthSignStr = r.nextInt(2) == 0 ? "" : "-"; - String dayTimeStr = String.format("%s%d %02d:%02d:%02d%s", + final String yearMonthSignStr = r.nextInt(2) == 0 ? 
"" : "-"; + final String dayTimeStr = String.format("%s%d %02d:%02d:%02d%s", yearMonthSignStr, Integer.valueOf(1 + r.nextInt(28)), // day Integer.valueOf(0 + r.nextInt(24)), // hour Integer.valueOf(0 + r.nextInt(60)), // minute Integer.valueOf(0 + r.nextInt(60)), // second optionalNanos); - HiveIntervalDayTime intervalDayTimeVal = HiveIntervalDayTime.valueOf(dayTimeStr); + final HiveIntervalDayTime intervalDayTimeVal = HiveIntervalDayTime.valueOf(dayTimeStr); return intervalDayTimeVal; } }
http://git-wip-us.apache.org/repos/asf/hive/blob/d467e172/serde/src/test/org/apache/hadoop/hive/serde2/VerifyFast.java ---------------------------------------------------------------------- diff --git a/serde/src/test/org/apache/hadoop/hive/serde2/VerifyFast.java b/serde/src/test/org/apache/hadoop/hive/serde2/VerifyFast.java index 19b04bb..2442fca 100644 --- a/serde/src/test/org/apache/hadoop/hive/serde2/VerifyFast.java +++ b/serde/src/test/org/apache/hadoop/hive/serde2/VerifyFast.java @@ -18,9 +18,14 @@ package org.apache.hadoop.hive.serde2; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.sql.Date; import java.sql.Timestamp; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map.Entry; import junit.framework.TestCase; @@ -30,7 +35,7 @@ import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth; import org.apache.hadoop.hive.common.type.HiveVarchar; import org.apache.hadoop.hive.serde2.fast.DeserializeRead; -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector.StandardUnion; import org.apache.hadoop.hive.serde2.fast.SerializeWrite; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DateWritable; @@ -44,7 +49,13 @@ import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.hive.serde2.io.TimestampWritable; import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; +import 
org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; import org.apache.hadoop.io.BooleanWritable; import org.apache.hadoop.io.BytesWritable; @@ -52,7 +63,6 @@ import org.apache.hadoop.io.FloatWritable; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; -import org.apache.hadoop.io.Writable; /** * TestBinarySortableSerDe. @@ -61,338 +71,635 @@ import org.apache.hadoop.io.Writable; public class VerifyFast { public static void verifyDeserializeRead(DeserializeRead deserializeRead, - PrimitiveTypeInfo primitiveTypeInfo, Writable writable) throws IOException { + TypeInfo typeInfo, Object object) throws IOException { boolean isNull; isNull = !deserializeRead.readNextField(); + doVerifyDeserializeRead(deserializeRead, typeInfo, object, isNull); + } + + public static void doVerifyDeserializeRead(DeserializeRead deserializeRead, + TypeInfo typeInfo, Object object, boolean isNull) throws IOException { if (isNull) { - if (writable != null) { - TestCase.fail("Field reports null but object is not null (class " + writable.getClass().getName() + ", " + writable.toString() + ")"); + if (object != null) { + TestCase.fail("Field reports null but object is not null (class " + object.getClass().getName() + ", " + object.toString() + ")"); } return; - } else if (writable == null) { + } else if (object == null) { TestCase.fail("Field report not null but object is null"); } - switch (primitiveTypeInfo.getPrimitiveCategory()) { - case BOOLEAN: - { - boolean value = deserializeRead.currentBoolean; - if (!(writable instanceof BooleanWritable)) { - TestCase.fail("Boolean expected writable not Boolean"); - } - boolean expected = ((BooleanWritable) writable).get(); - if (value != expected) { - TestCase.fail("Boolean field mismatch (expected " + 
expected + " found " + value + ")"); - } - } - break; - case BYTE: - { - byte value = deserializeRead.currentByte; - if (!(writable instanceof ByteWritable)) { - TestCase.fail("Byte expected writable not Byte"); - } - byte expected = ((ByteWritable) writable).get(); - if (value != expected) { - TestCase.fail("Byte field mismatch (expected " + (int) expected + " found " + (int) value + ")"); - } - } - break; - case SHORT: - { - short value = deserializeRead.currentShort; - if (!(writable instanceof ShortWritable)) { - TestCase.fail("Short expected writable not Short"); - } - short expected = ((ShortWritable) writable).get(); - if (value != expected) { - TestCase.fail("Short field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case INT: - { - int value = deserializeRead.currentInt; - if (!(writable instanceof IntWritable)) { - TestCase.fail("Integer expected writable not Integer"); - } - int expected = ((IntWritable) writable).get(); - if (value != expected) { - TestCase.fail("Int field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case LONG: - { - long value = deserializeRead.currentLong; - if (!(writable instanceof LongWritable)) { - TestCase.fail("Long expected writable not Long"); - } - Long expected = ((LongWritable) writable).get(); - if (value != expected) { - TestCase.fail("Long field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case FLOAT: - { - float value = deserializeRead.currentFloat; - if (!(writable instanceof FloatWritable)) { - TestCase.fail("Float expected writable not Float"); - } - float expected = ((FloatWritable) writable).get(); - if (value != expected) { - TestCase.fail("Float field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case DOUBLE: - { - double value = deserializeRead.currentDouble; - if (!(writable instanceof DoubleWritable)) { - TestCase.fail("Double expected writable not Double"); - } - double expected 
= ((DoubleWritable) writable).get(); - if (value != expected) { - TestCase.fail("Double field mismatch (expected " + expected + " found " + value + ")"); - } - } - break; - case STRING: - { - byte[] stringBytes = Arrays.copyOfRange( - deserializeRead.currentBytes, - deserializeRead.currentBytesStart, - deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); - Text text = new Text(stringBytes); - String string = text.toString(); - String expected = ((Text) writable).toString(); - if (!string.equals(expected)) { - TestCase.fail("String field mismatch (expected '" + expected + "' found '" + string + "')"); - } - } - break; - case CHAR: - { - byte[] stringBytes = Arrays.copyOfRange( - deserializeRead.currentBytes, - deserializeRead.currentBytesStart, - deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); - Text text = new Text(stringBytes); - String string = text.toString(); - - HiveChar hiveChar = new HiveChar(string, ((CharTypeInfo) primitiveTypeInfo).getLength()); - - HiveChar expected = ((HiveCharWritable) writable).getHiveChar(); - if (!hiveChar.equals(expected)) { - TestCase.fail("Char field mismatch (expected '" + expected + "' found '" + hiveChar + "')"); + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + { + boolean value = deserializeRead.currentBoolean; + if (!(object instanceof BooleanWritable)) { + TestCase.fail("Boolean expected writable not Boolean"); + } + boolean expected = ((BooleanWritable) object).get(); + if (value != expected) { + TestCase.fail("Boolean field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case BYTE: + { + byte value = deserializeRead.currentByte; + if (!(object instanceof ByteWritable)) { + TestCase.fail("Byte expected writable not Byte"); + } + byte expected = ((ByteWritable) object).get(); + if (value != expected) 
{ + TestCase.fail("Byte field mismatch (expected " + (int) expected + " found " + (int) value + ")"); + } + } + break; + case SHORT: + { + short value = deserializeRead.currentShort; + if (!(object instanceof ShortWritable)) { + TestCase.fail("Short expected writable not Short"); + } + short expected = ((ShortWritable) object).get(); + if (value != expected) { + TestCase.fail("Short field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case INT: + { + int value = deserializeRead.currentInt; + if (!(object instanceof IntWritable)) { + TestCase.fail("Integer expected writable not Integer"); + } + int expected = ((IntWritable) object).get(); + if (value != expected) { + TestCase.fail("Int field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case LONG: + { + long value = deserializeRead.currentLong; + if (!(object instanceof LongWritable)) { + TestCase.fail("Long expected writable not Long"); + } + Long expected = ((LongWritable) object).get(); + if (value != expected) { + TestCase.fail("Long field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case FLOAT: + { + float value = deserializeRead.currentFloat; + if (!(object instanceof FloatWritable)) { + TestCase.fail("Float expected writable not Float"); + } + float expected = ((FloatWritable) object).get(); + if (value != expected) { + TestCase.fail("Float field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case DOUBLE: + { + double value = deserializeRead.currentDouble; + if (!(object instanceof DoubleWritable)) { + TestCase.fail("Double expected writable not Double"); + } + double expected = ((DoubleWritable) object).get(); + if (value != expected) { + TestCase.fail("Double field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case STRING: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + 
deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); + String expected = ((Text) object).toString(); + if (!string.equals(expected)) { + TestCase.fail("String field mismatch (expected '" + expected + "' found '" + string + "')"); + } + } + break; + case CHAR: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); + + HiveChar hiveChar = new HiveChar(string, ((CharTypeInfo) primitiveTypeInfo).getLength()); + + HiveChar expected = ((HiveCharWritable) object).getHiveChar(); + if (!hiveChar.equals(expected)) { + TestCase.fail("Char field mismatch (expected '" + expected + "' found '" + hiveChar + "')"); + } + } + break; + case VARCHAR: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); + + HiveVarchar hiveVarchar = new HiveVarchar(string, ((VarcharTypeInfo) primitiveTypeInfo).getLength()); + + HiveVarchar expected = ((HiveVarcharWritable) object).getHiveVarchar(); + if (!hiveVarchar.equals(expected)) { + TestCase.fail("Varchar field mismatch (expected '" + expected + "' found '" + hiveVarchar + "')"); + } + } + break; + case DECIMAL: + { + HiveDecimal value = deserializeRead.currentHiveDecimalWritable.getHiveDecimal(); + if (value == null) { + TestCase.fail("Decimal field evaluated to NULL"); + } + HiveDecimal expected = ((HiveDecimalWritable) object).getHiveDecimal(); + if (!value.equals(expected)) { + DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo; + int precision = decimalTypeInfo.getPrecision(); + int scale = decimalTypeInfo.getScale(); + 
TestCase.fail("Decimal field mismatch (expected " + expected.toString() + " found " + value.toString() + ") precision " + precision + ", scale " + scale); + } + } + break; + case DATE: + { + Date value = deserializeRead.currentDateWritable.get(); + Date expected = ((DateWritable) object).get(); + if (!value.equals(expected)) { + TestCase.fail("Date field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case TIMESTAMP: + { + Timestamp value = deserializeRead.currentTimestampWritable.getTimestamp(); + Timestamp expected = ((TimestampWritable) object).getTimestamp(); + if (!value.equals(expected)) { + TestCase.fail("Timestamp field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case INTERVAL_YEAR_MONTH: + { + HiveIntervalYearMonth value = deserializeRead.currentHiveIntervalYearMonthWritable.getHiveIntervalYearMonth(); + HiveIntervalYearMonth expected = ((HiveIntervalYearMonthWritable) object).getHiveIntervalYearMonth(); + if (!value.equals(expected)) { + TestCase.fail("HiveIntervalYearMonth field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case INTERVAL_DAY_TIME: + { + HiveIntervalDayTime value = deserializeRead.currentHiveIntervalDayTimeWritable.getHiveIntervalDayTime(); + HiveIntervalDayTime expected = ((HiveIntervalDayTimeWritable) object).getHiveIntervalDayTime(); + if (!value.equals(expected)) { + TestCase.fail("HiveIntervalDayTime field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case BINARY: + { + byte[] byteArray = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + BytesWritable bytesWritable = (BytesWritable) object; + byte[] expected = Arrays.copyOfRange(bytesWritable.getBytes(), 0, bytesWritable.getLength()); + if 
(byteArray.length != expected.length){ + TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) + + " found " + Arrays.toString(byteArray) + ")"); + } + for (int b = 0; b < byteArray.length; b++) { + if (byteArray[b] != expected[b]) { + TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) + + " found " + Arrays.toString(byteArray) + ")"); + } + } + } + break; + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); } } break; - case VARCHAR: - { - byte[] stringBytes = Arrays.copyOfRange( - deserializeRead.currentBytes, - deserializeRead.currentBytesStart, - deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); - Text text = new Text(stringBytes); - String string = text.toString(); - - HiveVarchar hiveVarchar = new HiveVarchar(string, ((VarcharTypeInfo) primitiveTypeInfo).getLength()); + case LIST: + case MAP: + case STRUCT: + case UNION: + throw new Error("Complex types need to be handled separately"); + default: + throw new Error("Unknown category " + typeInfo.getCategory()); + } + } - HiveVarchar expected = ((HiveVarcharWritable) writable).getHiveVarchar(); - if (!hiveVarchar.equals(expected)) { - TestCase.fail("Varchar field mismatch (expected '" + expected + "' found '" + hiveVarchar + "')"); - } - } - break; - case DECIMAL: - { - HiveDecimal value = deserializeRead.currentHiveDecimalWritable.getHiveDecimal(); - if (value == null) { - TestCase.fail("Decimal field evaluated to NULL"); - } - HiveDecimal expected = ((HiveDecimalWritable) writable).getHiveDecimal(); - if (!value.equals(expected)) { - DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo; - int precision = decimalTypeInfo.getPrecision(); - int scale = decimalTypeInfo.getScale(); - TestCase.fail("Decimal field mismatch (expected " + expected.toString() + " found " + value.toString() + ") precision " + precision + ", scale " + scale); - } - } - break; - case DATE: - { 
- Date value = deserializeRead.currentDateWritable.get(); - Date expected = ((DateWritable) writable).get(); - if (!value.equals(expected)) { - TestCase.fail("Date field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + public static void serializeWrite(SerializeWrite serializeWrite, + TypeInfo typeInfo, Object object) throws IOException { + if (object == null) { + serializeWrite.writeNull(); + return; + } + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + { + boolean value = ((BooleanWritable) object).get(); + serializeWrite.writeBoolean(value); + } + break; + case BYTE: + { + byte value = ((ByteWritable) object).get(); + serializeWrite.writeByte(value); + } + break; + case SHORT: + { + short value = ((ShortWritable) object).get(); + serializeWrite.writeShort(value); + } + break; + case INT: + { + int value = ((IntWritable) object).get(); + serializeWrite.writeInt(value); + } + break; + case LONG: + { + long value = ((LongWritable) object).get(); + serializeWrite.writeLong(value); + } + break; + case FLOAT: + { + float value = ((FloatWritable) object).get(); + serializeWrite.writeFloat(value); + } + break; + case DOUBLE: + { + double value = ((DoubleWritable) object).get(); + serializeWrite.writeDouble(value); + } + break; + case STRING: + { + Text value = (Text) object; + byte[] stringBytes = value.getBytes(); + int stringLength = stringBytes.length; + serializeWrite.writeString(stringBytes, 0, stringLength); + } + break; + case CHAR: + { + HiveChar value = ((HiveCharWritable) object).getHiveChar(); + serializeWrite.writeHiveChar(value); + } + break; + case VARCHAR: + { + HiveVarchar value = ((HiveVarcharWritable) object).getHiveVarchar(); + serializeWrite.writeHiveVarchar(value); + } + break; + case DECIMAL: + { + HiveDecimal value = ((HiveDecimalWritable) object).getHiveDecimal(); 
+ DecimalTypeInfo decTypeInfo = (DecimalTypeInfo)primitiveTypeInfo; + serializeWrite.writeHiveDecimal(value, decTypeInfo.scale()); + } + break; + case DATE: + { + Date value = ((DateWritable) object).get(); + serializeWrite.writeDate(value); + } + break; + case TIMESTAMP: + { + Timestamp value = ((TimestampWritable) object).getTimestamp(); + serializeWrite.writeTimestamp(value); + } + break; + case INTERVAL_YEAR_MONTH: + { + HiveIntervalYearMonth value = ((HiveIntervalYearMonthWritable) object).getHiveIntervalYearMonth(); + serializeWrite.writeHiveIntervalYearMonth(value); + } + break; + case INTERVAL_DAY_TIME: + { + HiveIntervalDayTime value = ((HiveIntervalDayTimeWritable) object).getHiveIntervalDayTime(); + serializeWrite.writeHiveIntervalDayTime(value); + } + break; + case BINARY: + { + BytesWritable byteWritable = (BytesWritable) object; + byte[] binaryBytes = byteWritable.getBytes(); + int length = byteWritable.getLength(); + serializeWrite.writeBinary(binaryBytes, 0, length); + } + break; + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory().name()); } } break; - case TIMESTAMP: + case LIST: { - Timestamp value = deserializeRead.currentTimestampWritable.getTimestamp(); - Timestamp expected = ((TimestampWritable) writable).getTimestamp(); - if (!value.equals(expected)) { - TestCase.fail("Timestamp field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo(); + ArrayList<Object> elements = (ArrayList<Object>) object; + serializeWrite.beginList(elements); + boolean isFirst = true; + for (Object elementObject : elements) { + if (isFirst) { + isFirst = false; + } else { + serializeWrite.separateList(); + } + if (elementObject == null) { + serializeWrite.writeNull(); + } else { + serializeWrite(serializeWrite, elementTypeInfo, elementObject); + } } - } - break; - case 
INTERVAL_YEAR_MONTH: - { - HiveIntervalYearMonth value = deserializeRead.currentHiveIntervalYearMonthWritable.getHiveIntervalYearMonth(); - HiveIntervalYearMonth expected = ((HiveIntervalYearMonthWritable) writable).getHiveIntervalYearMonth(); - if (!value.equals(expected)) { - TestCase.fail("HiveIntervalYearMonth field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + serializeWrite.finishList(); + } + break; + case MAP: + { + MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo(); + TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo(); + HashMap<Object, Object> hashMap = (HashMap<Object, Object>) object; + serializeWrite.beginMap(hashMap); + boolean isFirst = true; + for (Entry<Object, Object> entry : hashMap.entrySet()) { + if (isFirst) { + isFirst = false; + } else { + serializeWrite.separateKeyValuePair(); + } + if (entry.getKey() == null) { + serializeWrite.writeNull(); + } else { + serializeWrite(serializeWrite, keyTypeInfo, entry.getKey()); + } + serializeWrite.separateKey(); + if (entry.getValue() == null) { + serializeWrite.writeNull(); + } else { + serializeWrite(serializeWrite, valueTypeInfo, entry.getValue()); + } } - } - break; - case INTERVAL_DAY_TIME: - { - HiveIntervalDayTime value = deserializeRead.currentHiveIntervalDayTimeWritable.getHiveIntervalDayTime(); - HiveIntervalDayTime expected = ((HiveIntervalDayTimeWritable) writable).getHiveIntervalDayTime(); - if (!value.equals(expected)) { - TestCase.fail("HiveIntervalDayTime field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + serializeWrite.finishMap(); + } + break; + case STRUCT: + { + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + ArrayList<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + ArrayList<Object> fieldValues = (ArrayList<Object>) object; + final int size = fieldValues.size(); + serializeWrite.beginStruct(fieldValues); + 
boolean isFirst = true; + for (int i = 0; i < size; i++) { + if (isFirst) { + isFirst = false; + } else { + serializeWrite.separateStruct(); + } + serializeWrite(serializeWrite, fieldTypeInfos.get(i), fieldValues.get(i)); } + serializeWrite.finishStruct(); } break; - case BINARY: + case UNION: { - byte[] byteArray = Arrays.copyOfRange( - deserializeRead.currentBytes, - deserializeRead.currentBytesStart, - deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); - BytesWritable bytesWritable = (BytesWritable) writable; - byte[] expected = Arrays.copyOfRange(bytesWritable.getBytes(), 0, bytesWritable.getLength()); - if (byteArray.length != expected.length){ - TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) - + " found " + Arrays.toString(byteArray) + ")"); - } - for (int b = 0; b < byteArray.length; b++) { - if (byteArray[b] != expected[b]) { - TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) - + " found " + Arrays.toString(byteArray) + ")"); - } - } + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + List<TypeInfo> fieldTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + final int size = fieldTypeInfos.size(); + StandardUnion standardUnion = (StandardUnion) object; + byte tag = standardUnion.getTag(); + serializeWrite.beginUnion(tag); + serializeWrite(serializeWrite, fieldTypeInfos.get(tag), standardUnion.getObject()); + serializeWrite.finishUnion(); } break; default: - throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); + throw new Error("Unknown category " + typeInfo.getCategory().name()); } } - public static void serializeWrite(SerializeWrite serializeWrite, - PrimitiveTypeInfo primitiveTypeInfo, Writable writable) throws IOException { - if (writable == null) { - serializeWrite.writeNull(); - return; + public Object readComplexPrimitiveField(DeserializeRead deserializeRead, + PrimitiveTypeInfo primitiveTypeInfo) throws IOException 
{ + boolean isNull = !deserializeRead.readComplexField(); + if (isNull) { + return null; + } else { + return doReadComplexPrimitiveField(deserializeRead, primitiveTypeInfo); } + } + + private static Object doReadComplexPrimitiveField(DeserializeRead deserializeRead, + PrimitiveTypeInfo primitiveTypeInfo) throws IOException { switch (primitiveTypeInfo.getPrimitiveCategory()) { - case BOOLEAN: - { - boolean value = ((BooleanWritable) writable).get(); - serializeWrite.writeBoolean(value); - } - break; + case BOOLEAN: + return new BooleanWritable(deserializeRead.currentBoolean); case BYTE: - { - byte value = ((ByteWritable) writable).get(); - serializeWrite.writeByte(value); - } - break; + return new ByteWritable(deserializeRead.currentByte); case SHORT: - { - short value = ((ShortWritable) writable).get(); - serializeWrite.writeShort(value); - } - break; + return new ShortWritable(deserializeRead.currentShort); case INT: - { - int value = ((IntWritable) writable).get(); - serializeWrite.writeInt(value); - } - break; + return new IntWritable(deserializeRead.currentInt); case LONG: - { - long value = ((LongWritable) writable).get(); - serializeWrite.writeLong(value); - } - break; + return new LongWritable(deserializeRead.currentLong); case FLOAT: - { - float value = ((FloatWritable) writable).get(); - serializeWrite.writeFloat(value); - } - break; + return new FloatWritable(deserializeRead.currentFloat); case DOUBLE: - { - double value = ((DoubleWritable) writable).get(); - serializeWrite.writeDouble(value); - } - break; + return new DoubleWritable(deserializeRead.currentDouble); case STRING: - { - Text value = (Text) writable; - byte[] stringBytes = value.getBytes(); - int stringLength = stringBytes.length; - serializeWrite.writeString(stringBytes, 0, stringLength); - } - break; + return new Text(new String( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength, + StandardCharsets.UTF_8)); case CHAR: - { - HiveChar 
value = ((HiveCharWritable) writable).getHiveChar(); - serializeWrite.writeHiveChar(value); - } - break; + return new HiveCharWritable(new HiveChar( + new String( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength, + StandardCharsets.UTF_8), + ((CharTypeInfo) primitiveTypeInfo).getLength())); case VARCHAR: - { - HiveVarchar value = ((HiveVarcharWritable) writable).getHiveVarchar(); - serializeWrite.writeHiveVarchar(value); + if (deserializeRead.currentBytes == null) { + throw new RuntimeException(); } - break; + return new HiveVarcharWritable(new HiveVarchar( + new String( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength, + StandardCharsets.UTF_8), + ((VarcharTypeInfo) primitiveTypeInfo).getLength())); case DECIMAL: - { - HiveDecimal value = ((HiveDecimalWritable) writable).getHiveDecimal(); - DecimalTypeInfo decTypeInfo = (DecimalTypeInfo)primitiveTypeInfo; - serializeWrite.writeHiveDecimal(value, decTypeInfo.scale()); - } - break; + return new HiveDecimalWritable(deserializeRead.currentHiveDecimalWritable); case DATE: - { - Date value = ((DateWritable) writable).get(); - serializeWrite.writeDate(value); - } - break; + return new DateWritable(deserializeRead.currentDateWritable); case TIMESTAMP: - { - Timestamp value = ((TimestampWritable) writable).getTimestamp(); - serializeWrite.writeTimestamp(value); - } - break; + return new TimestampWritable(deserializeRead.currentTimestampWritable); case INTERVAL_YEAR_MONTH: - { - HiveIntervalYearMonth value = ((HiveIntervalYearMonthWritable) writable).getHiveIntervalYearMonth(); - serializeWrite.writeHiveIntervalYearMonth(value); - } - break; + return new HiveIntervalYearMonthWritable(deserializeRead.currentHiveIntervalYearMonthWritable); case INTERVAL_DAY_TIME: - { - HiveIntervalDayTime value = ((HiveIntervalDayTimeWritable) writable).getHiveIntervalDayTime(); - serializeWrite.writeHiveIntervalDayTime(value); - 
} - break; + return new HiveIntervalDayTimeWritable(deserializeRead.currentHiveIntervalDayTimeWritable); case BINARY: - { - BytesWritable byteWritable = (BytesWritable) writable; - byte[] binaryBytes = byteWritable.getBytes(); - int length = byteWritable.getLength(); - serializeWrite.writeBinary(binaryBytes, 0, length); + return new BytesWritable( + Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesLength + deserializeRead.currentBytesStart)); + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); + } + } + + public static Object deserializeReadComplexType(DeserializeRead deserializeRead, + TypeInfo typeInfo) throws IOException { + + boolean isNull = !deserializeRead.readNextField(); + if (isNull) { + return null; + } + return getComplexField(deserializeRead, typeInfo); + } + + static int fake = 0; + + private static Object getComplexField(DeserializeRead deserializeRead, + TypeInfo typeInfo) throws IOException { + switch (typeInfo.getCategory()) { + case PRIMITIVE: + return doReadComplexPrimitiveField(deserializeRead, (PrimitiveTypeInfo) typeInfo); + case LIST: + { + ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo(); + ArrayList<Object> list = new ArrayList<Object>(); + Object eleObj; + boolean isNull; + while (deserializeRead.isNextComplexMultiValue()) { + isNull = !deserializeRead.readComplexField(); + if (isNull) { + eleObj = null; + } else { + eleObj = getComplexField(deserializeRead, elementTypeInfo); + } + list.add(eleObj); + } + return list; + } + case MAP: + { + MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo(); + TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo(); + HashMap<Object, Object> hashMap = new HashMap<Object, Object>(); + Object keyObj; + Object valueObj; + boolean isNull; + while 
(deserializeRead.isNextComplexMultiValue()) { + isNull = !deserializeRead.readComplexField(); + if (isNull) { + keyObj = null; + } else { + keyObj = getComplexField(deserializeRead, keyTypeInfo); + } + isNull = !deserializeRead.readComplexField(); + if (isNull) { + valueObj = null; + } else { + valueObj = getComplexField(deserializeRead, valueTypeInfo); + } + hashMap.put(keyObj, valueObj); + } + return hashMap; + } + case STRUCT: + { + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + ArrayList<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + final int size = fieldTypeInfos.size(); + ArrayList<Object> fieldValues = new ArrayList<Object>(); + Object fieldObj; + boolean isNull; + for (int i = 0; i < size; i++) { + isNull = !deserializeRead.readComplexField(); + if (isNull) { + fieldObj = null; + } else { + fieldObj = getComplexField(deserializeRead, fieldTypeInfos.get(i)); + } + fieldValues.add(fieldObj); + } + deserializeRead.finishComplexVariableFieldsType(); + return fieldValues; + } + case UNION: + { + UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + List<TypeInfo> unionTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + final int size = unionTypeInfos.size(); + Object tagObj; + int tag; + Object unionObj; + boolean isNull = !deserializeRead.readComplexField(); + if (isNull) { + unionObj = null; + } else { + // Get the tag value. + tagObj = getComplexField(deserializeRead, TypeInfoFactory.intTypeInfo); + tag = ((IntWritable) tagObj).get(); + + isNull = !deserializeRead.readComplexField(); + if (isNull) { + unionObj = null; + } else { + // Get the union value. 
+ unionObj = new StandardUnion((byte) tag, getComplexField(deserializeRead, unionTypeInfos.get(tag))); + } + } + + deserializeRead.finishComplexVariableFieldsType(); + return unionObj; } - break; default: - throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory().name()); + throw new Error("Unexpected category " + typeInfo.getCategory()); } } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/hive/blob/d467e172/serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/MyTestClass.java ---------------------------------------------------------------------- diff --git a/serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/MyTestClass.java b/serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/MyTestClass.java index df5e8db..5302819 100644 --- a/serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/MyTestClass.java +++ b/serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/MyTestClass.java @@ -23,8 +23,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Random; -import org.apache.commons.lang.ArrayUtils; -import org.apache.commons.lang.StringUtils; import org.apache.hadoop.hive.common.type.HiveChar; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; @@ -33,25 +31,6 @@ import org.apache.hadoop.hive.common.type.HiveVarchar; import org.apache.hadoop.hive.common.type.RandomTypeUtil; import org.apache.hadoop.hive.serde2.binarysortable.MyTestPrimitiveClass.ExtraTypeInfo; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableBooleanObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableByteObjectInspector; -import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDateObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableFloatObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableHiveCharObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableHiveDecimalObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableHiveIntervalDayTimeObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableHiveIntervalYearMonthObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableHiveVarcharObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableLongObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableShortObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableStringObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableTimestampObjectInspector; -import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; public class MyTestClass { @@ -230,6 +209,9 @@ public class MyTestClass { for (int i = 0; i < minCount; i++) { Object[] row = rows[i]; for (int c = 0; c < primitiveCategories.length; c++) { + if (primitiveCategories[c] == null) { + continue; + } Object object = row[c]; // Current value. switch (primitiveCategories[c]) { case BOOLEAN:
