This is an automated email from the ASF dual-hosted git repository.

CurtHagenlocher pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-dotnet.git


The following commit(s) were added to refs/heads/main by this push:
     new 2b2afa2  chore(src): add checked conversions where needed (#318)
2b2afa2 is described below

commit 2b2afa201cf406167044a505b30dd096ebd37193
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Thu Apr 23 21:23:01 2026 -0700

    chore(src): add checked conversions where needed (#318)
    
    ## What's Changed
    
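    Adds range checks to int64->int32 conversions when reading data through
    IPC: message body length, record batch length, field node length and
    null count, and buffer offset/length are validated, and out-of-range
    values raise InvalidDataException.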
---
 .../Internal/RecordBatchReaderImplementation.cs    |   9 +-
 src/Apache.Arrow/Arrays/RunEndEncodedArray.cs      |   2 +-
 .../Ipc/ArrowMemoryReaderImplementation.cs         |   6 +
 src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs  |  29 +++-
 test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs  | 175 ++++++++++++++++++++-
 5 files changed, 209 insertions(+), 12 deletions(-)

diff --git a/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs b/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs
index cb64867..16e080d 100644
--- a/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs
+++ b/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs
@@ -15,6 +15,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.IO;
 using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
@@ -134,8 +135,14 @@ namespace Apache.Arrow.Flight.Internal
                 switch (message.HeaderType)
                 {
                     case MessageHeader.RecordBatch:
+                        if (message.BodyLength < 0 || message.BodyLength > 
int.MaxValue)
+                        {
+                            throw new InvalidDataException(
+                                $"Cannot read batch. Message body of 
{message.BodyLength} bytes is out of range");
+                        }
+
                         var body = _flightDataStream.Current.DataBody.Memory;
-                        return CreateArrowObjectFromMessage(message, 
CreateByteBuffer(body.Slice(0, (int)message.BodyLength)), null);
+                        return CreateArrowObjectFromMessage(message, 
CreateByteBuffer(body.Slice(0, checked((int)message.BodyLength))), null);
                     default:
                         throw new NotImplementedException();
                 }
diff --git a/src/Apache.Arrow/Arrays/RunEndEncodedArray.cs b/src/Apache.Arrow/Arrays/RunEndEncodedArray.cs
index c1aff92..2cf73ab 100644
--- a/src/Apache.Arrow/Arrays/RunEndEncodedArray.cs
+++ b/src/Apache.Arrow/Arrays/RunEndEncodedArray.cs
@@ -152,7 +152,7 @@ public class RunEndEncodedArray : Array, IIndexes
                     long? lastValue = int64Array.GetValue(int64Array.Length - 
1);
                     if (lastValue.HasValue && lastValue.Value > int.MaxValue)
                     {
-                        throw new ArgumentException("Run ends value exceeds 
maximum supported length.");
+                        throw new OverflowException($"Run ends value 
{lastValue.Value} exceeds maximum supported length {int.MaxValue}.");
                     }
                     return (int)(lastValue ?? throw new 
ArgumentException("invalid length"));
                 }
diff --git a/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs b/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
index e43d1d5..e9cd936 100644
--- a/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
+++ b/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
@@ -106,6 +106,12 @@ namespace Apache.Arrow.Ipc
                     CreateByteBuffer(_buffer.Slice(_bufferPosition, 
messageLength)));
                 _bufferPosition += messageLength;
 
+                if (message.BodyLength < 0 || message.BodyLength > 
int.MaxValue)
+                {
+                    throw new InvalidDataException(
+                        $"Cannot read batch. Message body of 
{message.BodyLength} bytes is out of range");
+                }
+
                 int bodyLength = (int)message.BodyLength;
                 ByteBuffer bodybb = 
CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength));
                 _bufferPosition += bodyLength;
diff --git a/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
index 799a488..398c80b 100644
--- a/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
+++ b/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
@@ -141,6 +141,12 @@ namespace Apache.Arrow.Ipc
                     break;
                 case Flatbuf.MessageHeader.RecordBatch:
                     Flatbuf.RecordBatch rb = 
message.Header<Flatbuf.RecordBatch>().Value;
+                    if (rb.Length < 0 || rb.Length > int.MaxValue)
+                    {
+                        throw new InvalidDataException(
+                            $"Cannot read batch. Message body of {rb.Length} 
rows is out of range.");
+                    }
+
                     List<IArrowArray> arrays = BuildArrays(message.Version, 
Schema, bodyByteBuffer, rb);
                     return new RecordBatch(Schema, memoryOwner, arrays, 
(int)rb.Length);
                 default:
@@ -256,20 +262,19 @@ namespace Apache.Arrow.Ipc
             ByteBuffer bodyData,
             IBufferCreator bufferCreator)
         {
-
-            int fieldLength = (int)fieldNode.Length;
-            int fieldNullCount = (int)fieldNode.NullCount;
-
-            if (fieldLength < 0)
+            if (fieldNode.Length < 0 || fieldNode.Length > int.MaxValue)
             {
-                throw new InvalidDataException("Field length must be >= 0"); 
// TODO:Localize exception message
+                throw new InvalidDataException($"Field length of 
{fieldNode.Length} must be >= 0 and <= {int.MaxValue}");
             }
 
-            if (fieldNullCount < 0)
+            if (fieldNode.NullCount < 0 || fieldNode.NullCount > int.MaxValue)
             {
-                throw new InvalidDataException("Null count must be >= 0"); // 
TODO:Localize exception message
+                throw new InvalidDataException($"Null count of 
{fieldNode.NullCount} must be >= 0 and <= {int.MaxValue}");
             }
 
+            int fieldLength = (int)fieldNode.Length;
+            int fieldNullCount = (int)fieldNode.NullCount;
+
             int buffers;
             IArrowType storageType = 
Types.ArrowTypeExtensions.GetStorageType(field.DataType);
             switch (storageType.TypeId)
@@ -375,6 +380,14 @@ namespace Apache.Arrow.Ipc
                 return ArrowBuffer.Empty;
             }
 
+            if (buffer.Offset < 0 || buffer.Offset > int.MaxValue ||
+                buffer.Length < 0 || buffer.Length > int.MaxValue ||
+                buffer.Length + buffer.Offset > bodyData.Length)
+            {
+                throw new InvalidDataException(
+                    $"IPC buffer range is out of range for a .NET buffer: 
offset={buffer.Offset}, length={buffer.Length}, size={bodyData.Length}.");
+            }
+
             int offset = (int)buffer.Offset;
             int length = (int)buffer.Length;
 
diff --git a/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs b/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
index 22ed233..5e7c57e 100644
--- a/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
+++ b/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
@@ -14,12 +14,11 @@
 // limitations under the License.
 
 using System;
+using System.Buffers.Binary;
 using System.IO;
-using System.Reflection;
 using System.Threading;
 using System.Threading.Tasks;
 using Apache.Arrow.Ipc;
-using Apache.Arrow.Memory;
 using Apache.Arrow.Types;
 using Xunit;
 
@@ -316,6 +315,163 @@ namespace Apache.Arrow.Tests
             });
         }
 
+        [Fact]
+        public void MalformedBodyLength_OverflowsInt32()
+        {
+            byte[] buffer = BuildSimpleInt32Batch(rowCount: 3);
+            int messageStart = FindRecordBatchMessageStart(buffer);
+            int messageTablePos = ReadRootTablePos(buffer, messageStart);
+
+            // Message table vtable slot 10 = BodyLength
+            int bodyLengthPos = ReadFieldAbsolutePos(buffer, messageTablePos, 
vtableSlot: 10);
+            Assert.True(ToInt64LittleEndian(buffer, bodyLengthPos) > 0);
+
+            WriteInt64LittleEndian(buffer, bodyLengthPos, (long)int.MaxValue + 
1);
+
+            InvalidDataException ex = Assert.Throws<InvalidDataException>(
+                () => new ArrowStreamReader(buffer).ReadNextRecordBatch());
+            Assert.Contains("Message body", ex.Message);
+            Assert.Contains("out of range", ex.Message);
+        }
+
+        [Fact]
+        public void MalformedRecordBatchLength_OverflowsInt32()
+        {
+            byte[] buffer = BuildSimpleInt32Batch(rowCount: 3);
+            int messageStart = FindRecordBatchMessageStart(buffer);
+            int messageTablePos = ReadRootTablePos(buffer, messageStart);
+
+            // Message.Header (slot 8) is a union pointing at the RecordBatch 
table
+            int recordBatchTablePos = ReadUnionTablePos(buffer, 
messageTablePos, vtableSlot: 8);
+
+            // RecordBatch table vtable slot 4 = Length (row count)
+            int lengthPos = ReadFieldAbsolutePos(buffer, recordBatchTablePos, 
vtableSlot: 4);
+            Assert.Equal(3L, ToInt64LittleEndian(buffer, lengthPos));
+
+            WriteInt64LittleEndian(buffer, lengthPos, (long)int.MaxValue + 1);
+
+            InvalidDataException ex = Assert.Throws<InvalidDataException>(
+                () => new ArrowStreamReader(buffer).ReadNextRecordBatch());
+            Assert.Contains("out of range", ex.Message);
+        }
+
+        [Fact]
+        public void MalformedFieldNodeLength_OverflowsInt32()
+        {
+            byte[] buffer = BuildSimpleInt32Batch(rowCount: 3);
+            int messageStart = FindRecordBatchMessageStart(buffer);
+            int messageTablePos = ReadRootTablePos(buffer, messageStart);
+            int recordBatchTablePos = ReadUnionTablePos(buffer, 
messageTablePos, vtableSlot: 8);
+
+            // RecordBatch.Nodes (slot 6) is a vector of 16-byte FieldNode 
structs
+            // where the first 8 bytes are Length and the next 8 bytes are 
NullCount.
+            int nodesDataStart = ReadVectorDataStart(buffer, 
recordBatchTablePos, vtableSlot: 6);
+            Assert.Equal(3L, ToInt64LittleEndian(buffer, nodesDataStart));
+
+            WriteInt64LittleEndian(buffer, nodesDataStart, (long)int.MaxValue 
+ 1);
+
+            InvalidDataException ex = Assert.Throws<InvalidDataException>(
+                () => new ArrowStreamReader(buffer).ReadNextRecordBatch());
+            Assert.Contains("Field length", ex.Message);
+        }
+
+        [Fact]
+        public void MalformedBufferLength_OverflowsInt32()
+        {
+            byte[] buffer = BuildSimpleInt32Batch(rowCount: 3);
+            int messageStart = FindRecordBatchMessageStart(buffer);
+            int messageTablePos = ReadRootTablePos(buffer, messageStart);
+            int recordBatchTablePos = ReadUnionTablePos(buffer, 
messageTablePos, vtableSlot: 8);
+
+            // RecordBatch.Buffers (slot 8) is a vector of 16-byte Buffer 
structs
+            // (8 bytes Offset, 8 bytes Length). Find the first buffer with 
non-zero
+            // length and corrupt its Length field.
+            int buffersDataStart = ReadVectorDataStart(buffer, 
recordBatchTablePos, vtableSlot: 8);
+            int buffersLength = ToInt32LittleEndian(buffer, buffersDataStart - 
4);
+            int targetLengthPos = -1;
+            for (int i = 0; i < buffersLength; i++)
+            {
+                int lengthPos = buffersDataStart + i * 16 + 8;
+                if (ToInt64LittleEndian(buffer, lengthPos) > 0)
+                {
+                    targetLengthPos = lengthPos;
+                    break;
+                }
+            }
+            Assert.NotEqual(-1, targetLengthPos);
+
+            WriteInt64LittleEndian(buffer, targetLengthPos, (long)int.MaxValue 
+ 1);
+
+            InvalidDataException ex = Assert.Throws<InvalidDataException>(
+                () => new ArrowStreamReader(buffer).ReadNextRecordBatch());
+            Assert.Contains("IPC buffer range", ex.Message);
+        }
+
+        private static byte[] BuildSimpleInt32Batch(int rowCount)
+        {
+            Schema schema = new(
+                [new Field("x", Int32Type.Default, nullable: true)],
+                metadata: []);
+            Int32Array.Builder arrayBuilder = new();
+            for (int i = 0; i < rowCount; i++)
+            {
+                arrayBuilder.Append(i);
+            }
+            RecordBatch batch = new(schema, [arrayBuilder.Build()], rowCount);
+
+            using MemoryStream stream = new();
+            using (ArrowStreamWriter writer = new(stream, schema, leaveOpen: 
true))
+            {
+                writer.WriteRecordBatch(batch);
+                writer.WriteEnd();
+            }
+            return stream.ToArray();
+        }
+
+        private static int FindRecordBatchMessageStart(byte[] buffer)
+        {
+            // Stream layout: [continuation(0xFFFFFFFF)][len][schema 
message][continuation][len][batch message][body]...
+            int pos = 0;
+            Assert.Equal(-1, ToInt32LittleEndian(buffer, pos)); pos += 4;
+            int schemaLen = ToInt32LittleEndian(buffer, pos); pos += 4;
+            pos += schemaLen;
+            Assert.Equal(-1, ToInt32LittleEndian(buffer, pos)); pos += 4;
+            pos += 4; // batch message length prefix
+            return pos;
+        }
+
+        private static int ReadRootTablePos(byte[] buffer, int messageStart)
+        {
+            return messageStart + ToInt32LittleEndian(buffer, messageStart);
+        }
+
+        private static int ReadFieldAbsolutePos(byte[] buffer, int tablePos, 
int vtableSlot)
+        {
+            int vtable = tablePos - ToInt32LittleEndian(buffer, tablePos);
+            short fieldOffset = ToInt16LittleEndian(buffer, vtable + 
vtableSlot);
+            Assert.NotEqual(0, fieldOffset); // field must be present in the 
vtable
+            return tablePos + fieldOffset;
+        }
+
+        private static int ReadUnionTablePos(byte[] buffer, int tablePos, int 
vtableSlot)
+        {
+            int unionPtrPos = ReadFieldAbsolutePos(buffer, tablePos, 
vtableSlot);
+            return unionPtrPos + ToInt32LittleEndian(buffer, unionPtrPos);
+        }
+
+        private static int ReadVectorDataStart(byte[] buffer, int tablePos, 
int vtableSlot)
+        {
+            int vectorPtrPos = ReadFieldAbsolutePos(buffer, tablePos, 
vtableSlot);
+            int vectorLengthPos = vectorPtrPos + ToInt32LittleEndian(buffer, 
vectorPtrPos);
+            return vectorLengthPos + 4; // skip the 4-byte vector length prefix
+        }
+
+        private static void WriteInt64LittleEndian(byte[] buffer, int offset, 
long value)
+        {
+            System.Buffers.Binary.BinaryPrimitives.WriteInt64LittleEndian(
+                buffer.AsSpan(offset), value);
+        }
+
         [Fact]
         public async Task EmptyStreamNoSyncRead()
         {
@@ -327,6 +483,21 @@ namespace Apache.Arrow.Tests
             }
         }
 
+        private static short ToInt16LittleEndian(byte[] buffer, int offset)
+        {
+            return 
BinaryPrimitives.ReadInt16LittleEndian(buffer.AsSpan().Slice(offset));
+        }
+
+        private static int ToInt32LittleEndian(byte[] buffer, int offset)
+        {
+            return 
BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan().Slice(offset));
+        }
+
+        private static long ToInt64LittleEndian(byte[] buffer, int offset)
+        {
+            return 
BinaryPrimitives.ReadInt64LittleEndian(buffer.AsSpan().Slice(offset));
+        }
+
         private class EmptyAsyncOnlyStream : Stream
         {
             public override bool CanRead => true;

Reply via email to