shubhamvishu commented on code in PR #15508:
URL: https://github.com/apache/lucene/pull/15508#discussion_r2962527526


##########
lucene/core/src/java25/org/apache/lucene/internal/vectorization/NativeVectorUtilSupport.java:
##########
@@ -0,0 +1,590 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.internal.vectorization;
+
+import static java.lang.foreign.ValueLayout.JAVA_BYTE;
+import static java.lang.foreign.ValueLayout.JAVA_DOUBLE;
+import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
+import static java.lang.foreign.ValueLayout.JAVA_INT;
+import static java.lang.foreign.ValueLayout.JAVA_LONG;
+
+import java.lang.foreign.AddressLayout;
+import java.lang.foreign.FunctionDescriptor;
+import java.lang.foreign.Linker;
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.SymbolLookup;
+import java.lang.foreign.ValueLayout;
+import java.lang.invoke.MethodHandle;
+import java.util.function.Supplier;
+import java.util.logging.Logger;
+import org.apache.lucene.util.Constants;
+
+/**
+ * VectorUtilSupport implementation that uses native bindings for optimized 
vector operations(using
+ * Foreign Function and Memory API (FFM)) if available(optional) or else 
fallback to
+ * PanamaVectorUtil implementations.
+ *
+ * <p>This class provides access to native C implementations of dot product 
operations from the
+ * loaded shared/dynamic library(.so|.dylib|.dll) which is generated from C 
code and linked at
+ * runtime. The native library contains multiple optimized implementations:
+ *
+ * <p>PanamaVectorUtilSupport#dotProduct use this Native C implementation for 
dot product
+ * calculation if system property <b>lucene.useNativeDotProduct=true</b> is 
passed it always tries
+ * to ensure binary is provided and required methods are implemented
+ *
+ * <p>It Uses <code>Linker.Option.critical(true)</code> for optimal 
performance by eliminating the
+ * overhead of ensuring MemorySegments are allocated off-heap before native 
calls.
+ */
+@SuppressWarnings("restricted")
+public final class NativeVectorUtilSupport implements VectorUtilSupport {
+
+  private final VectorUtilSupport delegateVectorUtilSupport;
+
+  public static final AddressLayout POINTER = ValueLayout.ADDRESS;
+
+  private static final Linker LINKER = Linker.nativeLinker();
+  private static final SymbolLookup SYMBOL_LOOKUP;
+
+  @SuppressWarnings("NonFinalStaticField")
+  private static boolean isLibraryLoaded;
+
+  // TODO: Make this dynamic?
+  public static final String NATIVE_VECTOR_LIBRARY_NAME = "dotProduct";
+
+  public NativeVectorUtilSupport(VectorUtilSupport vectorUtilSupport) {
+    this.delegateVectorUtilSupport = vectorUtilSupport;
+  }
+
+  static {
+    try {
+      // Attempt to load the library
+      System.loadLibrary(NATIVE_VECTOR_LIBRARY_NAME);
+      isLibraryLoaded = true; // If successful, set the flag to true
+    } catch (UnsatisfiedLinkError e) {
+      // If the library loading fails, set the flag to false
+      isLibraryLoaded = false;
+      Logger.getLogger(NativeVectorUtilSupport.class.getName())
+          .warning("No native library" + NATIVE_VECTOR_LIBRARY_NAME + " found 
: " + e.getMessage());
+    }
+  }
+
+  // Function descriptors
+  // (POINTER, POINTER, INT) -> INT
+  private static final FunctionDescriptor twoPointerIntToInt =
+      FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT);
+
+  // (POINTER, POINTER, INT) -> LONG
+  private static final FunctionDescriptor twoPointerIntToLong =
+      FunctionDescriptor.of(JAVA_LONG, POINTER, POINTER, JAVA_INT);
+
+  // (POINTER, POINTER, INT) -> FLOAT
+  private static final FunctionDescriptor twoPointerIntToFloat =
+      FunctionDescriptor.of(JAVA_FLOAT, POINTER, POINTER, JAVA_INT);
+
+  // (POINTER, POINTER, FLOAT, FLOAT, FLOAT, FLOAT, INT) -> FLOAT
+  private static final FunctionDescriptor minMaxScalarQuantizeDesc =
+      FunctionDescriptor.of(
+          JAVA_FLOAT, POINTER, POINTER, JAVA_FLOAT, JAVA_FLOAT, JAVA_FLOAT, 
JAVA_FLOAT, JAVA_INT);
+
+  // (POINTER, FLOAT, FLOAT, FLOAT, FLOAT, FLOAT, FLOAT, INT) -> FLOAT
+  private static final FunctionDescriptor recalculateOffsetDesc =
+      FunctionDescriptor.of(
+          JAVA_FLOAT,
+          POINTER,
+          JAVA_FLOAT,
+          JAVA_FLOAT,
+          JAVA_FLOAT,
+          JAVA_FLOAT,
+          JAVA_FLOAT,
+          JAVA_FLOAT,
+          JAVA_INT);
+
+  // (POINTER, POINTER, DOUBLE, INT) -> INT
+  private static final FunctionDescriptor filterByScoreDesc =
+      FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_DOUBLE, JAVA_INT);
+
+  // (POINTER, BYTE, INT) -> POINTER
+  private static final FunctionDescriptor l2normalizeDesc =
+      FunctionDescriptor.of(POINTER, POINTER, JAVA_BYTE, JAVA_INT);
+
+  // (POINTER, INT) -> void
+  private static final FunctionDescriptor expand8Desc =
+      FunctionDescriptor.ofVoid(POINTER, JAVA_INT);
+
+  // (POINTER, INT, INT, INT) -> INT
+  private static final FunctionDescriptor findNextGEQDesc =
+      FunctionDescriptor.of(JAVA_INT, POINTER, JAVA_INT, JAVA_INT, JAVA_INT);
+
+  // Method handles
+  private static final MethodHandle dotProduct$MH;
+  private static final MethodHandle squareDistance$MH;
+  private static final MethodHandle cosine$MH;
+  private static final MethodHandle dotProductFloat$MH;
+  private static final MethodHandle squareDistanceFloat$MH;
+  private static final MethodHandle cosineFloat$MH;
+  private static final MethodHandle int4SquareDistance$MH;
+  private static final MethodHandle int4SquareDistanceSinglePacked$MH;
+  private static final MethodHandle int4SquareDistanceBothPacked$MH;
+  private static final MethodHandle uint8SquareDistance$MH;
+  private static final MethodHandle uint8DotProduct$MH;
+  private static final MethodHandle int4DotProduct$MH;
+  private static final MethodHandle int4DotProductSinglePacked$MH;
+  private static final MethodHandle int4DotProductBothPacked$MH;
+  private static final MethodHandle int4BitDotProduct$MH;
+  private static final MethodHandle int4DibitDotProduct$MH;
+  private static final MethodHandle minMaxScalarQuantize$MH;
+  private static final MethodHandle recalculateScalarQuantizationOffset$MH;
+  private static final MethodHandle filterByScore$MH;
+  private static final MethodHandle l2normalize$MH;
+  private static final MethodHandle expand8$MH;
+  private static final MethodHandle findNextGEQ$MH;
+
+  public static boolean isLibraryLoaded() {
+    return isLibraryLoaded;
+  }
+
+  static {
+    if (isLibraryLoaded) {
+      SymbolLookup loaderLookup = SymbolLookup.loaderLookup();
+      SYMBOL_LOOKUP = name -> loaderLookup.find(name).or(() -> 
LINKER.defaultLookup().find(name));
+
+      // Each method handle with a unique native method name
+      dotProduct$MH = getMethodHandle("dotProduct", twoPointerIntToInt);
+      squareDistance$MH = getMethodHandle("squareDistance", 
twoPointerIntToInt);
+      cosine$MH = getMethodHandle("cosine", twoPointerIntToInt);
+      dotProductFloat$MH = getMethodHandle("dotProductFloat", 
twoPointerIntToFloat);
+      squareDistanceFloat$MH = getMethodHandle("squareDistanceFloat", 
twoPointerIntToFloat);
+      cosineFloat$MH = getMethodHandle("cosineFloat", twoPointerIntToFloat);
+      int4SquareDistance$MH = getMethodHandle("int4SquareDistance", 
twoPointerIntToInt);
+      int4SquareDistanceSinglePacked$MH =
+          getMethodHandle("int4SquareDistanceSinglePacked", 
twoPointerIntToInt);
+      int4SquareDistanceBothPacked$MH =
+          getMethodHandle("int4SquareDistanceBothPacked", twoPointerIntToInt);
+      uint8SquareDistance$MH = getMethodHandle("uint8SquareDistance", 
twoPointerIntToInt);
+      uint8DotProduct$MH = getMethodHandle("uint8DotProduct", 
twoPointerIntToInt);
+      int4DotProduct$MH = getMethodHandle("int4DotProduct", 
twoPointerIntToInt);
+      int4DotProductSinglePacked$MH =
+          getMethodHandle("int4DotProductSinglePacked", twoPointerIntToInt);
+      int4DotProductBothPacked$MH = 
getMethodHandle("int4DotProductBothPacked", twoPointerIntToInt);
+      int4BitDotProduct$MH = getMethodHandle("int4BitDotProduct", 
twoPointerIntToLong);
+      int4DibitDotProduct$MH = getMethodHandle("int4DibitDotProduct", 
twoPointerIntToLong);
+      minMaxScalarQuantize$MH = getMethodHandle("minMaxScalarQuantize", 
minMaxScalarQuantizeDesc);
+      recalculateScalarQuantizationOffset$MH =
+          getMethodHandle("recalculateScalarQuantizationOffset", 
recalculateOffsetDesc);
+      filterByScore$MH = getMethodHandle("filterByScore", filterByScoreDesc);
+      l2normalize$MH = getMethodHandle("l2normalize", l2normalizeDesc);
+      expand8$MH = getMethodHandle("expand8", expand8Desc);
+      findNextGEQ$MH = getMethodHandle("findNextGEQ", findNextGEQDesc);
+    } else if (Constants.NATIVE_DOT_PRODUCT_ENABLED) {
+      throw new RuntimeException("Native library dotProduct missing!");
+    } else {
+      SYMBOL_LOOKUP = null;
+      dotProduct$MH = null;
+      squareDistance$MH = null;
+      cosine$MH = null;
+      dotProductFloat$MH = null;
+      squareDistanceFloat$MH = null;
+      cosineFloat$MH = null;
+      int4SquareDistance$MH = null;
+      int4SquareDistanceSinglePacked$MH = null;
+      int4SquareDistanceBothPacked$MH = null;
+      uint8SquareDistance$MH = null;
+      uint8DotProduct$MH = null;
+      int4DotProduct$MH = null;
+      int4DotProductSinglePacked$MH = null;
+      int4DotProductBothPacked$MH = null;
+      int4BitDotProduct$MH = null;
+      int4DibitDotProduct$MH = null;
+      minMaxScalarQuantize$MH = null;
+      recalculateScalarQuantizationOffset$MH = null;
+      filterByScore$MH = null;
+      l2normalize$MH = null;
+      expand8$MH = null;
+      findNextGEQ$MH = null;
+    }
+  }
+
+  private static MethodHandle getMethodHandle(String methodName, 
FunctionDescriptor descriptor) {
+    MethodHandle mh =
+        SYMBOL_LOOKUP
+            .find(methodName)
+            .map(addr -> LINKER.downcallHandle(addr, descriptor, 
Linker.Option.critical(true)))
+            .orElse(null);
+    if (mh == null && Constants.NATIVE_STRICT_MODE) {
+      throw new RuntimeException("C code for " + methodName + " was not 
linked!");
+    }
+    return mh;
+  }
+
+  // Reusable invoke helpers for signatures used multiple times
+  private static int invokeIntMethodHandle(MethodHandle mh, MemorySegment a, 
MemorySegment b) {
+    try {
+      return (int) mh.invokeExact(a, b, (int) a.byteSize());
+    } catch (Throwable ex) {
+      throw new AssertionError("should not reach here", ex);
+    }
+  }
+
+  private static long invokeLongMethodHandle(MethodHandle mh, MemorySegment a, 
MemorySegment b) {
+    try {
+      return (long) mh.invokeExact(a, b, (int) a.byteSize());
+    } catch (Throwable ex) {
+      throw new AssertionError("should not reach here", ex);
+    }
+  }
+
+  private static float invokeFloatMethodHandle(MethodHandle mh, MemorySegment 
a, MemorySegment b) {
+    try {
+      return (float) mh.invokeExact(a, b, (int) a.byteSize());
+    } catch (Throwable ex) {
+      throw new AssertionError("should not reach here", ex);
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  private static <T> T invokeOrDelegate(MethodHandle mh, Supplier<T> delegate, 
Object... args) {
+    if (mh != null) {
+      try {
+        return (T) mh.invokeExact(args);

Review Comment:
   > I'd like to get rid of this magic. it slows down,.....
   > :
   > If we don't want to do it statically ........
   
   I like the suggestion, makes sense.
   
   
   > I wonder why tests pass, do we really test all methods here?
   
   We are not running it as part of the tests since we don't generate any 
binary (that is still in the user's court: the user needs to supply the binary 
for the tests to use it, overriding whichever implementations they have), and 
the current 
[C code](https://github.com/apache/lucene/blob/main/lucene/misc/src/c/dotProduct.c)
 only has the dotProduct implementation. I tested a few others, but not all of 
the existing VectorUtil methods. It would be good to always run the tests to 
ensure they catch all issues, but I don't know if we want to go down the path 
of generating a binary. I'm wondering whether it's fine to not provide any 
platform-specific binaries (abiding by what we discussed earlier) and instead 
just extend the C code to have basic VectorUtil implementations and test that 
in the unit tests (which this PR was doing initially). Maybe you have some 
better ideas for how we can get better test coverage here, with or without a 
binary? Or should we just use fake Java implementations to test each operation? 
   Right now this piece is a bit disconnected because no automated test 
coverage is running for it.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to