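Replacing the `while` loops that handle the trailing "extra bytes" (those left over after the word-sized comparisons) with fall-through `switch` statements also yields a small improvement. However, pulling that common tail logic out into a shared utility method negates most of the benefit, so it is repeated in each method. Here is the updated `ArrayUtil` class.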
package org.tukaani.xz.common; import static java.lang.invoke.MethodType.methodType; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; import java.lang.reflect.Constructor; import java.lang.reflect.Method; import java.nio.ByteOrder; import java.util.Arrays; import java.util.Locale; import java.util.Properties; import java.util.logging.Level; import java.util.logging.Logger; /** * Utilities for optimized array interactions. * * <p> * The means of comparing arrays can be controlled by setting the sytem property * {@code org.tukaani.xz.ArrayComparison} to a value from {@link ArrayComparison}. * </p> * * @author Brett Okken */ public final class ArrayUtil { /** * Enumerated options for controlling implementation of how to compare arrays. */ public static enum ArrayComparison { /** * Uses {@code VarHandle} for {@code int[]} access. * <p> * This is default behavior on jdk9+ for 32 bit x86. * </p> */ VH_INT, /** * Uses {@code VarHandle} for {@code int[]} access after attempting * to align the reads on 4 byte boundaries. */ VH_INT_ALIGN, /** * Uses {@code VarHandle} for {@code long[]} access. * <p> * This is default behavior on jdk9+ for 64 bit x86. * </p> */ VH_LONG, /** * Uses {@code VarHandle} for {@code long[]} access after attempting * to align the reads. */ VH_LONG_ALIGN, /** * Uses {@code Arrays.mismatch()} to perform vectorized comparison. * <p> * This is default behavior on jdk9+ for non-x86. * </p> */ VECTOR, /** * Uses {@code sun.misc.Unsafe.getInt()} for unaligned {@code int[]} * access. * <p> * This is default behavior on jdk 8 and prior for 32 bit x86. * </p> */ UNSAFE_GET_INT, /** * Uses {@code sun.misc.Unsafe.getLong()} for unaligned {@code long[]} * access. * <p> * This is default behavior on jdk 8 and prior for 64 bit x86. * </p> */ UNSAFE_GET_LONG, /** * Performs byte-by-byte comparison. 
*/ LEGACY; static ArrayComparison getFromProperty(String prop) { if (prop == null || prop.isEmpty()) { return null; } try { return ArrayComparison.valueOf(prop.toUpperCase(Locale.US)); } catch (Exception e) { final Logger logger = Logger.getLogger(ArrayUtil.class.getName()); logger.log(Level.INFO, "Invalid ArrayComparison option, using default behavior", e); return null; } } } /** * MethodHandle to the actual mismatch method to use at runtime. */ private static final MethodHandle MISMATCH; /** * The method this is bound to at runtime is depends on the chosen * implementation for {@code byte[]} comparison. * <p> * For {@code long} based comparisons, it will be bound to either * {@link Long#numberOfLeadingZeros(long)} or * {@link Long#numberOfTrailingZeros(long)} depending on * {@link ByteOrder#nativeOrder()}. * </p> * <p> * For {@code int} based comparisons it will be bound to either * {@link Integer#numberOfLeadingZeros(int)} or * {@link Integer#numberOfTrailingZeros(int)} depending on * {@link ByteOrder#nativeOrder()}. * </p> */ private static final MethodHandle LEADING_ZEROS; /** * Populated from reflected read of * {@code sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET} if one of the unsafe * implementations is used. */ private static final long ARRAY_BASE_OFFSET; /** * The method this is bound to at runtime is depends on the chosen * implementation for {@code byte[]} comparison. * <p> * For {@link ArrayComparison#VECTOR} and * {@link ArrayComparison#LEGACY} this will be {@code null}. * </p> * <p> * For {@link ArrayComparison#VH_INT} and {@link ArrayComparison#VH_INT_ALIGN} * this will be a jdk 9+ {@code byteArrayViewVarHandle} for {@code int[]} * using the {@link ByteOrder#nativeOrder()}. The method signature is * {@code int get(byte[], int)}. * </p> * <p> * For {@link ArrayComparison#VH_LONG} and {@link ArrayComparison#VH_LONG_ALIGN} * this will be a jdk 9+ {@code byteArrayViewVarHandle} for {@code long[]} * using the {@link ByteOrder#nativeOrder()}. 
The method signature is * {@code long get(byte[], int)}. * </p> * <p> * For {@link ArrayComparison#UNSAFE_GET_INT} this is bound to * {@code sun.misc.Unsafe.getInt(Object, long)}. * </p> * <p> * For {@link ArrayComparison#UNSAFE_GET_LONG} this is bound to * {@code sun.misc.Unsafe.getLong(Object, long)}. * </p> */ private static final MethodHandle GET_PRIMITIVE; /** * MethodHandle to the jdk 9+ * {@code Arrays.mismatch(byte[] a, int aFromIndex, int aToIndex, byte[] b, int bFromIndex, int bToIndex)}. */ private static final MethodHandle ARRAYS_MISMATCH; static { final Logger logger = Logger.getLogger(ArrayUtil.class.getName()); MethodHandle leadingZeros = null; MethodHandle getPrimitive = null; MethodHandle arraysMismatch = null; long arrayBaseOffset = 0; MethodHandle mismatch = null; final MethodHandles.Lookup lookup = MethodHandles.lookup(); final MethodType mismatchType = methodType( int.class, byte[].class, int.class, int.class, int.class); try { final Properties props = System.getProperties(); final ArrayComparison algo = ArrayComparison.getFromProperty( props.getProperty("org.tukaani.xz.ArrayComparison")); final String arch = props.getProperty("os.arch", ""); final boolean unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64)$"); //if unaligned, or explicitly configured, try VarHandles if ((unaligned && algo == null) || algo == ArrayComparison.VH_LONG || algo == ArrayComparison.VH_LONG_ALIGN || algo == ArrayComparison.VH_INT || algo == ArrayComparison.VH_INT_ALIGN) { try { final Class<?> varHandleClazz = Class.forName("java.lang.invoke.VarHandle", true, null); final Method byteArrayViewHandle = MethodHandles.class.getDeclaredMethod("byteArrayViewVarHandle", new Class[] {Class.class, ByteOrder.class}); final boolean doLong = (algo == null && arch.contains("64")) || algo == ArrayComparison.VH_LONG || algo == ArrayComparison.VH_LONG_ALIGN; final Object varHandle = byteArrayViewHandle.invoke(null, doLong ? 
long[].class : int[].class, ByteOrder.nativeOrder()); final Class<?> accessModeEnum = Class.forName("java.lang.invoke.VarHandle$AccessMode", true, null); @SuppressWarnings({ "unchecked", "rawtypes" }) final Object getAccessModeEnum = Enum.valueOf((Class)accessModeEnum, "GET"); final Method toMethodHandle = varHandleClazz.getDeclaredMethod("toMethodHandle", accessModeEnum); getPrimitive = (MethodHandle) toMethodHandle.invoke(varHandle, getAccessModeEnum); //the primitive will be in platform byte order. the concept of "leading //zeros" being bytes in encounter order is true for big endian //for little endian platform, the trailing zeros gives the encounter order result leadingZeros = lookup.findStatic(doLong? Long.class : Integer.class, ByteOrder.BIG_ENDIAN == ByteOrder.nativeOrder() ? "numberOfLeadingZeros" : "numberOfTrailingZeros", methodType(int.class, doLong ? long.class : int.class)); final String mismatchMethod; if (doLong) { mismatchMethod = algo == null || algo == ArrayComparison.VH_LONG ? "varHandleMismatch" : "alignedVarHandleMismatch"; logger.finest("byte[] comparison using long VarHandle"); } else { mismatchMethod = algo == null || algo == ArrayComparison.VH_INT ? 
"intVarHandleMismatch" : "alignedIntVarHandleMismatch"; logger.finest("byte[] comparison using int VarHandle"); } mismatch = lookup.findStatic(ArrayUtil.class, mismatchMethod, mismatchType); } catch (Throwable t) { logger.log(Level.FINE, "failed trying to load a MethodHandle to invoke get on a byteArrayViewVarHandle", t); mismatch = null; } } if (mismatch == null && ((!unaligned && algo == null) || algo == ArrayComparison.VECTOR)) { try { final MethodType arraysType = methodType(int.class, byte[].class, int.class, int.class, byte[].class, int.class, int.class); arraysMismatch = lookup.findStatic(Arrays.class, "mismatch", arraysType); mismatch = lookup.findStatic(ArrayUtil.class, "arraysMismatch", mismatchType); logger.finest("byte[] comparisons using Arrays.mismatch"); } catch (Throwable t) { logger.log(Level.FINE, "failed trying to load a MethodHandle to invoke Arrays.mismatch", t); arraysMismatch = null; } } //if byteArrayViewVarHandle for a long[] could not be loaded, then //try to load sun.misc.Unsafe for unaligned archs only if (mismatch == null && ((unaligned && algo == null) || algo == ArrayComparison.UNSAFE_GET_LONG || algo == ArrayComparison.UNSAFE_GET_INT)) { Class<?> unsafeClazz = Class.forName("sun.misc.Unsafe", true, null); Constructor<?> unsafeConstructor = unsafeClazz.getDeclaredConstructor(); unsafeConstructor.setAccessible(true); Object unsafe = unsafeConstructor.newInstance(); arrayBaseOffset = unsafeClazz.getField("ARRAY_BYTE_BASE_OFFSET") .getLong(null); if (algo == ArrayComparison.UNSAFE_GET_LONG || (algo == null && arch.contains("64"))) { leadingZeros = lookup.findStatic(Long.class, ByteOrder.BIG_ENDIAN == ByteOrder.nativeOrder() ? 
"numberOfLeadingZeros" : "numberOfTrailingZeros", methodType(int.class, long.class)); MethodHandle virtualGetLong = lookup.findVirtual( unsafeClazz, "getLong", methodType(long.class, Object.class, long.class)); getPrimitive = virtualGetLong.bindTo(unsafe); // do a test read to confirm unsafe is actually functioning long val = (long) getPrimitive.invokeExact( (Object) new byte[] { 0, 0, 0, 0, 0, 0, 0, 0 }, arrayBaseOffset + 0L); if (val != 0) { throw new IllegalStateException("invalid value: " + val); } mismatch = lookup.findStatic(ArrayUtil.class, "unsafeMismatch", mismatchType); logger.finest("byte[] comparisons using Unsafe.getLong"); } else { leadingZeros = lookup.findStatic(Integer.class, ByteOrder.BIG_ENDIAN == ByteOrder.nativeOrder() ? "numberOfLeadingZeros" : "numberOfTrailingZeros", methodType(int.class, int.class)); MethodHandle virtualGetInt = lookup.findVirtual( unsafeClazz, "getInt", methodType(int.class, Object.class, long.class)); getPrimitive = virtualGetInt.bindTo(unsafe); // do a test read to confirm unsafe is actually functioning int val = (int) getPrimitive.invokeExact( (Object) new byte[] { 0, 0, 0, 0 }, arrayBaseOffset + 0L); if (val != 0) { throw new IllegalStateException("invalid value: " + val); } mismatch = lookup.findStatic(ArrayUtil.class, "intUnsafeMismatch", mismatchType); logger.finest("byte[] comparisons using Unsafe.getInt"); } } } catch (Throwable t) { logger.log(Level.FINE, "failed trying to load means to compare byte[] by longs", t); } if (mismatch == null) { getPrimitive = null; leadingZeros = null; arraysMismatch = null; logger.finest("byte[] comparisons byte by byte"); try { mismatch = lookup.findStatic(ArrayUtil.class, "legacyMismatch", mismatchType); } catch (Exception e) { throw new IllegalStateException(e); } } GET_PRIMITIVE = getPrimitive; ARRAY_BASE_OFFSET = arrayBaseOffset; LEADING_ZEROS = leadingZeros; ARRAYS_MISMATCH = arraysMismatch; MISMATCH = mismatch; } /** * Compares the values in <i>bytes</i>, starting at 
<i>aFromIndex</i> and * <i>bFromIndex</i> and returns the zero-based index of the first * {@code byte} which differs. * @param bytes The {@code byte[]} for comparison. * @param aFromIndex The first offset into <i>bytes</i> to start reading from. * @param bFromIndex The second offset into <i>bytes</i> to start reading from. * @param length The number of bytes to compare. * @return The offset from the starting indexes of the first byte which differs. * If all match, <i>length</i> will be returned. */ public static int mismatch( byte[] bytes, int aFromIndex, int bFromIndex, int length) { try { return (int) MISMATCH.invokeExact(bytes, aFromIndex, bFromIndex, length); } catch (RuntimeException e) { throw e; } catch (Error e) { throw e; } catch (Throwable t) { throw new RuntimeException(t); } } /** * If <i>aFromIndex</i> and <i>bFromIndex</i> are similarly mis-aligned, will * do single byte comparisons to obtain alignment, then call * {@link #varHandleMismatch(byte[], int, byte[], int, int)}. 
*/ @SuppressWarnings("unused") private static int alignedVarHandleMismatch( byte[] a, int aFromIndex, int bFromIndex, int length) throws Throwable { //while we could do an index check, the VarHandle call incorporates a check, //making any check here duplicative int aFromAlignment = aFromIndex & 7; int bFromAlignment = bFromIndex & 7; //if they are aligned, just go if (aFromAlignment == 0 && bFromAlignment == 0) { return varHandleMismatch(a, aFromIndex, bFromIndex, length); } int i=0; //if both are similarly out of alignment, adjust if (aFromAlignment == bFromAlignment) { for (int j = Math.min(8 - aFromAlignment, length); i < j; ++i) { if (a[aFromIndex + i] != a[bFromIndex + i]) { return i; } } } else if ((aFromAlignment & 1) == 1 && (bFromAlignment & 1) == 1) { //if they both have an odd alignment, adjust by one if (a[aFromIndex] != a[bFromIndex]) { return 0; } ++i; } return i + varHandleMismatch(a, aFromIndex + i, bFromIndex + i, length - i); } /** * Uses {@link #VAR_HANDLE_GET_LONG} to compare 8 bytes at a time. 
*/ private static int varHandleMismatch( byte[] a, int aFromIndex, int bFromIndex, int length) throws Throwable { //while we could do an index check, the VarHandle call incorporates a check, //making any check here duplicative int i=0; for (int j=length-7; i<j; i+=8) { final long aVal = (long) GET_PRIMITIVE.invokeExact(a, aFromIndex + i); final long bVal = (long) GET_PRIMITIVE.invokeExact(a, bFromIndex + i); if (aVal != bVal) { //this returns a value where bits which match are 0 and bits which //differ are 1 final long diff = aVal ^ bVal; //the first (in native byte order) bit which differs tells us which //byte differed final int leadingZeros = (int) LEADING_ZEROS.invokeExact(diff); return Math.min(i + (leadingZeros >>> 3), length); } } int result = length; switch(length & 7) { case 7: if (a[aFromIndex + i + 6] != a[bFromIndex + i + 6]) result = i + 6; case 6: if (a[aFromIndex + i + 5] != a[bFromIndex + i + 5]) result = i + 5; case 5: if (a[aFromIndex + i + 4] != a[bFromIndex + i + 4]) result = i + 4; case 4: if (a[aFromIndex + i + 3] != a[bFromIndex + i + 3]) result = i + 3; case 3: if (a[aFromIndex + i + 2] != a[bFromIndex + i + 2]) result = i + 2; case 2: if (a[aFromIndex + i + 1] != a[bFromIndex + i + 1]) result = i + 1; case 1: if (a[aFromIndex + i] != a[bFromIndex + i]) result = i; } return result; } /** * If <i>aFromIndex</i> and <i>bFromIndex</i> are similarly mis-aligned, will * do single byte comparisons to obtain alignment, then call * {@link #intVarHandleMismatch(byte[], int, byte[], int, int)}. 
*/ @SuppressWarnings("unused") private static int alignedIntVarHandleMismatch( byte[] a, int aFromIndex, int bFromIndex, int length) throws Throwable { //while we could do an index check, the VarHandle call incorporates a check, //making any check here duplicative int aFromAlignment = aFromIndex & 3; int bFromAlignment = bFromIndex & 3; //if they are aligned, just go if (aFromAlignment == 0 && bFromAlignment == 0) { return intVarHandleMismatch(a, aFromIndex, bFromIndex, length); } int i=0; //if both are similarly out of alignment, adjust if (aFromAlignment == bFromAlignment) { for (int j=Math.min(4 - aFromAlignment, length); i<j; ++i) { if (a[aFromIndex + i] != a[bFromIndex + i]) { return i; } } } else if ((aFromAlignment & 1) == 1 && (bFromAlignment & 1) == 1) { //if they both have an odd alignment, adjust by one if (a[aFromIndex] != a[bFromIndex]) { return 0; } ++i; } return i + intVarHandleMismatch(a, aFromIndex + i, bFromIndex + i, length - i); } /** * Uses {@link #VAR_HANDLE_GET_INT} to compare 4 bytes at a time. 
*/ private static int intVarHandleMismatch( byte[] a, int aFromIndex, int bFromIndex, int length) throws Throwable { //while we could do an index check, the VarHandle call incorporates a check, //making any check here duplicative int i=0; for (int j=length-3; i<j; i+=4) { final int aVal = (int) GET_PRIMITIVE.invokeExact(a, aFromIndex + i); final int bVal = (int) GET_PRIMITIVE.invokeExact(a, bFromIndex + i); if (aVal != bVal) { //this returns a value where bits which match are 0 and bits which //differ are 1 final int diff = aVal ^ bVal; //the first (in native byte order) bit which differs tells us //which byte differed final int leadingZeros = (int) LEADING_ZEROS.invokeExact(diff); return Math.min(i + (leadingZeros >>> 3), length); } } int result = length; switch(length & 3) { case 3: if (a[aFromIndex + i + 2] != a[bFromIndex + i + 2]) result = i + 2; case 2: if (a[aFromIndex + i + 1] != a[bFromIndex + i + 1]) result = i + 1; case 1: if (a[aFromIndex + i] != a[bFromIndex + i]) result = i; } return result; } /** * Uses {@code UNSAFE_GET_LONG} to compare 8 bytes at a time. 
*/ @SuppressWarnings("unused") private static int unsafeMismatch( byte[] a, int aFromIndex, int bFromIndex, int length) throws Throwable { //it is important to check the indexes prior to making the Unsafe calls, //as Unsafe does not validate and could result in SIGSEGV if out of bounds if (length < 0 || aFromIndex < 0 || bFromIndex < 0 || Math.max(aFromIndex, bFromIndex) > a.length - length) { throw new ArrayIndexOutOfBoundsException(); } int i=0; for (int j=length-7; i<j; i+=8) { final long aVal = (long) GET_PRIMITIVE.invokeExact((Object) a, ARRAY_BASE_OFFSET + aFromIndex + i); final long bVal = (long) GET_PRIMITIVE.invokeExact((Object) a, ARRAY_BASE_OFFSET + bFromIndex + i); if (aVal != bVal) { //this returns a value where bits which match are 0 and bits which differ are 1 final long diff = aVal ^ bVal; //the first (in native byte order) bit which differs tells us which byte differed final int leadingZeros = (int) LEADING_ZEROS.invokeExact(diff); return i + (leadingZeros >>> 3); } } int result = length; switch(length & 7) { case 7: if (a[aFromIndex + i + 6] != a[bFromIndex + i + 6]) result = i + 6; case 6: if (a[aFromIndex + i + 5] != a[bFromIndex + i + 5]) result = i + 5; case 5: if (a[aFromIndex + i + 4] != a[bFromIndex + i + 4]) result = i + 4; case 4: if (a[aFromIndex + i + 3] != a[bFromIndex + i + 3]) result = i + 3; case 3: if (a[aFromIndex + i + 2] != a[bFromIndex + i + 2]) result = i + 2; case 2: if (a[aFromIndex + i + 1] != a[bFromIndex + i + 1]) result = i + 1; case 1: if (a[aFromIndex + i] != a[bFromIndex + i]) result = i; } return result; } /** * Uses {@code UNSAFE_GET_LONG} to compare 8 bytes at a time. 
*/ @SuppressWarnings("unused") private static int intUnsafeMismatch( byte[] a, int aFromIndex, int bFromIndex, int length) throws Throwable { //it is important to check the indexes prior to making the Unsafe calls, //as Unsafe does not validate and could result in SIGSEGV if out of bounds if (length < 0 || aFromIndex < 0 || bFromIndex < 0 || Math.max(aFromIndex, bFromIndex) > a.length - length) { throw new ArrayIndexOutOfBoundsException(); } int i=0; for (int j=length-3; i<j; i+=4) { final int aVal = (int) GET_PRIMITIVE.invokeExact((Object) a, ARRAY_BASE_OFFSET + aFromIndex + i); final int bVal = (int) GET_PRIMITIVE.invokeExact((Object) a, ARRAY_BASE_OFFSET + bFromIndex + i); if (aVal != bVal) { //this returns a value where bits which match are 0 and bits which //differ are 1 final int diff = aVal ^ bVal; //the first (in native byte order) bit which differs tells us //which byte differed final int leadingZeros = (int) LEADING_ZEROS.invokeExact(diff); return i + (leadingZeros >>> 3); } } int result = length; switch(length & 3) { case 3: if (a[aFromIndex + i + 2] != a[bFromIndex + i + 2]) result = i + 2; case 2: if (a[aFromIndex + i + 1] != a[bFromIndex + i + 1]) result = i + 1; case 1: if (a[aFromIndex + i] != a[bFromIndex + i]) result = i; } return result; } /** * Uses {@code ARRAYS_MISMATCH} to compare <i>a</i> and <i>b</i>. */ @SuppressWarnings("unused") private static int arraysMismatch( byte[] a, int aFromIndex, int bFromIndex, int length) throws Throwable { final int m = (int) ARRAYS_MISMATCH.invokeExact(a, aFromIndex, aFromIndex + length, a, bFromIndex, bFromIndex + length); return m == -1 ? length : m; } /** * Simply loops over all of the bytes, comparing one at a time. 
*/ @SuppressWarnings("unused") private static int legacyMismatch( byte[] a, int aFromIndex, int bFromIndex, int length) { int i=0; for (int j=length - 7; i<j; i+=8) { if (a[aFromIndex + i] != a[bFromIndex + i]) return i; if (a[aFromIndex + i + 1] != a[bFromIndex + i + 1]) return i + 1; if (a[aFromIndex + i + 2] != a[bFromIndex + i + 2]) return i + 2; if (a[aFromIndex + i + 3] != a[bFromIndex + i + 3]) return i + 3; if (a[aFromIndex + i + 4] != a[bFromIndex + i + 4]) return i + 4; if (a[aFromIndex + i + 5] != a[bFromIndex + i + 5]) return i + 5; if (a[aFromIndex + i + 6] != a[bFromIndex + i + 6]) return i + 6; if (a[aFromIndex + i + 7] != a[bFromIndex + i + 7]) return i + 7; } int result = length; switch(length & 7) { case 7: if (a[aFromIndex + i + 6] != a[bFromIndex + i + 6]) result = i + 6; case 6: if (a[aFromIndex + i + 5] != a[bFromIndex + i + 5]) result = i + 5; case 5: if (a[aFromIndex + i + 4] != a[bFromIndex + i + 4]) result = i + 4; case 4: if (a[aFromIndex + i + 3] != a[bFromIndex + i + 3]) result = i + 3; case 3: if (a[aFromIndex + i + 2] != a[bFromIndex + i + 2]) result = i + 2; case 2: if (a[aFromIndex + i + 1] != a[bFromIndex + i + 1]) result = i + 1; case 1: if (a[aFromIndex + i] != a[bFromIndex + i]) result = i; } return result; } private ArrayUtil() { } }