This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1633f8d028 GH-38351: [C#] Add SqlDecimal support to Decimal128Array (#38481)
1633f8d028 is described below

commit 1633f8d0284bfefa2e50bbd1378c46acfbc5ceab
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Fri Oct 27 15:13:33 2023 -0700

    GH-38351: [C#] Add SqlDecimal support to Decimal128Array (#38481)
    
    ### What changes are included in this PR?
    
    Adds support for reading System.Data.SqlTypes.SqlDecimal values from, and writing them to, a Decimal128Array.
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
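    Adds new public methods: Decimal128Array.Builder.Append(SqlDecimal), Decimal128Array.Builder.AppendRange(IEnumerable<SqlDecimal>), and Decimal128Array.GetSqlDecimal(int). They are not available when targeting .NET Standard 1.3.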
    * Closes: #38351
    
    Authored-by: Curt Hagenlocher <[email protected]>
    Signed-off-by: Curt Hagenlocher <[email protected]>
---
 csharp/src/Apache.Arrow/Arrays/Decimal128Array.cs  |  40 ++++++
 csharp/src/Apache.Arrow/DecimalUtility.cs          |  49 ++++++++
 .../Apache.Arrow.Tests/Decimal128ArrayTests.cs     | 139 +++++++++++++++++++--
 .../test/Apache.Arrow.Tests/DecimalUtilityTests.cs |  61 ++++++++-
 4 files changed, 278 insertions(+), 11 deletions(-)

diff --git a/csharp/src/Apache.Arrow/Arrays/Decimal128Array.cs b/csharp/src/Apache.Arrow/Arrays/Decimal128Array.cs
index 128e9e5f08..7b147f5124 100644
--- a/csharp/src/Apache.Arrow/Arrays/Decimal128Array.cs
+++ b/csharp/src/Apache.Arrow/Arrays/Decimal128Array.cs
@@ -15,6 +15,9 @@
 
 using System;
 using System.Collections.Generic;
+#if !NETSTANDARD1_3
+using System.Data.SqlTypes;
+#endif
 using System.Diagnostics;
 using System.Numerics;
 using Apache.Arrow.Arrays;
@@ -61,6 +64,31 @@ namespace Apache.Arrow
                 return Instance;
             }
 
+#if !NETSTANDARD1_3
+            public Builder Append(SqlDecimal value) // Appends value converted to this column's precision/scale; SqlDecimal.Null appends a null slot.
+            {
+                if (value.IsNull) { return AppendNull(); } // reading Precision/Data of SqlDecimal.Null would throw SqlNullValueException
+                Span<byte> bytes = stackalloc byte[DataType.ByteWidth];
+                DecimalUtility.GetBytes(value, DataType.Precision, DataType.Scale, bytes);
+                return Append(bytes);
+            }
+
+            public Builder AppendRange(IEnumerable<SqlDecimal> values)
+            {
+                // A null sequence is rejected before any element is appended; otherwise
+                // each element goes through the single-value Append(SqlDecimal) overload.
+                IEnumerable<SqlDecimal> source =
+                    values ?? throw new ArgumentNullException(nameof(values));
+
+                foreach (SqlDecimal item in source)
+                {
+                    Append(item);
+                }
+
+                return Instance;
+            }
+#endif
+
             public Builder Set(int index, decimal value)
             {
                 Span<byte> bytes = stackalloc byte[DataType.ByteWidth];
@@ -91,5 +119,17 @@ namespace Apache.Arrow
             }
             return DecimalUtility.GetDecimal(ValueBuffer, index, Scale, 
ByteWidth);
         }
+
+#if !NETSTANDARD1_3
+        public SqlDecimal? GetSqlDecimal(int index)
+        {
+            // Null slots yield a null Nullable<SqlDecimal> rather than SqlDecimal.Null.
+            if (!IsNull(index))
+            {
+                return DecimalUtility.GetSqlDecimal128(ValueBuffer, index, Precision, Scale);
+            }
+            return null;
+        }
+#endif
     }
 }
diff --git a/csharp/src/Apache.Arrow/DecimalUtility.cs b/csharp/src/Apache.Arrow/DecimalUtility.cs
index 4a29d068c6..35e56ff65e 100644
--- a/csharp/src/Apache.Arrow/DecimalUtility.cs
+++ b/csharp/src/Apache.Arrow/DecimalUtility.cs
@@ -14,6 +14,9 @@
 // limitations under the License.
 
 using System;
+#if !NETSTANDARD1_3
+using System.Data.SqlTypes;
+#endif
 using System.Numerics;
 
 namespace Apache.Arrow
@@ -73,6 +76,32 @@ namespace Apache.Arrow
             }
         }
 
+#if !NETSTANDARD1_3
+        internal static SqlDecimal GetSqlDecimal128(in ArrowBuffer valueBuffer, int index, int precision, int scale)
+        { // Decodes the 16-byte value in slot `index` of valueBuffer into a SqlDecimal tagged with the given precision and scale.
+            const int byteWidth = 16; // bytes per Decimal128 slot
+            const int intWidth = byteWidth / 4; // 32-bit words per slot
+            const int longWidth = byteWidth / 8; // 64-bit words per slot
+            // The slot's last byte is treated as the most significant byte; its top bit is the two's-complement sign.
+            byte mostSignificantByte = valueBuffer.Span[(index + 1) * byteWidth - 1];
+            bool isPositive = (mostSignificantByte & 0x80) == 0;
+            // SqlDecimal takes a sign flag plus an unsigned 128-bit magnitude, so negative values are negated before construction.
+            if (isPositive)
+            {
+                ReadOnlySpan<int> value = valueBuffer.Span.CastTo<int>().Slice(index * intWidth, intWidth); // magnitude words, least significant first
+                return new SqlDecimal((byte)precision, (byte)scale, true, value[0], value[1], value[2], value[3]);
+            }
+            else
+            {
+                ReadOnlySpan<long> value = valueBuffer.Span.CastTo<long>().Slice(index * longWidth, longWidth); // NOTE(review): low/high word order presumes CastTo reads host-order words on a little-endian host — confirm
+                long data1 = -value[0]; // low 64 bits of the negated (absolute) value
+                long data2 = (data1 == 0) ? -value[1] : ~value[1]; // the +1 of (~x + 1) carries into the high word only when the low word is zero
+                // Split both 64-bit halves into the four 32-bit words the SqlDecimal constructor takes, low word first.
+                return new SqlDecimal((byte)precision, (byte)scale, false, (int)(data1 & 0xffffffff), (int)(data1 >> 32), (int)(data2 & 0xffffffff), (int)(data2 >> 32));
+            }
+        }
+#endif
+
         private static decimal DivideByScale(BigInteger integerValue, int 
scale)
         {
             decimal result = (decimal)integerValue; // this cast is safe here
@@ -169,5 +198,25 @@ namespace Apache.Arrow
                 }
             }
         }
+
+#if !NETSTANDARD1_3
+        internal static void GetBytes(SqlDecimal value, int precision, int scale, Span<byte> bytes)
+        { // Writes value, rescaled to precision/scale when they differ, into bytes as a 16-byte two's-complement integer.
+            if (value.Precision != precision || value.Scale != scale)
+            {
+                value = SqlDecimal.ConvertToPrecScale(value, precision, scale); // NOTE(review): SqlDecimal may round or throw here if value does not fit — see its docs
+            }
+            // Copy the unsigned magnitude (SqlDecimal.Data: four 32-bit words) into the start of bytes; assumes bytes.Length >= 16 — TODO confirm at callers.
+            // TODO: Consider groveling in the internals to avoid the probable allocation (value.Data returns an int[]).
+            Span<int> span = bytes.CastTo<int>();
+            value.Data.AsSpan().CopyTo(span);
+            if (!value.IsPositive)
+            {
+                Span<long> longSpan = bytes.CastTo<long>(); // view the same 16 bytes as low/high 64-bit words
+                longSpan[0] = -longSpan[0]; // negate the low word
+                longSpan[1] = (longSpan[0] == 0) ? -longSpan[1] : ~longSpan[1]; // the negated low word is zero only if it was zero before, which is exactly when the +1 carries into the high word
+            }
+        }
+#endif
     }
 }
diff --git a/csharp/test/Apache.Arrow.Tests/Decimal128ArrayTests.cs b/csharp/test/Apache.Arrow.Tests/Decimal128ArrayTests.cs
index 4c4e653726..8d7adfef42 100644
--- a/csharp/test/Apache.Arrow.Tests/Decimal128ArrayTests.cs
+++ b/csharp/test/Apache.Arrow.Tests/Decimal128ArrayTests.cs
@@ -14,7 +14,9 @@
 // limitations under the License.
 
 using System;
-using System.Collections.Generic;
+#if !NETSTANDARD1_3
+using System.Data.SqlTypes;
+#endif
 using Apache.Arrow.Types;
 using Xunit;
 
@@ -22,6 +24,18 @@ namespace Apache.Arrow.Tests
 {
     public class Decimal128ArrayTests
     {
+#if !NETSTANDARD1_3
+        // Maps a nullable decimal to a nullable SqlDecimal, keeping null as null.
+        static SqlDecimal? Convert(decimal? value) =>
+            value.HasValue ? new SqlDecimal(value.Value)
+                           : (SqlDecimal?)null;
+
+        // Maps a nullable SqlDecimal back to a nullable decimal, keeping null as null.
+        static decimal? Convert(SqlDecimal? value) =>
+            value.HasValue ? value.Value.Value
+                           : (decimal?)null;
+#endif
+
         public class Builder
         {
             public class AppendNull
@@ -30,7 +44,7 @@ namespace Apache.Arrow.Tests
                 public void AppendThenGetGivesNull()
                 {
                     // Arrange
-                    var builder = new Decimal128Array.Builder(new 
Decimal128Type(8,2));
+                    var builder = new Decimal128Array.Builder(new 
Decimal128Type(8, 2));
 
                     // Act
 
@@ -45,6 +59,12 @@ namespace Apache.Arrow.Tests
                     Assert.Null(array.GetValue(0));
                     Assert.Null(array.GetValue(1));
                     Assert.Null(array.GetValue(2));
+
+#if !NETSTANDARD1_3
+                    Assert.Null(array.GetSqlDecimal(0));
+                    Assert.Null(array.GetSqlDecimal(1));
+                    Assert.Null(array.GetSqlDecimal(2));
+#endif
                 }
             }
 
@@ -67,7 +87,7 @@ namespace Apache.Arrow.Tests
                             testData[i] = null;
                             continue;
                         }
-                        decimal rnd = i * (decimal)Math.Round(new 
Random().NextDouble(),10);
+                        decimal rnd = i * (decimal)Math.Round(new 
Random().NextDouble(), 10);
                         testData[i] = rnd;
                         builder.Append(rnd);
                     }
@@ -78,6 +98,9 @@ namespace Apache.Arrow.Tests
                     for (int i = 0; i < count; i++)
                     {
                         Assert.Equal(testData[i], array.GetValue(i));
+#if !NETSTANDARD1_3
+                        Assert.Equal(Convert(testData[i]), 
array.GetSqlDecimal(i));
+#endif
                     }
                 }
 
@@ -95,6 +118,11 @@ namespace Apache.Arrow.Tests
                     var array = builder.Build();
                     Assert.Equal(large, array.GetValue(0));
                     Assert.Equal(-large, array.GetValue(1));
+
+#if !NETSTANDARD1_3
+                    Assert.Equal(Convert(large), array.GetSqlDecimal(0));
+                    Assert.Equal(Convert(-large), array.GetSqlDecimal(1));
+#endif
                 }
 
                 [Fact]
@@ -115,6 +143,13 @@ namespace Apache.Arrow.Tests
                     Assert.Equal(Decimal.MinValue, array.GetValue(1));
                     Assert.Equal(Decimal.MaxValue - 10, array.GetValue(2));
                     Assert.Equal(Decimal.MinValue + 10, array.GetValue(3));
+
+#if !NETSTANDARD1_3
+                    Assert.Equal(Convert(Decimal.MaxValue), 
array.GetSqlDecimal(0));
+                    Assert.Equal(Convert(Decimal.MinValue), 
array.GetSqlDecimal(1));
+                    Assert.Equal(Convert(Decimal.MaxValue) - 10, 
array.GetSqlDecimal(2));
+                    Assert.Equal(Convert(Decimal.MinValue) + 10, 
array.GetSqlDecimal(3));
+#endif
                 }
 
                 [Fact]
@@ -131,6 +166,11 @@ namespace Apache.Arrow.Tests
                     var array = builder.Build();
                     Assert.Equal(fraction, array.GetValue(0));
                     Assert.Equal(-fraction, array.GetValue(1));
+
+#if !NETSTANDARD1_3
+                    Assert.Equal(Convert(fraction), array.GetSqlDecimal(0));
+                    Assert.Equal(Convert(-fraction), array.GetSqlDecimal(1));
+#endif
                 }
 
                 [Fact]
@@ -138,7 +178,7 @@ namespace Apache.Arrow.Tests
                 {
                     // Arrange
                     var builder = new Decimal128Array.Builder(new 
Decimal128Type(24, 8));
-                    var range = new decimal[] {2.123M, 1.5984M, -0.0000001M, 
9878987987987987.1235407M};
+                    var range = new decimal[] { 2.123M, 1.5984M, -0.0000001M, 
9878987987987987.1235407M };
 
                     // Act
                     builder.AppendRange(range);
@@ -146,12 +186,15 @@ namespace Apache.Arrow.Tests
 
                     // Assert
                     var array = builder.Build();
-                    for(int i = 0; i < range.Length; i ++)
+                    for (int i = 0; i < range.Length; i++)
                     {
                         Assert.Equal(range[i], array.GetValue(i));
+#if !NETSTANDARD1_3
+                        Assert.Equal(Convert(range[i]), 
array.GetSqlDecimal(i));
+#endif
                     }
-                    
-                    Assert.Null( array.GetValue(range.Length));
+
+                    Assert.Null(array.GetValue(range.Length));
                 }
 
                 [Fact]
@@ -159,7 +202,7 @@ namespace Apache.Arrow.Tests
                 {
                     // Arrange
                     var builder = new Decimal128Array.Builder(new 
Decimal128Type(24, 8));
-                    
+
                     // Act
                     builder.Append(1);
                     builder.Clear();
@@ -256,6 +299,86 @@ namespace Apache.Arrow.Tests
                     Assert.Equal(123.456M, array.GetValue(1));
                 }
             }
+
+#if !NETSTANDARD1_3
+            public class SqlDecimals // Builder/array round trips using System.Data.SqlTypes.SqlDecimal values
+            {
+                [Theory]
+                [InlineData(200)]
+                public void AppendSqlDecimal(int count)
+                {
+                    // Arrange
+                    const int scale = 10; // digits after the decimal point; passed as Decimal128Type's scale, not its precision
+                    var builder = new Decimal128Array.Builder(new Decimal128Type(14, scale));
+
+                    // Act
+                    SqlDecimal?[] testData = new SqlDecimal?[count];
+                    for (int i = 0; i < count; i++)
+                    {
+                        if (i == count - 2)
+                        {
+                            builder.AppendNull();
+                            testData[i] = null;
+                            continue;
+                        }
+                        SqlDecimal rnd = i * (SqlDecimal)Math.Round(new Random().NextDouble(), 10);
+                        builder.Append(rnd);
+                        testData[i] = SqlDecimal.Round(rnd, scale); // the builder rescales to `scale` digits, so compare against the rounded value
+                    }
+
+                    // Assert
+                    var array = builder.Build();
+                    Assert.Equal(count, array.Length);
+                    for (int i = 0; i < count; i++)
+                    {
+                        Assert.Equal(testData[i], array.GetSqlDecimal(i));
+                        Assert.Equal(Convert(testData[i]), array.GetValue(i));
+                    }
+                }
+
+                [Fact]
+                public void AppendMaxAndMinSqlDecimal()
+                {
+                    // Arrange
+                    var builder = new Decimal128Array.Builder(new Decimal128Type(38, 0));
+
+                    // Act
+                    builder.Append(SqlDecimal.MaxValue);
+                    builder.Append(SqlDecimal.MinValue);
+                    builder.Append(SqlDecimal.MaxValue - 10);
+                    builder.Append(SqlDecimal.MinValue + 10);
+
+                    // Assert
+                    var array = builder.Build();
+                    Assert.Equal(SqlDecimal.MaxValue, array.GetSqlDecimal(0));
+                    Assert.Equal(SqlDecimal.MinValue, array.GetSqlDecimal(1));
+                    Assert.Equal(SqlDecimal.MaxValue - 10, array.GetSqlDecimal(2));
+                    Assert.Equal(SqlDecimal.MinValue + 10, array.GetSqlDecimal(3));
+                }
+
+                [Fact]
+                public void AppendRangeSqlDecimal()
+                {
+                    // Arrange
+                    var builder = new Decimal128Array.Builder(new Decimal128Type(24, 8));
+                    var range = new SqlDecimal[] { 2.123M, 1.5984M, -0.0000001M, 9878987987987987.1235407M };
+
+                    // Act
+                    builder.AppendRange(range);
+                    builder.AppendNull();
+
+                    // Assert
+                    var array = builder.Build();
+                    for (int i = 0; i < range.Length; i++)
+                    {
+                        Assert.Equal(range[i], array.GetSqlDecimal(i));
+                        Assert.Equal(Convert(range[i]), array.GetValue(i));
+                    }
+
+                    Assert.Null(array.GetValue(range.Length));
+                }
+            }
+#endif
         }
     }
 }
diff --git a/csharp/test/Apache.Arrow.Tests/DecimalUtilityTests.cs b/csharp/test/Apache.Arrow.Tests/DecimalUtilityTests.cs
index 9c7e5b587c..dd5f7b9d3f 100644
--- a/csharp/test/Apache.Arrow.Tests/DecimalUtilityTests.cs
+++ b/csharp/test/Apache.Arrow.Tests/DecimalUtilityTests.cs
@@ -14,6 +14,9 @@
 // limitations under the License.
 
 using System;
+#if !NETSTANDARD1_3
+using System.Data.SqlTypes;
+#endif
 using Apache.Arrow.Types;
 using Xunit;
 
@@ -31,13 +34,13 @@ namespace Apache.Arrow.Tests
             [InlineData(100.123, 5, 2, true)]
             [InlineData(100.123, 5, 3, true)]
             [InlineData(100.123, 6, 3, false)]
-            public void HasExpectedResultOrThrows(decimal d, int precision , 
int scale, bool shouldThrow)
+            public void HasExpectedResultOrThrows(decimal d, int precision, 
int scale, bool shouldThrow)
             {
                 var builder = new Decimal128Array.Builder(new 
Decimal128Type(precision, scale));
 
                 if (shouldThrow)
                 {
-                   Assert.Throws<OverflowException>(() => builder.Append(d));
+                    Assert.Throws<OverflowException>(() => builder.Append(d));
                 }
                 else
                 {
@@ -55,7 +58,7 @@ namespace Apache.Arrow.Tests
                 var builder = new Decimal256Array.Builder(new 
Decimal256Type(precision, scale));
                 builder.Append(d);
                 Decimal256Array result = builder.Build(new 
TestMemoryAllocator()); ;
-                
+
                 if (shouldThrow)
                 {
                     Assert.Throws<OverflowException>(() => result.GetValue(0));
@@ -66,5 +69,57 @@ namespace Apache.Arrow.Tests
                 }
             }
         }
+
+        public class SqlDecimals
+        {
+            // Round trips between decimal, SqlDecimal and the 16-byte Decimal128 encoding.
+#if !NETSTANDARD1_3
+            [Fact]
+            public void NegativeSqlDecimal()
+            {
+                const int precision = 38;
+                const int scale = 0;
+                const int byteWidth = 16; // size in bytes of one Decimal128 value
+
+                var negative = new SqlDecimal(precision, scale, false, 0, 0, 1, 0);
+                var bytes = new byte[byteWidth];
+                DecimalUtility.GetBytes(negative.Value, precision, scale, byteWidth, bytes);
+                var sqlNegative = DecimalUtility.GetSqlDecimal128(new ArrowBuffer(bytes), 0, precision, scale);
+                Assert.Equal(negative, sqlNegative);
+
+                DecimalUtility.GetBytes(sqlNegative, precision, scale, bytes);
+                var decimalNegative = DecimalUtility.GetDecimal(new ArrowBuffer(bytes), 0, scale, byteWidth);
+                Assert.Equal(negative.Value, decimalNegative);
+            }
+
+            [Fact]
+            public void LargeScale()
+            {
+                string digits = "1.2345678901234567890123456789012345678";
+
+                var positive = SqlDecimal.Parse(digits);
+                Assert.Equal(38, positive.Precision);
+                Assert.Equal(37, positive.Scale);
+
+                var bytes = new byte[16];
+                DecimalUtility.GetBytes(positive, positive.Precision, positive.Scale, bytes);
+                var sqlPositive = DecimalUtility.GetSqlDecimal128(new ArrowBuffer(bytes), 0, positive.Precision, positive.Scale);
+
+                Assert.Equal(positive, sqlPositive);
+                Assert.Equal(digits, sqlPositive.ToString());
+
+                digits = "-" + digits;
+                var negative = SqlDecimal.Parse(digits);
+                Assert.Equal(38, negative.Precision); // check the negative parse itself, not the positive value again
+                Assert.Equal(37, negative.Scale);
+
+                DecimalUtility.GetBytes(negative, negative.Precision, negative.Scale, bytes);
+                var sqlNegative = DecimalUtility.GetSqlDecimal128(new ArrowBuffer(bytes), 0, negative.Precision, negative.Scale);
+
+                Assert.Equal(negative, sqlNegative);
+                Assert.Equal(digits, sqlNegative.ToString());
+            }
+#endif
+        }
     }
 }

Reply via email to