Lasse,

I have found a way to use VarHandle byte array access at runtime in
code which is compile-time compatible with JDK 7. So here is an
updated ArrayUtil class which will use a VarHandle to read long values
on JDK 9+. If that is not available, it will attempt to use
sun.misc.Unsafe. If that cannot be found, it falls back to a standard
byte-by-byte comparison.
I did add an index bounds check for the Unsafe implementation and
found it had minimal impact on overall performance.
Using VarHandle (at least on JDK 11) offers very similar performance
to Unsafe across all 3 files I used for benchmarking.

--Baseline 1.8
Benchmark                                 (file)  Mode  Cnt      Score     Error  Units
XZCompressionBenchmark.compress  ihe_ovly_pr.dcm  avgt    4      9.558 ±   0.239  ms/op
XZCompressionBenchmark.compress       image1.dcm  avgt    4   6553.304 ± 112.475  ms/op
XZCompressionBenchmark.compress        large.xml  avgt    4  10592.151 ± 291.527  ms/op

--Unsafe
Benchmark                                 (file)  Mode  Cnt     Score    Error  Units
XZCompressionBenchmark.compress  ihe_ovly_pr.dcm  avgt    4     7.699 ±   0.058  ms/op
XZCompressionBenchmark.compress       image1.dcm  avgt    4  6001.170 ± 143.814  ms/op
XZCompressionBenchmark.compress        large.xml  avgt    4  7853.963 ±  83.753  ms/op

--VarHandle
Benchmark                                 (file)  Mode  Cnt     Score    Error  Units
XZCompressionBenchmark.compress  ihe_ovly_pr.dcm  avgt    4     7.630 ±   0.542  ms/op
XZCompressionBenchmark.compress       image1.dcm  avgt    4  5872.098 ±  71.185  ms/op
XZCompressionBenchmark.compress        large.xml  avgt    4  8239.880 ± 346.036  ms/op

I know you said you were not going to be able to work on xz-java for
a while, but given these benchmark results, which really exceeded my
expectations, could this get some priority for a release?


package org.tukaani.xz.common;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.nio.ByteOrder;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Utilities for optimized array interactions.
 *
 * <p>The comparison strategy is selected once, when the class is initialized:
 * <ol>
 *   <li>a jdk 9+ {@code byteArrayViewVarHandle} reading 8 bytes at a time;</li>
 *   <li>otherwise {@code sun.misc.Unsafe} reading 8 bytes at a time;</li>
 *   <li>otherwise a plain byte by byte loop.</li>
 * </ol>
 * Everything newer than jdk 7 is reached reflectively so that the class stays
 * compile time compatible with jdk 7.
 *
 * @author Brett Okken
 */
public final class ArrayUtil {

    /**
     * Number of bytes consumed by a single {@code long} read.
     */
    private static final int LONG_BYTES = Long.SIZE / Byte.SIZE;

    /**
     * MethodHandle to the actual mismatch method to use at runtime.
     */
    private static final MethodHandle MISMATCH;

    /**
     * If {@code sun.misc.Unsafe} can be loaded, this is MethodHandle bound to an
     * instance of Unsafe for method {@code long getLong(Object, long)}.
     */
    private static final MethodHandle UNSAFE_GET_LONG;

    /**
     * MethodHandle to either {@link Long#numberOfLeadingZeros(long)} or
     * {@link Long#numberOfTrailingZeros(long)} depending on
     * {@link ByteOrder#nativeOrder()}.
     */
    private static final MethodHandle LEADING_ZEROS;

    /**
     * Populated from reflected read of
     * {@code sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET}.
     */
    private static final long ARRAY_BASE_OFFSET;

    /**
     * {@code MethodHandle} for a jdk 9+ {@code byteArrayViewVarHandle} for
     * {@code long[]} using the {@link ByteOrder#nativeOrder()}.
     * The method signature is {@code long get(byte[], int)}.
     */
    private static final MethodHandle VAR_HANDLE_GET_LONG;

    static {
        final Logger logger = Logger.getLogger(ArrayUtil.class.getName());
        final MethodHandles.Lookup lookup = MethodHandles.lookup();
        final MethodType mismatchType = MethodType.methodType(
                int.class, byte[].class, int.class, byte[].class, int.class, int.class);

        MethodHandle leadingZeros = null;
        MethodHandle varHandleGetLong = null;
        MethodHandle unsafeGetLong = null;
        long arrayBaseOffset = 0;
        MethodHandle mismatch = null;
        try {
            // getLong interprets in platform byte order. the concept of
            // "leading zeros" being bytes in encounter order is true for big
            // endian. for little endian platform, the trailing zeros gives the
            // encounter order result
            final String zeroCounter = ByteOrder.BIG_ENDIAN == ByteOrder.nativeOrder()
                    ? "numberOfLeadingZeros"
                    : "numberOfTrailingZeros";
            leadingZeros = lookup.findStatic(Long.class, zeroCounter,
                    MethodType.methodType(int.class, long.class));

            // first try to load byteArrayViewVarHandle for a long[]
            try {
                final Class<?> varHandleClazz =
                        Class.forName("java.lang.invoke.VarHandle", true, null);
                final Method byteArrayViewHandle = MethodHandles.class.getDeclaredMethod(
                        "byteArrayViewVarHandle", Class.class, ByteOrder.class);
                final Object varHandle =
                        byteArrayViewHandle.invoke(null, long[].class, ByteOrder.nativeOrder());
                final Class<?> accessModeEnum =
                        Class.forName("java.lang.invoke.VarHandle$AccessMode", true, null);
                @SuppressWarnings({ "unchecked", "rawtypes" })
                final Object getAccessModeEnum = Enum.valueOf((Class) accessModeEnum, "GET");
                final Method toMethodHandle =
                        varHandleClazz.getDeclaredMethod("toMethodHandle", accessModeEnum);
                varHandleGetLong = (MethodHandle) toMethodHandle.invoke(varHandle, getAccessModeEnum);

                // do a test read: invokeExact demands the exact
                // (byte[], int)long type, so a surprise here must trigger the
                // fallback now rather than fail on every later comparison
                final long probe = (long) varHandleGetLong.invokeExact(new byte[LONG_BYTES], 0);
                if (probe != 0) {
                    throw new IllegalStateException("invalid value: " + probe);
                }
                mismatch = lookup.findStatic(ArrayUtil.class, "varHandleMismatch", mismatchType);
                logger.finest("byte[] comparison using VarHandle");
            } catch (Throwable t) {
                logger.log(Level.FINE,
                        "failed trying to load a MethodHandle to invoke get on a byteArrayViewVarHandle for a long[]",
                        t);
                // discard anything partially set up by the failed attempt
                varHandleGetLong = null;
                mismatch = null;
            }

            // if byteArrayViewVarHandle for a long[] could not be loaded, then
            // try to load sun.misc.Unsafe
            if (mismatch == null) {
                final Class<?> unsafeClazz = Class.forName("sun.misc.Unsafe", true, null);
                final Constructor<?> unsafeConstructor = unsafeClazz.getDeclaredConstructor();
                unsafeConstructor.setAccessible(true);
                final Object unsafe = unsafeConstructor.newInstance();

                arrayBaseOffset = unsafeClazz.getField("ARRAY_BYTE_BASE_OFFSET").getLong(null);

                final MethodHandle virtualGetLong = lookup.findVirtual(unsafeClazz, "getLong",
                        MethodType.methodType(long.class, Object.class, long.class));
                unsafeGetLong = virtualGetLong.bindTo(unsafe);

                // do a test read to confirm unsafe is actually functioning
                final long val = (long) unsafeGetLong.invokeExact(
                        (Object) new byte[] { 0, 0, 0, 0, 0, 0, 0, 0 }, arrayBaseOffset + 0L);
                if (val != 0) {
                    throw new IllegalStateException("invalid value: " + val);
                }
                mismatch = lookup.findStatic(ArrayUtil.class, "unsafeMismatch", mismatchType);
                logger.finest("byte[] comparisons using Unsafe");
            }
        } catch (Throwable t) {
            logger.log(Level.FINE, "failed trying to load means to compare byte[] by longs", t);
            logger.finest("byte[] comparisons byte by byte");
            varHandleGetLong = null;
            unsafeGetLong = null;
            leadingZeros = null;
            try {
                mismatch = lookup.findStatic(ArrayUtil.class, "legacyMismatch", mismatchType);
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }

        VAR_HANDLE_GET_LONG = varHandleGetLong;
        UNSAFE_GET_LONG = unsafeGetLong;
        ARRAY_BASE_OFFSET = arrayBaseOffset;
        LEADING_ZEROS = leadingZeros;
        MISMATCH = mismatch;
    }

    /**
     * Compares the values in <i>a</i> and <i>b</i> and returns the index of
     * the first {@code byte} which differs.
     *
     * <p>A <i>length</i> of zero or less compares nothing and is returned
     * unchanged.
     *
     * @param a The first {@code byte[]} for comparison.
     * @param aFromIndex The offset into <i>a</i> to start reading from.
     * @param b The second {@code byte[]} for comparison.
     * @param bFromIndex The offset into <i>b</i> to start reading from.
     * @param length The number of bytes to compare.
     * @return The offset from the starting indexes of the first byte which
     *         differs. If all match, <i>length</i> will be returned.
     * @throws IndexOutOfBoundsException if a read falls outside of <i>a</i> or
     *         <i>b</i>. The Unsafe based implementation validates the complete
     *         range before reading; the other implementations fail only once
     *         an out of range position is actually reached.
     */
    public static int mismatch(byte[] a, int aFromIndex, byte[] b, int bFromIndex, int length) {
        try {
            return (int) MISMATCH.invokeExact(a, aFromIndex, b, bFromIndex, length);
        } catch (RuntimeException | Error e) {
            throw e;
        } catch (Throwable t) {
            // invokeExact is declared to throw Throwable; none of the targets
            // throw checked exceptions, so this is purely defensive
            throw new RuntimeException(t);
        }
    }

    /**
     * Uses {@link #VAR_HANDLE_GET_LONG} to compare 8 bytes at a time.
     */
    @SuppressWarnings("unused")
    private static int varHandleMismatch(byte[] a, int aFromIndex, byte[] b, int bFromIndex, int length)
            throws Throwable {
        // while we could do an index check, the VarHandle call incorporates a
        // check, making any check here duplicative
        int i = 0;
        // "length - i" cannot underflow the way "length - 7" can for an
        // extreme negative length
        for (; length - i >= LONG_BYTES; i += LONG_BYTES) {
            final long aVal = (long) VAR_HANDLE_GET_LONG.invokeExact(a, aFromIndex + i);
            final long bVal = (long) VAR_HANDLE_GET_LONG.invokeExact(b, bFromIndex + i);
            if (aVal != bVal) {
                // this returns a value where bits which match are 0 and bits
                // which differ are 1
                final long diff = aVal ^ bVal;
                // the first (in native byte order) bit which differs tells us
                // which byte differed
                final int leadingZeros = (int) LEADING_ZEROS.invokeExact(diff);
                return i + (leadingZeros / Byte.SIZE);
            }
        }
        return byteWiseMismatch(a, aFromIndex, b, bFromIndex, i, length);
    }

    /**
     * Uses {@code UNSAFE_GET_LONG} to compare 8 bytes at a time.
     */
    @SuppressWarnings("unused")
    private static int unsafeMismatch(byte[] a, int aFromIndex, byte[] b, int bFromIndex, int length)
            throws Throwable {
        // it is important to check the indexes prior to making the Unsafe
        // calls, as Unsafe does not validate and could result in SIGSEGV if
        // out of bounds. the sums are computed as long so that a large index
        // cannot overflow into a value which passes the check
        if (aFromIndex < 0 || (long) aFromIndex + length > a.length
                || bFromIndex < 0 || (long) bFromIndex + length > b.length) {
            throw new ArrayIndexOutOfBoundsException(
                    "aFromIndex=" + aFromIndex + ", a.length=" + a.length
                    + ", bFromIndex=" + bFromIndex + ", b.length=" + b.length
                    + ", length=" + length);
        }
        int i = 0;
        // "length - i" cannot underflow the way "length - 7" can for an
        // extreme negative length, so no unchecked read happens unless at
        // least 8 validated bytes remain
        for (; length - i >= LONG_BYTES; i += LONG_BYTES) {
            final long aVal = (long) UNSAFE_GET_LONG.invokeExact(
                    (Object) a, ARRAY_BASE_OFFSET + aFromIndex + i);
            final long bVal = (long) UNSAFE_GET_LONG.invokeExact(
                    (Object) b, ARRAY_BASE_OFFSET + bFromIndex + i);
            if (aVal != bVal) {
                // this returns a value where bits which match are 0 and bits
                // which differ are 1
                final long diff = aVal ^ bVal;
                // the first (in native byte order) bit which differs tells us
                // which byte differed
                final int leadingZeros = (int) LEADING_ZEROS.invokeExact(diff);
                return i + (leadingZeros / Byte.SIZE);
            }
        }
        return byteWiseMismatch(a, aFromIndex, b, bFromIndex, i, length);
    }

    /**
     * Simply loops over all of the bytes, comparing one at a time.
     */
    @SuppressWarnings("unused")
    private static int legacyMismatch(byte[] a, int aFromIndex, byte[] b, int bFromIndex, int length) {
        return byteWiseMismatch(a, aFromIndex, b, bFromIndex, 0, length);
    }

    /**
     * Compares one byte at a time, beginning at relative offset <i>start</i>.
     *
     * @return The relative offset of the first differing byte, or
     *         <i>length</i> if none differs in {@code [start, length)}.
     */
    private static int byteWiseMismatch(byte[] a, int aFromIndex, byte[] b, int bFromIndex,
                                        int start, int length) {
        for (int i = start; i < length; ++i) {
            if (a[aFromIndex + i] != b[bFromIndex + i]) {
                return i;
            }
        }
        return length;
    }

    private ArrayUtil() {
    }
}

Reply via email to