Copilot commented on code in PR #318:
URL: https://github.com/apache/arrow-dotnet/pull/318#discussion_r3130197577


##########
src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs:
##########
@@ -141,6 +141,13 @@ protected RecordBatch CreateArrowObjectFromMessage(
                     break;
                 case Flatbuf.MessageHeader.RecordBatch:
                     Flatbuf.RecordBatch rb = 
message.Header<Flatbuf.RecordBatch>().Value;
+                    if (rb.Length > int.MaxValue)
+                    {
+                        throw new InvalidDataException(
+                            $"Cannot read batch. Message body of {rb.Length} 
bytes " +
+                            $"is greater than the maximum supported length 
({int.MaxValue})");

Review Comment:
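   The overflow guard validates `rb.Length`, which is the record batch row 
count, but the exception message says “Message body … bytes”. This is 
misleading, because the value is a row count, not a byte count. Also consider 
validating `rb.Length >= 0`, since a malformed IPC message could make `Length` 
negative; the subsequent cast to `int` would then produce an invalid RecordBatch 
length. Note that `MalformedRecordBatchLength_OverflowsInt32` asserts that the 
exception message contains `maximum supported length`, so that assertion must be 
updated if the wording below is adopted.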
   ```suggestion
                       if (rb.Length < 0 || rb.Length > int.MaxValue)
                       {
                           throw new InvalidDataException(
                               $"Cannot read batch. Record batch length of 
{rb.Length} rows " +
                               $"is outside the supported range [0, 
{int.MaxValue}]");
   ```



##########
src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs:
##########
@@ -106,6 +106,13 @@ public override RecordBatch ReadNextRecordBatch()
                     CreateByteBuffer(_buffer.Slice(_bufferPosition, 
messageLength)));
                 _bufferPosition += messageLength;
 
+                if (message.BodyLength > int.MaxValue)
+                {
+                    throw new InvalidDataException(
+                        $"Cannot read batch. Message body of 
{message.BodyLength} bytes " +
+                        $"is greater than the maximum supported length 
({int.MaxValue})");
+                }
+
                 int bodyLength = (int)message.BodyLength;

Review Comment:
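   `message.BodyLength` is only checked for `> int.MaxValue`. If it is negative 
(malformed IPC), the `(int)message.BodyLength` cast still succeeds. For values in 
`int` range, `_buffer.Slice(..., bodyLength)` then throws 
`ArgumentOutOfRangeException` instead of `InvalidDataException`. For values below 
`int.MinValue`, an unchecked cast truncates the high bits and can even yield a 
non-negative length, so the body would be silently misread. Consider validating 
`0 <= message.BodyLength && message.BodyLength <= int.MaxValue` before casting. 
Optionally, also verify that the buffer has at least `bodyLength` bytes remaining, 
so that corrupted input consistently produces an `InvalidDataException`. If 
`MalformedBodyLength_OverflowsInt32` exercises this path, its assertion on 
`maximum supported length` must change along with the wording below.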
   ```suggestion
                   if (message.BodyLength < 0 || message.BodyLength > 
int.MaxValue)
                   {
                       throw new InvalidDataException(
                           $"Cannot read batch. Message body of 
{message.BodyLength} bytes " +
                           $"is outside the supported length range (0 to 
{int.MaxValue})");
                   }
   
                   int bodyLength = (int)message.BodyLength;
                   if (_buffer.Length - _bufferPosition < bodyLength)
                   {
                       throw new InvalidDataException(
                           $"Corrupted IPC message. Message body length of 
{bodyLength} bytes exceeds the remaining buffer length.");
                   }
   ```



##########
test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs:
##########
@@ -316,6 +316,163 @@ [new Field("index", Int32Type.Default, nullable: false)],
             });
         }
 
+        [Fact]
+        public void MalformedBodyLength_OverflowsInt32()
+        {
+            byte[] buffer = BuildSimpleInt32Batch(rowCount: 3);
+            int messageStart = FindRecordBatchMessageStart(buffer);
+            int messageTablePos = ReadRootTablePos(buffer, messageStart);
+
+            // Message table vtable slot 10 = BodyLength
+            int bodyLengthPos = ReadFieldAbsolutePos(buffer, messageTablePos, 
vtableSlot: 10);
+            Assert.True(BitConverter.ToInt64(buffer, bodyLengthPos) > 0);
+
+            WriteInt64LittleEndian(buffer, bodyLengthPos, (long)int.MaxValue + 
1);
+
+            InvalidDataException ex = Assert.Throws<InvalidDataException>(
+                () => new ArrowStreamReader(buffer).ReadNextRecordBatch());
+            Assert.Contains("Message body", ex.Message);
+            Assert.Contains("maximum supported length", ex.Message);
+        }
+
+        [Fact]
+        public void MalformedRecordBatchLength_OverflowsInt32()
+        {
+            byte[] buffer = BuildSimpleInt32Batch(rowCount: 3);
+            int messageStart = FindRecordBatchMessageStart(buffer);
+            int messageTablePos = ReadRootTablePos(buffer, messageStart);
+
+            // Message.Header (slot 8) is a union pointing at the RecordBatch 
table
+            int recordBatchTablePos = ReadUnionTablePos(buffer, 
messageTablePos, vtableSlot: 8);
+
+            // RecordBatch table vtable slot 4 = Length (row count)
+            int lengthPos = ReadFieldAbsolutePos(buffer, recordBatchTablePos, 
vtableSlot: 4);
+            Assert.Equal(3L, BitConverter.ToInt64(buffer, lengthPos));
+
+            WriteInt64LittleEndian(buffer, lengthPos, (long)int.MaxValue + 1);
+
+            InvalidDataException ex = Assert.Throws<InvalidDataException>(
+                () => new ArrowStreamReader(buffer).ReadNextRecordBatch());
+            Assert.Contains("maximum supported length", ex.Message);
+        }
+
+        [Fact]
+        public void MalformedFieldNodeLength_OverflowsInt32()
+        {
+            byte[] buffer = BuildSimpleInt32Batch(rowCount: 3);
+            int messageStart = FindRecordBatchMessageStart(buffer);
+            int messageTablePos = ReadRootTablePos(buffer, messageStart);
+            int recordBatchTablePos = ReadUnionTablePos(buffer, 
messageTablePos, vtableSlot: 8);
+
+            // RecordBatch.Nodes (slot 6) is a vector of 16-byte FieldNode 
structs
+            // where the first 8 bytes are Length and the next 8 bytes are 
NullCount.
+            int nodesDataStart = ReadVectorDataStart(buffer, 
recordBatchTablePos, vtableSlot: 6);
+            Assert.Equal(3L, BitConverter.ToInt64(buffer, nodesDataStart));
+
+            WriteInt64LittleEndian(buffer, nodesDataStart, (long)int.MaxValue 
+ 1);
+
+            InvalidDataException ex = Assert.Throws<InvalidDataException>(
+                () => new ArrowStreamReader(buffer).ReadNextRecordBatch());
+            Assert.Contains("Field length", ex.Message);
+        }
+
+        [Fact]
+        public void MalformedBufferLength_OverflowsInt32()
+        {
+            byte[] buffer = BuildSimpleInt32Batch(rowCount: 3);
+            int messageStart = FindRecordBatchMessageStart(buffer);
+            int messageTablePos = ReadRootTablePos(buffer, messageStart);
+            int recordBatchTablePos = ReadUnionTablePos(buffer, 
messageTablePos, vtableSlot: 8);
+
+            // RecordBatch.Buffers (slot 8) is a vector of 16-byte Buffer 
structs
+            // (8 bytes Offset, 8 bytes Length). Find the first buffer with 
non-zero
+            // length and corrupt its Length field.
+            int buffersDataStart = ReadVectorDataStart(buffer, 
recordBatchTablePos, vtableSlot: 8);
+            int buffersLength = BitConverter.ToInt32(buffer, buffersDataStart 
- 4);
+            int targetLengthPos = -1;
+            for (int i = 0; i < buffersLength; i++)
+            {
+                int lengthPos = buffersDataStart + i * 16 + 8;
+                if (BitConverter.ToInt64(buffer, lengthPos) > 0)
+                {
+                    targetLengthPos = lengthPos;
+                    break;
+                }
+            }
+            Assert.NotEqual(-1, targetLengthPos);
+
+            WriteInt64LittleEndian(buffer, targetLengthPos, (long)int.MaxValue 
+ 1);
+
+            InvalidDataException ex = Assert.Throws<InvalidDataException>(
+                () => new ArrowStreamReader(buffer).ReadNextRecordBatch());
+            Assert.Contains("IPC buffer length", ex.Message);
+        }
+
+        private static byte[] BuildSimpleInt32Batch(int rowCount)
+        {
+            Schema schema = new(
+                [new Field("x", Int32Type.Default, nullable: true)],
+                metadata: []);
+            Int32Array.Builder arrayBuilder = new();
+            for (int i = 0; i < rowCount; i++)
+            {
+                arrayBuilder.Append(i);
+            }
+            RecordBatch batch = new(schema, [arrayBuilder.Build()], rowCount);
+
+            using MemoryStream stream = new();
+            using (ArrowStreamWriter writer = new(stream, schema, leaveOpen: 
true))
+            {
+                writer.WriteRecordBatch(batch);
+                writer.WriteEnd();
+            }
+            return stream.ToArray();
+        }
+
+        private static int FindRecordBatchMessageStart(byte[] buffer)
+        {
+            // Stream layout: [continuation(0xFFFFFFFF)][len][schema 
message][continuation][len][batch message][body]...
+            int pos = 0;
+            Assert.Equal(-1, BitConverter.ToInt32(buffer, pos)); pos += 4;
+            int schemaLen = BitConverter.ToInt32(buffer, pos); pos += 4;
+            pos += schemaLen;
+            Assert.Equal(-1, BitConverter.ToInt32(buffer, pos)); pos += 4;
+            pos += 4; // batch message length prefix
+            return pos;
+        }
+
+        private static int ReadRootTablePos(byte[] buffer, int messageStart)
+        {
+            return messageStart + BitConverter.ToInt32(buffer, messageStart);
+        }
+
+        private static int ReadFieldAbsolutePos(byte[] buffer, int tablePos, 
int vtableSlot)
+        {
+            int vtable = tablePos - BitConverter.ToInt32(buffer, tablePos);
+            short fieldOffset = BitConverter.ToInt16(buffer, vtable + 
vtableSlot);
+            Assert.NotEqual(0, fieldOffset); // field must be present in the 
vtable
+            return tablePos + fieldOffset;

Review Comment:
   These new FlatBuffer parsing helpers use 
`BitConverter.ToInt32/ToInt16/ToInt64`, as do several assertions in the test 
bodies themselves. These methods are endianness-dependent. Since Arrow 
IPC/FlatBuffers data is little-endian, the tests should use 
`BinaryPrimitives.ReadInt32LittleEndian` / `ReadInt16LittleEndian` / 
`ReadInt64LittleEndian` (as other IPC tests do). This avoids failures on 
big-endian platforms and makes the intent explicit.



##########
src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs:
##########
@@ -134,8 +135,15 @@ public override async ValueTask<RecordBatch> 
ReadNextRecordBatchAsync(Cancellati
                 switch (message.HeaderType)
                 {
                     case MessageHeader.RecordBatch:
+                        if (message.BodyLength > int.MaxValue)
+                        {
+                            throw new InvalidDataException(
+                                $"Cannot read batch. Message body of 
{message.BodyLength} bytes " +
+                                $"is greater than the maximum supported length 
({int.MaxValue})");
+                        }
+
                         var body = _flightDataStream.Current.DataBody.Memory;
-                        return CreateArrowObjectFromMessage(message, 
CreateByteBuffer(body.Slice(0, (int)message.BodyLength)), null);
+                        return CreateArrowObjectFromMessage(message, 
CreateByteBuffer(body.Slice(0, checked((int)message.BodyLength))), null);

Review Comment:
   `message.BodyLength` is only checked for `> int.MaxValue`. If it is negative 
but not below `int.MinValue`, `checked((int)message.BodyLength)` succeeds and 
`body.Slice(0, ...)` throws `ArgumentOutOfRangeException`. If it is below 
`int.MinValue`, the checked cast throws `OverflowException`. Either way, callers 
do not get `InvalidDataException`. Consider validating `message.BodyLength >= 0` 
(and ideally `<= body.Length`) before slicing, so that malformed Flight/IPC 
payloads fail with a consistent `InvalidDataException`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to