http://git-wip-us.apache.org/repos/asf/hive/blob/d467e172/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
index cbde615..ec392c2 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
@@ -25,21 +25,29 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Random;

-import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
-import org.apache.commons.lang.ArrayUtils;
-import org.apache.commons.lang.StringUtils;
 import org.apache.hadoop.hive.common.type.HiveChar;
 import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.common.type.HiveIntervalDayTime;
 import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth;
 import org.apache.hadoop.hive.common.type.HiveVarchar;
 import org.apache.hadoop.hive.common.type.RandomTypeUtil;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.io.HiveCharWritable;
+import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector.StandardUnion;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.UnionObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableBooleanObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableByteObjectInspector;
@@ -58,11 +66,19 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableStringObj
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableTimestampObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
+import
org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; import org.apache.hive.common.util.DateUtils; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.BytesWritable; +import com.google.common.base.Preconditions; import com.google.common.base.Charsets; /** @@ -76,6 +92,14 @@ public class VectorRandomRowSource { private List<String> typeNames; + private Category[] categories; + + private TypeInfo[] typeInfos; + + private List<ObjectInspector> objectInspectorList; + + // Primitive. + private PrimitiveCategory[] primitiveCategories; private PrimitiveTypeInfo[] primitiveTypeInfos; @@ -86,6 +110,8 @@ public class VectorRandomRowSource { private String[] alphabets; + private boolean allowNull; + private boolean addEscapables; private String needsEscapeStr; @@ -93,6 +119,14 @@ public class VectorRandomRowSource { return typeNames; } + public Category[] categories() { + return categories; + } + + public TypeInfo[] typeInfos() { + return typeInfos; + } + public PrimitiveCategory[] primitiveCategories() { return primitiveCategories; } @@ -106,30 +140,37 @@ public class VectorRandomRowSource { } public StructObjectInspector partialRowStructObjectInspector(int partialFieldCount) { - ArrayList<ObjectInspector> partialPrimitiveObjectInspectorList = + ArrayList<ObjectInspector> partialObjectInspectorList = new ArrayList<ObjectInspector>(partialFieldCount); List<String> columnNames = new ArrayList<String>(partialFieldCount); for (int i = 0; i < partialFieldCount; i++) { columnNames.add(String.format("partial%d", i)); - partialPrimitiveObjectInspectorList.add( - PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( - primitiveTypeInfos[i])); + partialObjectInspectorList.add(getObjectInspector(typeInfos[i])); } return ObjectInspectorFactory.getStandardStructObjectInspector( - columnNames, primitiveObjectInspectorList); + columnNames, objectInspectorList); + } + + public enum SupportedTypes { + ALL, PRIMITIVES, ALL_EXCEPT_MAP + } + + public void init(Random r, SupportedTypes supportedTypes, int maxComplexDepth) { + init(r, supportedTypes, maxComplexDepth, true); } - public void init(Random r) { + public void init(Random r, SupportedTypes supportedTypes, int maxComplexDepth, boolean allowNull) { this.r = r; - chooseSchema(); + this.allowNull = allowNull; + chooseSchema(supportedTypes, maxComplexDepth); } /* * For now, exclude CHAR until we determine why there is a difference (blank padding) * serializing with LazyBinarySerializeWrite and the regular SerDe... 
*/ - private static String[] possibleHiveTypeNames = { + private static String[] possibleHivePrimitiveTypeNames = { "boolean", "tinyint", "smallint", @@ -149,39 +190,217 @@ public class VectorRandomRowSource { "decimal" }; - private void chooseSchema() { + private static String[] possibleHiveComplexTypeNames = { + "array", + "struct", + "uniontype", + "map" + }; + + private String getRandomTypeName(SupportedTypes supportedTypes) { + String typeName = null; + if (r.nextInt(10 ) != 0) { + typeName = possibleHivePrimitiveTypeNames[r.nextInt(possibleHivePrimitiveTypeNames.length)]; + } else { + switch (supportedTypes) { + case PRIMITIVES: + typeName = possibleHivePrimitiveTypeNames[r.nextInt(possibleHivePrimitiveTypeNames.length)]; + break; + case ALL_EXCEPT_MAP: + typeName = possibleHiveComplexTypeNames[r.nextInt(possibleHiveComplexTypeNames.length - 1)]; + break; + case ALL: + typeName = possibleHiveComplexTypeNames[r.nextInt(possibleHiveComplexTypeNames.length)]; + break; + } + } + return typeName; + } + + private String getDecoratedTypeName(String typeName, SupportedTypes supportedTypes, int depth, int maxDepth) { + depth++; + if (depth < maxDepth) { + supportedTypes = SupportedTypes.PRIMITIVES; + } + if (typeName.equals("char")) { + final int maxLength = 1 + r.nextInt(100); + typeName = String.format("char(%d)", maxLength); + } else if (typeName.equals("varchar")) { + final int maxLength = 1 + r.nextInt(100); + typeName = String.format("varchar(%d)", maxLength); + } else if (typeName.equals("decimal")) { + typeName = String.format("decimal(%d,%d)", HiveDecimal.SYSTEM_DEFAULT_PRECISION, HiveDecimal.SYSTEM_DEFAULT_SCALE); + } else if (typeName.equals("array")) { + String elementTypeName = getRandomTypeName(supportedTypes); + elementTypeName = getDecoratedTypeName(elementTypeName, supportedTypes, depth, maxDepth); + typeName = String.format("array<%s>", elementTypeName); + } else if (typeName.equals("map")) { + String keyTypeName = getRandomTypeName(SupportedTypes.PRIMITIVES); + keyTypeName = getDecoratedTypeName(keyTypeName, supportedTypes, depth, maxDepth); + String valueTypeName = getRandomTypeName(supportedTypes); + valueTypeName = getDecoratedTypeName(valueTypeName, supportedTypes, depth, maxDepth); + typeName = String.format("map<%s,%s>", keyTypeName, valueTypeName); + } else if (typeName.equals("struct")) { + final int fieldCount = 1 + r.nextInt(10); + final StringBuilder sb = new StringBuilder(); + for (int i = 0; i < fieldCount; i++) { + String fieldTypeName = getRandomTypeName(supportedTypes); + fieldTypeName = getDecoratedTypeName(fieldTypeName, supportedTypes, depth, maxDepth); + if (i > 0) { + sb.append(","); + } + sb.append("col"); + sb.append(i); + sb.append(":"); + sb.append(fieldTypeName); + } + typeName = String.format("struct<%s>", sb.toString()); + } else if (typeName.equals("struct") || + typeName.equals("uniontype")) { + final int fieldCount = 1 + r.nextInt(10); + final StringBuilder sb = new StringBuilder(); + for (int i = 0; i < fieldCount; i++) { + String fieldTypeName = getRandomTypeName(supportedTypes); + fieldTypeName = getDecoratedTypeName(fieldTypeName, supportedTypes, depth, maxDepth); + if (i > 0) { + sb.append(","); + } + sb.append(fieldTypeName); + } + typeName = String.format("uniontype<%s>", sb.toString()); + } + return typeName; + } + + private ObjectInspector getObjectInspector(TypeInfo typeInfo) { + final ObjectInspector objectInspector; + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + final PrimitiveTypeInfo primitiveType = 
(PrimitiveTypeInfo) typeInfo; + objectInspector = + PrimitiveObjectInspectorFactory. + getPrimitiveWritableObjectInspector(primitiveType); + } + break; + case MAP: + { + final MapTypeInfo mapType = (MapTypeInfo) typeInfo; + final MapObjectInspector mapInspector = + ObjectInspectorFactory.getStandardMapObjectInspector( + getObjectInspector(mapType.getMapKeyTypeInfo()), + getObjectInspector(mapType.getMapValueTypeInfo())); + objectInspector = mapInspector; + } + break; + case LIST: + { + final ListTypeInfo listType = (ListTypeInfo) typeInfo; + final ListObjectInspector listInspector = + ObjectInspectorFactory.getStandardListObjectInspector( + getObjectInspector(listType.getListElementTypeInfo())); + objectInspector = listInspector; + } + break; + case STRUCT: + { + final StructTypeInfo structType = (StructTypeInfo) typeInfo; + final List<TypeInfo> fieldTypes = structType.getAllStructFieldTypeInfos(); + + final List<ObjectInspector> fieldInspectors = new ArrayList<ObjectInspector>(); + for (TypeInfo fieldType : fieldTypes) { + fieldInspectors.add(getObjectInspector(fieldType)); + } + + final StructObjectInspector structInspector = + ObjectInspectorFactory.getStandardStructObjectInspector( + structType.getAllStructFieldNames(), fieldInspectors); + objectInspector = structInspector; + } + break; + case UNION: + { + final UnionTypeInfo unionType = (UnionTypeInfo) typeInfo; + final List<TypeInfo> fieldTypes = unionType.getAllUnionObjectTypeInfos(); + + final List<ObjectInspector> fieldInspectors = new ArrayList<ObjectInspector>(); + for (TypeInfo fieldType : fieldTypes) { + fieldInspectors.add(getObjectInspector(fieldType)); + } + + final UnionObjectInspector unionInspector = + ObjectInspectorFactory.getStandardUnionObjectInspector( + fieldInspectors); + objectInspector = unionInspector; + } + break; + default: + throw new RuntimeException("Unexpected category " + typeInfo.getCategory()); + } + Preconditions.checkState(objectInspector != null); + return objectInspector; + } + + private void chooseSchema(SupportedTypes supportedTypes, int maxComplexDepth) { HashSet hashSet = null; - boolean allTypes; - boolean onlyOne = (r.nextInt(100) == 7); + final boolean allTypes; + final boolean onlyOne = (r.nextInt(100) == 7); if (onlyOne) { columnCount = 1; allTypes = false; } else { allTypes = r.nextBoolean(); if (allTypes) { - // One of each type. 
- columnCount = possibleHiveTypeNames.length; + switch (supportedTypes) { + case ALL: + columnCount = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length; + break; + case ALL_EXCEPT_MAP: + columnCount = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length - 1; + break; + case PRIMITIVES: + columnCount = possibleHivePrimitiveTypeNames.length; + break; + } hashSet = new HashSet<Integer>(); } else { columnCount = 1 + r.nextInt(20); } } typeNames = new ArrayList<String>(columnCount); + categories = new Category[columnCount]; + typeInfos = new TypeInfo[columnCount]; + objectInspectorList = new ArrayList<ObjectInspector>(columnCount); + primitiveCategories = new PrimitiveCategory[columnCount]; primitiveTypeInfos = new PrimitiveTypeInfo[columnCount]; primitiveObjectInspectorList = new ArrayList<ObjectInspector>(columnCount); List<String> columnNames = new ArrayList<String>(columnCount); for (int c = 0; c < columnCount; c++) { columnNames.add(String.format("col%d", c)); - String typeName; + final String typeName; if (onlyOne) { - typeName = possibleHiveTypeNames[r.nextInt(possibleHiveTypeNames.length)]; + typeName = getRandomTypeName(supportedTypes); } else { int typeNum; if (allTypes) { + int maxTypeNum = 0; + switch (supportedTypes) { + case PRIMITIVES: + maxTypeNum = possibleHivePrimitiveTypeNames.length; + break; + case ALL_EXCEPT_MAP: + maxTypeNum = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length - 1; + break; + case ALL: + maxTypeNum = possibleHivePrimitiveTypeNames.length + possibleHiveComplexTypeNames.length; + break; + } while (true) { - typeNum = r.nextInt(possibleHiveTypeNames.length); + + typeNum = r.nextInt(maxTypeNum); + Integer typeNumInteger = new Integer(typeNum); if (!hashSet.contains(typeNumInteger)) { hashSet.add(typeNumInteger); @@ -189,32 +408,129 @@ public class VectorRandomRowSource { } } } else { - typeNum = r.nextInt(possibleHiveTypeNames.length); + if (supportedTypes == SupportedTypes.PRIMITIVES || r.nextInt(10) != 0) { + typeNum = r.nextInt(possibleHivePrimitiveTypeNames.length); + } else { + typeNum = possibleHivePrimitiveTypeNames.length + r.nextInt(possibleHiveComplexTypeNames.length); + if (supportedTypes == SupportedTypes.ALL_EXCEPT_MAP) { + typeNum--; + } + } + } + if (typeNum < possibleHivePrimitiveTypeNames.length) { + typeName = possibleHivePrimitiveTypeNames[typeNum]; + } else { + typeName = possibleHiveComplexTypeNames[typeNum - possibleHivePrimitiveTypeNames.length]; + } + + } + + String decoratedTypeName = getDecoratedTypeName(typeName, supportedTypes, 0, maxComplexDepth); + + final TypeInfo typeInfo; + try { + typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(decoratedTypeName); + } catch (Exception e) { + throw new RuntimeException("Cannot convert type name " + decoratedTypeName + " to a type " + e); + } + + typeInfos[c] = typeInfo; + final Category category = typeInfo.getCategory(); + categories[c] = category; + ObjectInspector objectInspector = getObjectInspector(typeInfo); + switch (category) { + case PRIMITIVE: + { + final PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + objectInspector = PrimitiveObjectInspectorFactory. 
+ getPrimitiveWritableObjectInspector(primitiveTypeInfo); + primitiveTypeInfos[c] = primitiveTypeInfo; + PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory(); + primitiveCategories[c] = primitiveCategory; + primitiveObjectInspectorList.add(objectInspector); } - typeName = possibleHiveTypeNames[typeNum]; + break; + case LIST: + case MAP: + case STRUCT: + case UNION: + primitiveObjectInspectorList.add(null); + break; + default: + throw new RuntimeException("Unexpected catagory " + category); } - if (typeName.equals("char")) { - int maxLength = 1 + r.nextInt(100); - typeName = String.format("char(%d)", maxLength); - } else if (typeName.equals("varchar")) { - int maxLength = 1 + r.nextInt(100); - typeName = String.format("varchar(%d)", maxLength); - } else if (typeName.equals("decimal")) { - typeName = String.format("decimal(%d,%d)", HiveDecimal.SYSTEM_DEFAULT_PRECISION, HiveDecimal.SYSTEM_DEFAULT_SCALE); + objectInspectorList.add(objectInspector); + + if (category == Category.PRIMITIVE) { } - PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) TypeInfoUtils.getTypeInfoFromTypeString(typeName); - primitiveTypeInfos[c] = primitiveTypeInfo; - PrimitiveCategory primitiveCategory = primitiveTypeInfo.getPrimitiveCategory(); - primitiveCategories[c] = primitiveCategory; - primitiveObjectInspectorList.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(primitiveTypeInfo)); - typeNames.add(typeName); + typeNames.add(decoratedTypeName); } - rowStructObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(columnNames, primitiveObjectInspectorList); + rowStructObjectInspector = ObjectInspectorFactory. + getStandardStructObjectInspector(columnNames, objectInspectorList); alphabets = new String[columnCount]; } + public Object[][] randomRows(int n) { + + final Object[][] result = new Object[n][]; + for (int i = 0; i < n; i++) { + result[i] = randomRow(); + } + return result; + } + + public Object[] randomRow() { + + final Object row[] = new Object[columnCount]; + for (int c = 0; c < columnCount; c++) { + row[c] = randomWritable(c); + } + return row; + } + + public Object[] randomRow(boolean allowNull) { + + final Object row[] = new Object[columnCount]; + for (int c = 0; c < columnCount; c++) { + row[c] = randomWritable(typeInfos[c], objectInspectorList.get(c), allowNull); + } + return row; + } + + public Object[] randomPrimitiveRow(int columnCount) { + return randomPrimitiveRow(columnCount, r, primitiveTypeInfos); + } + + public static Object[] randomPrimitiveRow(int columnCount, Random r, + PrimitiveTypeInfo[] primitiveTypeInfos) { + + final Object row[] = new Object[columnCount]; + for (int c = 0; c < columnCount; c++) { + row[c] = randomPrimitiveObject(r, primitiveTypeInfos[c]); + } + return row; + } + + public static Object[] randomWritablePrimitiveRow(int columnCount, Random r, + PrimitiveTypeInfo[] primitiveTypeInfos) { + + final Object row[] = new Object[columnCount]; + for (int c = 0; c < columnCount; c++) { + final PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[c]; + final ObjectInspector objectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(primitiveTypeInfo); + final Object object = randomPrimitiveObject(r, primitiveTypeInfo); + row[c] = getWritablePrimitiveObject(primitiveTypeInfo, objectInspector, object); + } + return row; + } + public void addBinarySortableAlphabets() { + for (int c = 0; c < columnCount; c++) { + if (primitiveCategories[c] == null) { + continue; + } switch 
(primitiveCategories[c]) { case STRING: case CHAR: @@ -241,52 +557,6 @@ public class VectorRandomRowSource { this.needsEscapeStr = needsEscapeStr; } - public Object[][] randomRows(int n) { - Object[][] result = new Object[n][]; - for (int i = 0; i < n; i++) { - result[i] = randomRow(); - } - return result; - } - - public Object[] randomRow() { - Object row[] = new Object[columnCount]; - for (int c = 0; c < columnCount; c++) { - Object object = randomObject(c); - if (object == null) { - throw new Error("Unexpected null for column " + c); - } - row[c] = getWritableObject(c, object); - if (row[c] == null) { - throw new Error("Unexpected null for writable for column " + c); - } - } - return row; - } - - public Object[] randomRow(int columnCount) { - return randomRow(columnCount, r, primitiveObjectInspectorList, primitiveCategories, - primitiveTypeInfos); - } - - public static Object[] randomRow(int columnCount, Random r, - List<ObjectInspector> primitiveObjectInspectorList, PrimitiveCategory[] primitiveCategories, - PrimitiveTypeInfo[] primitiveTypeInfos) { - Object row[] = new Object[columnCount]; - for (int c = 0; c < columnCount; c++) { - Object object = randomObject(c, r, primitiveCategories, primitiveTypeInfos); - if (object == null) { - throw new Error("Unexpected null for column " + c); - } - row[c] = getWritableObject(c, object, primitiveObjectInspectorList, - primitiveCategories, primitiveTypeInfos); - if (row[c] == null) { - throw new Error("Unexpected null for writable for column " + c); - } - } - return row; - } - public static void sort(Object[][] rows, ObjectInspector oi) { for (int i = 0; i < rows.length; i++) { for (int j = i + 1; j < rows.length; j++) { @@ -303,18 +573,10 @@ public class VectorRandomRowSource { VectorRandomRowSource.sort(rows, rowStructObjectInspector); } - public Object getWritableObject(int column, Object object) { - return getWritableObject(column, object, primitiveObjectInspectorList, - primitiveCategories, primitiveTypeInfos); - } + public static Object getWritablePrimitiveObject(PrimitiveTypeInfo primitiveTypeInfo, + ObjectInspector objectInspector, Object object) { - public static Object getWritableObject(int column, Object object, - List<ObjectInspector> primitiveObjectInspectorList, PrimitiveCategory[] primitiveCategories, - PrimitiveTypeInfo[] primitiveTypeInfos) { - ObjectInspector objectInspector = primitiveObjectInspectorList.get(column); - PrimitiveCategory primitiveCategory = primitiveCategories[column]; - PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[column]; - switch (primitiveCategory) { + switch (primitiveTypeInfo.getPrimitiveCategory()) { case BOOLEAN: return ((WritableBooleanObjectInspector) objectInspector).create((boolean) object); case BYTE: @@ -334,17 +596,17 @@ public class VectorRandomRowSource { case STRING: return ((WritableStringObjectInspector) objectInspector).create((String) object); case CHAR: - { - WritableHiveCharObjectInspector writableCharObjectInspector = - new WritableHiveCharObjectInspector( (CharTypeInfo) primitiveTypeInfo); - return writableCharObjectInspector.create((HiveChar) object); - } + { + WritableHiveCharObjectInspector writableCharObjectInspector = + new WritableHiveCharObjectInspector( (CharTypeInfo) primitiveTypeInfo); + return writableCharObjectInspector.create((HiveChar) object); + } case VARCHAR: - { - WritableHiveVarcharObjectInspector writableVarcharObjectInspector = - new WritableHiveVarcharObjectInspector( (VarcharTypeInfo) primitiveTypeInfo); - return 
writableVarcharObjectInspector.create((HiveVarchar) object); - } + { + WritableHiveVarcharObjectInspector writableVarcharObjectInspector = + new WritableHiveVarcharObjectInspector( (VarcharTypeInfo) primitiveTypeInfo); + return writableVarcharObjectInspector.create((HiveVarchar) object); + } case BINARY: return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector.create((byte[]) object); case TIMESTAMP: @@ -354,113 +616,221 @@ public class VectorRandomRowSource { case INTERVAL_DAY_TIME: return ((WritableHiveIntervalDayTimeObjectInspector) objectInspector).create((HiveIntervalDayTime) object); case DECIMAL: - { - WritableHiveDecimalObjectInspector writableDecimalObjectInspector = - new WritableHiveDecimalObjectInspector((DecimalTypeInfo) primitiveTypeInfo); - HiveDecimalWritable result = (HiveDecimalWritable) writableDecimalObjectInspector.create((HiveDecimal) object); - return result; - } + { + WritableHiveDecimalObjectInspector writableDecimalObjectInspector = + new WritableHiveDecimalObjectInspector((DecimalTypeInfo) primitiveTypeInfo); + return writableDecimalObjectInspector.create((HiveDecimal) object); + } default: - throw new Error("Unknown primitive category " + primitiveCategory); + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); } } - public Object randomObject(int column) { - return randomObject(column, r, primitiveCategories, primitiveTypeInfos, alphabets, addEscapables, needsEscapeStr); + public Object randomWritable(int column) { + return randomWritable(typeInfos[column], objectInspectorList.get(column)); } - public static Object randomObject(int column, Random r, PrimitiveCategory[] primitiveCategories, - PrimitiveTypeInfo[] primitiveTypeInfos) { - return randomObject(column, r, primitiveCategories, primitiveTypeInfos, null, false, ""); - } - - public static Object randomObject(int column, Random r, PrimitiveCategory[] primitiveCategories, - PrimitiveTypeInfo[] primitiveTypeInfos, String[] alphabets, boolean addEscapables, String needsEscapeStr) { - PrimitiveCategory primitiveCategory = primitiveCategories[column]; - PrimitiveTypeInfo primitiveTypeInfo = primitiveTypeInfos[column]; - try { - switch (primitiveCategory) { - case BOOLEAN: - return Boolean.valueOf(r.nextInt(1) == 1); - case BYTE: - return Byte.valueOf((byte) r.nextInt()); - case SHORT: - return Short.valueOf((short) r.nextInt()); - case INT: - return Integer.valueOf(r.nextInt()); - case LONG: - return Long.valueOf(r.nextLong()); - case DATE: - return RandomTypeUtil.getRandDate(r); - case FLOAT: - return Float.valueOf(r.nextFloat() * 10 - 5); - case DOUBLE: - return Double.valueOf(r.nextDouble() * 10 - 5); - case STRING: - case CHAR: - case VARCHAR: - { - String result; - if (alphabets != null && alphabets[column] != null) { - result = RandomTypeUtil.getRandString(r, alphabets[column], r.nextInt(10)); - } else { - result = RandomTypeUtil.getRandString(r); + public Object randomWritable(TypeInfo typeInfo, ObjectInspector objectInspector) { + return randomWritable(typeInfo, objectInspector, allowNull); + } + + public Object randomWritable(TypeInfo typeInfo, ObjectInspector objectInspector, boolean allowNull) { + + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + final Object object = randomPrimitiveObject(r, (PrimitiveTypeInfo) typeInfo); + return getWritablePrimitiveObject((PrimitiveTypeInfo) typeInfo, objectInspector, object); + } + case LIST: + { + if (allowNull && r.nextInt(20) == 0) { + return null; + } + // Always generate a list with at least 1 
value? + final int elementCount = 1 + r.nextInt(100); + final StandardListObjectInspector listObjectInspector = + (StandardListObjectInspector) objectInspector; + final ObjectInspector elementObjectInspector = + listObjectInspector.getListElementObjectInspector(); + final TypeInfo elementTypeInfo = + TypeInfoUtils.getTypeInfoFromObjectInspector( + elementObjectInspector); + boolean isStringFamily = false; + PrimitiveCategory primitiveCategory = null; + if (elementTypeInfo.getCategory() == Category.PRIMITIVE) { + primitiveCategory = ((PrimitiveTypeInfo) elementTypeInfo).getPrimitiveCategory(); + if (primitiveCategory == PrimitiveCategory.STRING || + primitiveCategory == PrimitiveCategory.BINARY || + primitiveCategory == PrimitiveCategory.CHAR || + primitiveCategory == PrimitiveCategory.VARCHAR) { + isStringFamily = true; } - if (addEscapables && result.length() > 0) { - int escapeCount = 1 + r.nextInt(2); - for (int i = 0; i < escapeCount; i++) { - int index = r.nextInt(result.length()); - String begin = result.substring(0, index); - String end = result.substring(index); - Character needsEscapeChar = needsEscapeStr.charAt(r.nextInt(needsEscapeStr.length())); - result = begin + needsEscapeChar + end; - } + } + final Object listObj = listObjectInspector.create(elementCount); + for (int i = 0; i < elementCount; i++) { + final Object ele = randomWritable(elementTypeInfo, elementObjectInspector, allowNull); + // UNDONE: For now, a 1-element list with a null element is a null list... + if (ele == null && elementCount == 1) { + return null; } - switch (primitiveCategory) { - case STRING: - return result; - case CHAR: - return new HiveChar(result, ((CharTypeInfo) primitiveTypeInfo).getLength()); - case VARCHAR: - return new HiveVarchar(result, ((VarcharTypeInfo) primitiveTypeInfo).getLength()); - default: - throw new Error("Unknown primitive category " + primitiveCategory); + if (isStringFamily && elementCount == 1) { + switch (primitiveCategory) { + case STRING: + if (((Text) ele).getLength() == 0) { + return null; + } + break; + case BINARY: + if (((BytesWritable) ele).getLength() == 0) { + return null; + } + break; + case CHAR: + if (((HiveCharWritable) ele).getHiveChar().getStrippedValue().isEmpty()) { + return null; + } + break; + case VARCHAR: + if (((HiveVarcharWritable) ele).getHiveVarchar().getValue().isEmpty()) { + return null; + } + break; + default: + throw new RuntimeException("Unexpected primitive category " + primitiveCategory); + } } + listObjectInspector.set(listObj, i, ele); } - case BINARY: - return getRandBinary(r, 1 + r.nextInt(100)); - case TIMESTAMP: - return RandomTypeUtil.getRandTimestamp(r); - case INTERVAL_YEAR_MONTH: - return getRandIntervalYearMonth(r); - case INTERVAL_DAY_TIME: - return getRandIntervalDayTime(r); - case DECIMAL: - return getRandHiveDecimal(r, (DecimalTypeInfo) primitiveTypeInfo); - default: - throw new Error("Unknown primitive category " + primitiveCategory); + return listObj; + } + case MAP: + { + if (allowNull && r.nextInt(20) == 0) { + return null; + } + final int keyPairCount = r.nextInt(100); + final StandardMapObjectInspector mapObjectInspector = + (StandardMapObjectInspector) objectInspector; + final ObjectInspector keyObjectInspector = + mapObjectInspector.getMapKeyObjectInspector(); + final TypeInfo keyTypeInfo = + TypeInfoUtils.getTypeInfoFromObjectInspector( + keyObjectInspector); + final ObjectInspector valueObjectInspector = + mapObjectInspector.getMapValueObjectInspector(); + final TypeInfo valueTypeInfo = + 
TypeInfoUtils.getTypeInfoFromObjectInspector( + valueObjectInspector); + final Object mapObj = mapObjectInspector.create(); + for (int i = 0; i < keyPairCount; i++) { + Object key = randomWritable(keyTypeInfo, keyObjectInspector); + Object value = randomWritable(valueTypeInfo, valueObjectInspector); + mapObjectInspector.put(mapObj, key, value); + } + return mapObj; } - } catch (Exception e) { - throw new RuntimeException("randomObject failed on column " + column + " type " + primitiveCategory, e); + case STRUCT: + { + if (allowNull && r.nextInt(20) == 0) { + return null; + } + final StandardStructObjectInspector structObjectInspector = + (StandardStructObjectInspector) objectInspector; + final List<? extends StructField> fieldRefsList = structObjectInspector.getAllStructFieldRefs(); + final int fieldCount = fieldRefsList.size(); + final Object structObj = structObjectInspector.create(); + for (int i = 0; i < fieldCount; i++) { + final StructField fieldRef = fieldRefsList.get(i); + final ObjectInspector fieldObjectInspector = + fieldRef.getFieldObjectInspector(); + final TypeInfo fieldTypeInfo = + TypeInfoUtils.getTypeInfoFromObjectInspector( + fieldObjectInspector); + final Object fieldObj = randomWritable(fieldTypeInfo, fieldObjectInspector); + structObjectInspector.setStructFieldData(structObj, fieldRef, fieldObj); + } + return structObj; + } + case UNION: + { + final StandardUnionObjectInspector unionObjectInspector = + (StandardUnionObjectInspector) objectInspector; + final List<ObjectInspector> objectInspectorList = unionObjectInspector.getObjectInspectors(); + final int unionCount = objectInspectorList.size(); + final byte tag = (byte) r.nextInt(unionCount); + final ObjectInspector fieldObjectInspector = + objectInspectorList.get(tag); + final TypeInfo fieldTypeInfo = + TypeInfoUtils.getTypeInfoFromObjectInspector( + fieldObjectInspector); + final Object fieldObj = randomWritable(fieldTypeInfo, fieldObjectInspector, false); + if (fieldObj == null) { + throw new RuntimeException(); + } + return new StandardUnion(tag, fieldObj); + } + default: + throw new RuntimeException("Unexpected category " + typeInfo.getCategory()); } } - public static HiveChar getRandHiveChar(Random r, CharTypeInfo charTypeInfo, String alphabet) { - int maxLength = 1 + r.nextInt(charTypeInfo.getLength()); - String randomString = RandomTypeUtil.getRandString(r, alphabet, 100); - HiveChar hiveChar = new HiveChar(randomString, maxLength); - return hiveChar; + public Object randomPrimitiveObject(int column) { + return randomPrimitiveObject(r, primitiveTypeInfos[column]); + } + + public static Object randomPrimitiveObject(Random r, PrimitiveTypeInfo primitiveTypeInfo) { + + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + return Boolean.valueOf(r.nextBoolean()); + case BYTE: + return Byte.valueOf((byte) r.nextInt()); + case SHORT: + return Short.valueOf((short) r.nextInt()); + case INT: + return Integer.valueOf(r.nextInt()); + case LONG: + return Long.valueOf(r.nextLong()); + case DATE: + return RandomTypeUtil.getRandDate(r); + case FLOAT: + return Float.valueOf(r.nextFloat() * 10 - 5); + case DOUBLE: + return Double.valueOf(r.nextDouble() * 10 - 5); + case STRING: + return RandomTypeUtil.getRandString(r); + case CHAR: + return getRandHiveChar(r, (CharTypeInfo) primitiveTypeInfo); + case VARCHAR: + return getRandHiveVarchar(r, (VarcharTypeInfo) primitiveTypeInfo); + case BINARY: + return getRandBinary(r, 1 + r.nextInt(100)); + case TIMESTAMP: + return RandomTypeUtil.getRandTimestamp(r); + case 
INTERVAL_YEAR_MONTH: + return getRandIntervalYearMonth(r); + case INTERVAL_DAY_TIME: + return getRandIntervalDayTime(r); + case DECIMAL: + { + return getRandHiveDecimal(r, (DecimalTypeInfo) primitiveTypeInfo); + } + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getCategory()); + } } public static HiveChar getRandHiveChar(Random r, CharTypeInfo charTypeInfo) { - return getRandHiveChar(r, charTypeInfo, "abcdefghijklmnopqrstuvwxyz"); + final int maxLength = 1 + r.nextInt(charTypeInfo.getLength()); + final String randomString = RandomTypeUtil.getRandString(r, "abcdefghijklmnopqrstuvwxyz", 100); + return new HiveChar(randomString, maxLength); } public static HiveVarchar getRandHiveVarchar(Random r, VarcharTypeInfo varcharTypeInfo, String alphabet) { - int maxLength = 1 + r.nextInt(varcharTypeInfo.getLength()); - String randomString = RandomTypeUtil.getRandString(r, alphabet, 100); - HiveVarchar hiveVarchar = new HiveVarchar(randomString, maxLength); - return hiveVarchar; + final int maxLength = 1 + r.nextInt(varcharTypeInfo.getLength()); + final String randomString = RandomTypeUtil.getRandString(r, alphabet, 100); + return new HiveVarchar(randomString, maxLength); } public static HiveVarchar getRandHiveVarchar(Random r, VarcharTypeInfo varcharTypeInfo) { @@ -468,7 +838,7 @@ public class VectorRandomRowSource { } public static byte[] getRandBinary(Random r, int len){ - byte[] bytes = new byte[len]; + final byte[] bytes = new byte[len]; for (int j = 0; j < len; j++){ bytes[j] = Byte.valueOf((byte) r.nextInt()); } @@ -479,11 +849,11 @@ public class VectorRandomRowSource { public static HiveDecimal getRandHiveDecimal(Random r, DecimalTypeInfo decimalTypeInfo) { while (true) { - StringBuilder sb = new StringBuilder(); - int precision = 1 + r.nextInt(18); - int scale = 0 + r.nextInt(precision + 1); + final StringBuilder sb = new StringBuilder(); + final int precision = 1 + r.nextInt(18); + final int scale = 0 + r.nextInt(precision + 1); - int integerDigits = precision - scale; + final int integerDigits = precision - scale; if (r.nextBoolean()) { sb.append("-"); @@ -499,19 +869,17 @@ public class VectorRandomRowSource { sb.append(RandomTypeUtil.getRandString(r, DECIMAL_CHARS, scale)); } - HiveDecimal dec = HiveDecimal.create(sb.toString()); - - return dec; + return HiveDecimal.create(sb.toString()); } } public static HiveIntervalYearMonth getRandIntervalYearMonth(Random r) { - String yearMonthSignStr = r.nextInt(2) == 0 ? "" : "-"; - String intervalYearMonthStr = String.format("%s%d-%d", + final String yearMonthSignStr = r.nextInt(2) == 0 ? "" : "-"; + final String intervalYearMonthStr = String.format("%s%d-%d", yearMonthSignStr, Integer.valueOf(1800 + r.nextInt(500)), // year Integer.valueOf(0 + r.nextInt(12))); // month - HiveIntervalYearMonth intervalYearMonthVal = HiveIntervalYearMonth.valueOf(intervalYearMonthStr); + final HiveIntervalYearMonth intervalYearMonthVal = HiveIntervalYearMonth.valueOf(intervalYearMonthStr); return intervalYearMonthVal; } @@ -521,8 +889,8 @@ public class VectorRandomRowSource { optionalNanos = String.format(".%09d", Integer.valueOf(0 + r.nextInt(DateUtils.NANOS_PER_SEC))); } - String yearMonthSignStr = r.nextInt(2) == 0 ? "" : "-"; - String dayTimeStr = String.format("%s%d %02d:%02d:%02d%s", + final String yearMonthSignStr = r.nextInt(2) == 0 ? "" : "-"; + final String dayTimeStr = String.format("%s%d %02d:%02d:%02d%s", yearMonthSignStr, Integer.valueOf(1 + r.nextInt(28)), // day Integer.valueOf(0 + r.nextInt(24)), // hour
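
[Editorial illustration, not part of the commit] To make the reworked generator API above easier to follow, here is a rough usage sketch of the new entry points added in this file: the init(Random, SupportedTypes, maxComplexDepth) overload, the typeInfos() accessor, and randomRows(). The wrapper class name, seed, row count, and depth value are illustrative assumptions; only the VectorRandomRowSource calls themselves come from the diff above.

import java.util.Random;

import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource;
import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource.SupportedTypes;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

public class VectorRandomRowSourceSketch {
  public static void main(String[] args) {
    // Pick a random schema that may include array/struct/uniontype/map columns,
    // with complex-type nesting bounded by maxComplexDepth; nulls are allowed
    // by default (the three-argument init delegates with allowNull = true).
    VectorRandomRowSource source = new VectorRandomRowSource();
    source.init(new Random(12345), SupportedTypes.ALL, 2);

    // The schema the source chose for itself.
    TypeInfo[] typeInfos = source.typeInfos();
    for (int c = 0; c < typeInfos.length; c++) {
      System.out.println("col" + c + ": " + typeInfos[c].getTypeName());
    }

    // Rows come back as Object[] of writable values (nested Lists/Maps/StandardUnion
    // objects for complex columns), one entry per column in typeInfos().
    Object[][] rows = source.randomRows(10);
    System.out.println("generated " + rows.length + " rows x " + typeInfos.length + " columns");
  }
}
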
http://git-wip-us.apache.org/repos/asf/hive/blob/d467e172/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorVerifyFast.java ---------------------------------------------------------------------- diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorVerifyFast.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorVerifyFast.java new file mode 100644 index 0000000..b091026 --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorVerifyFast.java @@ -0,0 +1,698 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector; + +import junit.framework.TestCase; +import org.apache.hadoop.hive.common.type.HiveChar; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; +import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth; +import org.apache.hadoop.hive.common.type.HiveVarchar; +import org.apache.hadoop.hive.serde2.fast.DeserializeRead; +import org.apache.hadoop.hive.serde2.fast.SerializeWrite; +import org.apache.hadoop.hive.serde2.io.ByteWritable; +import org.apache.hadoop.hive.serde2.io.DateWritable; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveCharWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; +import org.apache.hadoop.hive.serde2.io.HiveIntervalDayTimeWritable; +import org.apache.hadoop.hive.serde2.io.HiveIntervalYearMonthWritable; +import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable; +import org.apache.hadoop.hive.serde2.io.ShortWritable; +import org.apache.hadoop.hive.serde2.io.TimestampWritable; +import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; +import org.apache.hadoop.io.BooleanWritable; +import org.apache.hadoop.io.BytesWritable; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.ArrayList; 
+import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class VectorVerifyFast { + + public static void verifyDeserializeRead( + DeserializeRead deserializeRead, TypeInfo typeInfo, Object object) throws IOException { + + boolean isNull; + + isNull = !deserializeRead.readNextField(); + doVerifyDeserializeRead(deserializeRead, typeInfo, object, isNull); + } + + public static void doVerifyDeserializeRead( + DeserializeRead deserializeRead, TypeInfo typeInfo, Object object, boolean isNull) throws IOException { + if (isNull) { + if (object != null) { + TestCase.fail("Field reports null but object is not null (class " + object.getClass().getName() + ", " + object.toString() + ")"); + } + return; + } else if (object == null) { + TestCase.fail("Field report not null but object is null"); + } + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + { + boolean value = deserializeRead.currentBoolean; + if (!(object instanceof BooleanWritable)) { + TestCase.fail("Boolean expected writable not Boolean"); + } + boolean expected = ((BooleanWritable) object).get(); + if (value != expected) { + TestCase.fail("Boolean field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case BYTE: + { + byte value = deserializeRead.currentByte; + if (!(object instanceof ByteWritable)) { + TestCase.fail("Byte expected writable not Byte"); + } + byte expected = ((ByteWritable) object).get(); + if (value != expected) { + TestCase.fail("Byte field mismatch (expected " + (int) expected + " found " + (int) value + ")"); + } + } + break; + case SHORT: + { + short value = deserializeRead.currentShort; + if (!(object instanceof ShortWritable)) { + TestCase.fail("Short expected writable not Short"); + } + short expected = ((ShortWritable) object).get(); + if (value != expected) { + TestCase.fail("Short field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case INT: + { + int value = deserializeRead.currentInt; + if (!(object instanceof IntWritable)) { + TestCase.fail("Integer expected writable not Integer"); + } + int expected = ((IntWritable) object).get(); + if (value != expected) { + TestCase.fail("Int field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case LONG: + { + long value = deserializeRead.currentLong; + if (!(object instanceof LongWritable)) { + TestCase.fail("Long expected writable not Long"); + } + Long expected = ((LongWritable) object).get(); + if (value != expected) { + TestCase.fail("Long field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case FLOAT: + { + float value = deserializeRead.currentFloat; + if (!(object instanceof FloatWritable)) { + TestCase.fail("Float expected writable not Float"); + } + float expected = ((FloatWritable) object).get(); + if (value != expected) { + TestCase.fail("Float field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case DOUBLE: + { + double value = deserializeRead.currentDouble; + if (!(object instanceof DoubleWritable)) { + TestCase.fail("Double expected writable not Double"); + } + double expected = ((DoubleWritable) object).get(); + if (value != expected) { + TestCase.fail("Double field mismatch (expected " + expected + " found " + value + ")"); + } + } + break; + case STRING: + { + byte[] stringBytes = Arrays.copyOfRange( 
+ deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); + String expected = ((Text) object).toString(); + if (!string.equals(expected)) { + TestCase.fail("String field mismatch (expected '" + expected + "' found '" + string + "')"); + } + } + break; + case CHAR: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); + + HiveChar hiveChar = new HiveChar(string, ((CharTypeInfo) primitiveTypeInfo).getLength()); + + HiveChar expected = ((HiveCharWritable) object).getHiveChar(); + if (!hiveChar.equals(expected)) { + TestCase.fail("Char field mismatch (expected '" + expected + "' found '" + hiveChar + "')"); + } + } + break; + case VARCHAR: + { + byte[] stringBytes = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + Text text = new Text(stringBytes); + String string = text.toString(); + + HiveVarchar hiveVarchar = new HiveVarchar(string, ((VarcharTypeInfo) primitiveTypeInfo).getLength()); + + HiveVarchar expected = ((HiveVarcharWritable) object).getHiveVarchar(); + if (!hiveVarchar.equals(expected)) { + TestCase.fail("Varchar field mismatch (expected '" + expected + "' found '" + hiveVarchar + "')"); + } + } + break; + case DECIMAL: + { + HiveDecimal value = deserializeRead.currentHiveDecimalWritable.getHiveDecimal(); + if (value == null) { + TestCase.fail("Decimal field evaluated to NULL"); + } + HiveDecimal expected = ((HiveDecimalWritable) object).getHiveDecimal(); + if (!value.equals(expected)) { + DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo; + int precision = decimalTypeInfo.getPrecision(); + int scale = decimalTypeInfo.getScale(); + TestCase.fail("Decimal field mismatch (expected " + expected.toString() + " found " + value.toString() + ") precision " + precision + ", scale " + scale); + } + } + break; + case DATE: + { + Date value = deserializeRead.currentDateWritable.get(); + Date expected = ((DateWritable) object).get(); + if (!value.equals(expected)) { + TestCase.fail("Date field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case TIMESTAMP: + { + Timestamp value = deserializeRead.currentTimestampWritable.getTimestamp(); + Timestamp expected = ((TimestampWritable) object).getTimestamp(); + if (!value.equals(expected)) { + TestCase.fail("Timestamp field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case INTERVAL_YEAR_MONTH: + { + HiveIntervalYearMonth value = deserializeRead.currentHiveIntervalYearMonthWritable.getHiveIntervalYearMonth(); + HiveIntervalYearMonth expected = ((HiveIntervalYearMonthWritable) object).getHiveIntervalYearMonth(); + if (!value.equals(expected)) { + TestCase.fail("HiveIntervalYearMonth field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case INTERVAL_DAY_TIME: + { + HiveIntervalDayTime value = deserializeRead.currentHiveIntervalDayTimeWritable.getHiveIntervalDayTime(); + HiveIntervalDayTime expected = ((HiveIntervalDayTimeWritable) object).getHiveIntervalDayTime(); + if (!value.equals(expected)) { + 
TestCase.fail("HiveIntervalDayTime field mismatch (expected " + expected.toString() + " found " + value.toString() + ")"); + } + } + break; + case BINARY: + { + byte[] byteArray = Arrays.copyOfRange( + deserializeRead.currentBytes, + deserializeRead.currentBytesStart, + deserializeRead.currentBytesStart + deserializeRead.currentBytesLength); + BytesWritable bytesWritable = (BytesWritable) object; + byte[] expected = Arrays.copyOfRange(bytesWritable.getBytes(), 0, bytesWritable.getLength()); + if (byteArray.length != expected.length){ + TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) + + " found " + Arrays.toString(byteArray) + ")"); + } + for (int b = 0; b < byteArray.length; b++) { + if (byteArray[b] != expected[b]) { + TestCase.fail("Byte Array field mismatch (expected " + Arrays.toString(expected) + + " found " + Arrays.toString(byteArray) + ")"); + } + } + } + break; + default: + throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory()); + } + } + break; + case LIST: + case MAP: + case STRUCT: + case UNION: + throw new Error("Complex types need to be handled separately"); + default: + throw new Error("Unknown category " + typeInfo.getCategory()); + } + } + + public static void serializeWrite(SerializeWrite serializeWrite, + TypeInfo typeInfo, Object object) throws IOException { + if (object == null) { + serializeWrite.writeNull(); + return; + } + switch (typeInfo.getCategory()) { + case PRIMITIVE: + { + PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + { + boolean value = ((BooleanWritable) object).get(); + serializeWrite.writeBoolean(value); + } + break; + case BYTE: + { + byte value = ((ByteWritable) object).get(); + serializeWrite.writeByte(value); + } + break; + case SHORT: + { + short value = ((ShortWritable) object).get(); + serializeWrite.writeShort(value); + } + break; + case INT: + { + int value = ((IntWritable) object).get(); + serializeWrite.writeInt(value); + } + break; + case LONG: + { + long value = ((LongWritable) object).get(); + serializeWrite.writeLong(value); + } + break; + case FLOAT: + { + float value = ((FloatWritable) object).get(); + serializeWrite.writeFloat(value); + } + break; + case DOUBLE: + { + double value = ((DoubleWritable) object).get(); + serializeWrite.writeDouble(value); + } + break; + case STRING: + { + Text value = (Text) object; + byte[] stringBytes = value.getBytes(); + int stringLength = stringBytes.length; + serializeWrite.writeString(stringBytes, 0, stringLength); + } + break; + case CHAR: + { + HiveChar value = ((HiveCharWritable) object).getHiveChar(); + serializeWrite.writeHiveChar(value); + } + break; + case VARCHAR: + { + HiveVarchar value = ((HiveVarcharWritable) object).getHiveVarchar(); + serializeWrite.writeHiveVarchar(value); + } + break; + case DECIMAL: + { + HiveDecimal value = ((HiveDecimalWritable) object).getHiveDecimal(); + DecimalTypeInfo decTypeInfo = (DecimalTypeInfo)primitiveTypeInfo; + serializeWrite.writeHiveDecimal(value, decTypeInfo.scale()); + } + break; + case DATE: + { + Date value = ((DateWritable) object).get(); + serializeWrite.writeDate(value); + } + break; + case TIMESTAMP: + { + Timestamp value = ((TimestampWritable) object).getTimestamp(); + serializeWrite.writeTimestamp(value); + } + break; + case INTERVAL_YEAR_MONTH: + { + HiveIntervalYearMonth value = ((HiveIntervalYearMonthWritable) object).getHiveIntervalYearMonth(); + 
serializeWrite.writeHiveIntervalYearMonth(value);
+          }
+          break;
+        case INTERVAL_DAY_TIME:
+          {
+            HiveIntervalDayTime value = ((HiveIntervalDayTimeWritable) object).getHiveIntervalDayTime();
+            serializeWrite.writeHiveIntervalDayTime(value);
+          }
+          break;
+        case BINARY:
+          {
+            BytesWritable byteWritable = (BytesWritable) object;
+            byte[] binaryBytes = byteWritable.getBytes();
+            int length = byteWritable.getLength();
+            serializeWrite.writeBinary(binaryBytes, 0, length);
+          }
+          break;
+        default:
+          throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory().name());
+        }
+      }
+      break;
+    case LIST:
+      {
+        ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo;
+        TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo();
+        ArrayList<Object> elements = (ArrayList<Object>) object;
+        serializeWrite.beginList(elements);
+        boolean isFirst = true;
+        for (Object elementObject : elements) {
+          if (isFirst) {
+            isFirst = false;
+          } else {
+            serializeWrite.separateList();
+          }
+          if (elementObject == null) {
+            serializeWrite.writeNull();
+          } else {
+            serializeWrite(serializeWrite, elementTypeInfo, elementObject);
+          }
+        }
+        serializeWrite.finishList();
+      }
+      break;
+    case MAP:
+      {
+        MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo;
+        TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo();
+        TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo();
+        HashMap<Object, Object> hashMap = (HashMap<Object, Object>) object;
+        serializeWrite.beginMap(hashMap);
+        boolean isFirst = true;
+        for (Map.Entry<Object, Object> entry : hashMap.entrySet()) {
+          if (isFirst) {
+            isFirst = false;
+          } else {
+            serializeWrite.separateKeyValuePair();
+          }
+          if (entry.getKey() == null) {
+            serializeWrite.writeNull();
+          } else {
+            serializeWrite(serializeWrite, keyTypeInfo, entry.getKey());
+          }
+          serializeWrite.separateKey();
+          if (entry.getValue() == null) {
+            serializeWrite.writeNull();
+          } else {
+            serializeWrite(serializeWrite, valueTypeInfo, entry.getValue());
+          }
+        }
+        serializeWrite.finishMap();
+      }
+      break;
+    case STRUCT:
+      {
+        StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
+        ArrayList<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
+        ArrayList<Object> fieldValues = (ArrayList<Object>) object;
+        final int size = fieldValues.size();
+        serializeWrite.beginStruct(fieldValues);
+        boolean isFirst = true;
+        for (int i = 0; i < size; i++) {
+          if (isFirst) {
+            isFirst = false;
+          } else {
+            serializeWrite.separateStruct();
+          }
+          serializeWrite(serializeWrite, fieldTypeInfos.get(i), fieldValues.get(i));
+        }
+        serializeWrite.finishStruct();
+      }
+      break;
+    case UNION:
+      {
+        UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo;
+        List<TypeInfo> fieldTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos();
+        final int size = fieldTypeInfos.size();
+        StandardUnionObjectInspector.StandardUnion standardUnion = (StandardUnionObjectInspector.StandardUnion) object;
+        byte tag = standardUnion.getTag();
+        serializeWrite.beginUnion(tag);
+        serializeWrite(serializeWrite, fieldTypeInfos.get(tag), standardUnion.getObject());
+        serializeWrite.finishUnion();
+      }
+      break;
+    default:
+      throw new Error("Unknown category " + typeInfo.getCategory().name());
+    }
+  }
+
+  public Object readComplexPrimitiveField(DeserializeRead deserializeRead,
+      PrimitiveTypeInfo primitiveTypeInfo) throws IOException {
+    boolean isNull = !deserializeRead.readComplexField();
+    if (isNull) {
+      return null;
+    } else {
+      return doReadComplexPrimitiveField(deserializeRead, primitiveTypeInfo);
+    }
+  }
+
+  private static Object doReadComplexPrimitiveField(DeserializeRead deserializeRead,
+      PrimitiveTypeInfo primitiveTypeInfo) throws IOException {
+    switch (primitiveTypeInfo.getPrimitiveCategory()) {
+    case BOOLEAN:
+      return new BooleanWritable(deserializeRead.currentBoolean);
+    case BYTE:
+      return new ByteWritable(deserializeRead.currentByte);
+    case SHORT:
+      return new ShortWritable(deserializeRead.currentShort);
+    case INT:
+      return new IntWritable(deserializeRead.currentInt);
+    case LONG:
+      return new LongWritable(deserializeRead.currentLong);
+    case FLOAT:
+      return new FloatWritable(deserializeRead.currentFloat);
+    case DOUBLE:
+      return new DoubleWritable(deserializeRead.currentDouble);
+    case STRING:
+      return new Text(new String(
+          deserializeRead.currentBytes,
+          deserializeRead.currentBytesStart,
+          deserializeRead.currentBytesLength,
+          StandardCharsets.UTF_8));
+    case CHAR:
+      return new HiveCharWritable(new HiveChar(
+          new String(
+              deserializeRead.currentBytes,
+              deserializeRead.currentBytesStart,
+              deserializeRead.currentBytesLength,
+              StandardCharsets.UTF_8),
+          ((CharTypeInfo) primitiveTypeInfo).getLength()));
+    case VARCHAR:
+      if (deserializeRead.currentBytes == null) {
+        throw new RuntimeException();
+      }
+      return new HiveVarcharWritable(new HiveVarchar(
+          new String(
+              deserializeRead.currentBytes,
+              deserializeRead.currentBytesStart,
+              deserializeRead.currentBytesLength,
+              StandardCharsets.UTF_8),
+          ((VarcharTypeInfo) primitiveTypeInfo).getLength()));
+    case DECIMAL:
+      return new HiveDecimalWritable(deserializeRead.currentHiveDecimalWritable);
+    case DATE:
+      return new DateWritable(deserializeRead.currentDateWritable);
+    case TIMESTAMP:
+      return new TimestampWritable(deserializeRead.currentTimestampWritable);
+    case INTERVAL_YEAR_MONTH:
+      return new HiveIntervalYearMonthWritable(deserializeRead.currentHiveIntervalYearMonthWritable);
+    case INTERVAL_DAY_TIME:
+      return new HiveIntervalDayTimeWritable(deserializeRead.currentHiveIntervalDayTimeWritable);
+    case BINARY:
+      return new BytesWritable(
+          Arrays.copyOfRange(
+              deserializeRead.currentBytes,
+              deserializeRead.currentBytesStart,
+              deserializeRead.currentBytesLength + deserializeRead.currentBytesStart));
+    default:
+      throw new Error("Unknown primitive category " + primitiveTypeInfo.getPrimitiveCategory());
+    }
+  }
+
+  public static Object deserializeReadComplexType(DeserializeRead deserializeRead,
+      TypeInfo typeInfo) throws IOException {
+
+    boolean isNull = !deserializeRead.readNextField();
+    if (isNull) {
+      return null;
+    }
+    return getComplexField(deserializeRead, typeInfo);
+  }
+
+  private static Object getComplexField(DeserializeRead deserializeRead,
+      TypeInfo typeInfo) throws IOException {
+    switch (typeInfo.getCategory()) {
+    case PRIMITIVE:
+      return doReadComplexPrimitiveField(deserializeRead, (PrimitiveTypeInfo) typeInfo);
+    case LIST:
+      {
+        ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo;
+        TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo();
+        ArrayList<Object> list = new ArrayList<Object>();
+        Object eleObj;
+        boolean isNull;
+        while (deserializeRead.isNextComplexMultiValue()) {
+          isNull = !deserializeRead.readComplexField();
+          if (isNull) {
+            eleObj = null;
+          } else {
+            eleObj = getComplexField(deserializeRead, elementTypeInfo);
+          }
+          list.add(eleObj);
+        }
+        return list;
+      }
+    case MAP:
+      {
+        MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo;
+        TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo();
+        TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo();
+        HashMap<Object, Object> hashMap = new HashMap<Object, Object>();
+        Object keyObj;
+        Object valueObj;
+        boolean isNull;
+        while (deserializeRead.isNextComplexMultiValue()) {
+          isNull = !deserializeRead.readComplexField();
+          if (isNull) {
+            keyObj = null;
+          } else {
+            keyObj = getComplexField(deserializeRead, keyTypeInfo);
+          }
+          isNull = !deserializeRead.readComplexField();
+          if (isNull) {
+            valueObj = null;
+          } else {
+            valueObj = getComplexField(deserializeRead, valueTypeInfo);
+          }
+          hashMap.put(keyObj, valueObj);
+        }
+        return hashMap;
+      }
+    case STRUCT:
+      {
+        StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
+        ArrayList<TypeInfo> fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos();
+        final int size = fieldTypeInfos.size();
+        ArrayList<Object> fieldValues = new ArrayList<Object>();
+        Object fieldObj;
+        boolean isNull;
+        for (int i = 0; i < size; i++) {
+          isNull = !deserializeRead.readComplexField();
+          if (isNull) {
+            fieldObj = null;
+          } else {
+            fieldObj = getComplexField(deserializeRead, fieldTypeInfos.get(i));
+          }
+          fieldValues.add(fieldObj);
+        }
+        deserializeRead.finishComplexVariableFieldsType();
+        return fieldValues;
+      }
+    case UNION:
+      {
+        UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo;
+        List<TypeInfo> unionTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos();
+        final int size = unionTypeInfos.size();
+        Object tagObj;
+        int tag;
+        Object unionObj;
+        boolean isNull = !deserializeRead.readComplexField();
+        if (isNull) {
+          unionObj = null;
+        } else {
+          // Get the tag value.
+          tagObj = getComplexField(deserializeRead, TypeInfoFactory.intTypeInfo);
+          tag = ((IntWritable) tagObj).get();
+
+          isNull = !deserializeRead.readComplexField();
+          if (isNull) {
+            unionObj = null;
+          } else {
+            // Get the union value.
+            unionObj = new StandardUnionObjectInspector.StandardUnion((byte) tag, getComplexField(deserializeRead, unionTypeInfos.get(tag)));
+          }
+        }
+
+        deserializeRead.finishComplexVariableFieldsType();
+        return unionObj;
+      }
+    default:
+      throw new Error("Unexpected category " + typeInfo.getCategory());
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/hive/blob/d467e172/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/CheckFastRowHashMap.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/CheckFastRowHashMap.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/CheckFastRowHashMap.java
index 72fceb9..24b178f 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/CheckFastRowHashMap.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/CheckFastRowHashMap.java
@@ -36,15 +36,16 @@ import org.apache.hadoop.hive.ql.plan.VectorMapJoinDesc.HashTableKeyType;
 import org.apache.hadoop.hive.serde2.WriteBuffers;
 import org.apache.hadoop.hive.serde2.io.ByteWritable;
 import org.apache.hadoop.hive.serde2.io.ShortWritable;
+import org.apache.hadoop.hive.serde2.lazy.VerifyLazy;
 import org.apache.hadoop.hive.serde2.lazybinary.fast.LazyBinaryDeserializeRead;
-import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.UnionObject;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.io.BooleanWritable;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.LongWritable;
 import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;

 import com.google.common.base.Preconditions;
@@ -76,8 +77,7 @@ public class CheckFastRowHashMap extends CheckFastHashTable {
     lazyBinaryDeserializeRead.set(bytes, offset, length);

     for (int index = 0; index < columnCount; index++) {
-      Writable writable = (Writable) row[index];
-      VerifyFastRow.verifyDeserializeRead(lazyBinaryDeserializeRead, (PrimitiveTypeInfo) typeInfos[index], writable);
+      verifyRead(lazyBinaryDeserializeRead, typeInfos[index], row[index]);
     }

     TestCase.assertTrue(lazyBinaryDeserializeRead.isEndOfInputReached());
@@ -132,8 +132,7 @@ public class CheckFastRowHashMap extends CheckFastHashTable {
     int index = 0;
     try {
       for (index = 0; index < columnCount; index++) {
-        Writable writable = (Writable) row[index];
-        VerifyFastRow.verifyDeserializeRead(lazyBinaryDeserializeRead, (PrimitiveTypeInfo) typeInfos[index], writable);
+        verifyRead(lazyBinaryDeserializeRead, typeInfos[index], row[index]);
       }
     } catch (Exception e) {
       thrown = true;
@@ -175,6 +174,39 @@ public class CheckFastRowHashMap extends CheckFastHashTable {
     }
   }

+  private static void verifyRead(LazyBinaryDeserializeRead lazyBinaryDeserializeRead,
+      TypeInfo typeInfo, Object expectedObject) throws IOException {
+    if (typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE) {
+      VerifyFastRow.verifyDeserializeRead(lazyBinaryDeserializeRead, typeInfo, expectedObject);
+    } else {
+      final Object complexFieldObj =
+          VerifyFastRow.deserializeReadComplexType(lazyBinaryDeserializeRead, typeInfo);
+      if (expectedObject == null) {
+        if (complexFieldObj != null) {
+          TestCase.fail("Field reports not null but object is null (class " +
+              complexFieldObj.getClass().getName() + ", " + complexFieldObj.toString() + ")");
+        }
+      } else {
+        if (complexFieldObj == null) {
+          // It's hard to distinguish a union with null from a null union.
+          if (expectedObject instanceof UnionObject) {
+            UnionObject expectedUnion = (UnionObject) expectedObject;
+            if (expectedUnion.getObject() == null) {
+              return;
+            }
+          }
+          TestCase.fail("Field reports null but object is not null (class " +
+              expectedObject.getClass().getName() + ", " + expectedObject.toString() + ")");
+        }
+      }
+      if (!VerifyLazy.lazyCompare(typeInfo, complexFieldObj, expectedObject)) {
+        TestCase.fail("Comparision failed typeInfo " + typeInfo.toString());
+      }
+    }
+  }
+
   /*
    * Element for Key: row and byte[] x Hash Table: HashMap
    */
@@ -283,7 +315,7 @@ public class CheckFastRowHashMap extends CheckFastHashTable {

   public void verify(VectorMapJoinFastHashTable map,
       HashTableKeyType hashTableKeyType,
-      PrimitiveTypeInfo[] valuePrimitiveTypeInfos, boolean doClipping,
+      TypeInfo[] valueTypeInfos, boolean doClipping,
       boolean useExactBytes, Random random) throws IOException {
     int mapSize = map.size();
     if (mapSize != count) {
@@ -368,10 +400,10 @@
       List<Object[]> rows = element.getValueRows();

       if (!doClipping && !useExactBytes) {
-        verifyHashMapRows(rows, actualToValueMap, hashMapResult, valuePrimitiveTypeInfos);
+        verifyHashMapRows(rows, actualToValueMap, hashMapResult, valueTypeInfos);
       } else {
         int clipIndex = random.nextInt(rows.size());
-        verifyHashMapRowsMore(rows, actualToValueMap, hashMapResult, valuePrimitiveTypeInfos,
+        verifyHashMapRowsMore(rows, actualToValueMap, hashMapResult, valueTypeInfos,
             clipIndex, useExactBytes);
       }
     }
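
Note on expected-object shapes: the new verifyRead path compares what LazyBinaryDeserializeRead materializes against the row objects the test generator produced, so both sides must use the same standard-object representation the VerifyFastRow helpers above assume: HashMap<Object, Object> for MAP, ArrayList<Object> for LIST and STRUCT fields, StandardUnion for UNION, and Hadoop/Hive writables for primitive leaves. The sketch below illustrates that contract for a hypothetical map<string, array<int>> column; the class, method, and variable names are illustrative only and are not part of this patch.

// Sketch only: container and writable choices mirror the casts visible in the diff.
import java.util.ArrayList;
import java.util.HashMap;

import org.apache.hadoop.hive.serde2.objectinspector.StandardUnionObjectInspector.StandardUnion;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

public class ComplexRowShapeSketch {

  public static HashMap<Object, Object> exampleMapValue() {
    // LIST values are ArrayList<Object> with writables (or null) as elements.
    ArrayList<Object> intList = new ArrayList<Object>();
    intList.add(new IntWritable(1));
    intList.add(null);                       // nulls are emitted via serializeWrite.writeNull()
    // MAP values are HashMap<Object, Object>; STRING keys are Text writables.
    HashMap<Object, Object> mapValue = new HashMap<Object, Object>();
    mapValue.put(new Text("key"), intList);
    return mapValue;
  }

  public static StandardUnion exampleUnionWithNullValue() {
    // A union whose selected value is null; after a round trip this is indistinguishable
    // from a null union field, which is why verifyRead special-cases UnionObject.
    return new StandardUnion((byte) 0, null);
  }
}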

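For context, the per-row loop shared by the two modified hunks can be expressed as a standalone helper. The sketch below is not part of the patch (the class, method, and variable names are hypothetical) and only uses entry points visible in this diff: LazyBinaryDeserializeRead.set/isEndOfInputReached, VerifyFastRow.verifyDeserializeRead, VerifyFastRow.deserializeReadComplexType, and VerifyLazy.lazyCompare. It is assumed to live in the same package as CheckFastRowHashMap so that VerifyFastRow resolves without an import.

// Sketch only: a simplified mixed primitive/complex row check under the assumptions above.
package org.apache.hadoop.hive.ql.exec.vector.mapjoin.fast;

import java.io.IOException;

import org.apache.hadoop.hive.serde2.lazy.VerifyLazy;
import org.apache.hadoop.hive.serde2.lazybinary.fast.LazyBinaryDeserializeRead;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

import junit.framework.TestCase;

public class RowVerifySketch {

  // Verify one serialized row whose columns mix primitive and complex types.
  public static void verifyRow(LazyBinaryDeserializeRead read, byte[] bytes,
      TypeInfo[] typeInfos, Object[] expectedRow) throws IOException {
    read.set(bytes, 0, bytes.length);
    for (int i = 0; i < typeInfos.length; i++) {
      if (typeInfos[i].getCategory() == ObjectInspector.Category.PRIMITIVE) {
        // Primitive columns keep the pre-existing per-primitive verification.
        VerifyFastRow.verifyDeserializeRead(read, typeInfos[i], expectedRow[i]);
      } else {
        // Complex columns are materialized into standard objects and compared lazily.
        Object actual = VerifyFastRow.deserializeReadComplexType(read, typeInfos[i]);
        if (actual == null || expectedRow[i] == null) {
          // Both must be null; the full verifyRead additionally tolerates a union
          // whose selected value is null (see the UnionObject check above).
          TestCase.assertTrue(actual == expectedRow[i]);
        } else {
          TestCase.assertTrue(VerifyLazy.lazyCompare(typeInfos[i], actual, expectedRow[i]));
        }
      }
    }
    TestCase.assertTrue(read.isEndOfInputReached());
  }
}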