Here is a class that is compatible with JDK 7. At runtime it uses a
MethodHandle to invoke Arrays.mismatch if that method is available. If
it is not, the class tries to obtain Unsafe so that it can read 4 bytes
at a time and compare them as ints. If Unsafe cannot be found, loaded,
or invoked, it falls back to iterating over the bytes and comparing them
one by one.

On JDK 9 and later, the mismatch method could instead be implemented as:
return Arrays.mismatch(a, aFromIndex, aFromIndex + length, b,
bFromIndex, bFromIndex + length);


package org.tukaani.xz.common;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Constructor;
import java.nio.ByteOrder;
import java.util.Arrays;

import sun.misc.Unsafe;

public final class ArrayUtil {

    private static interface ArrayComparer {
        int mismatch(byte[] a, int aFromIndex, byte[] b, int
bFromIndex, int length);
    }

    private static final ArrayComparer COMPARER;

    static {
        //try to create an instance using Unsafe
        ArrayComparer comparer = null;
        try {
            try {
                final MethodHandle mh =
MethodHandles.lookup().findStatic(Arrays.class, "mismatch",
MethodType.methodType(int.class, byte[].class, int.class, int.class,
byte[].class, int.class, int.class));
                comparer = new ArrayComparer() {
                    @Override
                    public int mismatch(byte[] a, int aFromIndex,
byte[] b, int bFromIndex, int length) {
                        try {
                            return (int) mh.invokeExact(a, aFromIndex,
aFromIndex + length, b, bFromIndex, bFromIndex + length);
                        } catch (Throwable e) {
                            if (e instanceof RuntimeException) {
                                throw (RuntimeException) e;
                            }
                            if (e instanceof Error) {
                                throw (Error) e;
                            }
                            throw new IllegalStateException(e);
                        }
                    }
                };
            } catch (Exception e) {
                //TODO: log out?
            }

            if (comparer == null) {
                Constructor<Unsafe> unsafeConstructor =
Unsafe.class.getDeclaredConstructor();
                unsafeConstructor.setAccessible(true);
                final Unsafe unsafe = unsafeConstructor.newInstance();

                //do a test read to confirm unsafe is actually functioning
                int val = unsafe.getInt(new byte[] {0,0,0,0},
Unsafe.ARRAY_BYTE_BASE_OFFSET + 0L);
                if (val != 0) {
                    throw new IllegalStateException("invalid value: " + val);
                }

                final boolean bigEndian = ByteOrder.BIG_ENDIAN ==
ByteOrder.nativeOrder();

                comparer = new ArrayComparer() {
                    @Override
                    public int mismatch(byte[] a, int aFromIndex,
byte[] b, int bFromIndex, int length) {
                        int i=0;
                        for (int j=length - 3; i<j; i+=4) {
                            int aVal = unsafe.getInt(a, ((long)
Unsafe.ARRAY_BYTE_BASE_OFFSET) + aFromIndex + i);
                            int bVal = unsafe.getInt(b, ((long)
Unsafe.ARRAY_BYTE_BASE_OFFSET) +  bFromIndex + i);
                            if (aVal != bVal) {
                                final int diff = aVal ^ bVal;
                                //getInt interprets in platform byte
order. the concept of "leading zeros" being bytes
                                //in encounter order is true for big endian
                                //for little endian platform, the
trailing zeros gives the encounter order result
                                final int leadingZeros = bigEndian ?
Integer.numberOfLeadingZeros(diff) :
Integer.numberOfTrailingZeros(diff);
                                return i + (leadingZeros / Byte.SIZE);
                            }
                        }
                        for ( ; i<length; ++i) {
                            if (a[aFromIndex + i] != b[bFromIndex + i]) {
                                return i;
                            }
                        }
                        return -1;
                    }
                };
            }
        } catch (Throwable t) {
            //TODO: log out?

            //fall back to "normal" implementation
            comparer = new ArrayComparer() {
                @Override
                public int mismatch(byte[] a, int aFromIndex, byte[]
b, int bFromIndex, int length) {
                    for (int i=0; i<length; ++i) {
                        if (a[aFromIndex + i] != b[bFromIndex + i]) {
                            return i;
                        }
                    }
                    return -1;
                }
            };
        }

        COMPARER = comparer;
    }

    public static int mismatch(byte[] a, int aFromIndex, byte[] b, int
bFromIndex, int length) {
       return COMPARER.mismatch(a, aFromIndex, b, bFromIndex, length);
    }

    private ArrayUtil() {
    }
}

Reply via email to