This is an automated email from the ASF dual-hosted git repository.

exceptionfactory pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/nifi.git


The following commit(s) were added to refs/heads/main by this push:
     new b3b718dd63d NIFI-15319 Decrease memory counter after buffer shutdown in ConsumeKinesis
b3b718dd63d is described below

commit b3b718dd63d714cda73ff3e4f389853cd079a449
Author: Alaksiej Ščarbaty <[email protected]>
AuthorDate: Thu Dec 11 17:27:40 2025 +0100

    NIFI-15319 Decrease memory counter after buffer shutdown in ConsumeKinesis
    
    This closes #10633
    
    Signed-off-by: David Handermann <[email protected]>
---
 .../aws/kinesis/MemoryBoundRecordBuffer.java       | 89 +++++++++++++---------
 .../aws/kinesis/MemoryBoundRecordBufferTest.java   | 28 +++++++
 2 files changed, 81 insertions(+), 36 deletions(-)

diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBuffer.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBuffer.java
index 10b349a81ae..c4d1522a350 100644
--- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBuffer.java
+++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBuffer.java
@@ -113,7 +113,7 @@ final class MemoryBoundRecordBuffer implements 
RecordBuffer.ForKinesisClientLibr
         }
 
         final RecordBatch recordBatch = new RecordBatch(records, checkpointer, 
calculateMemoryUsage(records));
-        memoryTracker.reserveMemory(recordBatch);
+        memoryTracker.reserveMemory(recordBatch, bufferId);
         final boolean addedRecords = buffer.offer(recordBatch);
 
         if (addedRecords) {
@@ -131,7 +131,7 @@ final class MemoryBoundRecordBuffer implements 
RecordBuffer.ForKinesisClientLibr
                     records.getLast().sequenceNumber(),
                     records.getLast().subSequenceNumber());
             // If the buffer was invalidated, we should free memory reserved 
for these records.
-            memoryTracker.freeMemory(List.of(recordBatch));
+            memoryTracker.freeMemory(List.of(recordBatch), bufferId);
         }
     }
 
@@ -152,28 +152,27 @@ final class MemoryBoundRecordBuffer implements 
RecordBuffer.ForKinesisClientLibr
 
     @Override
     public void shutdownShardConsumption(final ShardBufferId bufferId, final 
RecordProcessorCheckpointer checkpointer) {
-        final ShardBuffer buffer = shardBuffers.get(bufferId);
+        final ShardBuffer buffer = shardBuffers.remove(bufferId);
+
         if (buffer == null) {
             logger.debug("Buffer with id {} not found. Cannot shutdown shard 
consumption", bufferId);
-            return;
+        } else {
+            logger.debug("Shutting down the buffer {}. Checkpointing last 
consumed record", bufferId);
+            final Collection<RecordBatch> invalidatedBatches = 
buffer.shutdownBuffer(checkpointer);
+            memoryTracker.freeMemory(invalidatedBatches, bufferId);
         }
-
-        logger.debug("Shutting down the buffer {}. Checkpointing last consumed 
record", bufferId);
-        buffer.shutdownBuffer(checkpointer);
-
-        logger.debug("Removing buffer with id {} after successful last 
consumed record checkpoint", bufferId);
-        shardBuffers.remove(bufferId);
     }
 
     @Override
     public void consumerLeaseLost(final ShardBufferId bufferId) {
         final ShardBuffer buffer = shardBuffers.remove(bufferId);
 
-        logger.debug("Lease lost for buffer {}: Invalidating", bufferId);
-
-        if (buffer != null) {
+        if (buffer == null) {
+            logger.debug("Buffer with id {} not found. Ignoring lease lost 
event", bufferId);
+        } else {
+            logger.debug("Lease lost for buffer {}: Invalidating", bufferId);
             final Collection<RecordBatch> invalidatedBatches = 
buffer.invalidate();
-            memoryTracker.freeMemory(invalidatedBatches);
+            memoryTracker.freeMemory(invalidatedBatches, bufferId);
         }
     }
 
@@ -244,8 +243,9 @@ final class MemoryBoundRecordBuffer implements 
RecordBuffer.ForKinesisClientLibr
             return;
         }
 
+        logger.debug("Committing consumed records for buffer {}", bufferId);
         final List<RecordBatch> consumedBatches = 
buffer.commitConsumedRecords();
-        memoryTracker.freeMemory(consumedBatches);
+        memoryTracker.freeMemory(consumedBatches, bufferId);
     }
 
     @Override
@@ -259,6 +259,7 @@ final class MemoryBoundRecordBuffer implements 
RecordBuffer.ForKinesisClientLibr
         final ShardBuffer buffer = shardBuffers.get(bufferId);
 
         if (buffer != null) {
+            logger.debug("Rolling back consumed records for buffer {}", 
bufferId);
             buffer.rollbackConsumedRecords();
         }
     }
@@ -331,10 +332,11 @@ final class MemoryBoundRecordBuffer implements 
RecordBuffer.ForKinesisClientLibr
             this.maxMemoryBytes = maxMemoryBytes;
         }
 
-        void reserveMemory(final RecordBatch recordBatch) {
+        void reserveMemory(final RecordBatch recordBatch, final ShardBufferId 
bufferId) {
             final long consumedBytes = recordBatch.batchSizeBytes();
 
             if (consumedBytes == 0) {
+                logger.debug("The batch for buffer {} is empty. No need to 
reserve memory", bufferId);
                 return;
             }
 
@@ -352,8 +354,8 @@ final class MemoryBoundRecordBuffer implements 
RecordBuffer.ForKinesisClientLibr
                 } else {
                     final long newConsumedBytes = currentlyConsumedBytes + 
consumedBytes;
                     if 
(consumedMemoryBytes.compareAndSet(currentlyConsumedBytes, newConsumedBytes)) {
-                        logger.debug("Reserved {} bytes for {} records. Total 
consumed memory: {} bytes",
-                                consumedBytes, recordBatch.size(), 
newConsumedBytes);
+                        logger.debug("Reserved {} bytes for {} records for 
buffer {}. Total consumed memory: {} bytes",
+                                consumedBytes, recordBatch.size(), bufferId, 
newConsumedBytes);
                         break;
                     }
                     // If we're here, the compare and set operation failed, as 
another thread has modified the gauge in meantime.
@@ -362,8 +364,9 @@ final class MemoryBoundRecordBuffer implements 
RecordBuffer.ForKinesisClientLibr
             }
         }
 
-        void freeMemory(final Collection<RecordBatch> consumedBatches) {
+        void freeMemory(final Collection<RecordBatch> consumedBatches, final 
ShardBufferId bufferId) {
             if (consumedBatches.isEmpty()) {
+                logger.debug("No batches were consumed from buffer {}. No need 
to free memory", bufferId);
                 return;
             }
 
@@ -380,8 +383,8 @@ final class MemoryBoundRecordBuffer implements 
RecordBuffer.ForKinesisClientLibr
 
                 final long newTotal = currentlyConsumedBytes - freedBytes;
                 if (consumedMemoryBytes.compareAndSet(currentlyConsumedBytes, 
newTotal)) {
-                    logger.debug("Freed {} bytes for {} batches. Total 
consumed memory: {} bytes",
-                            freedBytes, consumedBatches.size(), newTotal);
+                    logger.debug("Freed {} bytes for {} batches from buffer 
{}. Total consumed memory: {} bytes",
+                            freedBytes, consumedBatches.size(), bufferId, 
newTotal);
 
                     final CountDownLatch oldLatch = 
memoryAvailableLatch.getAndSet(new CountDownLatch(1));
                     oldLatch.countDown(); // Release any waiting threads for 
free memory.
@@ -585,25 +588,28 @@ final class MemoryBoundRecordBuffer implements 
RecordBuffer.ForKinesisClientLibr
             }
         }
 
-        void shutdownBuffer(final RecordProcessorCheckpointer checkpointer) {
-            if (invalidated.get()) {
-                return;
+        Collection<RecordBatch> shutdownBuffer(final 
RecordProcessorCheckpointer checkpointer) {
+            if (invalidated.getAndSet(true)) {
+                return emptyList();
             }
 
             if (batchesCount.get() == 0) {
                 checkpointLastReceivedRecord(checkpointer);
-            } else {
-                // If there are still records in the buffer, checkpointing 
with the latest provided checkpointer is not safe.
-                // But, if the records were committed without checkpointing in 
the past, we can checkpoint them now.
-                final LastIgnoredCheckpoint ignoredCheckpoint = 
this.lastIgnoredCheckpoint;
-                if (ignoredCheckpoint != null) {
-                    checkpointSequenceNumber(
-                            ignoredCheckpoint.checkpointer(),
-                            ignoredCheckpoint.sequenceNumber(),
-                            ignoredCheckpoint.subSequenceNumber()
-                    );
-                }
+                return emptyList();
+            }
+
+            // If there are still records in the buffer, checkpointing with 
the latest provided checkpointer is not safe.
+            // But, if the records were committed without checkpointing in the 
past, we can checkpoint them now.
+            final LastIgnoredCheckpoint ignoredCheckpoint = 
this.lastIgnoredCheckpoint;
+            if (ignoredCheckpoint != null) {
+                checkpointSequenceNumber(
+                        ignoredCheckpoint.checkpointer(),
+                        ignoredCheckpoint.sequenceNumber(),
+                        ignoredCheckpoint.subSequenceNumber()
+                );
             }
+
+            return drainInvalidatedBatches();
         }
 
         Collection<RecordBatch> invalidate() {
@@ -611,9 +617,17 @@ final class MemoryBoundRecordBuffer implements 
RecordBuffer.ForKinesisClientLibr
                 return emptyList();
             }
 
+            return drainInvalidatedBatches();
+        }
+
+        private Collection<RecordBatch> drainInvalidatedBatches() {
+            if (!invalidated.get()) {
+                throw new IllegalStateException("Unable to drain invalidated 
batches for valid shard buffer: " + bufferId);
+            }
+
             final List<RecordBatch> batches = new ArrayList<>();
             RecordBatch batch;
-            // If both consumeRecords and invalidate are called concurrently, 
invalidation must always consume all batches.
+            // If both consumeRecords and drainInvalidatedBatches are called 
concurrently, invalidation must always consume all batches.
             // Since consumeRecords moves batches from pending to in_progress, 
during invalidation pending batches should be drained first.
             while ((batch = pendingBatches.poll()) != null) {
                 batches.add(batch);
@@ -683,6 +697,9 @@ final class MemoryBoundRecordBuffer implements 
RecordBuffer.ForKinesisClientLibr
                 } catch (final ShutdownException e) {
                     logger.warn("Failed to checkpoint records due to shutdown. 
Ignoring checkpoint", e);
                     return;
+                } catch (final RuntimeException e) {
+                    logger.warn("Failed to checkpoint records due to an error. 
Ignoring checkpoint", e);
+                    return;
                 }
             }
         }
diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBufferTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBufferTest.java
index 9c16b6b18a8..1b53e3d2025 100644
--- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBufferTest.java
+++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBufferTest.java
@@ -54,6 +54,7 @@ import static org.junit.jupiter.api.Assertions.assertAll;
 import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 class MemoryBoundRecordBufferTest {
@@ -354,6 +355,33 @@ class MemoryBoundRecordBufferTest {
         assertTrue(recordBuffer.acquireBufferLease().isEmpty(), "Buffer should 
not be available after shutdown");
     }
 
+    @Test
+    void testShutdownShardConsumption_whileOtherShardIsValid() {
+        final int bufferSize = 100;
+
+        // Create buffer with small memory limit.
+        final MemoryBoundRecordBuffer recordBuffer = new 
MemoryBoundRecordBuffer(new NopComponentLog(), bufferSize, CHECKPOINT_INTERVAL);
+        final ShardBufferId bufferId1 = recordBuffer.createBuffer(SHARD_ID_1);
+        final ShardBufferId bufferId2 = recordBuffer.createBuffer(SHARD_ID_2);
+
+        final List<KinesisClientRecord> records1 = 
List.of(createRecordWithSize(bufferSize));
+        recordBuffer.addRecords(bufferId1, records1, checkpointer1);
+
+        // Shutting down a buffer with a record.
+        recordBuffer.shutdownShardConsumption(bufferId1, checkpointer1);
+
+        // Adding records to another buffer.
+        final List<KinesisClientRecord> records2 = 
List.of(createRecordWithSize(bufferSize));
+        assertTimeoutPreemptively(
+                Duration.ofSeconds(1),
+                () -> recordBuffer.addRecords(bufferId2, records2, 
checkpointer2),
+                "Records should be added to a buffer without memory 
backpressure");
+
+        final Lease lease = recordBuffer.acquireBufferLease().orElseThrow();
+        assertEquals(SHARD_ID_2, lease.shardId(), "Expected to acquire a lease 
for " + SHARD_ID_2);
+        assertEquals(records2, recordBuffer.consumeRecords(lease));
+    }
+
     @Test
     @Timeout(value = 5, unit = SECONDS)
     void testMemoryBackpressure() throws InterruptedException {

Reply via email to