Pulkitg64 commented on PR #15549: URL: https://github.com/apache/lucene/pull/15549#issuecomment-3757043489
I don't have any good news yet, but since it has been more than a week, here is some progress. I tried using the [JDK PR](https://github.com/openjdk/jdk/pull/28002/files#diff-7a632162d7be3b75c981219fd4e441a0df429de9d8a0f6858d489c7840865187) for float16 computation, as suggested by @rmuir. For this I had to check out the JDK, pull the PR locally, build it, and use that build to compile the Lucene code. I then added new APIs to support float16 vectors everywhere, e.g. a new `KnnFloat16VectorField`, `KnnFloat16VectorQuery`, scorer, etc. (*In hindsight I should not have spent time on this first and should have focused only on the vector-score implementation and its benchmarks, but I have learned the lesson.*)

After all of the above, I ran the benchmarks below.

#### With DefaultVectorUtilSupport (NumDocs: 100k)

For `DefaultVectorUtilSupport`, I implemented the function below for the dot-product score, which converts to float32 vectors before doing any computation. With this I am seeing a latency regression of around 100% for the no-quantization case, because of the extra conversion. For the quantized cases, latency is comparable, but indexing is much slower; I think this is again because of the short-to-float conversion during vector quantization (I think I can try to optimize it).
```java
@Override
public short dotProduct(short[] a, short[] b) {
  assert a.length == b.length : "Vector lengths must match";
  float sum = 0f;
  for (int i = 0; i < a.length; i++) {
    sum = Math.fma(Float.float16ToFloat(a[i]), Float.float16ToFloat(b[i]), sum);
  }
  return Float.floatToFloat16(sum);
}
```

| Encoding | recall | latency (ms) | netCPU | avgCpuCount | quantized | visited | index (s) | index_docs/s | force_merge (s) | index_size (MB) | vec_disk (MB) | vec_RAM (MB) |
|----------|--------|--------------|--------|-------------|-----------|---------|-----------|--------------|-----------------|-----------------|---------------|--------------|
| float16 | 0.989 | 9.924 | 9.868 | 0.994 | no | 5659 | 75.67 | 1321.46 | 0 | 206.19 | 390.625 | 390.625 |
| float16 | 0.981 | 4.911 | 4.884 | 0.994 | 8 bits | 5680 | 82.27 | 1215.54 | 40.13 | 305.41 | 294.495 | 99.182 |
| float16 | 0.926 | 5.885 | 5.849 | 0.994 | 4 bits | 5727 | 82.98 | 1205.08 | 0 | 256.59 | 245.667 | 50.354 |
| float32 | 0.991 | 5.123 | 5.103 | 0.996 | no | 5680 | 28.56 | 3501.16 | 44.64 | 401.53 | 390.625 | 390.625 |
| float32 | 0.981 | 4.692 | 4.675 | 0.997 | 8 bits | 5689 | 32.29 | 3097.13 | 52.05 | 500.71 | 489.807 | 99.182 |
| float32 | 0.926 | 5.822 | 5.786 | 0.994 | 4 bits | 5728 | 32.62 | 3065.79 | 64.55 | 451.9 | 440.979 | 50.354 |

#### With PanamaVectorUtilSupport (NumDocs: 10k only, because 100k was taking too long)

With the Float16 Panama implementation, I am seeing very bad results (almost 40x higher latency). I checked the profiler output, and a single JDK-internal call (`VectorPayload.getPayload()`) is taking most of the time.
I have yet to understand why that call takes so long.

| encoding | recall | latency (ms) | netCPU | avgCpuCount | visited | index (s) | index_docs/s | index_size (MB) |
|----------|--------|--------------|--------|-------------|---------|-----------|--------------|-----------------|
| float16 | 0.996 | 35.076 | 137.727 | 3.927 | 16072 | 29.70 | 336.71 | 19.96 |
| float32 | 0.998 | 0.865 | 3.139 | 3.630 | 16088 | 1.95 | 5117.71 | 39.49 |

Profiler output for float16:

```
PERCENT   CPU SAMPLES   STACK
66.03%    217295        jdk.internal.vm.vector.VectorSupport$VectorPayload#getPayload() [Inlined code]
10.08%    33172         jdk.incubator.vector.Float16Vector#tOpTemplate() [Inlined code]
7.90%     25995         jdk.incubator.vector.Float16#valueOf() [Inlined code]
6.16%     20265         jdk.incubator.vector.Float16Vector#lanewiseTemplate() [Inlined code]
5.07%     16691         jdk.incubator.vector.Float16#lambda$fma$0() [Inlined code]
1.54%     5084          jdk.incubator.vector.Float16Vector256#vectorFactory() [Inlined code]
0.40%     1323          jdk.incubator.vector.Float16#shortBitsToFloat16() [Inlined code]
0.35%     1160          jdk.incubator.vector.Float16Vector#fromArray0Template() [Inlined code]
0.30%     988           jdk.internal.vm.vector.VectorSupport#ternaryOp() [JIT compiled code]
0.23%     757           jdk.jfr.internal.JVM#emitEvent() [Native code]
0.21%     683           jdk.internal.vm.vector.VectorSupport$VectorPayload#<init>() [Inlined code]
0.16%     523           jdk.incubator.vector.Float16Vector$$Lambda.0x000000003811dec0#apply() [Inlined code]
0.15%     506           jdk.incubator.vector.Float16Vector#bOpTemplate() [Inlined code]
0.10%     340           jdk.incubator.vector.Float16Vector256#vec() [Inlined code]
0.10%     333           org.apache.lucene.util.TernaryLongHeap#downHeap() [Inlined code]
0.09%     302           org.apache.lucene.codecs.lucene90.compressing.Lucene90CompressingStoredFieldsReader$BlockState#doReset() [JIT compiled code]
0.09%     289           jdk.internal.vm.vector.VectorSupport#maybeRebox() [Inlined code]
0.09%     288           jdk.internal.foreign.MemorySessionImpl#checkValidStateRaw() [Inlined code]
0.07%     231           sun.nio.ch.UnixFileDispatcherImpl#force0() [Native code]
0.05%     171           jdk.incubator.vector.Float16Vector#lambda$reductionOperations$1() [Inlined code]
0.03%     112           jdk.incubator.vector.Float16Vector#rOpTemplate() [Inlined code]
0.03%     98            sun.nio.fs.UnixNativeDispatcher#open0() [Native code]
0.02%     82            java.util.TimSort#binarySort() [JIT compiled code]
0.02%     79            jdk.internal.misc.ScopedMemoryAccess#loadFromMemorySegmentScopedInternal() [Inlined code]
0.02%     77            jdk.internal.vm.vector.VectorSupport#binaryOp() [JIT compiled code]
0.02%     73            sun.nio.ch.UnixFileDispatcherImpl#read0() [Native code]
0.02%     60            org.apache.lucene.codecs.lucene90.compressing.StoredFieldsInts#readInts8() [Inlined code]
0.02%     59            org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader#search() [JIT compiled code]
0.02%     56            org.apache.lucene.util.TernaryLongHeap#upHeap() [Inlined code]
0.02%     52            org.apache.lucene.util.hnsw.RandomVectorScorer#bulkScore() [Inlined code]
```

Profiler output for float32:

```
PERCENT   CPU SAMPLES   STACK
65.82%    679346        org.apache.lucene.internal.vectorization.DefaultVectorUtilSupport#fma() [Inlined code]
27.94%    288356        org.apache.lucene.internal.vectorization.DefaultVectorUtilSupport#dotProduct() [Inlined code]
1.31%     13488         org.apache.lucene.index.Float16VectorValues$1#vectorValue() [Inlined code]
0.36%     3739          org.apache.lucene.util.hnsw.HnswGraphSearcher#searchLevel() [JIT compiled code]
0.36%     3677          org.apache.lucene.util.FixedBitSet#getAndSet() [Inlined code]
0.31%     3227          org.apache.lucene.util.hnsw.HnswConcurrentMergeBuilder$MergeSearcher#graphSeek() [JIT compiled code]
0.30%     3098          org.apache.lucene.util.VectorUtil#dotProduct() [Inlined code]
0.30%     3075          org.apache.lucene.util.hnsw.OnHeapHnswGraph#nextNeighbor() [Inlined code]
0.25%     2567          org.apache.lucene.util.TernaryLongHeap#downHeap() [Inlined code]
0.20%     2055          java.util.Arrays#fill() [Inlined code]
0.15%     1594          org.apache.lucene.util.hnsw.NeighborArray#nodes() [Inlined code]
0.14%     1489          org.apache.lucene.codecs.lucene90.compressing.Lucene90CompressingStoredFieldsReader$BlockState#doReset() [JIT compiled code]
0.13%     1374          org.apache.lucene.util.hnsw.RandomVectorScorer#bulkScore() [JIT compiled code]
0.10%     1046          jdk.jfr.internal.JVM#emitEvent() [Native code]
0.09%     931           org.apache.lucene.codecs.lucene95.OffHeapFloat16VectorValues#vectorValue() [Inlined code]
0.09%     915           org.apache.lucene.util.hnsw.OnHeapHnswGraph#getNeighbors() [Inlined code]
0.08%     778           org.apache.lucene.util.TernaryLongHeap#upHeap() [Inlined code]
0.07%     734           java.util.concurrent.locks.AbstractQueuedLongSynchronizer#apparentlyFirstQueuedIsExclusive() [Inlined code]
0.06%     615           java.util.ArrayList#elementData() [Inlined code]
0.06%     574           sun.nio.ch.UnixFileDispatcherImpl#write0() [Native code]
0.05%     564           jdk.internal.foreign.MemorySessionImpl#checkValidStateRaw() [Inlined code]
0.05%     514           org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer$Float16ScoringSupplier$1#score() [Inlined code]
0.05%     507           jdk.internal.foreign.AbstractMemorySegmentImpl#copy() [Inlined code]
0.05%     501           jdk.internal.misc.ScopedMemoryAccess#getByteInternal() [Inlined code]
0.05%     479           org.apache.lucene.util.hnsw.RandomVectorScorer#bulkScore() [Inlined code]
0.04%     461           jdk.internal.foreign.AbstractMemorySegmentImpl#checkBounds() [Inlined code]
0.04%     426           org.apache.lucene.util.ArrayUtil#growExact() [Inlined code]
0.04%     414           org.apache.lucene.util.hnsw.HnswGraphSearcher#graphNextNeighbor() [Inlined code]
0.04%     412           org.apache.lucene.util.hnsw.NeighborArray#addOutOfOrder() [Inlined code]
0.04%     391           org.apache.lucene.util.packed.DirectMonotonicReader#get() [Inlined code]
```

#### Next Steps:

* Try to understand why there is so much regression with the Float16 Panama support, and make better sense of the profiler results.
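One direction I may try for the `DefaultVectorUtilSupport` regression (a sketch of my own, not code from this PR; the class name below is hypothetical): since a vector is typically scored many times, the float16-to-float32 conversion can be hoisted out of the dot product and done once when the vector is loaded, after which the inner loop is identical to the existing float32 path. `Float.float16ToFloat` / `Float.floatToFloat16` are the standard JDK methods and are intrinsified on recent JDKs:

```java
// Hypothetical sketch: decode IEEE-754 half-precision values (stored as short
// bit patterns) to float32 once per vector, then reuse a plain float32 dot
// product, instead of converting every lane on every scoring call.
public class Float16DecodeSketch {

  // One-time conversion, e.g. when the query vector is created for a search.
  static float[] decode(short[] half) {
    float[] out = new float[half.length];
    for (int i = 0; i < half.length; i++) {
      out[i] = Float.float16ToFloat(half[i]); // intrinsified on recent JDKs
    }
    return out;
  }

  // Same inner loop as the existing float32 path: no per-call float16 work.
  static float dotProduct(float[] a, float[] b) {
    assert a.length == b.length : "Vector lengths must match";
    float sum = 0f;
    for (int i = 0; i < a.length; i++) {
      sum = Math.fma(a[i], b[i], sum);
    }
    return sum;
  }

  public static void main(String[] args) {
    // 1.0, 2.0, 0.5, 0.25 are all exactly representable in float16.
    short[] q = {Float.floatToFloat16(1.0f), Float.floatToFloat16(2.0f)};
    short[] d = {Float.floatToFloat16(0.5f), Float.floatToFloat16(0.25f)};
    // 1.0 * 0.5 + 2.0 * 0.25 = 1.0
    System.out.println(dotProduct(decode(q), decode(d)));
  }
}
```

This trades memory for latency (the decoded copy is float32-sized), so it only helps where a vector is reused across many score calls; the on-disk format would stay float16. Returning the score as a float would also avoid the lossy `floatToFloat16` round trip that my current `dotProduct` does on the accumulated sum.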
