This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new c21cbf12a feat(csharp/src/Client): add support for parameterized
execution (#2096)
c21cbf12a is described below
commit c21cbf12ab7be98605927ec4166213c99c2981dc
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Fri Aug 23 05:35:27 2024 -0700
feat(csharp/src/Client): add support for parameterized execution (#2096)
Add support for parameterized execution through the ADO.NET wrapper.
Resolves #2095
---
csharp/src/Client/AdbcCommand.cs | 217 ++++++++++++++++++++-
csharp/src/Client/AdbcParameter.cs | 44 +++++
csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs | 7 +-
.../test/Drivers/Interop/Snowflake/ClientTests.cs | 22 +++
.../test/Drivers/Interop/Snowflake/DriverTests.cs | 18 ++
5 files changed, 298 insertions(+), 10 deletions(-)
diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs
index 5b3ed7c24..f76c246cc 100644
--- a/csharp/src/Client/AdbcCommand.cs
+++ b/csharp/src/Client/AdbcCommand.cs
@@ -16,9 +16,13 @@
*/
using System;
+using System.Collections;
+using System.Collections.Generic;
using System.Data;
using System.Data.Common;
+using System.Linq;
using System.Threading.Tasks;
+using Apache.Arrow.Types;
namespace Apache.Arrow.Adbc.Client
{
@@ -27,7 +31,8 @@ namespace Apache.Arrow.Adbc.Client
/// </summary>
public sealed class AdbcCommand : DbCommand
{
- private AdbcStatement adbcStatement;
+ private AdbcStatement _adbcStatement;
+ private AdbcParameterCollection? _dbParameterCollection;
private int _timeout = 30;
private bool _disposed;
@@ -45,7 +50,7 @@ namespace Apache.Arrow.Adbc.Client
this.DbConnection = adbcConnection;
this.DecimalBehavior = adbcConnection.DecimalBehavior;
- this.adbcStatement = adbcConnection.CreateStatement();
+ this._adbcStatement = adbcConnection.CreateStatement();
}
/// <summary>
@@ -61,7 +66,7 @@ namespace Apache.Arrow.Adbc.Client
if (adbcConnection == null)
throw new ArgumentNullException(nameof(adbcConnection));
- this.adbcStatement = adbcConnection.CreateStatement();
+ this._adbcStatement = adbcConnection.CreateStatement();
this.CommandText = query;
this.DbConnection = adbcConnection;
@@ -71,7 +76,7 @@ namespace Apache.Arrow.Adbc.Client
// For testing
/// <summary>
/// Wraps an already-created <paramref name="adbcStatement"/> instead of
/// creating a new statement from <paramref name="adbcConnection"/>; intended for tests.
/// </summary>
/// <param name="adbcStatement">The statement this command executes through.</param>
/// <param name="adbcConnection">The owning connection; its DecimalBehavior is copied onto this command.</param>
internal AdbcCommand(AdbcStatement adbcStatement, AdbcConnection
adbcConnection)
{
- this.adbcStatement = adbcStatement;
+ this._adbcStatement = adbcStatement;
this.DbConnection = adbcConnection;
this.DecimalBehavior = adbcConnection.DecimalBehavior;
}
@@ -80,7 +85,7 @@ namespace Apache.Arrow.Adbc.Client
/// Gets the <see cref="AdbcStatement"/> associated with
/// this <see cref="AdbcCommand"/>.
/// </summary>
/// <exception cref="ObjectDisposedException">The command has been disposed.</exception>
- public AdbcStatement AdbcStatement => _disposed ? throw new
ObjectDisposedException(nameof(AdbcCommand)) : this.adbcStatement;
+ public AdbcStatement AdbcStatement => _disposed ? throw new
ObjectDisposedException(nameof(AdbcCommand)) : this._adbcStatement;
/// <summary>
/// Decimal handling setting for this command. The constructors initialize it
/// from the connection's DecimalBehavior.
/// NOTE(review): presumably governs how decimal result values are surfaced; confirm.
/// </summary>
public DecimalBehavior DecimalBehavior { get; set; }
@@ -114,6 +119,18 @@ namespace Apache.Arrow.Adbc.Client
set => _timeout = value;
}
/// <summary>
/// The command's parameter collection. It is created empty on first access
/// and the same instance is returned afterwards.
/// </summary>
protected override DbParameterCollection DbParameterCollection =>
    _dbParameterCollection ??= new AdbcParameterCollection();
+
/// <summary>
/// Gets or sets the Substrait plan used by the command.
/// </summary>
@@ -127,6 +144,7 @@ namespace Apache.Arrow.Adbc.Client
/// <summary>
/// Binds any parameters, runs the statement as an update and returns the
/// affected-row count narrowed to <see cref="int"/>.
/// </summary>
/// <exception cref="OverflowException">The affected-row count does not fit in an <see cref="int"/>.</exception>
public override int ExecuteNonQuery()
{
    // Parameter values must be pushed to the statement before it runs.
    BindParameters();
    long affectedRows = AdbcStatement.ExecuteUpdate().AffectedRows;
    return Convert.ToInt32(affectedRows);
}
@@ -137,6 +155,7 @@ namespace Apache.Arrow.Adbc.Client
/// <returns></returns>
public long ExecuteUpdate()
{
    // Parameter values must be pushed to the statement before it runs.
    BindParameters();
    var updateResult = AdbcStatement.ExecuteUpdate();
    // Unlike ExecuteNonQuery, the full 64-bit count is returned unchanged.
    return updateResult.AffectedRows;
}
@@ -146,6 +165,7 @@ namespace Apache.Arrow.Adbc.Client
/// <returns><see cref="Result"></returns>
public QueryResult ExecuteQuery()
{
+ BindParameters();
QueryResult executed = AdbcStatement.ExecuteQuery();
return executed;
@@ -195,13 +215,143 @@ namespace Apache.Arrow.Adbc.Client
if (disposing && !_disposed)
{
// TODO: ensure not in the middle of pulling
- this.adbcStatement.Dispose();
+ this._adbcStatement.Dispose();
_disposed = true;
}
base.Dispose(disposing);
}
/// <summary>
/// Packs the current parameter values into a single-row <see cref="RecordBatch"/>
/// and binds it to <see cref="AdbcStatement"/>. The batch has one column per
/// parameter, in collection order. Does nothing when no parameters were added.
/// </summary>
/// <remarks>
/// <see cref="DBNull.Value"/> is treated exactly like <c>null</c>. Non-binary,
/// non-date values are converted to the CLR type implied by
/// <see cref="DbParameter.DbType"/>, so a boxed <see cref="int"/> is accepted
/// for <see cref="DbType.Int64"/>, for example.
/// NOTE(review): when the collection is emptied after an earlier execution,
/// nothing is re-bound here, so the statement keeps whatever it had bound before.
/// </remarks>
/// <exception cref="NotSupportedException">
/// A parameter uses a <see cref="DbType"/> that has no Arrow mapping here.
/// </exception>
private void BindParameters()
{
    AdbcParameterCollection? collection = _dbParameterCollection;
    if (collection == null || collection.Count == 0)
    {
        return;
    }

    int count = collection.Count;
    Field[] fields = new Field[count];
    IArrowArray[] columns = new IArrowArray[count];
    for (int i = 0; i < count; i++)
    {
        AdbcParameter param = (AdbcParameter)collection[i];

        // ADO.NET callers conventionally use DBNull.Value to mean SQL NULL.
        object? value = param.Value is DBNull ? null : param.Value;

        (ArrowType type, IArrowArray column) = BuildParameterColumn(param.DbType, value);
        columns[i] = column;

        // Unnamed parameters still need a distinct field name in the schema.
        string name = string.IsNullOrWhiteSpace(param.ParameterName)
            ? Guid.NewGuid().ToString()
            : param.ParameterName;
        fields[i] = new Field(name, type, param.IsNullable || value == null);
    }

    Schema schema = new Schema(fields, null);
    AdbcStatement.Bind(new RecordBatch(schema, columns, 1), schema);
}

/// <summary>
/// Builds a one-element Arrow array holding <paramref name="value"/> for a
/// parameter declared as <paramref name="dbType"/>. Returns the array together
/// with the Arrow type used for the matching schema field.
/// </summary>
/// <param name="dbType">The declared <see cref="DbType"/> of the parameter.</param>
/// <param name="value">The parameter value; <c>null</c> means SQL NULL.</param>
/// <exception cref="NotSupportedException"><paramref name="dbType"/> has no mapping here.</exception>
private static (ArrowType Type, IArrowArray Column) BuildParameterColumn(DbType dbType, object? value)
{
    IFormatProvider invariant = System.Globalization.CultureInfo.InvariantCulture;
    switch (dbType)
    {
        case DbType.Binary:
        {
            var builder = new BinaryArray.Builder();
            if (value == null)
            {
                builder.AppendNull();
            }
            else
            {
                builder.Append(((byte[])value).AsSpan());
            }
            return (BinaryType.Default, builder.Build());
        }
        case DbType.Boolean:
        {
            var builder = new BooleanArray.Builder();
            if (value == null)
            {
                builder.AppendNull();
            }
            else
            {
                builder.Append(Convert.ToBoolean(value, invariant));
            }
            return (BooleanType.Default, builder.Build());
        }
        case DbType.Byte:
            return (UInt8Type.Default, new UInt8Array.Builder()
                .Append(value == null ? (byte?)null : Convert.ToByte(value, invariant))
                .Build());
        case DbType.Date:
        {
            var builder = new Date32Array.Builder();
            if (value == null)
            {
                builder.AppendNull();
            }
#if NET6_0_OR_GREATER
            // DateOnly was introduced in .NET 6.
            else if (value is DateOnly date)
            {
                builder.Append(date);
            }
#endif
            else
            {
                builder.Append((DateTime)value);
            }
            return (Date32Type.Default, builder.Build());
        }
        case DbType.DateTime:
        {
            var builder = new TimestampArray.Builder();
            if (value == null)
            {
                builder.AppendNull();
            }
            else
            {
                builder.Append((DateTime)value);
            }
            // The built array must be returned; otherwise the column is left unset.
            return (TimestampType.Default, builder.Build());
        }
        // TODO: case DbType.Decimal:
        case DbType.Double:
            return (DoubleType.Default, new DoubleArray.Builder()
                .Append(value == null ? (double?)null : Convert.ToDouble(value, invariant))
                .Build());
        case DbType.Int16:
            return (Int16Type.Default, new Int16Array.Builder()
                .Append(value == null ? (short?)null : Convert.ToInt16(value, invariant))
                .Build());
        case DbType.Int32:
            return (Int32Type.Default, new Int32Array.Builder()
                .Append(value == null ? (int?)null : Convert.ToInt32(value, invariant))
                .Build());
        case DbType.Int64:
            return (Int64Type.Default, new Int64Array.Builder()
                .Append(value == null ? (long?)null : Convert.ToInt64(value, invariant))
                .Build());
        case DbType.SByte:
            return (Int8Type.Default, new Int8Array.Builder()
                .Append(value == null ? (sbyte?)null : Convert.ToSByte(value, invariant))
                .Build());
        case DbType.Single:
            return (FloatType.Default, new FloatArray.Builder()
                .Append(value == null ? (float?)null : Convert.ToSingle(value, invariant))
                .Build());
        case DbType.String:
        {
            var builder = new StringArray.Builder();
            string? text = value == null ? null : Convert.ToString(value, invariant);
            if (text == null)
            {
                builder.AppendNull();
            }
            else
            {
                builder.Append(text);
            }
            return (StringType.Default, builder.Build());
        }
        // TODO: case DbType.Time:
        case DbType.UInt16:
            return (UInt16Type.Default, new UInt16Array.Builder()
                .Append(value == null ? (ushort?)null : Convert.ToUInt16(value, invariant))
                .Build());
        case DbType.UInt32:
            return (UInt32Type.Default, new UInt32Array.Builder()
                .Append(value == null ? (uint?)null : Convert.ToUInt32(value, invariant))
                .Build());
        case DbType.UInt64:
            return (UInt64Type.Default, new UInt64Array.Builder()
                .Append(value == null ? (ulong?)null : Convert.ToUInt64(value, invariant))
                .Build());
        default:
            throw new NotSupportedException($"Parameters of type {dbType} are not supported");
    }
}
+
#if NET5_0_OR_GREATER
public override ValueTask DisposeAsync()
{
@@ -214,8 +364,6 @@ namespace Apache.Arrow.Adbc.Client
/// <summary>
/// Not supported by this command; both the getter and the setter throw
/// <see cref="NotImplementedException"/>.
/// </summary>
public override UpdateRowSource UpdatedRowSource
{
    get { throw new NotImplementedException(); }
    set { throw new NotImplementedException(); }
}
- protected override DbParameterCollection DbParameterCollection =>
throw new NotImplementedException();
-
/// <summary>
/// Transactions cannot be attached to this command; both the getter and the
/// setter throw <see cref="NotImplementedException"/>.
/// </summary>
protected override DbTransaction? DbTransaction
{
    get { throw new NotImplementedException(); }
    set { throw new NotImplementedException(); }
}
public override void Cancel()
@@ -235,9 +383,60 @@ namespace Apache.Arrow.Adbc.Client
/// <summary>
/// Creates a new <see cref="AdbcParameter"/>. The parameter is only constructed
/// here; the caller must add it to the command's parameter collection.
/// </summary>
protected override DbParameter CreateDbParameter()
{
- throw new NotImplementedException();
+ return new AdbcParameter();
}
#endregion
+
/// <summary>
/// List-backed <see cref="DbParameterCollection"/> used by the command.
/// It accepts only non-null <see cref="AdbcParameter"/> instances and matches
/// parameter names exactly (ordinal, case-sensitive).
/// </summary>
private sealed class AdbcParameterCollection : DbParameterCollection
{
    private readonly List<AdbcParameter> _parameters = new List<AdbcParameter>();

    public override int Count => _parameters.Count;

    /// <summary>
    /// The backing list's synchronization object. The collection itself takes no locks.
    /// </summary>
    public override object SyncRoot => ((ICollection)_parameters).SyncRoot;

    public override int Add(object value)
    {
        _parameters.Add(ToAdbcParameter(value));
        return _parameters.Count - 1;
    }

    public override void AddRange(System.Array values)
    {
        if (values == null)
        {
            throw new ArgumentNullException(nameof(values));
        }

        // Validate every element first so a bad one leaves the collection unchanged.
        List<AdbcParameter> validated = new List<AdbcParameter>(values.Length);
        foreach (object? item in values)
        {
            validated.Add(ToAdbcParameter(item));
        }
        _parameters.AddRange(validated);
    }

    public override void Clear() => _parameters.Clear();

    public override bool Contains(object value) =>
        value is AdbcParameter parameter && _parameters.Contains(parameter);

    public override bool Contains(string value) => IndexOf(value) >= 0;

    public override void CopyTo(System.Array array, int index) =>
        ((ICollection)_parameters).CopyTo(array, index);

    public override IEnumerator GetEnumerator() => _parameters.GetEnumerator();

    public override int IndexOf(object value) =>
        value is AdbcParameter parameter ? _parameters.IndexOf(parameter) : -1;

    public override int IndexOf(string parameterName) =>
        GetParameterIndex(parameterName, throwOnFailure: false);

    public override void Insert(int index, object value) =>
        _parameters.Insert(index, ToAdbcParameter(value));

    public override void Remove(object value)
    {
        // Objects that are not AdbcParameters can never be in the list: nothing to remove.
        if (value is AdbcParameter parameter)
        {
            _parameters.Remove(parameter);
        }
    }

    public override void RemoveAt(int index) => _parameters.RemoveAt(index);

    public override void RemoveAt(string parameterName) =>
        _parameters.RemoveAt(GetParameterIndex(parameterName));

    protected override DbParameter GetParameter(int index) => _parameters[index];

    protected override DbParameter GetParameter(string parameterName) =>
        _parameters[GetParameterIndex(parameterName)];

    protected override void SetParameter(int index, DbParameter value) =>
        _parameters[index] = ToAdbcParameter(value);

    protected override void SetParameter(string parameterName, DbParameter value) =>
        _parameters[GetParameterIndex(parameterName)] = ToAdbcParameter(value);

    /// <summary>
    /// Returns the index of the first parameter whose name equals
    /// <paramref name="parameterName"/> (ordinal), or -1 when absent and
    /// <paramref name="throwOnFailure"/> is false.
    /// </summary>
    /// <exception cref="IndexOutOfRangeException">No match and <paramref name="throwOnFailure"/> is true.</exception>
    private int GetParameterIndex(string parameterName, bool throwOnFailure = true)
    {
        int index = _parameters.FindIndex(
            p => string.Equals(p.ParameterName, parameterName, StringComparison.Ordinal));
        if (index < 0 && throwOnFailure)
        {
            // IndexOutOfRangeException mirrors other ADO.NET providers' by-name lookups.
            throw new IndexOutOfRangeException($"Parameter '{parameterName}' was not found.");
        }
        return index;
    }

    /// <summary>
    /// Validates that <paramref name="value"/> is a non-null <see cref="AdbcParameter"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="value"/> is null.</exception>
    /// <exception cref="InvalidCastException"><paramref name="value"/> is another type.</exception>
    private static AdbcParameter ToAdbcParameter(object? value)
    {
        if (value == null)
        {
            throw new ArgumentNullException(nameof(value));
        }

        return value as AdbcParameter ?? throw new InvalidCastException(
            $"Only {nameof(AdbcParameter)} objects can be added to this collection, not {value.GetType().FullName} objects.");
    }
}
}
}
diff --git a/csharp/src/Client/AdbcParameter.cs
b/csharp/src/Client/AdbcParameter.cs
new file mode 100644
index 000000000..620b921c5
--- /dev/null
+++ b/csharp/src/Client/AdbcParameter.cs
@@ -0,0 +1,44 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System.Data.Common;
+using System.Data;
+using System.Diagnostics.CodeAnalysis;
+using System;
+
+namespace Apache.Arrow.Adbc.Client
+{
/// <summary>
/// An input-only <see cref="DbParameter"/> for ADBC commands.
/// </summary>
/// <remarks>
/// <see cref="Size"/>, <see cref="SourceColumn"/> and
/// <see cref="SourceColumnNullMapping"/> are plain stored values. This lets
/// generic ADO.NET code set them without failing, but nothing in this class
/// reads them.
/// </remarks>
public sealed class AdbcParameter : DbParameter
{
    private string _parameterName = string.Empty;
    private string _sourceColumn = string.Empty;

    /// <summary>
    /// The declared type of the parameter. It starts as <c>default(DbType)</c>.
    /// </summary>
    public override DbType DbType { get; set; }

    /// <summary>
    /// Always <see cref="ParameterDirection.Input"/>. Assigning
    /// <see cref="ParameterDirection.Input"/> is a no-op; any other direction
    /// throws <see cref="NotSupportedException"/>.
    /// </summary>
    public override ParameterDirection Direction
    {
        get => ParameterDirection.Input;
        set
        {
            if (value != ParameterDirection.Input)
            {
                throw new NotSupportedException(
                    $"Parameter direction {value} is not supported; only {ParameterDirection.Input} is.");
            }
        }
    }

    /// <summary>Whether the parameter accepts null. Defaults to <c>true</c>.</summary>
    public override bool IsNullable { get; set; } = true;

    /// <summary>The parameter name. Assigning null stores an empty string.</summary>
#if NET5_0_OR_GREATER
    [AllowNull]
#endif
    public override string ParameterName
    {
        get => _parameterName;
        set => _parameterName = value ?? string.Empty;
    }

    /// <summary>Maximum size in bytes/characters; stored only.</summary>
    public override int Size { get; set; }

    /// <summary>Source column name; stored only. Assigning null stores an empty string.</summary>
#if NET5_0_OR_GREATER
    [AllowNull]
#endif
    public override string SourceColumn
    {
        get => _sourceColumn;
        set => _sourceColumn = value ?? string.Empty;
    }

    /// <summary>Source column null mapping flag; stored only.</summary>
    public override bool SourceColumnNullMapping { get; set; }

    /// <summary>The parameter value; may be null.</summary>
    public override object? Value { get; set; }

    /// <summary>Restores <see cref="DbType"/> to its initial value, <c>default(DbType)</c>.</summary>
    public override void ResetDbType() => DbType = default;
}
+}
diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs
b/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs
index 1d947062d..d5a3b117c 100644
--- a/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs
+++ b/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs
@@ -94,7 +94,11 @@ namespace Apache.Arrow.Adbc.Tests
/// </summary>
/// <param name="adbcConnection">The <see
cref="Adbc.Client.AdbcConnection"/> to use.</param>
/// <param name="testConfiguration">The <see
cref="TestConfiguration"/> to use</param>
- public static void CanClientExecuteQuery(Adbc.Client.AdbcConnection
adbcConnection, TestConfiguration testConfiguration)
+ /// <param name="additionalCommandOptionsSetter">Allows additional
options to be set on the command before execution</param>
+ public static void CanClientExecuteQuery(
+ Adbc.Client.AdbcConnection adbcConnection,
+ TestConfiguration testConfiguration,
+ Action<AdbcCommand>? additionalCommandOptionsSetter = null)
{
if (adbcConnection == null) throw new
ArgumentNullException(nameof(adbcConnection));
if (testConfiguration == null) throw new
ArgumentNullException(nameof(testConfiguration));
@@ -104,6 +108,7 @@ namespace Apache.Arrow.Adbc.Tests
adbcConnection.Open();
using AdbcCommand adbcCommand = new
AdbcCommand(testConfiguration.Query, adbcConnection);
+ additionalCommandOptionsSetter?.Invoke(adbcCommand);
using AdbcDataReader reader = adbcCommand.ExecuteReader();
try
diff --git a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
index 04cf0d1eb..942a78fe7 100644
--- a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
@@ -142,6 +142,28 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
}
}
/// <summary>
/// Validates if the client can connect to a live server and execute a
/// parameterized query.
/// </summary>
[SkippableFact, Order(4)]
public void CanClientExecuteParameterizedQuery()
{
    SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration<SnowflakeTestConfiguration>(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE);
    // Of the inline values 1, 2 and 3, only 1 satisfies "column1 < 2".
    testConfiguration.Query = "SELECT * FROM (SELECT column1 FROM (VALUES (1), (2), (3))) WHERE column1 < ?";
    testConfiguration.ExpectedResultsCount = 1;

    using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration))
    {
        Tests.ClientTests.CanClientExecuteQuery(adbcConnection, testConfiguration, command =>
        {
            DbParameter parameter1 = command.CreateParameter();
            parameter1.Value = 2;
            parameter1.DbType = DbType.Int32;
            command.Parameters.Add(parameter1);
        });
    }
}
+
// <summary>
/// Validates if the client can connect to a live server
/// and parse the results.
diff --git a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
index 4e9fbb4f4..45ce416f7 100644
--- a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
@@ -399,6 +399,24 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
Tests.DriverTests.CanExecuteQuery(queryResult,
_testConfiguration.ExpectedResultsCount);
}
/// <summary>
/// Checks that the driver, against a live server, can run a query whose
/// placeholder is bound from a one-row Arrow record batch.
/// </summary>
[SkippableFact, Order(6)]
public void CanExecuteParameterizedQuery()
{
    // One non-nullable Int32 column holding 2; only the row with value 1 qualifies.
    Int32Array thresholdColumn = new Int32Array.Builder().Append(2).Build();
    Field thresholdField = new Field("column1", Int32Type.Default, false);
    Schema parameterSchema = new Schema(new[] { thresholdField }, null);
    RecordBatch parameters = new RecordBatch(parameterSchema, new IArrowArray[] { thresholdColumn }, 1);

    using AdbcStatement statement = _connection.CreateStatement();
    statement.SqlQuery = "SELECT * FROM (SELECT column1 FROM (VALUES (1), (2), (3))) WHERE column1 < ?";
    statement.Bind(parameters, parameterSchema);

    QueryResult queryResult = statement.ExecuteQuery();
    Tests.DriverTests.CanExecuteQuery(queryResult, 1);
}
+
[SkippableFact, Order(7)]
public void CanIngestData()
{