This is an automated email from the ASF dual-hosted git repository.
fanningpj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pekko.git
The following commit(s) were added to refs/heads/main by this push:
new 973327be25 VarHandle multi-byte reads for ByteString lastIndexOf
(#2838)
973327be25 is described below
commit 973327be25bda62953b17800259d37c398cb9f3c
Author: PJ Fanning <[email protected]>
AuthorDate: Sat Apr 11 20:23:06 2026 +0200
VarHandle multi-byte reads for ByteString lastIndexOf (#2838)
* perf: use VarHandle-based reads for ByteArrayIterator get methods and
SWAR lastIndexOf
Agent-Logs-Url:
https://github.com/pjfanning/incubator-pekko/sessions/1a4cc51a-e270-46cd-9cf0-e59a0e608650
Co-authored-by: pjfanning <[email protected]>
perf: add clarifying comments to unrolledLastIndexOf methods
Agent-Logs-Url:
https://github.com/pjfanning/incubator-pekko/sessions/1a4cc51a-e270-46cd-9cf0-e59a0e608650
Co-authored-by: pjfanning <[email protected]>
lastIndexOf (specialized)
scalafmt
Update SWARUtil.scala
Update ByteString.scala
Update ByteString.scala
* Update ByteStringSpec.scala
* Apply suggestions from code review
Co-authored-by: He-Pin(kerr) <[email protected]>
* review comments
* only call compilePattern if needed
* refactor
* refactor getLastIndex
* multi string concat tests
---------
Co-authored-by: copilot-swe-agent[bot]
<[email protected]>
Co-authored-by: He-Pin(kerr) <[email protected]>
---
.../org/apache/pekko/util/ByteStringSpec.scala | 96 ++++++++++
.../scala/org/apache/pekko/util/ByteString.scala | 208 ++++++++++++++++++++-
.../scala/org/apache/pekko/util/SWARUtil.scala | 11 ++
3 files changed, 309 insertions(+), 6 deletions(-)
diff --git
a/actor-tests/src/test/scala/org/apache/pekko/util/ByteStringSpec.scala
b/actor-tests/src/test/scala/org/apache/pekko/util/ByteStringSpec.scala
index 701b19df38..1601b581f1 100644
--- a/actor-tests/src/test/scala/org/apache/pekko/util/ByteStringSpec.scala
+++ b/actor-tests/src/test/scala/org/apache/pekko/util/ByteStringSpec.scala
@@ -699,6 +699,12 @@ class ByteStringSpec extends AnyWordSpec with Matchers
with Checkers {
byteStringLong.lastIndexOf('m') should ===(12)
byteStringLong.lastIndexOf('z') should ===(25)
byteStringLong.lastIndexOf('a') should ===(0)
+
+ val long1 = ByteString1.fromString("abcdefghijklmnop") // 16 bytes
+ long1.lastIndexOf('a'.toByte) should ===(0)
+ long1.lastIndexOf('p'.toByte) should ===(15)
+ long1.lastIndexOf('h'.toByte, 7) should ===(7)
+ long1.lastIndexOf('h'.toByte, 6) should ===(-1)
}
"indexOf from offset" in {
ByteString.empty.indexOf(5, -1) should ===(-1)
@@ -820,6 +826,74 @@ class ByteStringSpec extends AnyWordSpec with Matchers
with Checkers {
compact.lastIndexOf('b', 1) should ===(1)
compact.lastIndexOf('b', 0) should ===(-1)
compact.lastIndexOf('b', -1) should ===(-1)
+
+ val concat0 = ByteStrings(ByteString1.fromString("ab"),
ByteString1.fromString("dd"))
+ concat0.lastIndexOf('d'.toByte, 2) should ===(2)
+ concat0.lastIndexOf('d'.toByte, 3) should ===(3)
+ }
+ "lastIndexOf (specialized)" in {
+ ByteString.empty.lastIndexOf(5.toByte, -1) should ===(-1)
+ ByteString.empty.lastIndexOf(5.toByte, 0) should ===(-1)
+ ByteString.empty.lastIndexOf(5.toByte, 1) should ===(-1)
+ ByteString.empty.lastIndexOf(5.toByte) should ===(-1)
+ val byteString1 = ByteString1.fromString("abb")
+ byteString1.lastIndexOf('d'.toByte) should ===(-1)
+ byteString1.lastIndexOf('d'.toByte, -1) should ===(-1)
+ byteString1.lastIndexOf('d'.toByte, 4) should ===(-1)
+ byteString1.lastIndexOf('d'.toByte, 1) should ===(-1)
+ byteString1.lastIndexOf('d'.toByte, 0) should ===(-1)
+ byteString1.lastIndexOf('a'.toByte, -1) should ===(-1)
+ byteString1.lastIndexOf('a'.toByte) should ===(0)
+ byteString1.lastIndexOf('a'.toByte, 0) should ===(0)
+ byteString1.lastIndexOf('a'.toByte, 1) should ===(0)
+ byteString1.lastIndexOf('b'.toByte) should ===(2)
+ byteString1.lastIndexOf('b'.toByte, 2) should ===(2)
+ byteString1.lastIndexOf('b'.toByte, 1) should ===(1)
+ byteString1.lastIndexOf('b'.toByte, 0) should ===(-1)
+
+ val byteStrings = ByteStrings(ByteString1.fromString("abb"),
ByteString1.fromString("efg"))
+ byteStrings.lastIndexOf('e'.toByte) should ===(3)
+ byteStrings.lastIndexOf('e'.toByte, 6) should ===(3)
+ byteStrings.lastIndexOf('e'.toByte, 4) should ===(3)
+ byteStrings.lastIndexOf('e'.toByte, 1) should ===(-1)
+ byteStrings.lastIndexOf('e'.toByte, 0) should ===(-1)
+ byteStrings.lastIndexOf('e'.toByte, -1) should ===(-1)
+
+ byteStrings.lastIndexOf('b'.toByte) should ===(2)
+ byteStrings.lastIndexOf('b'.toByte, 6) should ===(2)
+ byteStrings.lastIndexOf('b'.toByte, 4) should ===(2)
+ byteStrings.lastIndexOf('b'.toByte, 1) should ===(1)
+ byteStrings.lastIndexOf('b'.toByte, 0) should ===(-1)
+ byteStrings.lastIndexOf('b'.toByte, -1) should ===(-1)
+
+ val compact = byteStrings.compact
+ compact.lastIndexOf('e'.toByte) should ===(3)
+ compact.lastIndexOf('e'.toByte, 6) should ===(3)
+ compact.lastIndexOf('e'.toByte, 4) should ===(3)
+ compact.lastIndexOf('e'.toByte, 1) should ===(-1)
+ compact.lastIndexOf('e'.toByte, 0) should ===(-1)
+ compact.lastIndexOf('e'.toByte, -1) should ===(-1)
+
+ compact.lastIndexOf('b'.toByte) should ===(2)
+ compact.lastIndexOf('b'.toByte, 6) should ===(2)
+ compact.lastIndexOf('b'.toByte, 4) should ===(2)
+ compact.lastIndexOf('b'.toByte, 1) should ===(1)
+ compact.lastIndexOf('b'.toByte, 0) should ===(-1)
+ compact.lastIndexOf('b'.toByte, -1) should ===(-1)
+
+ val sliced = ByteString1.fromString("xxabcdefghijk").drop(2)
+ sliced.lastIndexOf('k'.toByte) should ===(10)
+
+ val zeros = ByteString(Array[Byte](0, 1, 0, 1))
+ zeros.lastIndexOf(0.toByte) should ===(2)
+ val neg = ByteString(Array[Byte](-1, 0, -1))
+ neg.lastIndexOf((-1).toByte) should ===(2)
+
+ val concat0 = makeMultiByteStringsSample()
+ concat0.lastIndexOf(0xFF.toByte) should ===(18)
+ concat0.lastIndexOf(0xFF.toByte, 18) should ===(18)
+ concat0.lastIndexOf(0xFF.toByte, 17) should ===(0)
+ concat0.lastIndexOf(0xFE.toByte) should ===(-1)
}
"indexOf (specialized)" in {
ByteString.empty.indexOf(5.toByte) should ===(-1)
@@ -853,6 +927,10 @@ class ByteStringSpec extends AnyWordSpec with Matchers
with Checkers {
compact.indexOf('f'.toByte) should ===(4)
compact.indexOf('g'.toByte) should ===(5)
+ val concat0 = makeMultiByteStringsSample()
+ concat0.indexOf(0xFF.toByte) should ===(0)
+ concat0.indexOf(16.toByte) should ===(17)
+ concat0.indexOf(0xFE.toByte) should ===(-1)
}
"indexOf (specialized) from offset" in {
ByteString.empty.indexOf(5.toByte, -1) should ===(-1)
@@ -919,6 +997,11 @@ class ByteStringSpec extends AnyWordSpec with Matchers
with Checkers {
byteStringLong.indexOf('m', 2) should ===(12)
byteStringLong.indexOf('z', 2) should ===(25)
byteStringLong.indexOf('a', 2) should ===(-1)
+
+ val concat0 = makeMultiByteStringsSample()
+ concat0.indexOf(0xFF.toByte, 0) should ===(0)
+ concat0.indexOf(0xFF.toByte, 17) should ===(18)
+ concat0.indexOf(0xFE.toByte, 17) should ===(-1)
}
"contains" in {
ByteString.empty.contains(5) should ===(false)
@@ -1933,4 +2016,17 @@ class ByteStringSpec extends AnyWordSpec with Matchers
with Checkers {
}
}
}
+
+ private def makeMultiByteStringsSample(): ByteString = {
+ val byteStrings = Vector(
+ ByteString1(Array[Byte](0xFF.toByte)),
+ ByteString1(Array[Byte](0, 1, 2, 3)),
+ ByteString1(Array[Byte](4, 5)),
+ ByteString1(Array[Byte](6, 7, 8, 9)),
+ ByteString1(Array[Byte](10)),
+ ByteString1(Array[Byte](11, 12, 13, 14, 15, 16)),
+ ByteString1(Array[Byte](0xFF.toByte))
+ )
+ ByteStrings(byteStrings)
+ }
}
diff --git a/actor/src/main/scala/org/apache/pekko/util/ByteString.scala
b/actor/src/main/scala/org/apache/pekko/util/ByteString.scala
index 2742436b68..982227ee94 100644
--- a/actor/src/main/scala/org/apache/pekko/util/ByteString.scala
+++ b/actor/src/main/scala/org/apache/pekko/util/ByteString.scala
@@ -259,7 +259,7 @@ object ByteString {
if (offset == length) return -1
}
val longCount = searchLength >>> 3
- val pattern = SWARUtil.compilePattern(elem)
+ val pattern = if (longCount > 0) SWARUtil.compilePattern(elem) else 0L
var i = 0
while (i < longCount) {
val word = SWARUtil.getLong(bytes, offset, ByteOrder.BIG_ENDIAN)
@@ -285,7 +285,7 @@ object ByteString {
if (offset == length) return -1
}
val longCount = searchLength >>> 3
- val pattern = SWARUtil.compilePattern(elem)
+ val pattern = if (longCount > 0) SWARUtil.compilePattern(elem) else 0L
var i = 0
while (i < longCount) {
val word = SWARUtil.getLong(bytes, offset, ByteOrder.BIG_ENDIAN)
@@ -297,6 +297,62 @@ object ByteString {
-1
}
+ override def lastIndexOf[B >: Byte](elem: B, end: Int): Int = {
+ elem match {
+ case byte: Byte => lastIndexOf(byte, end)
+ case _ =>
+ if (end < 0) -1
+ else {
+ var found = -1
+ var i = math.min(end, length - 1)
+ while (i >= 0 && found == -1) {
+ if (bytes(i) == elem) found = i
+ i -= 1
+ }
+ found
+ }
+ }
+ }
+
+ override def lastIndexOf(elem: Byte, end: Int): Int = {
+ val endIdx = math.min(end, length - 1)
+ if (endIdx < 0) return -1
+ val searchLength = endIdx + 1
+ // Check the rightmost partial chunk first (bytes not fitting in a full
8-byte block)
+ val tailBytes = searchLength & 7
+ if (tailBytes > 0) {
+ val tailStart = searchLength - tailBytes
+ val index = unrolledLastIndexOf(tailStart, tailBytes, elem)
+ if (index != -1) return index
+ if (tailStart == 0) return -1
+ }
+ // Scan full 8-byte chunks from right to left
+ var chunkStart = searchLength - tailBytes - 8
+ if (chunkStart >= 0) {
+ val pattern = SWARUtil.compilePattern(elem)
+ while (chunkStart >= 0) {
+ val word = SWARUtil.getLong(bytes, chunkStart, ByteOrder.BIG_ENDIAN)
+ val result = SWARUtil.applyPattern(word, pattern)
+ if (result != 0) return chunkStart + SWARUtil.getLastIndex(result)
+ chunkStart -= 8
+ }
+ }
+ -1
+ }
+
+ // Searches byteCount bytes (1-7) starting at fromIndex from highest to
lowest index,
+ // returning the rightmost (last) match, or -1 if not found.
+ private def unrolledLastIndexOf(fromIndex: Int, byteCount: Int, value:
Byte): Int = {
+ if (byteCount >= 7 && bytes(fromIndex + 6) == value) fromIndex + 6
+ else if (byteCount >= 6 && bytes(fromIndex + 5) == value) fromIndex + 5
+ else if (byteCount >= 5 && bytes(fromIndex + 4) == value) fromIndex + 4
+ else if (byteCount >= 4 && bytes(fromIndex + 3) == value) fromIndex + 3
+ else if (byteCount >= 3 && bytes(fromIndex + 2) == value) fromIndex + 2
+ else if (byteCount >= 2 && bytes(fromIndex + 1) == value) fromIndex + 1
+ else if (bytes(fromIndex) == value) fromIndex
+ else -1
+ }
+
private def unrolledFirstIndexOf(fromIndex: Int, byteCount: Int, value:
Byte): Int = {
if (bytes(fromIndex) == value) fromIndex
else if (byteCount == 1) -1
@@ -531,7 +587,7 @@ object ByteString {
if (offset == length) return -1
}
val longCount = searchLength >>> 3
- val pattern = SWARUtil.compilePattern(elem)
+ val pattern = if (longCount > 0) SWARUtil.compilePattern(elem) else 0L
var i = 0
while (i < longCount) {
val word = SWARUtil.getLong(bytes, startIndex + offset,
ByteOrder.BIG_ENDIAN)
@@ -557,7 +613,7 @@ object ByteString {
if (offset == length) return -1
}
val longCount = searchLength >>> 3
- val pattern = SWARUtil.compilePattern(elem)
+ val pattern = if (longCount > 0) SWARUtil.compilePattern(elem) else 0L
var i = 0
while (i < longCount) {
val word = SWARUtil.getLong(bytes, startIndex + offset,
ByteOrder.BIG_ENDIAN)
@@ -570,6 +626,49 @@ object ByteString {
}
+ override def lastIndexOf[B >: Byte](elem: B, end: Int): Int = {
+ elem match {
+ case byte: Byte => lastIndexOf(byte, end)
+ case _ =>
+ if (end < 0) -1
+ else {
+ var found = -1
+ var i = math.min(end, length - 1)
+ while (i >= 0 && found == -1) {
+ if (bytes(startIndex + i) == elem) found = i
+ i -= 1
+ }
+ found
+ }
+ }
+ }
+
+ override def lastIndexOf(elem: Byte, end: Int): Int = {
+ val endIdx = math.min(end, length - 1)
+ if (endIdx < 0) return -1
+ val searchLength = endIdx + 1
+ // Check the rightmost partial chunk first (bytes not fitting in a full
8-byte block)
+ val tailBytes = searchLength & 7
+ if (tailBytes > 0) {
+ val tailStart = searchLength - tailBytes
+ val index = unrolledLastIndexOf(startIndex + tailStart, tailBytes,
elem)
+ if (index != -1) return index - startIndex
+ if (tailStart == 0) return -1
+ }
+ // Scan full 8-byte chunks from right to left
+ var chunkStart = searchLength - tailBytes - 8
+ if (chunkStart >= 0) {
+ val pattern = SWARUtil.compilePattern(elem)
+ while (chunkStart >= 0) {
+ val word = SWARUtil.getLong(bytes, startIndex + chunkStart,
ByteOrder.BIG_ENDIAN)
+ val result = SWARUtil.applyPattern(word, pattern)
+ if (result != 0) return chunkStart + SWARUtil.getLastIndex(result)
+ chunkStart -= 8
+ }
+ }
+ -1
+ }
+
// the calling code already adds the startIndex so this method does not
need to
private def unrolledFirstIndexOf(fromIndex: Int, byteCount: Int, value:
Byte): Int = {
if (bytes(fromIndex) == value) fromIndex
@@ -588,6 +687,20 @@ object ByteString {
else -1
}
+ // the calling code already adds the startIndex so this method does not
need to.
+ // Searches byteCount bytes (1-7) starting at fromIndex from highest to
lowest index,
+ // returning the rightmost (last) match, or -1 if not found.
+ private def unrolledLastIndexOf(fromIndex: Int, byteCount: Int, value:
Byte): Int = {
+ if (byteCount >= 7 && bytes(fromIndex + 6) == value) fromIndex + 6
+ else if (byteCount >= 6 && bytes(fromIndex + 5) == value) fromIndex + 5
+ else if (byteCount >= 5 && bytes(fromIndex + 4) == value) fromIndex + 4
+ else if (byteCount >= 4 && bytes(fromIndex + 3) == value) fromIndex + 3
+ else if (byteCount >= 3 && bytes(fromIndex + 2) == value) fromIndex + 2
+ else if (byteCount >= 2 && bytes(fromIndex + 1) == value) fromIndex + 1
+ else if (bytes(fromIndex) == value) fromIndex
+ else -1
+ }
+
override def copyToArray[B >: Byte](dest: Array[B], start: Int, len: Int):
Int = {
// min of the bytes available to copy, bytes there is room for in dest
and the requested number of bytes
val toCopy = math.min(math.min(len, length), dest.length - start)
@@ -921,6 +1034,64 @@ object ByteString {
}
}
+ override def lastIndexOf[B >: Byte](elem: B, end: Int): Int = {
+ if (end < 0) -1
+ else {
+ val byteStringsLast = bytestrings.size - 1
+
+ @tailrec
+ def find(bsIdx: Int, relativeIndex: Int, len: Int): Int = {
+ if (bsIdx < 0) -1
+ else {
+ val bs = bytestrings(bsIdx)
+ val bsStartIndex = len - bs.length
+
+ if (relativeIndex < bsStartIndex || bs.isEmpty) {
+ if (bsIdx == 0) -1
+ else find(bsIdx - 1, relativeIndex, bsStartIndex)
+ } else {
+ val subIndexOf = bs.lastIndexOf(elem, relativeIndex -
bsStartIndex)
+ if (subIndexOf < 0) {
+ if (bsIdx == 0) -1
+ else find(bsIdx - 1, relativeIndex, bsStartIndex)
+ } else subIndexOf + bsStartIndex
+ }
+ }
+ }
+
+ find(byteStringsLast, math.min(end, length - 1), length)
+ }
+ }
+
+ override def lastIndexOf(elem: Byte, end: Int): Int = {
+ if (end < 0) -1
+ else {
+ val byteStringsLast = bytestrings.size - 1
+
+ @tailrec
+ def find(bsIdx: Int, relativeIndex: Int, len: Int): Int = {
+ if (bsIdx < 0) -1
+ else {
+ val bs = bytestrings(bsIdx)
+ val bsStartIndex = len - bs.length
+
+ if (relativeIndex < bsStartIndex || bs.isEmpty) {
+ if (bsIdx == 0) -1
+ else find(bsIdx - 1, relativeIndex, bsStartIndex)
+ } else {
+ val subIndexOf = bs.lastIndexOf(elem, relativeIndex -
bsStartIndex)
+ if (subIndexOf < 0) {
+ if (bsIdx == 0) -1
+ else find(bsIdx - 1, relativeIndex, bsStartIndex)
+ } else subIndexOf + bsStartIndex
+ }
+ }
+ }
+
+ find(byteStringsLast, math.min(end, length - 1), length)
+ }
+ }
+
override def copyToArray[B >: Byte](dest: Array[B], start: Int, len: Int):
Int = {
if (bytestrings.size == 1) bytestrings.head.copyToArray(dest, start, len)
else {
@@ -995,11 +1166,11 @@ sealed abstract class ByteString
// of ByteString which changed for Scala 2.12, see
https://github.com/akka/akka/issues/21774
override final def className: String = "ByteString"
+ override def isEmpty: Boolean = length == 0
+
// Cache the hash code since ByteString is immutable
override lazy val hashCode: Int = super.hashCode()
- override def isEmpty: Boolean = length == 0
-
// override protected[this] def newBuilder: ByteStringBuilder =
ByteString.newBuilder
// *must* be overridden by derived classes. This construction is necessary
@@ -1085,6 +1256,31 @@ sealed abstract class ByteString
*/
def indexOf(elem: Byte): Int = indexOf(elem, 0)
+ /**
+ * Finds index of last occurrence of some byte in this ByteString before or
at some end index.
+ *
+ * Similar to lastIndexOf, but it avoids boxing if the value is already a
byte.
+ *
+ * @param elem the element value to search for.
+ * @param end the end index
+ * @return the index `<= end` of the last element of this ByteString that
is equal (as determined by `==`)
+ * to `elem`, or `-1`, if none exists.
+ * @since 2.0.0
+ */
+ def lastIndexOf(elem: Byte, end: Int): Int = lastIndexOf[Byte](elem, end)
+
+ /**
+ * Finds index of last occurrence of some byte in this ByteString.
+ *
+ * Similar to lastIndexOf, but it avoids boxing if the value is already a
byte.
+ *
+ * @param elem the element value to search for.
+ * @return the index of the last element of this ByteString that is equal
(as determined by `==`)
+ * to `elem`, or `-1`, if none exists.
+ * @since 2.0.0
+ */
+ def lastIndexOf(elem: Byte): Int = lastIndexOf(elem, length - 1)
+
override def indexOfSlice[B >: Byte](slice: scala.collection.Seq[B], from:
Int): Int = {
// this is only called if the first byte matches, so we can skip that check
def check(startPos: Int): Boolean = {
diff --git a/actor/src/main/scala/org/apache/pekko/util/SWARUtil.scala
b/actor/src/main/scala/org/apache/pekko/util/SWARUtil.scala
index a0d44fd52d..1a2912f849 100644
--- a/actor/src/main/scala/org/apache/pekko/util/SWARUtil.scala
+++ b/actor/src/main/scala/org/apache/pekko/util/SWARUtil.scala
@@ -146,6 +146,17 @@ private[pekko] object SWARUtil {
def getIndex(word: Long): Int =
java.lang.Long.numberOfLeadingZeros(word) >>> 3
+ /**
+ * Returns the index of the last occurrence of a byte specified in the
pattern within a word.
+ * If the byte does not occur in the word, the result is -1. Currently only supports big
endian.
+ *
+ * @param word the return value of [[applyPattern]]
+ * @return the index of the last occurrence of the specified pattern in the
specified word.
+ */
+ def getLastIndex(word: Long): Int =
+ if (word == 0) -1
+ else (java.lang.Long.SIZE - 1 -
java.lang.Long.numberOfTrailingZeros(word)) >>> 3
+
/**
* Returns the long value at the specified index in the given byte array.
* Uses big-endian byte order. Uses a VarHandle byte array view if supported.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]