This is an automated email from the ASF dual-hosted git repository.
CurtHagenlocher pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-dotnet.git
The following commit(s) were added to refs/heads/main by this push:
new 2b2afa2 chore(src): add checked conversions where needed (#318)
2b2afa2 is described below
commit 2b2afa201cf406167044a505b30dd096ebd37193
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Thu Apr 23 21:23:01 2026 -0700
chore(src): add checked conversions where needed (#318)
## What's Changed
Adds checks to some int64->int32 conversions when reading data through
IPC.
---
.../Internal/RecordBatchReaderImplementation.cs | 9 +-
src/Apache.Arrow/Arrays/RunEndEncodedArray.cs | 2 +-
.../Ipc/ArrowMemoryReaderImplementation.cs | 6 +
src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs | 29 +++-
test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs | 175 ++++++++++++++++++++-
5 files changed, 209 insertions(+), 12 deletions(-)
diff --git a/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs b/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs
index cb64867..16e080d 100644
--- a/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs
+++ b/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs
@@ -15,6 +15,7 @@
using System;
using System.Collections.Generic;
+using System.IO;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -134,8 +135,14 @@ namespace Apache.Arrow.Flight.Internal
switch (message.HeaderType)
{
case MessageHeader.RecordBatch:
+ if (message.BodyLength < 0 || message.BodyLength > int.MaxValue)
+ {
+ throw new InvalidDataException(
+ $"Cannot read batch. Message body of {message.BodyLength} bytes is out of range");
+ }
+
var body = _flightDataStream.Current.DataBody.Memory;
- return CreateArrowObjectFromMessage(message, CreateByteBuffer(body.Slice(0, (int)message.BodyLength)), null);
+ return CreateArrowObjectFromMessage(message, CreateByteBuffer(body.Slice(0, checked((int)message.BodyLength))), null);
default:
throw new NotImplementedException();
}
diff --git a/src/Apache.Arrow/Arrays/RunEndEncodedArray.cs b/src/Apache.Arrow/Arrays/RunEndEncodedArray.cs
index c1aff92..2cf73ab 100644
--- a/src/Apache.Arrow/Arrays/RunEndEncodedArray.cs
+++ b/src/Apache.Arrow/Arrays/RunEndEncodedArray.cs
@@ -152,7 +152,7 @@ public class RunEndEncodedArray : Array, IIndexes
long? lastValue = int64Array.GetValue(int64Array.Length - 1);
if (lastValue.HasValue && lastValue.Value > int.MaxValue)
{
- throw new ArgumentException("Run ends value exceeds maximum supported length.");
+ throw new OverflowException($"Run ends value {lastValue.Value} exceeds maximum supported length {int.MaxValue}.");
}
return (int)(lastValue ?? throw new ArgumentException("invalid length"));
}
diff --git a/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs b/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
index e43d1d5..e9cd936 100644
--- a/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
+++ b/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
@@ -106,6 +106,12 @@ namespace Apache.Arrow.Ipc
CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength)));
_bufferPosition += messageLength;
+ if (message.BodyLength < 0 || message.BodyLength > int.MaxValue)
+ {
+ throw new InvalidDataException(
+ $"Cannot read batch. Message body of {message.BodyLength} bytes is out of range");
+ }
+
int bodyLength = (int)message.BodyLength;
ByteBuffer bodybb = CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength));
_bufferPosition += bodyLength;
diff --git a/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
index 799a488..398c80b 100644
--- a/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
+++ b/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
@@ -141,6 +141,12 @@ namespace Apache.Arrow.Ipc
break;
case Flatbuf.MessageHeader.RecordBatch:
Flatbuf.RecordBatch rb = message.Header<Flatbuf.RecordBatch>().Value;
+ if (rb.Length < 0 || rb.Length > int.MaxValue)
+ {
+ throw new InvalidDataException(
+ $"Cannot read batch. Message body of {rb.Length} rows is out of range.");
+ }
+
List<IArrowArray> arrays = BuildArrays(message.Version, Schema, bodyByteBuffer, rb);
return new RecordBatch(Schema, memoryOwner, arrays, (int)rb.Length);
default:
@@ -256,20 +262,19 @@ namespace Apache.Arrow.Ipc
ByteBuffer bodyData,
IBufferCreator bufferCreator)
{
-
- int fieldLength = (int)fieldNode.Length;
- int fieldNullCount = (int)fieldNode.NullCount;
-
- if (fieldLength < 0)
+ if (fieldNode.Length < 0 || fieldNode.Length > int.MaxValue)
{
- throw new InvalidDataException("Field length must be >= 0"); // TODO:Localize exception message
+ throw new InvalidDataException($"Field length of {fieldNode.Length} must be >= 0 and <= {int.MaxValue}");
}
- if (fieldNullCount < 0)
+ if (fieldNode.NullCount < 0 || fieldNode.NullCount > int.MaxValue)
{
- throw new InvalidDataException("Null count must be >= 0"); // TODO:Localize exception message
+ throw new InvalidDataException($"Null count of {fieldNode.NullCount} must be >= 0 and <= {int.MaxValue}");
}
+ int fieldLength = (int)fieldNode.Length;
+ int fieldNullCount = (int)fieldNode.NullCount;
+
int buffers;
IArrowType storageType = Types.ArrowTypeExtensions.GetStorageType(field.DataType);
switch (storageType.TypeId)
@@ -375,6 +380,14 @@ namespace Apache.Arrow.Ipc
return ArrowBuffer.Empty;
}
+ if (buffer.Offset < 0 || buffer.Offset > int.MaxValue ||
+ buffer.Length < 0 || buffer.Length > int.MaxValue ||
+ buffer.Length + buffer.Offset > bodyData.Length)
+ {
+ throw new InvalidDataException(
+ $"IPC buffer range is out of range for a .NET buffer: offset={buffer.Offset}, length={buffer.Length}, size={bodyData.Length}.");
+ }
+
int offset = (int)buffer.Offset;
int length = (int)buffer.Length;
diff --git a/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs b/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
index 22ed233..5e7c57e 100644
--- a/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
+++ b/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
@@ -14,12 +14,11 @@
// limitations under the License.
using System;
+using System.Buffers.Binary;
using System.IO;
-using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
-using Apache.Arrow.Memory;
using Apache.Arrow.Types;
using Xunit;
@@ -316,6 +315,163 @@ namespace Apache.Arrow.Tests
});
}
+ [Fact]
+ public void MalformedBodyLength_OverflowsInt32()
+ {
+ byte[] buffer = BuildSimpleInt32Batch(rowCount: 3);
+ int messageStart = FindRecordBatchMessageStart(buffer);
+ int messageTablePos = ReadRootTablePos(buffer, messageStart);
+
+ // Message table vtable slot 10 = BodyLength
+ int bodyLengthPos = ReadFieldAbsolutePos(buffer, messageTablePos, vtableSlot: 10);
+ Assert.True(ToInt64LittleEndian(buffer, bodyLengthPos) > 0);
+
+ WriteInt64LittleEndian(buffer, bodyLengthPos, (long)int.MaxValue + 1);
+
+ InvalidDataException ex = Assert.Throws<InvalidDataException>(
+ () => new ArrowStreamReader(buffer).ReadNextRecordBatch());
+ Assert.Contains("Message body", ex.Message);
+ Assert.Contains("out of range", ex.Message);
+ }
+
+ [Fact]
+ public void MalformedRecordBatchLength_OverflowsInt32()
+ {
+ byte[] buffer = BuildSimpleInt32Batch(rowCount: 3);
+ int messageStart = FindRecordBatchMessageStart(buffer);
+ int messageTablePos = ReadRootTablePos(buffer, messageStart);
+
+ // Message.Header (slot 8) is a union pointing at the RecordBatch table
+ int recordBatchTablePos = ReadUnionTablePos(buffer, messageTablePos, vtableSlot: 8);
+
+ // RecordBatch table vtable slot 4 = Length (row count)
+ int lengthPos = ReadFieldAbsolutePos(buffer, recordBatchTablePos, vtableSlot: 4);
+ Assert.Equal(3L, ToInt64LittleEndian(buffer, lengthPos));
+
+ WriteInt64LittleEndian(buffer, lengthPos, (long)int.MaxValue + 1);
+
+ InvalidDataException ex = Assert.Throws<InvalidDataException>(
+ () => new ArrowStreamReader(buffer).ReadNextRecordBatch());
+ Assert.Contains("out of range", ex.Message);
+ }
+
+ [Fact]
+ public void MalformedFieldNodeLength_OverflowsInt32()
+ {
+ byte[] buffer = BuildSimpleInt32Batch(rowCount: 3);
+ int messageStart = FindRecordBatchMessageStart(buffer);
+ int messageTablePos = ReadRootTablePos(buffer, messageStart);
+ int recordBatchTablePos = ReadUnionTablePos(buffer, messageTablePos, vtableSlot: 8);
+
+ // RecordBatch.Nodes (slot 6) is a vector of 16-byte FieldNode structs
+ // where the first 8 bytes are Length and the next 8 bytes are NullCount.
+ int nodesDataStart = ReadVectorDataStart(buffer, recordBatchTablePos, vtableSlot: 6);
+ Assert.Equal(3L, ToInt64LittleEndian(buffer, nodesDataStart));
+
+ WriteInt64LittleEndian(buffer, nodesDataStart, (long)int.MaxValue + 1);
+
+ InvalidDataException ex = Assert.Throws<InvalidDataException>(
+ () => new ArrowStreamReader(buffer).ReadNextRecordBatch());
+ Assert.Contains("Field length", ex.Message);
+ }
+
+ [Fact]
+ public void MalformedBufferLength_OverflowsInt32()
+ {
+ byte[] buffer = BuildSimpleInt32Batch(rowCount: 3);
+ int messageStart = FindRecordBatchMessageStart(buffer);
+ int messageTablePos = ReadRootTablePos(buffer, messageStart);
+ int recordBatchTablePos = ReadUnionTablePos(buffer, messageTablePos, vtableSlot: 8);
+
+ // RecordBatch.Buffers (slot 8) is a vector of 16-byte Buffer structs
+ // (8 bytes Offset, 8 bytes Length). Find the first buffer with non-zero
+ // length and corrupt its Length field.
+ int buffersDataStart = ReadVectorDataStart(buffer, recordBatchTablePos, vtableSlot: 8);
+ int buffersLength = ToInt32LittleEndian(buffer, buffersDataStart - 4);
+ int targetLengthPos = -1;
+ for (int i = 0; i < buffersLength; i++)
+ {
+ int lengthPos = buffersDataStart + i * 16 + 8;
+ if (ToInt64LittleEndian(buffer, lengthPos) > 0)
+ {
+ targetLengthPos = lengthPos;
+ break;
+ }
+ }
+ Assert.NotEqual(-1, targetLengthPos);
+
+ WriteInt64LittleEndian(buffer, targetLengthPos, (long)int.MaxValue + 1);
+
+ InvalidDataException ex = Assert.Throws<InvalidDataException>(
+ () => new ArrowStreamReader(buffer).ReadNextRecordBatch());
+ Assert.Contains("IPC buffer range", ex.Message);
+ }
+
+ private static byte[] BuildSimpleInt32Batch(int rowCount)
+ {
+ Schema schema = new(
+ [new Field("x", Int32Type.Default, nullable: true)],
+ metadata: []);
+ Int32Array.Builder arrayBuilder = new();
+ for (int i = 0; i < rowCount; i++)
+ {
+ arrayBuilder.Append(i);
+ }
+ RecordBatch batch = new(schema, [arrayBuilder.Build()], rowCount);
+
+ using MemoryStream stream = new();
+ using (ArrowStreamWriter writer = new(stream, schema, leaveOpen: true))
+ {
+ writer.WriteRecordBatch(batch);
+ writer.WriteEnd();
+ }
+ return stream.ToArray();
+ }
+
+ private static int FindRecordBatchMessageStart(byte[] buffer)
+ {
+ // Stream layout: [continuation(0xFFFFFFFF)][len][schema message][continuation][len][batch message][body]...
+ int pos = 0;
+ Assert.Equal(-1, ToInt32LittleEndian(buffer, pos)); pos += 4;
+ int schemaLen = ToInt32LittleEndian(buffer, pos); pos += 4;
+ pos += schemaLen;
+ Assert.Equal(-1, ToInt32LittleEndian(buffer, pos)); pos += 4;
+ pos += 4; // batch message length prefix
+ return pos;
+ }
+
+ private static int ReadRootTablePos(byte[] buffer, int messageStart)
+ {
+ return messageStart + ToInt32LittleEndian(buffer, messageStart);
+ }
+
+ private static int ReadFieldAbsolutePos(byte[] buffer, int tablePos, int vtableSlot)
+ {
+ int vtable = tablePos - ToInt32LittleEndian(buffer, tablePos);
+ short fieldOffset = ToInt16LittleEndian(buffer, vtable + vtableSlot);
+ Assert.NotEqual(0, fieldOffset); // field must be present in the vtable
+ return tablePos + fieldOffset;
+ }
+
+ private static int ReadUnionTablePos(byte[] buffer, int tablePos, int vtableSlot)
+ {
+ int unionPtrPos = ReadFieldAbsolutePos(buffer, tablePos, vtableSlot);
+ return unionPtrPos + ToInt32LittleEndian(buffer, unionPtrPos);
+ }
+
+ private static int ReadVectorDataStart(byte[] buffer, int tablePos, int vtableSlot)
+ {
+ int vectorPtrPos = ReadFieldAbsolutePos(buffer, tablePos, vtableSlot);
+ int vectorLengthPos = vectorPtrPos + ToInt32LittleEndian(buffer, vectorPtrPos);
+ return vectorLengthPos + 4; // skip the 4-byte vector length prefix
+ }
+
+ private static void WriteInt64LittleEndian(byte[] buffer, int offset, long value)
+ {
+ System.Buffers.Binary.BinaryPrimitives.WriteInt64LittleEndian(
+ buffer.AsSpan(offset), value);
+ }
+
[Fact]
public async Task EmptyStreamNoSyncRead()
{
@@ -327,6 +483,21 @@ namespace Apache.Arrow.Tests
}
}
+ private static short ToInt16LittleEndian(byte[] buffer, int offset)
+ {
+ return BinaryPrimitives.ReadInt16LittleEndian(buffer.AsSpan().Slice(offset));
+ }
+
+ private static int ToInt32LittleEndian(byte[] buffer, int offset)
+ {
+ return BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan().Slice(offset));
+ }
+
+ private static long ToInt64LittleEndian(byte[] buffer, int offset)
+ {
+ return BinaryPrimitives.ReadInt64LittleEndian(buffer.AsSpan().Slice(offset));
+ }
+
private class EmptyAsyncOnlyStream : Stream
{
public override bool CanRead => true;