benwtrent commented on code in PR #14304:
URL: https://github.com/apache/lucene/pull/14304#discussion_r1985759854
##########
lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java:
##########
@@ -907,4 +907,87 @@ public static long int4BitDotProduct128(byte[] q, byte[]
d) {
}
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
}
+
+ @Override
+ public float quantize(
+ float[] vector, byte[] dest, float scale, float alpha, float
minQuantile, float maxQuantile) {
+ float correction = 0;
+ int i = 0;
+ // only vectorize if we have a viable BYTE_SPECIES we can use for output
+ if (VECTOR_BITSIZE >= 256) {
+ for (; i < FLOAT_SPECIES.loopBound(vector.length); i +=
FLOAT_SPECIES.length()) {
+ FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
+
+ // Make sure the value is within the quantile range, cutting off the
tails
+ // see first parenthesis in equation: byte = (float - minQuantile) *
127/(maxQuantile -
+ // minQuantile)
+ FloatVector dxc = v.min(maxQuantile).max(minQuantile).sub(minQuantile);
+ // Scale the value to the range [0, 127], this is our quantized value
+ // scale = 127/(maxQuantile - minQuantile)
+ // Math.round rounds to positive infinity, so do the same by +0.5 then
truncating to int
+ Vector<Integer> roundedDxs =
dxc.mul(scale).add(0.5f).convert(VectorOperators.F2I, 0);
+ // output this to the array
+ ((ByteVector) roundedDxs.castShape(BYTE_SPECIES, 0)).intoArray(dest,
i);
+ // We multiply by `alpha` here to get the quantized value back into
the original range
+ // to aid in calculating the corrective offset
+ Vector<Float> dxq = ((FloatVector) roundedDxs.castShape(FLOAT_SPECIES,
0)).mul(alpha);
+ // Calculate the corrective offset that needs to be applied to the
score
+ // in addition to the `byte * minQuantile * alpha` term in the equation
+ // we add the `(dx - dxq) * dxq` term to account for the fact that the
quantized value
+ // will be rounded to the nearest whole number and lose some accuracy
+ // Additionally, we account for the global correction of
`minQuantile^2` in the equation
+ correction +=
+ v.sub(minQuantile / 2f)
+ .mul(minQuantile)
+ .add(v.sub(minQuantile).sub(dxq).mul(dxq))
+ .reduceLanes(VectorOperators.ADD);
Review Comment:
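I think that if you keep the lanes separate for as long as possible, we would
get a bigger perf boost. For example, accumulate the per-iteration correction into a
`FloatVector` and call `reduceLanes` only once after the loop. Reducing lanes
on every iteration is a serious bottleneck.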
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]