This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new c21cbf12a feat(csharp/src/Client): add support for parameterized execution (#2096)
c21cbf12a is described below

commit c21cbf12ab7be98605927ec4166213c99c2981dc
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Fri Aug 23 05:35:27 2024 -0700

    feat(csharp/src/Client): add support for parameterized execution (#2096)
    
    Add support for parameterized execution through ADO.NET wrapper.
    
    Resolves #2095
---
 csharp/src/Client/AdbcCommand.cs                   | 217 ++++++++++++++++++++-
 csharp/src/Client/AdbcParameter.cs                 |  44 +++++
 csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs |   7 +-
 .../test/Drivers/Interop/Snowflake/ClientTests.cs  |  22 +++
 .../test/Drivers/Interop/Snowflake/DriverTests.cs  |  18 ++
 5 files changed, 298 insertions(+), 10 deletions(-)

diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs
index 5b3ed7c24..f76c246cc 100644
--- a/csharp/src/Client/AdbcCommand.cs
+++ b/csharp/src/Client/AdbcCommand.cs
@@ -16,9 +16,13 @@
 */
 
 using System;
+using System.Collections;
+using System.Collections.Generic;
 using System.Data;
 using System.Data.Common;
+using System.Linq;
 using System.Threading.Tasks;
+using Apache.Arrow.Types;
 
 namespace Apache.Arrow.Adbc.Client
 {
@@ -27,7 +31,8 @@ namespace Apache.Arrow.Adbc.Client
     /// </summary>
     public sealed class AdbcCommand : DbCommand
     {
-        private AdbcStatement adbcStatement;
+        private AdbcStatement _adbcStatement;
+        private AdbcParameterCollection? _dbParameterCollection;
         private int _timeout = 30;
         private bool _disposed;
 
@@ -45,7 +50,7 @@ namespace Apache.Arrow.Adbc.Client
 
             this.DbConnection = adbcConnection;
             this.DecimalBehavior = adbcConnection.DecimalBehavior;
-            this.adbcStatement = adbcConnection.CreateStatement();
+            this._adbcStatement = adbcConnection.CreateStatement();
         }
 
         /// <summary>
@@ -61,7 +66,7 @@ namespace Apache.Arrow.Adbc.Client
             if (adbcConnection == null)
                 throw new ArgumentNullException(nameof(adbcConnection));
 
-            this.adbcStatement = adbcConnection.CreateStatement();
+            this._adbcStatement = adbcConnection.CreateStatement();
             this.CommandText = query;
 
             this.DbConnection = adbcConnection;
@@ -71,7 +76,7 @@ namespace Apache.Arrow.Adbc.Client
         // For testing
+        // Wraps a caller-supplied statement instead of creating one from the
+        // connection. Neither argument is null-checked here.
         internal AdbcCommand(AdbcStatement adbcStatement, AdbcConnection 
adbcConnection)
         {
-            this.adbcStatement = adbcStatement;
+            this._adbcStatement = adbcStatement;
             this.DbConnection = adbcConnection;
             this.DecimalBehavior = adbcConnection.DecimalBehavior;
         }
@@ -80,7 +85,7 @@ namespace Apache.Arrow.Adbc.Client
         /// Gets the <see cref="AdbcStatement"/> associated with
         /// this <see cref="AdbcCommand"/>.
         /// </summary>
-        public AdbcStatement AdbcStatement => _disposed ? throw new 
ObjectDisposedException(nameof(AdbcCommand)) : this.adbcStatement;
+        // Throws ObjectDisposedException once Dispose has run; every Execute* path goes through this.
+        public AdbcStatement AdbcStatement => _disposed ? throw new 
ObjectDisposedException(nameof(AdbcCommand)) : this._adbcStatement;
 
+        // Assigned from the connection's DecimalBehavior during construction; can be overridden per command.
         public DecimalBehavior DecimalBehavior { get; set; }
 
@@ -114,6 +119,18 @@ namespace Apache.Arrow.Adbc.Client
             set => _timeout = value;
         }
 
+        /// <summary>
+        /// Gets the command's parameter collection. An empty
+        /// <c>AdbcParameterCollection</c> is created on first access and reused afterwards.
+        /// </summary>
+        protected override DbParameterCollection DbParameterCollection =>
+            _dbParameterCollection ??= new AdbcParameterCollection();
+
         /// <summary>
         /// Gets or sets the Substrait plan used by the command.
         /// </summary>
@@ -127,6 +144,7 @@ namespace Apache.Arrow.Adbc.Client
 
         public override int ExecuteNonQuery()
         {
+            // Push the current parameter values (if any) to the statement before executing.
+            // Convert.ToInt32 throws OverflowException if AffectedRows does not fit in an Int32.
             BindParameters();
             return Convert.ToInt32(AdbcStatement.ExecuteUpdate().AffectedRows);
         }
 
@@ -137,6 +155,7 @@ namespace Apache.Arrow.Adbc.Client
         /// <returns></returns>
         public long ExecuteUpdate()
         {
+            // Push the current parameter values (if any) to the statement before executing.
             BindParameters();
             return AdbcStatement.ExecuteUpdate().AffectedRows;
         }
 
@@ -146,6 +165,7 @@ namespace Apache.Arrow.Adbc.Client
         /// <returns><see cref="Result"></returns>
         public QueryResult ExecuteQuery()
         {
+            BindParameters();
             QueryResult executed = AdbcStatement.ExecuteQuery();
 
             return executed;
@@ -195,13 +215,143 @@ namespace Apache.Arrow.Adbc.Client
             if (disposing && !_disposed)
             {
                 // TODO: ensure not in the middle of pulling
-                this.adbcStatement.Dispose();
+                this._adbcStatement.Dispose();
                 _disposed = true;
             }
 
             base.Dispose(disposing);
         }
 
+        /// <summary>
+        /// Converts the command's parameters into a single-row <see cref="RecordBatch"/>
+        /// (one column per parameter, in collection order) and binds it to the
+        /// underlying <see cref="AdbcStatement"/>.
+        /// </summary>
+        /// <remarks>
+        /// When the command has no parameters nothing is bound; note that this does not
+        /// clear a binding made by an earlier execution. A parameter whose value is
+        /// <see langword="null"/> or <see cref="DBNull.Value"/> is bound as a null slot.
+        /// </remarks>
+        /// <exception cref="NotSupportedException">
+        /// A parameter's <see cref="DbType"/> has no Arrow mapping.
+        /// </exception>
+        private void BindParameters()
+        {
+            if (_dbParameterCollection == null || _dbParameterCollection.Count == 0)
+            {
+                return;
+            }
+
+            int count = _dbParameterCollection.Count;
+            Field[] fields = new Field[count];
+            IArrowArray[] parameters = new IArrowArray[count];
+            for (int i = 0; i < count; i++)
+            {
+                AdbcParameter param = (AdbcParameter)_dbParameterCollection[i];
+
+                // ADO.NET callers conventionally use DBNull.Value for SQL NULL; treat it
+                // exactly like a null reference so the casts below do not fail on it.
+                object? value = param.Value is DBNull ? null : param.Value;
+
+                parameters[i] = CreateParameterArray(param.DbType, value);
+
+                // Take the field type from the array that was actually built so the
+                // schema can never disagree with the data (e.g. a timestamp's unit or zone).
+                string name = string.IsNullOrWhiteSpace(param.ParameterName) ? Guid.NewGuid().ToString() : param.ParameterName;
+                fields[i] = new Field(name, parameters[i].Data.DataType, param.IsNullable || value == null);
+            }
+
+            Schema schema = new Schema(fields, null);
+            AdbcStatement.Bind(new RecordBatch(schema, parameters, 1), schema);
+        }
+
+        /// <summary>
+        /// Builds a one-element Arrow array for a parameter of the given <paramref name="dbType"/>.
+        /// </summary>
+        /// <param name="dbType">The ADO.NET type that selects the Arrow array type.</param>
+        /// <param name="value">The parameter value, or <see langword="null"/> to append a null slot.</param>
+        /// <returns>An array of length one. Every supported case must return a built array.</returns>
+        /// <exception cref="NotSupportedException"><paramref name="dbType"/> is not mapped.</exception>
+        /// <exception cref="InvalidCastException">
+        /// <paramref name="value"/> is not of the CLR type expected for <paramref name="dbType"/>.
+        /// </exception>
+        private static IArrowArray CreateParameterArray(DbType dbType, object? value)
+        {
+            switch (dbType)
+            {
+                case DbType.Binary:
+                    {
+                        var builder = new BinaryArray.Builder();
+                        if (value == null)
+                        {
+                            builder.AppendNull();
+                        }
+                        else
+                        {
+                            builder.Append(((byte[])value).AsSpan());
+                        }
+                        return builder.Build();
+                    }
+                case DbType.Boolean:
+                    {
+                        var builder = new BooleanArray.Builder();
+                        if (value == null)
+                        {
+                            builder.AppendNull();
+                        }
+                        else
+                        {
+                            builder.Append((bool)value);
+                        }
+                        return builder.Build();
+                    }
+                case DbType.Byte:
+                    return new UInt8Array.Builder().Append((byte?)value).Build();
+                case DbType.Date:
+                    {
+                        var builder = new Date32Array.Builder();
+                        if (value == null)
+                        {
+                            builder.AppendNull();
+                        }
+#if NET6_0_OR_GREATER
+                        // DateOnly was introduced in .NET 6.
+                        else if (value is DateOnly date)
+                        {
+                            builder.Append(date);
+                        }
+#endif
+                        else
+                        {
+                            builder.Append((DateTime)value);
+                        }
+                        return builder.Build();
+                    }
+                case DbType.DateTime:
+                    {
+                        var builder = new TimestampArray.Builder();
+                        if (value == null)
+                        {
+                            builder.AppendNull();
+                        }
+                        else
+                        {
+                            builder.Append((DateTime)value);
+                        }
+                        return builder.Build();
+                    }
+                // TODO: case DbType.Decimal:
+                case DbType.Double:
+                    return new DoubleArray.Builder().Append((double?)value).Build();
+                case DbType.Int16:
+                    return new Int16Array.Builder().Append((short?)value).Build();
+                case DbType.Int32:
+                    return new Int32Array.Builder().Append((int?)value).Build();
+                case DbType.Int64:
+                    return new Int64Array.Builder().Append((long?)value).Build();
+                case DbType.SByte:
+                    return new Int8Array.Builder().Append((sbyte?)value).Build();
+                case DbType.Single:
+                    return new FloatArray.Builder().Append((float?)value).Build();
+                case DbType.String:
+                    {
+                        var builder = new StringArray.Builder();
+                        if (value == null)
+                        {
+                            builder.AppendNull();
+                        }
+                        else
+                        {
+                            builder.Append((string)value);
+                        }
+                        return builder.Build();
+                    }
+                // TODO: case DbType.Time:
+                case DbType.UInt16:
+                    return new UInt16Array.Builder().Append((ushort?)value).Build();
+                case DbType.UInt32:
+                    return new UInt32Array.Builder().Append((uint?)value).Build();
+                case DbType.UInt64:
+                    return new UInt64Array.Builder().Append((ulong?)value).Build();
+                default:
+                    throw new NotSupportedException($"Parameters of type {dbType} are not supported");
+            }
+        }
+
 #if NET5_0_OR_GREATER
         public override ValueTask DisposeAsync()
         {
@@ -214,8 +364,6 @@ namespace Apache.Arrow.Adbc.Client
 
+        // Not supported: both accessors throw NotImplementedException.
         public override UpdateRowSource UpdatedRowSource { get => throw new 
NotImplementedException(); set => throw new NotImplementedException(); }
 
-        protected override DbParameterCollection DbParameterCollection => 
throw new NotImplementedException();
-
+        // Transactions are not exposed through the command: both accessors throw NotImplementedException.
         protected override DbTransaction? DbTransaction { get => throw new 
NotImplementedException(); set => throw new NotImplementedException(); }
 
         public override void Cancel()
@@ -235,9 +383,60 @@ namespace Apache.Arrow.Adbc.Client
 
         protected override DbParameter CreateDbParameter()
         {
-            throw new NotImplementedException();
+            // Must be an AdbcParameter: both the parameter collection and BindParameters cast to that type.
+            return new AdbcParameter();
         }
 
         #endregion
+
+        /// <summary>
+        /// An ordered, list-backed <see cref="DbParameterCollection"/> that only accepts
+        /// <see cref="AdbcParameter"/> instances. Parameters are bound by position, so the
+        /// order of this collection is the order of the bound columns.
+        /// </summary>
+        private sealed class AdbcParameterCollection : DbParameterCollection
+        {
+            private readonly List<AdbcParameter> _parameters = new List<AdbcParameter>();
+
+            public override int Count => _parameters.Count;
+
+            // List<T> already provides a synchronization root through ICollection.
+            public override object SyncRoot => ((ICollection)_parameters).SyncRoot;
+
+            public override int Add(object value)
+            {
+                _parameters.Add(ToParameter(value));
+                return _parameters.Count - 1;
+            }
+
+            public override void AddRange(System.Array values)
+            {
+                if (values == null)
+                {
+                    throw new ArgumentNullException(nameof(values));
+                }
+
+                // Validate every element before adding any, so a bad element leaves the
+                // collection unchanged.
+                AdbcParameter[] items = values.Cast<object>().Select(ToParameter).ToArray();
+                _parameters.AddRange(items);
+            }
+
+            public override void Clear() => _parameters.Clear();
+
+            public override bool Contains(object value) => value is AdbcParameter parameter && _parameters.Contains(parameter);
+
+            public override bool Contains(string value) => IndexOf(value) >= 0;
+
+            public override void CopyTo(System.Array array, int index) => ((ICollection)_parameters).CopyTo(array, index);
+
+            public override IEnumerator GetEnumerator() => _parameters.GetEnumerator();
+
+            public override int IndexOf(object value) => value is AdbcParameter parameter ? _parameters.IndexOf(parameter) : -1;
+
+            public override int IndexOf(string parameterName) => GetParameterIndex(parameterName, throwOnFailure: false);
+
+            public override void Insert(int index, object value) => _parameters.Insert(index, ToParameter(value));
+
+            public override void Remove(object value)
+            {
+                if (value is AdbcParameter parameter)
+                {
+                    _parameters.Remove(parameter);
+                }
+            }
+
+            public override void RemoveAt(int index) => _parameters.RemoveAt(index);
+
+            public override void RemoveAt(string parameterName) => _parameters.RemoveAt(GetParameterIndex(parameterName));
+
+            protected override DbParameter GetParameter(int index) => _parameters[index];
+
+            protected override DbParameter GetParameter(string parameterName) => _parameters[GetParameterIndex(parameterName)];
+
+            protected override void SetParameter(int index, DbParameter value) => _parameters[index] = ToParameter(value);
+
+            protected override void SetParameter(string parameterName, DbParameter value) =>
+                _parameters[GetParameterIndex(parameterName)] = ToParameter(value);
+
+            /// <summary>
+            /// Rejects null and non-<see cref="AdbcParameter"/> values at insertion time so that
+            /// later consumers can rely on every element being a usable AdbcParameter.
+            /// </summary>
+            /// <exception cref="ArgumentNullException"><paramref name="value"/> is null.</exception>
+            /// <exception cref="InvalidCastException"><paramref name="value"/> is not an AdbcParameter.</exception>
+            private static AdbcParameter ToParameter(object? value)
+            {
+                if (value is AdbcParameter parameter)
+                {
+                    return parameter;
+                }
+
+                if (value == null)
+                {
+                    throw new ArgumentNullException(nameof(value));
+                }
+
+                throw new InvalidCastException($"Expected a {nameof(AdbcParameter)} but received a {value.GetType().FullName}.");
+            }
+
+            /// <summary>
+            /// Returns the index of the first parameter whose name matches
+            /// <paramref name="parameterName"/> (ordinal comparison), or -1 when
+            /// not found and <paramref name="throwOnFailure"/> is false.
+            /// </summary>
+            /// <exception cref="IndexOutOfRangeException">Not found and <paramref name="throwOnFailure"/> is true.</exception>
+            private int GetParameterIndex(string parameterName, bool throwOnFailure = true)
+            {
+                int index = _parameters.FindIndex(p => p.ParameterName == parameterName);
+                if (index < 0 && throwOnFailure)
+                {
+                    throw new IndexOutOfRangeException("parameterName not found");
+                }
+
+                return index;
+            }
+        }
     }
 }
diff --git a/csharp/src/Client/AdbcParameter.cs 
b/csharp/src/Client/AdbcParameter.cs
new file mode 100644
index 000000000..620b921c5
--- /dev/null
+++ b/csharp/src/Client/AdbcParameter.cs
@@ -0,0 +1,44 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System.Data.Common;
+using System.Data;
+using System.Diagnostics.CodeAnalysis;
+using System;
+
+namespace Apache.Arrow.Adbc.Client
+{
+    /// <summary>
+    /// A <see cref="DbParameter"/> for use with <see cref="AdbcCommand"/>.
+    /// Only input parameters are supported.
+    /// </summary>
+    /// <remarks>
+    /// <see cref="Size"/>, <see cref="SourceColumn"/> and <see cref="SourceColumnNullMapping"/>
+    /// are simply stored so that general-purpose ADO.NET code can assign them without an
+    /// exception; this class does not otherwise use them.
+    /// </remarks>
+    public sealed class AdbcParameter : DbParameter
+    {
+        private string _parameterName = string.Empty;
+        private string _sourceColumn = string.Empty;
+
+        /// <summary>
+        /// Gets or sets the parameter's type. It is not inferred from <see cref="Value"/>;
+        /// the initial value is <c>default(DbType)</c>, i.e. <see cref="DbType.AnsiString"/>.
+        /// </summary>
+        public override DbType DbType { get; set; }
+
+        /// <summary>
+        /// Always <see cref="ParameterDirection.Input"/>. Assigning <see cref="ParameterDirection.Input"/>
+        /// is accepted as a no-op; any other direction throws <see cref="NotSupportedException"/>.
+        /// </summary>
+        public override ParameterDirection Direction
+        {
+            get => ParameterDirection.Input;
+            set
+            {
+                if (value != ParameterDirection.Input)
+                {
+                    throw new NotSupportedException($"Parameter direction {value} is not supported; only {ParameterDirection.Input} is.");
+                }
+            }
+        }
+
+        public override bool IsNullable { get; set; } = true;
+
+#if NET5_0_OR_GREATER
+        [AllowNull]
+#endif
+        public override string ParameterName
+        {
+            get => _parameterName;
+            // Null may be assigned (see AllowNull); normalize it so the getter never returns null.
+            set => _parameterName = value ?? string.Empty;
+        }
+
+        public override int Size { get; set; }
+
+#if NET5_0_OR_GREATER
+        [AllowNull]
+#endif
+        public override string SourceColumn
+        {
+            get => _sourceColumn;
+            set => _sourceColumn = value ?? string.Empty;
+        }
+
+        public override bool SourceColumnNullMapping { get; set; }
+
+        public override object? Value { get; set; }
+
+        /// <summary>Restores <see cref="DbType"/> to its initial value, <c>default(DbType)</c>.</summary>
+        public override void ResetDbType() => DbType = default;
+    }
+}
diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs 
b/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs
index 1d947062d..d5a3b117c 100644
--- a/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs
+++ b/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs
@@ -94,7 +94,11 @@ namespace Apache.Arrow.Adbc.Tests
         /// </summary>
         /// <param name="adbcConnection">The <see 
cref="Adbc.Client.AdbcConnection"/> to use.</param>
         /// <param name="testConfiguration">The <see 
cref="TestConfiguration"/> to use</param>
-        public static void CanClientExecuteQuery(Adbc.Client.AdbcConnection 
adbcConnection, TestConfiguration testConfiguration)
+        /// <param name="additionalCommandOptionsSetter">Allows additional 
options to be set on the command before execution</param>
+        public static void CanClientExecuteQuery(
+            Adbc.Client.AdbcConnection adbcConnection,
+            TestConfiguration testConfiguration,
+            Action<AdbcCommand>? additionalCommandOptionsSetter = null)
         {
             if (adbcConnection == null) throw new 
ArgumentNullException(nameof(adbcConnection));
             if (testConfiguration == null) throw new 
ArgumentNullException(nameof(testConfiguration));
@@ -104,6 +108,7 @@ namespace Apache.Arrow.Adbc.Tests
             adbcConnection.Open();
 
             using AdbcCommand adbcCommand = new 
AdbcCommand(testConfiguration.Query, adbcConnection);
+            additionalCommandOptionsSetter?.Invoke(adbcCommand);
             using AdbcDataReader reader = adbcCommand.ExecuteReader();
 
             try
diff --git a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs 
b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
index 04cf0d1eb..942a78fe7 100644
--- a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
@@ -142,6 +142,28 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
             }
         }
 
+        /// <summary>
+        /// Validates that the client can connect to a live server and execute a parameterized query.
+        /// </summary>
+        [SkippableFact, Order(4)]
+        public void CanClientExecuteParameterizedQuery()
+        {
+            SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration<SnowflakeTestConfiguration>(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE);
+            testConfiguration.Query = "SELECT * FROM (SELECT column1 FROM (VALUES (1), (2), (3))) WHERE column1 < ?";
+            // Of the rows 1, 2 and 3, only 1 satisfies "column1 < 2".
+            testConfiguration.ExpectedResultsCount = 1;
+
+            using Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration);
+            Tests.ClientTests.CanClientExecuteQuery(adbcConnection, testConfiguration, command =>
+            {
+                DbParameter upperBound = command.CreateParameter();
+                upperBound.DbType = DbType.Int32;
+                upperBound.Value = 2;
+                command.Parameters.Add(upperBound);
+            });
+        }
+
         // <summary>
         /// Validates if the client can connect to a live server
         /// and parse the results.
diff --git a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs 
b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
index 4e9fbb4f4..45ce416f7 100644
--- a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
@@ -399,6 +399,24 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
             Tests.DriverTests.CanExecuteQuery(queryResult, 
_testConfiguration.ExpectedResultsCount);
         }
 
+        /// <summary>
+        /// Validates that the driver can connect to a live server and execute a parameterized query.
+        /// </summary>
+        [SkippableFact, Order(6)]
+        public void CanExecuteParameterizedQuery()
+        {
+            using AdbcStatement statement = _connection.CreateStatement();
+            statement.SqlQuery = "SELECT * FROM (SELECT column1 FROM (VALUES (1), (2), (3))) WHERE column1 < ?";
+
+            // Bind one row holding the upper bound 2; only column1 = 1 should qualify.
+            Int32Array upperBound = new Int32Array.Builder().Append(2).Build();
+            Schema parameterSchema = new Schema(new[] { new Field("column1", Int32Type.Default, false) }, null);
+            statement.Bind(new RecordBatch(parameterSchema, new IArrowArray[] { upperBound }, 1), parameterSchema);
+
+            QueryResult queryResult = statement.ExecuteQuery();
+            Tests.DriverTests.CanExecuteQuery(queryResult, 1);
+        }
+
         [SkippableFact, Order(7)]
         public void CanIngestData()
         {

Reply via email to