This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new e39d71549 feat(csharp): Add support for Prepare to ImportedStatement
and to ADO.NET wrapper (#2628)
e39d71549 is described below
commit e39d71549bde79ebd2662480ae82b94140c62e03
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Tue Mar 18 07:31:08 2025 -0700
feat(csharp): Add support for Prepare to ImportedStatement and to ADO.NET
wrapper (#2628)
Closes #2616.
---
.../src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs | 70 ++++++++++++++--------
csharp/src/Client/AdbcCommand.cs | 59 ++++++++++++++----
.../Client/DuckDbClientTests.cs | 32 +++++++++-
.../Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs | 32 ++++++++++
4 files changed, 157 insertions(+), 36 deletions(-)
diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
index 73ae28b7d..133fff57a 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
@@ -797,6 +797,15 @@ namespace Apache.Arrow.Adbc.C
Dispose(false);
}
+ /// <summary>
+ /// Sets the SQL text for this statement. The value is pushed to the native statement
+ /// (via SetSqlQuery) at assignment time, before the managed copy is updated, so the
+ /// Execute* methods no longer re-send the query themselves. If SetSqlQuery throws,
+ /// the managed value is left unchanged.
+ /// </summary>
+ /// <remarks>
+ /// NOTE(review): unlike the previous Execute*-time logic, a null value is now forwarded
+ /// to SetSqlQuery as well; confirm that SetSqlQuery and the underlying driver handle a
+ /// null query safely.
+ /// </remarks>
+ public override string? SqlQuery
+ {
+ set
+ {
+ SetSqlQuery(value);
+ base.SqlQuery = value;
+ }
+ }
+
private unsafe ref CAdbcDriver Driver
{
get
@@ -885,12 +894,6 @@ namespace Apache.Arrow.Adbc.C
public unsafe override QueryResult ExecuteQuery()
{
- if (SqlQuery != null)
- {
- // TODO: Consider moving this to the setter
- SetSqlQuery(SqlQuery);
- }
-
using (CallHelper caller = new CallHelper())
{
fixed (CAdbcStatement* statement = &_nativeStatement)
@@ -911,12 +914,6 @@ namespace Apache.Arrow.Adbc.C
public override unsafe Schema ExecuteSchema()
{
- if (SqlQuery != null)
- {
- // TODO: Consider moving this to the setter
- SetSqlQuery(SqlQuery);
- }
-
using (CallHelper caller = new CallHelper())
{
fixed (CAdbcStatement* statement = &_nativeStatement)
@@ -936,12 +933,6 @@ namespace Apache.Arrow.Adbc.C
public unsafe override UpdateResult ExecuteUpdate()
{
- if (SqlQuery != null)
- {
- // TODO: Consider moving this to the setter
- SetSqlQuery(SqlQuery);
- }
-
using (CallHelper caller = new CallHelper())
{
fixed (CAdbcStatement* statement = &_nativeStatement)
@@ -962,12 +953,6 @@ namespace Apache.Arrow.Adbc.C
public unsafe override PartitionedResult ExecutePartitioned()
{
- if (SqlQuery != null)
- {
- // TODO: Consider moving this to the setter
- SetSqlQuery(SqlQuery);
- }
-
using (CallHelper caller = new CallHelper())
{
fixed (CAdbcStatement* statement = &_nativeStatement)
@@ -1013,6 +998,41 @@ namespace Apache.Arrow.Adbc.C
}
}
+ /// <summary>
+ /// Retrieves the schema of the statement's parameters by invoking the driver's
+ /// StatementGetParameterSchema entry point on the native statement.
+ /// </summary>
+ /// <returns>The schema imported from the native ArrowSchema filled in by the driver.</returns>
+ /// <remarks>
+ /// The native status code is passed to CallHelper.TranslateCode, which presumably
+ /// raises on failure. TODO confirm the exception type callers should expect.
+ /// </remarks>
+ public unsafe override Schema GetParameterSchema()
+ {
+ using (CallHelper caller = new CallHelper())
+ {
+ fixed (CAdbcStatement* statement = &_nativeStatement)
+ {
+ caller.TranslateCode(
+#if NET5_0_OR_GREATER
+ Driver.StatementGetParameterSchema
+#else
+ Marshal.GetDelegateForFunctionPointer<StatementGetParameterSchema>(Driver.StatementGetParameterSchema)
+#endif
+ (statement, caller.CreateSchema(), &caller._error));
+ }
+ return caller.ImportSchema();
+ }
+ }
+
+ /// <summary>
+ /// Prepares the native statement by invoking the driver's StatementPrepare entry point.
+ /// The SQL text is expected to already be on the native statement; the SqlQuery
+ /// setter forwards it at assignment time.
+ /// </summary>
+ /// <remarks>
+ /// The native status code is passed to CallHelper.TranslateCode, which presumably
+ /// raises on failure. TODO confirm.
+ /// </remarks>
+ public unsafe override void Prepare()
+ {
+ using (CallHelper caller = new CallHelper())
+ {
+ fixed (CAdbcStatement* statement = &_nativeStatement)
+ {
+ caller.TranslateCode(
+#if NET5_0_OR_GREATER
+ Driver.StatementPrepare
+#else
+ Marshal.GetDelegateForFunctionPointer<StatementPrepare>(Driver.StatementPrepare)
+#endif
+ (statement, &caller._error));
+ }
+ }
+ }
+
public unsafe override void SetOption(string key, string value)
{
using (CallHelper caller = new CallHelper())
@@ -1055,7 +1075,7 @@ namespace Apache.Arrow.Adbc.C
}
}
- private unsafe void SetSqlQuery(string sqlQuery)
+ private unsafe void SetSqlQuery(string? sqlQuery)
{
fixed (CAdbcStatement* statement = &_nativeStatement)
{
diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs
index a317ca19c..7ad0678af 100644
--- a/csharp/src/Client/AdbcCommand.cs
+++ b/csharp/src/Client/AdbcCommand.cs
@@ -483,12 +483,61 @@ namespace Apache.Arrow.Adbc.Client
}
}
+ /// <summary>
+ /// Prepares the command on the underlying ADBC statement and rebuilds the command's
+ /// parameter collection from the parameter schema reported by the driver.
+ /// </summary>
+ /// <remarks>
+ /// Any parameters previously added to the command are discarded. Each new parameter is
+ /// named after its schema field, inherits the field's nullability, and gets a
+ /// <see cref="DbType"/> derived from the field's Arrow type. Arrow types without a direct
+ /// mapping, including the null type a driver may report for untyped placeholders, become
+ /// <see cref="DbType.Object"/> and can be overridden by the caller before execution.
+ /// </remarks>
+ public override void Prepare()
+ {
+ _adbcStatement.Prepare();
+ var schema = _adbcStatement.GetParameterSchema();
+
+ DbParameterCollection.Clear();
+
+ foreach (Field field in schema.FieldsList)
+ {
+ AdbcParameter parameter = new AdbcParameter
+ {
+ ParameterName = field.Name,
+ IsNullable = field.IsNullable,
+ DbType = DbTypeFromArrowType(field.DataType.TypeId),
+ };
+ DbParameterCollection.Add(parameter);
+ }
+ }
+
+ /// <summary>
+ /// Maps an Arrow type id to the closest <see cref="DbType"/>, falling back to
+ /// <see cref="DbType.Object"/> for types with no direct equivalent.
+ /// </summary>
+ private static DbType DbTypeFromArrowType(ArrowTypeId typeId) => typeId switch
+ {
+ ArrowTypeId.UInt8 => DbType.Byte,
+ ArrowTypeId.UInt16 => DbType.UInt16,
+ ArrowTypeId.UInt32 => DbType.UInt32,
+ ArrowTypeId.UInt64 => DbType.UInt64,
+ ArrowTypeId.Int8 => DbType.SByte,
+ ArrowTypeId.Int16 => DbType.Int16,
+ ArrowTypeId.Int32 => DbType.Int32,
+ ArrowTypeId.Int64 => DbType.Int64,
+ ArrowTypeId.Float => DbType.Single,
+ ArrowTypeId.Double => DbType.Double,
+ ArrowTypeId.Boolean => DbType.Boolean,
+ ArrowTypeId.String => DbType.String,
+ ArrowTypeId.Binary => DbType.Binary,
+ ArrowTypeId.Date32 => DbType.Date,
+ ArrowTypeId.Date64 => DbType.DateTime,
+ ArrowTypeId.Time32 => DbType.Time,
+ ArrowTypeId.Time64 => DbType.Time,
+ ArrowTypeId.Timestamp => DbType.DateTime,
+ ArrowTypeId.Decimal32 or
+ ArrowTypeId.Decimal64 or
+ ArrowTypeId.Decimal128 or
+ ArrowTypeId.Decimal256 => DbType.Decimal,
+ _ => DbType.Object,
+ };
+
+ /// <summary>
+ /// Creates a new, unconfigured <see cref="AdbcParameter"/> for use with this command.
+ /// </summary>
+ protected override DbParameter CreateDbParameter() => new AdbcParameter();
+
#if NET5_0_OR_GREATER
public override ValueTask DisposeAsync()
{
return base.DisposeAsync();
}
#endif
+
#region NOT_IMPLEMENTED
public override bool DesignTimeVisible { get => throw new
NotImplementedException(); set => throw new NotImplementedException(); }
@@ -507,16 +556,6 @@ namespace Apache.Arrow.Adbc.Client
throw new NotImplementedException();
}
- public override void Prepare()
- {
- throw new NotImplementedException();
- }
-
- protected override DbParameter CreateDbParameter()
- {
- return new AdbcParameter();
- }
-
#endregion
private class AdbcParameterCollection : DbParameterCollection
diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Client/DuckDbClientTests.cs
b/csharp/test/Apache.Arrow.Adbc.Tests/Client/DuckDbClientTests.cs
index 31e25f662..0a8e89d82 100644
--- a/csharp/test/Apache.Arrow.Adbc.Tests/Client/DuckDbClientTests.cs
+++ b/csharp/test/Apache.Arrow.Adbc.Tests/Client/DuckDbClientTests.cs
@@ -15,8 +15,9 @@
* limitations under the License.
*/
-using System.Collections.Generic;
+using System.Data;
using Apache.Arrow.Adbc.Client;
+using Apache.Arrow.Types;
using Xunit;
namespace Apache.Arrow.Adbc.Tests.Client
@@ -117,6 +118,35 @@ namespace Apache.Arrow.Adbc.Tests.Client
});
}
+ /// <summary>
+ /// Verifies that Prepare populates one parameter per placeholder and that values bound
+ /// through those parameters reach the query result.
+ /// </summary>
+ [Fact]
+ public void BindParameters()
+ {
+ using var connection = _duckDb.CreateConnection("bindparameters.db", null);
+ connection.Open();
+ using var command = connection.CreateCommand();
+
+ // Untyped placeholders are surfaced as DbType.Object parameters named by position.
+ command.CommandText = "select ?, ?";
+ command.Prepare();
+ Assert.Equal(2, command.Parameters.Count);
+ Assert.Equal("0", command.Parameters[0].ParameterName);
+ Assert.Equal(DbType.Object, command.Parameters[0].DbType);
+ Assert.Equal("1", command.Parameters[1].ParameterName);
+ Assert.Equal(DbType.Object, command.Parameters[1].DbType);
+
+ command.Parameters[0].DbType = DbType.Int32;
+ command.Parameters[0].Value = 1;
+ command.Parameters[1].DbType = DbType.String;
+ command.Parameters[1].Value = "foo";
+
+ using var reader = command.ExecuteReader();
+ long count = 0;
+ while (reader.Read())
+ {
+ // The bound values should come back unchanged in the single result row.
+ Assert.Equal(1, reader.GetInt32(0));
+ Assert.Equal("foo", reader.GetString(1));
+ count++;
+ }
+ Assert.Equal(1, count);
+ }
+
private static long GetResultCount(AdbcCommand command, string query)
{
command.CommandText = "SELECT * from test";
diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs
b/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs
index d5ad20c9c..63bc1022b 100644
--- a/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs
+++ b/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs
@@ -209,6 +209,33 @@ namespace Apache.Arrow.Adbc.Tests
Assert.Equal(6, GetResultCount(statement3, "SELECT * from
main.ingested"));
}
+ /// <summary>
+ /// Verifies that an imported statement can be prepared, report its parameter schema,
+ /// and execute after a record batch of parameter values is bound.
+ /// </summary>
+ [Fact]
+ public void PrepareAndBind()
+ {
+ using var database = _duckDb.OpenDatabase("bind.db");
+ using var connection = database.Connect(null);
+ using var statement = connection.CreateStatement();
+
+ // DuckDB reports both untyped placeholders with the Arrow null type.
+ statement.SqlQuery = "select ?, ?";
+ statement.Prepare();
+ Schema parameterSchema = statement.GetParameterSchema();
+ Assert.Equal(2, parameterSchema.FieldsList.Count);
+ Assert.Equal("0", parameterSchema.FieldsList[0].Name);
+ Assert.Equal(ArrowTypeId.Null, parameterSchema.FieldsList[0].DataType.TypeId);
+ Assert.Equal("1", parameterSchema.FieldsList[1].Name);
+ Assert.Equal(ArrowTypeId.Null, parameterSchema.FieldsList[1].DataType.TypeId);
+
+ // Bind a single row supplying concrete values for both parameters.
+ Schema bindSchema = new Schema(
+ [new Field("0", Int32Type.Default, false), new Field("1", StringType.Default, false)],
+ null);
+ IArrowArray intColumn = new Int32Array.Builder().AppendRange([1]).Build();
+ IArrowArray stringColumn = new StringArray.Builder().AppendRange(["foo"]).Build();
+ RecordBatch recordBatch = new RecordBatch(bindSchema, [intColumn, stringColumn], 1);
+ statement.Bind(recordBatch, bindSchema);
+
+ QueryResult results = statement.ExecuteQuery();
+ Assert.Equal(1, GetResultCount(results));
+ }
+
[Fact]
public async Task GetTableTypes()
{
@@ -255,6 +282,11 @@ namespace Apache.Arrow.Adbc.Tests
{
statement.SqlQuery = query;
var results = statement.ExecuteQuery();
+ return GetResultCount(results);
+ }
+
+ private static long GetResultCount(QueryResult results)
+ {
long count = 0;
using (var stream = results.Stream ?? throw new
InvalidOperationException("no results found"))
{