This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 73cd1e773 fix(csharp/src/Apache.Arrow.Adbc): imported statements and databases don't allow options to be set (#1816)
73cd1e773 is described below

commit 73cd1e77338985ef3968e5e373736cfb431144d8
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Sat May 4 05:37:24 2024 -0700

    fix(csharp/src/Apache.Arrow.Adbc): imported statements and databases don't allow options to be set (#1816)
    
    Allow options to be set on imported databases and statements. Hook up
    AdbcConnection.BulkIngest.
    
    Closes #1814
---
 csharp/src/Apache.Arrow.Adbc/AdbcConnection.cs     |  1 +
 csharp/src/Apache.Arrow.Adbc/AdbcDatabase.cs       | 11 ++++
 csharp/src/Apache.Arrow.Adbc/AdbcOptions.cs        | 29 ++++++++++
 .../src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs | 63 +++++++++++++++++++++-
 .../Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs | 36 ++++++++++++-
 .../test/Drivers/Interop/Snowflake/ClientTests.cs  |  2 +-
 .../test/Drivers/Interop/Snowflake/DriverTests.cs  | 54 ++++++++++++++++++-
 7 files changed, 191 insertions(+), 5 deletions(-)

diff --git a/csharp/src/Apache.Arrow.Adbc/AdbcConnection.cs b/csharp/src/Apache.Arrow.Adbc/AdbcConnection.cs
index 2ebb585fb..07cb586c8 100644
--- a/csharp/src/Apache.Arrow.Adbc/AdbcConnection.cs
+++ b/csharp/src/Apache.Arrow.Adbc/AdbcConnection.cs
@@ -18,6 +18,7 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using Apache.Arrow.Adbc.C;
 using Apache.Arrow.Ipc;
 
 namespace Apache.Arrow.Adbc
diff --git a/csharp/src/Apache.Arrow.Adbc/AdbcDatabase.cs b/csharp/src/Apache.Arrow.Adbc/AdbcDatabase.cs
index deac7d80a..58f07dcb4 100644
--- a/csharp/src/Apache.Arrow.Adbc/AdbcDatabase.cs
+++ b/csharp/src/Apache.Arrow.Adbc/AdbcDatabase.cs
@@ -29,6 +29,17 @@ namespace Apache.Arrow.Adbc
     /// </summary>
     public abstract class AdbcDatabase : IDisposable
     {
+        /// <summary>
+        /// Options are generally set before opening a database.  Some drivers may
+        /// support setting options after opening as well.
+        /// </summary>
+        /// <param name="key">Option name</param>
+        /// <param name="value">Option value</param>
+        public virtual void SetOption(string key, string value)
+        {
+            throw AdbcException.NotImplemented("Connection does not support setting options");
+        }
+
         /// <summary>
         /// Create a new connection to the database.
         /// </summary>
diff --git a/csharp/src/Apache.Arrow.Adbc/AdbcOptions.cs b/csharp/src/Apache.Arrow.Adbc/AdbcOptions.cs
index 8fc261aa2..b5a54b240 100644
--- a/csharp/src/Apache.Arrow.Adbc/AdbcOptions.cs
+++ b/csharp/src/Apache.Arrow.Adbc/AdbcOptions.cs
@@ -39,6 +39,18 @@ namespace Apache.Arrow.Adbc
             public const string Linearizable = "adbc.connection.transaction.isolation.linearizable";
         }
 
+        public static class Ingest
+        {
+            public const string TargetTable = "adbc.ingest.target_table";
+            public const string Mode = "adbc.ingest.mode";
+        }
+
+        public static class IngestMode
+        {
+            public const string Create = "adbc.ingest.mode.create";
+            public const string Append = "adbc.ingest.mode.append";
+        }
+
         public static string GetEnabled(bool value) => value ? Enabled : Disabled;
         public static bool GetEnabled(string value)
         {
@@ -73,5 +85,22 @@ namespace Apache.Arrow.Adbc
             if (StringComparer.OrdinalIgnoreCase.Equals(value, IsolationLevels.Linearizable)) { return Adbc.IsolationLevel.Linearizable; }
             throw new NotSupportedException("unknown isolation level");
         }
+
+        public static string GetIngestMode(BulkIngestMode value)
+        {
+            return value switch
+            {
+                BulkIngestMode.Create => IngestMode.Create,
+                BulkIngestMode.Append => IngestMode.Append,
+                _ => throw new NotSupportedException("unknown ingestion mode"),
+            };
+        }
+
+        public static BulkIngestMode GetIngestMode(string value)
+        {
+            if (StringComparer.OrdinalIgnoreCase.Equals(value, IngestMode.Create)) { return BulkIngestMode.Create; }
+            if (StringComparer.OrdinalIgnoreCase.Equals(value, IngestMode.Append)) { return BulkIngestMode.Append; }
+            throw new NotSupportedException("unknown ingestion mode");
+        }
     }
 }
diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
index 24c820631..b7b755405 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
@@ -514,6 +514,14 @@ namespace Apache.Arrow.Adbc.C
                 return result;
             }
 
+            public unsafe override void SetOption(string key, string value)
+            {
+                using (CallHelper caller = new CallHelper())
+                {
+                    caller.Call(Driver.DatabaseSetOption, ref _nativeDatabase, key, value);
+                }
+            }
+
             public override void Dispose()
             {
                 Dispose(true);
@@ -751,6 +759,23 @@ namespace Apache.Arrow.Adbc.C
                 }
             }
 
+            public override AdbcStatement BulkIngest(string targetTableName, BulkIngestMode mode)
+            {
+                AdbcStatement statement = CreateStatement();
+                bool succeeded = false;
+                try
+                {
+                    statement.SetOption(AdbcOptions.Ingest.TargetTable, targetTableName);
+                    statement.SetOption(AdbcOptions.Ingest.Mode, AdbcOptions.GetIngestMode(mode));
+                    succeeded = true;
+                    return statement;
+                }
+                finally
+                {
+                    if (!succeeded) { statement.Dispose(); }
+                }
+            }
+
             public unsafe override void Commit()
             {
                 using (CallHelper caller = new CallHelper())
@@ -953,7 +978,7 @@ namespace Apache.Arrow.Adbc.C
 #else
                             Marshal.GetDelegateForFunctionPointer<StatementExecuteQuery>(Driver.StatementExecuteQuery)
 #endif
-                            (statement, caller.CreateStream(), &rows, &caller._error));
+                            (statement, null, &rows, &caller._error));
 
                         return new UpdateResult(rows);
                     }
@@ -1013,6 +1038,14 @@ namespace Apache.Arrow.Adbc.C
                 }
             }
 
+            public unsafe override void SetOption(string key, string value)
+            {
+                using (CallHelper caller = new CallHelper())
+                {
+                    caller.Call(Driver.StatementSetOption, ref _nativeStatement, key, value);
+                }
+            }
+
             public override void Dispose()
             {
                 Dispose(true);
@@ -1327,6 +1360,34 @@ namespace Apache.Arrow.Adbc.C
             }
 #endif
 
+#if NET5_0_OR_GREATER
+            public unsafe void Call(delegate* unmanaged<CAdbcStatement*, byte*, byte*, CAdbcError*, AdbcStatusCode> fn, ref CAdbcStatement nativeStatement, string key, string? value)
+            {
+                fixed (CAdbcStatement* stmt = &nativeStatement)
+                fixed (CAdbcError* e = &_error)
+                {
+                    using (Utf8Helper utf8Key = new Utf8Helper(key))
+                    using (Utf8Helper utf8Value = new Utf8Helper(value))
+                    {
+                        TranslateCode(fn(stmt, utf8Key, utf8Value, e));
+                    }
+                }
+            }
+#else
+            public unsafe void Call(IntPtr fn, ref CAdbcStatement nativeStatement, string key, string? value)
+            {
+                fixed (CAdbcStatement* stmt = &nativeStatement)
+                fixed (CAdbcError* e = &_error)
+                {
+                    using (Utf8Helper utf8Key = new Utf8Helper(key))
+                    using (Utf8Helper utf8Value = new Utf8Helper(value))
+                    {
+                        TranslateCode(Marshal.GetDelegateForFunctionPointer<StatementSetOption>(fn)(stmt, utf8Key, utf8Value, e));
+                    }
+                }
+            }
+#endif
+
             public unsafe void Dispose()
             {
                 if (_error.release != default)
diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs
index e48cfd3ab..0cd25a14c 100644
--- a/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs
+++ b/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs
@@ -162,9 +162,43 @@ namespace Apache.Arrow.Adbc.Tests
             });
         }
 
+        [Fact]
+        public void IngestData()
+        {
+            using var database = _duckDb.OpenDatabase("ingest.db");
+            using var connection = database.Connect(null);
+            using var statement = connection.CreateStatement();
+            statement.SetOption("adbc.ingest.target_table", "ingested");
+            statement.SetOption("adbc.ingest.mode", "adbc.ingest.mode.create");
+
+            Schema schema = new Schema([new Field("key", Int32Type.Default, false), new Field("value", StringType.Default, false)], null);
+            RecordBatch recordBatch = new RecordBatch(schema, [
+                new Int32Array.Builder().AppendRange([1, 2, 3]).Build(),
+                new StringArray.Builder().AppendRange(["foo", "bar", "baz"]).Build()
+                ], 3);
+            statement.Bind(recordBatch, schema);
+            statement.ExecuteUpdate();
+
+            Schema foundSchema = connection.GetTableSchema(null, null, "ingested");
+            Assert.Equal(schema.FieldsList.Count, foundSchema.FieldsList.Count);
+
+            Assert.Equal(3, GetResultCount(statement, "SELECT * from ingested"));
+
+            using var statement2 = connection.BulkIngest("ingested", BulkIngestMode.Append);
+
+            recordBatch = new RecordBatch(schema, [
+                new Int32Array.Builder().AppendRange([4, 5]).Build(),
+                new StringArray.Builder().AppendRange(["quux", "zozzle"]).Build()
+                ], 2);
+            statement2.Bind(recordBatch, schema);
+            statement2.ExecuteUpdate();
+
+            Assert.Equal(5, GetResultCount(statement2, "SELECT * from ingested"));
+        }
+
         private static long GetResultCount(AdbcStatement statement, string query)
         {
-            statement.SqlQuery = "SELECT * from test";
+            statement.SqlQuery = query;
             var results = statement.ExecuteQuery();
             long count = 0;
             using (var stream = results.Stream ?? throw new InvalidOperationException("no results found"))
diff --git a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
index f90e278f1..04cf0d1eb 100644
--- a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
@@ -54,7 +54,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
             {
                 string[] queries = SnowflakeTestingUtils.GetQueries(testConfiguration);
 
-                List<int> expectedResults = new List<int>() { 1, 1, 1 };
+                List<int> expectedResults = new List<int>() { -1, 1, 1 };
 
                 Tests.ClientTests.CanClientExecuteUpdate(adbcConnection, testConfiguration, queries, expectedResults);
             }
diff --git a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
index da27ef1b5..4e9fbb4f4 100644
--- a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
@@ -16,12 +16,12 @@
 */
 
 using System;
-using System.Collections;
 using System.Collections.Generic;
 using System.Linq;
 using Apache.Arrow.Adbc.Tests.Metadata;
 using Apache.Arrow.Adbc.Tests.Xunit;
 using Apache.Arrow.Ipc;
+using Apache.Arrow.Types;
 using Xunit;
 
 namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
@@ -100,7 +100,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
         {
             string[] queries = SnowflakeTestingUtils.GetQueries(_testConfiguration);
 
-            List<int> expectedResults = new List<int>() { 1, 1, 1 };
+            List<int> expectedResults = new List<int>() { -1, 1, 1 };
 
             for (int i = 0; i < queries.Length; i++)
             {
@@ -399,6 +399,56 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
             Tests.DriverTests.CanExecuteQuery(queryResult, _testConfiguration.ExpectedResultsCount);
         }
 
+        [SkippableFact, Order(7)]
+        public void CanIngestData()
+        {
+            const string tableName = "AdbcIngestTest";
+
+            using var statement = _connection.CreateStatement();
+            statement.SqlQuery = $"USE SCHEMA \"{_testConfiguration.Metadata.Schema}\"";
+            statement.ExecuteUpdate();
+            statement.SqlQuery = $"DROP TABLE IF EXISTS \"{tableName}\"";
+            statement.ExecuteUpdate();
+
+            statement.SqlQuery = null;
+            statement.SetOption("adbc.ingest.target_table", tableName);
+            statement.SetOption("adbc.ingest.mode", "adbc.ingest.mode.create");
+
+            Schema schema = new Schema([new Field("key", Int32Type.Default, false), new Field("value", StringType.Default, false)], null);
+            RecordBatch recordBatch = new RecordBatch(schema, [
+                new Int32Array.Builder().AppendRange([1, 2, 3]).Build(),
+                new StringArray.Builder().AppendRange(["foo", "bar", "baz"]).Build()
+                ], 3);
+            statement.Bind(recordBatch, schema);
+            statement.ExecuteUpdate();
+
+            Schema foundSchema = _connection.GetTableSchema(null, null, tableName);
+            Assert.Equal(schema.FieldsList.Count, foundSchema.FieldsList.Count);
+
+            statement.SqlQuery = $"SELECT * FROM \"{tableName}\"";
+            var result = statement.ExecuteQuery();
+            Assert.Equal(3, result.RowCount);
+            result.Stream?.Dispose();
+
+            using var statement2 = _connection.BulkIngest(tableName, BulkIngestMode.Append);
+
+            recordBatch = new RecordBatch(schema, [
+                new Int32Array.Builder().AppendRange([4, 5]).Build(),
+                new StringArray.Builder().AppendRange(["quux", "zozzle"]).Build()
+                ], 2);
+            statement2.Bind(recordBatch, schema);
+            statement2.ExecuteUpdate();
+
+            statement.SqlQuery = $"SELECT * FROM \"{tableName}\"";
+            result = statement.ExecuteQuery();
+            Assert.Equal(5, result.RowCount);
+            result.Stream?.Dispose();
+
+            statement.SqlQuery = $"DROP TABLE IF EXISTS \"{tableName}\"";
+            statement.ExecuteUpdate();
+        }
+
+
         private void CreateDatabaseAndTable(string databaseName, string schemaName, string tableName)
         {
             databaseName = databaseName.Replace("\"", "\"\"");

Reply via email to