This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new e8444219 feat(go/adbc/driver/snowflake): enable passing private key for JWT via string and not file (#1207)
e8444219 is described below

commit e8444219d4f2723c081a5df5ccacca284f47ac5a
Author: davidhcoe <[email protected]>
AuthorDate: Mon Oct 30 11:23:57 2023 -0400

    feat(go/adbc/driver/snowflake): enable passing private key for JWT via string and not file (#1207)
    
    Adds support for passing a PKCS #8 private key value directly to the
    Snowflake driver, without having to read it from a file.
    
    - Includes C# integration test updates
    - Includes support for encrypted PKCS #8 private keys (via a passcode)
    - Includes a test case in the Snowflake driver
    
    Resolves https://github.com/apache/arrow-adbc/issues/1204
    
    ---------
    
    Co-authored-by: David Coe <[email protected]>
    Co-authored-by: David Li <[email protected]>
---
 .../src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs |  2 +-
 csharp/src/Client/readme.md                        |  4 +-
 csharp/test/Drivers/Snowflake/ClientTests.cs       | 29 ++++++--
 csharp/test/Drivers/Snowflake/DriverTests.cs       | 12 +--
 .../Drivers/Snowflake/Resources/SnowflakeData.sql  |  6 +-
 .../Snowflake/Resources/snowflakeconfig.json       |  4 +
 .../Snowflake/SnowflakeTestConfiguration.cs        | 24 ++++++
 .../Drivers/Snowflake/SnowflakeTestingUtils.cs     | 56 ++++++++++----
 docs/source/driver/snowflake.rst                   |  9 +++
 go/adbc/driver/snowflake/driver.go                 |  9 ++-
 go/adbc/driver/snowflake/driver_test.go            | 87 ++++++++++++++++++++++
 go/adbc/driver/snowflake/record_reader.go          | 40 ++++------
 go/adbc/driver/snowflake/snowflake_database.go     | 41 ++++++++++
 go/adbc/go.mod                                     |  1 +
 go/adbc/go.sum                                     |  2 +
 .../adbc_driver_snowflake/__init__.py              |  9 +++
 16 files changed, 278 insertions(+), 57 deletions(-)

diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs 
b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
index 626c8357..914a7f22 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
@@ -333,7 +333,7 @@ namespace Apache.Arrow.Adbc.C
             {
                 if (arrowArray is BooleanArray)
                 {
-                    return 
Convert.ToBoolean(((BooleanArray)arrowArray).Values[index]);
+                    return ((BooleanArray)arrowArray).GetValue(index);
                 }
                 else if (arrowArray is Date32Array)
                 {
diff --git a/csharp/src/Client/readme.md b/csharp/src/Client/readme.md
index 5ad4698b..91e33ca9 100644
--- a/csharp/src/Client/readme.md
+++ b/csharp/src/Client/readme.md
@@ -67,9 +67,9 @@ For example, when using the [Snowflake ADBC Go 
Driver](https://arrow.apache.org/
 if using the default user name and password authentication, but look like
 
 ```
- 
adbc.snowflake.sql.account={account};adbc.snowflake.sql.warehouse={warehouse};username={user};password={password};adbc.snowflake.sql.auth_type=snowflake_jwt;adbc.snowflake.sql.client_option.auth_token={token}
+ 
adbc.snowflake.sql.account={account};adbc.snowflake.sql.warehouse={warehouse};username={user};password={password};adbc.snowflake.sql.auth_type=snowflake_jwt;adbc.snowflake.sql.client_option.jwt_private_key={private_key_file}
 ```
 
-when using JWT authentication.
+when using JWT authentication with an unencrypted key file.
 
 Other ADBC drivers will have different connection parameters, so be sure to 
check the documentation for each driver.
diff --git a/csharp/test/Drivers/Snowflake/ClientTests.cs 
b/csharp/test/Drivers/Snowflake/ClientTests.cs
index abd1ac48..3eaf13c0 100644
--- a/csharp/test/Drivers/Snowflake/ClientTests.cs
+++ b/csharp/test/Drivers/Snowflake/ClientTests.cs
@@ -156,6 +156,11 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
                         while (reader.Read())
                         {
                             count++;
+
+                            for (int i = 0; i < reader.FieldCount; i++)
+                            {
+                                Console.WriteLine($"{reader.GetName(i)}: 
{reader.GetValue(i)}");
+                            }
                         }
                     }
                     finally { reader.Close(); }
@@ -251,19 +256,31 @@ namespace 
Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
 
             DbConnectionStringBuilder builder = new 
DbConnectionStringBuilder(true);
 
-            builder["adbc.snowflake.sql.account"] = testConfiguration.Account;
-            builder["adbc.snowflake.sql.warehouse"] = 
testConfiguration.Warehouse;
-            builder["username"] = testConfiguration.User;
+            builder[SnowflakeParameters.ACCOUNT] = testConfiguration.Account;
+            builder[SnowflakeParameters.WAREHOUSE] = 
testConfiguration.Warehouse;
+            builder[SnowflakeParameters.HOST] = testConfiguration.Host;
+            builder[SnowflakeParameters.DATABASE] = testConfiguration.Database;
+            builder[SnowflakeParameters.USERNAME] = testConfiguration.User;
 
             if 
(!string.IsNullOrEmpty(testConfiguration.AuthenticationTokenPath))
             {
+                builder[SnowflakeParameters.AUTH_TYPE] = 
testConfiguration.AuthenticationType;
+
                 string privateKey = 
File.ReadAllText(testConfiguration.AuthenticationTokenPath);
-                builder["adbc.snowflake.sql.auth_type"] = 
testConfiguration.AuthenticationType;
-                builder["adbc.snowflake.sql.client_option.auth_token"] = 
privateKey;
+
+                if (testConfiguration.AuthenticationType.Equals("auth_jwt", 
StringComparison.OrdinalIgnoreCase))
+                {
+                    builder[SnowflakeParameters.PKCS8_VALUE] = privateKey;
+
+                    if(!string.IsNullOrEmpty(testConfiguration.Pkcs8Passcode))
+                    {
+                        builder[SnowflakeParameters.PKCS8_PASS] = 
testConfiguration.Pkcs8Passcode;
+                    }
+                }
             }
             else
             {
-                builder["password"] = testConfiguration.Password;
+                builder[SnowflakeParameters.PASSWORD] = 
testConfiguration.Password;
             }
 
             AdbcDriver snowflakeDriver = 
SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration);
diff --git a/csharp/test/Drivers/Snowflake/DriverTests.cs 
b/csharp/test/Drivers/Snowflake/DriverTests.cs
index 21593f5c..41527e99 100644
--- a/csharp/test/Drivers/Snowflake/DriverTests.cs
+++ b/csharp/test/Drivers/Snowflake/DriverTests.cs
@@ -132,8 +132,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
                 string tableName = metadataTestConfiguration.Metadata.Table;
                 string columnName = null;
 
-                parameters["adbc.snowflake.sql.db"] = databaseName;
-                parameters["adbc.snowflake.sql.schema"] = schemaName;
+                parameters[SnowflakeParameters.DATABASE] = databaseName;
+                parameters[SnowflakeParameters.SCHEMA] = schemaName;
 
                 AdbcDatabase adbcDatabase = driver.Open(parameters);
                 AdbcConnection adbcConnection = adbcDatabase.Connect(new 
Dictionary<string, string>());
@@ -158,6 +158,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
                     .Select(c => c.Columns)
                     .FirstOrDefault();
 
+                Assert.IsTrue(columns != null, "Columns cannot be null");
+
                 
Assert.AreEqual(metadataTestConfiguration.Metadata.ExpectedColumnCount, 
columns.Count);
             }
         }
@@ -215,9 +217,9 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
                 StringArray stringArray = 
(StringArray)recordBatch.Column("table_type");
 
                 List<string> known_types = new List<string>
-            {
-                "BASE TABLE", "TEMPORARY TABLE", "VIEW"
-            };
+                {
+                    "BASE TABLE", "TEMPORARY TABLE", "VIEW"
+                };
 
                 int results = 0;
 
diff --git a/csharp/test/Drivers/Snowflake/Resources/SnowflakeData.sql 
b/csharp/test/Drivers/Snowflake/Resources/SnowflakeData.sql
index 48ceead9..d17dfd4a 100644
--- a/csharp/test/Drivers/Snowflake/Resources/SnowflakeData.sql
+++ b/csharp/test/Drivers/Snowflake/Resources/SnowflakeData.sql
@@ -14,7 +14,7 @@
  -- See the License for the specific language governing permissions and
  -- limitations under the License.
 
-CREATE OR REPLACE TRANSIENT TABLE {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} (
+CREATE OR REPLACE TRANSIENT TABLE {ADBC_CATALOG}.{ADBC_SCHEMA}.{ADBC_TABLE} (
        NUMBERTYPE NUMBER(38,0),
        DECIMALTYPE NUMBER(18,1),
        NUMERICTYPE NUMBER(28,1),
@@ -47,7 +47,7 @@ CREATE OR REPLACE TRANSIENT TABLE 
{ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} (
        TIMESTAMPTZTYPE TIMESTAMP_TZ(9)
 );
 
-INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE}
+INSERT INTO {ADBC_CATALOG}.{ADBC_SCHEMA}.{ADBC_TABLE}
 (
     NUMBERTYPE, DECIMALTYPE, NUMERICTYPE, INTTYPE, INTEGERTYPE, BIGINTTYPE, 
SMALLINTTYPE, TINYINTTYPE,
     BYTEINTTYPE, FLOATTYPE, FLOAT4TYPE, FLOAT8TYPE, DOUBLETYPE, 
DOUBLEPRECISIONTYPE, REALTYPE, VARCHARTYPE,
@@ -62,7 +62,7 @@ VALUES
     '12:34:56', '2023-07-28 12:34:56', '2023-07-28 12:34:56', '2023-07-28 
12:34:56', '2023-07-28 12:34:56'
 );
 
-INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE}
+INSERT INTO {ADBC_CATALOG}.{ADBC_SCHEMA}.{ADBC_TABLE}
 (
     NUMBERTYPE, DECIMALTYPE, NUMERICTYPE, INTTYPE, INTEGERTYPE, BIGINTTYPE, 
SMALLINTTYPE, TINYINTTYPE,
     BYTEINTTYPE, FLOATTYPE, FLOAT4TYPE, FLOAT8TYPE, DOUBLETYPE, 
DOUBLEPRECISIONTYPE, REALTYPE, VARCHARTYPE,
diff --git a/csharp/test/Drivers/Snowflake/Resources/snowflakeconfig.json 
b/csharp/test/Drivers/Snowflake/Resources/snowflakeconfig.json
index e14e2d41..d3b43581 100644
--- a/csharp/test/Drivers/Snowflake/Resources/snowflakeconfig.json
+++ b/csharp/test/Drivers/Snowflake/Resources/snowflakeconfig.json
@@ -2,11 +2,15 @@
     "driverPath": "",
     "driverEntryPoint": "",
     "account": "",
+    "host": "",
+    "database": "",
     "user": "",
     "password": "",
     "warehouse": "",
     "authenticationType": "",
     "authenticationTokenPath": "",
+    "pkcs8Passcode": "",
+    "useHighPrecision":  true,
     "metadata": {
         "catalog": "",
         "schema": "",
diff --git a/csharp/test/Drivers/Snowflake/SnowflakeTestConfiguration.cs 
b/csharp/test/Drivers/Snowflake/SnowflakeTestConfiguration.cs
index 591f851d..96bdcc22 100644
--- a/csharp/test/Drivers/Snowflake/SnowflakeTestConfiguration.cs
+++ b/csharp/test/Drivers/Snowflake/SnowflakeTestConfiguration.cs
@@ -42,6 +42,18 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
         [JsonPropertyName("account")]
         public string Account { get; set; }
 
+        /// <summary>
+        /// The Snowflake host.
+        /// </summary>
+        [JsonPropertyName("host")]
+        public string Host { get; set; }
+
+        /// <summary>
+        /// The Snowflake database.
+        /// </summary>
+        [JsonPropertyName("database")]
+        public string Database { get; set; }
+
         /// <summary>
         /// The Snowflake user.
         /// </summary>
@@ -71,5 +83,17 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
         /// </summary>
         [JsonPropertyName("authenticationTokenPath"), JsonIgnore(Condition = 
JsonIgnoreCondition.WhenWritingDefault)]
         public string AuthenticationTokenPath { get; set; }
+
+        /// <summary>
+        /// The passcode to use if the JWT token is encrypted.
+        /// </summary>
+        [JsonPropertyName("pkcs8Passcode"), JsonIgnore(Condition = 
JsonIgnoreCondition.WhenWritingDefault)]
+        public string Pkcs8Passcode { get; set; }
+
+        /// <summary>
+        /// The Snowflake authentication type.
+        /// </summary>
+        [JsonPropertyName("useHighPrecision")]
+        public bool UseHighPrecision { get; set; } = true;
     }
 }
diff --git a/csharp/test/Drivers/Snowflake/SnowflakeTestingUtils.cs 
b/csharp/test/Drivers/Snowflake/SnowflakeTestingUtils.cs
index 3ef8e718..0da3eff7 100644
--- a/csharp/test/Drivers/Snowflake/SnowflakeTestingUtils.cs
+++ b/csharp/test/Drivers/Snowflake/SnowflakeTestingUtils.cs
@@ -23,6 +23,21 @@ using Apache.Arrow.Adbc.C;
 
 namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
 {
+    internal class SnowflakeParameters
+    {
+        public const string DATABASE = "adbc.snowflake.sql.db";
+        public const string SCHEMA = "adbc.snowflake.sql.schema";
+        public const string ACCOUNT = "adbc.snowflake.sql.account";
+        public const string USERNAME = "username";
+        public const string PASSWORD = "password";
+        public const string WAREHOUSE = "adbc.snowflake.sql.warehouse";
+        public const string AUTH_TYPE = "adbc.snowflake.sql.auth_type";
+        public const string HOST = "adbc.snowflake.sql.uri.host";
+        public const string PKCS8_VALUE = 
"adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_value";
+        public const string PKCS8_PASS = 
"adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_password";
+        public const string USE_HIGH_PRECISION = 
"adbc.snowflake.sql.client_option.use_high_precision";
+    }
+
     internal class SnowflakeTestingUtils
     {
         internal const string SNOWFLAKE_TEST_CONFIG_VARIABLE = 
"SNOWFLAKE_TEST_CONFIG_FILE";
@@ -43,13 +58,24 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
 
             parameters = new Dictionary<string, string>
             {
-                { "adbc.snowflake.sql.account", testConfiguration.Account },
-                { "username", testConfiguration.User },
-                { "password", testConfiguration.Password },
-                { "adbc.snowflake.sql.warehouse", testConfiguration.Warehouse 
},
-                { "adbc.snowflake.sql.auth_type", 
testConfiguration.AuthenticationType }
+                { SnowflakeParameters.ACCOUNT, testConfiguration.Account },
+                { SnowflakeParameters.USERNAME, testConfiguration.User },
+                { SnowflakeParameters.PASSWORD, testConfiguration.Password },
+                { SnowflakeParameters.WAREHOUSE, testConfiguration.Warehouse },
+                { SnowflakeParameters.AUTH_TYPE, 
testConfiguration.AuthenticationType },
+                { SnowflakeParameters.USE_HIGH_PRECISION, 
testConfiguration.UseHighPrecision.ToString().ToLowerInvariant() }
             };
 
+            if(!string.IsNullOrWhiteSpace(testConfiguration.Host))
+            {
+                parameters[SnowflakeParameters.HOST] = testConfiguration.Host;
+            }
+
+            if(!string.IsNullOrWhiteSpace(testConfiguration.Database))
+            {
+                parameters[SnowflakeParameters.DATABASE] = 
testConfiguration.Database;
+            }
+
             Dictionary<string, string> options = new Dictionary<string, 
string>() { };
             AdbcDriver snowflakeDriver = 
CAdbcDriverImporter.Load(testConfiguration.DriverPath, 
testConfiguration.DriverEntryPoint);
 
@@ -67,7 +93,6 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
             SnowflakeTestConfiguration testConfiguration
            )
         {
-            Dictionary<string, string> options = new Dictionary<string, 
string>() { };
             AdbcDriver snowflakeDriver = 
CAdbcDriverImporter.Load(testConfiguration.DriverPath, 
testConfiguration.DriverEntryPoint);
 
             return snowflakeDriver;
@@ -83,22 +108,25 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
 
             string[] sql = File.ReadAllLines("resources/SnowflakeData.sql");
 
-            string placeholder = "{ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE}";
+            Dictionary<string, string> placeholderValues = new 
Dictionary<string, string>() {
+                {"{ADBC_CATALOG}", testConfiguration.Metadata.Catalog },
+                {"{ADBC_SCHEMA}", testConfiguration.Metadata.Schema },
+                {"{ADBC_TABLE}", testConfiguration.Metadata.Table }
+            };
 
             foreach (string line in sql)
             {
                 if (!line.TrimStart().StartsWith("--"))
                 {
-                    if (line.Contains(placeholder))
-                    {
-                        string modifiedLine = line.Replace(placeholder, 
$"{testConfiguration.Metadata.Catalog}.{testConfiguration.Metadata.Schema}.{testConfiguration.Metadata.Table}");
+                    string modifiedLine = line;
 
-                        content.AppendLine(modifiedLine);
-                    }
-                    else
+                    foreach(string key in placeholderValues.Keys)
                     {
-                        content.AppendLine(line);
+                        if(modifiedLine.Contains(key))
+                            modifiedLine = modifiedLine.Replace(key, 
placeholderValues[key]);
                     }
+
+                    content.AppendLine(modifiedLine);
                 }
             }
 
diff --git a/docs/source/driver/snowflake.rst b/docs/source/driver/snowflake.rst
index 784eded8..04023a62 100644
--- a/docs/source/driver/snowflake.rst
+++ b/docs/source/driver/snowflake.rst
@@ -406,6 +406,15 @@ These options map 1:1 with the Snowflake `Config object 
<https://pkg.go.dev/gith
     private key to be read in and parsed. Commonly encoded in PEM blocks
     of type "RSA PRIVATE KEY".
 
+``adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_value``
+    Parses an encrypted or unencrypted PKCS #8 private key without having to
+    read it from the file system. If using encrypted, the
+    ``adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_password`` value
+    is required and used to decrypt.
+
+``adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_password``
+    Passcode to use when passing an encrypted PKCS #8 value.
+
 ``adbc.snowflake.sql.client_option.disable_telemetry``
     The Snowflake driver allows for telemetry information which can be
     disabled by setting this to ``true``. Value should be either ``true``
diff --git a/go/adbc/driver/snowflake/driver.go 
b/go/adbc/driver/snowflake/driver.go
index 51a9d056..1d8ea30a 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/driver.go
@@ -87,8 +87,13 @@ const (
        // specify the RSA private key to use to sign the JWT
        // this should point to a file containing a PKCS1 private key to be
        // loaded. Commonly encoded in PEM blocks of type "RSA PRIVATE KEY"
-       OptionJwtPrivateKey    = 
"adbc.snowflake.sql.client_option.jwt_private_key"
-       OptionDisableTelemetry = 
"adbc.snowflake.sql.client_option.disable_telemetry"
+       OptionJwtPrivateKey = "adbc.snowflake.sql.client_option.jwt_private_key"
+       // parses a private key in PKCS #8, ASN.1 DER form. Specify the private 
key
+       // value without having to load it from the file system.
+       OptionJwtPrivateKeyPkcs8Value = 
"adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_value"
+       // a passcode to use with encrypted private keys for JWT authentication
+       OptionJwtPrivateKeyPkcs8Password = 
"adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_password"
+       OptionDisableTelemetry           = 
"adbc.snowflake.sql.client_option.disable_telemetry"
        // snowflake driver logging level
        OptionLogTracing = "adbc.snowflake.sql.client_option.tracing"
        // When true, the MFA token is cached in the credential manager. True 
by default
diff --git a/go/adbc/driver/snowflake/driver_test.go 
b/go/adbc/driver/snowflake/driver_test.go
index 3d795463..846bc256 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -27,6 +27,7 @@ import (
        "encoding/pem"
        "fmt"
        "os"
+       "runtime"
        "strconv"
        "strings"
        "testing"
@@ -678,6 +679,92 @@ func (suite *SnowflakeTests) TestUseHighPrecision() {
        suite.Equal(9876543210.99, rec.Column(1).(*array.Float64).Value(1))
 }
 
+func TestJwtAuthenticationUnencryptedValue(t *testing.T) {
+       // test doesn't participate in SnowflakeTests because
+       // JWT auth has a different behavior
+       uri, ok := os.LookupEnv("SNOWFLAKE_URI")
+       if !ok {
+               t.Skip("Cannot find the `SNOWFLAKE_URI` value")
+       }
+
+       keyValue, ok := os.LookupEnv("SNOWFLAKE_TEST_PKCS8_VALUE")
+       if !ok {
+               t.Skip("Cannot find the `SNOWFLAKE_TEST_PKCS8_VALUE` value")
+       }
+
+       ConnectWithJwt(uri, keyValue, "")
+}
+
+func TestJwtAuthenticationEncryptedValue(t *testing.T) {
+       // test doesn't participate in SnowflakeTests because
+       // JWT auth has a different behavior
+       uri, ok := os.LookupEnv("SNOWFLAKE_URI")
+       if !ok {
+               t.Skip("Cannot find the `SNOWFLAKE_URI` value")
+       }
+
+       keyValue, ok := os.LookupEnv("SNOWFLAKE_TEST_PKCS8_EN_VALUE")
+       if !ok {
+               t.Skip("Cannot find the `SNOWFLAKE_TEST_PKCS8_EN_VALUE` value")
+       }
+
+       passcode, ok := os.LookupEnv("SNOWFLAKE_TEST_PKCS8_PASS")
+       if !ok {
+               t.Skip("Cannot find the `SNOWFLAKE_TEST_PKCS8_PASS` value")
+       }
+
+       ConnectWithJwt(uri, keyValue, passcode)
+}
+
+func ConnectWithJwt(uri, keyValue, passcode string) {
+
+       // Windows funkiness
+       if runtime.GOOS == "windows" {
+               keyValue = strings.ReplaceAll(keyValue, "\\r", "\r")
+               keyValue = strings.ReplaceAll(keyValue, "\\n", "\n")
+       }
+
+       cfg, err := gosnowflake.ParseDSN(uri)
+       if err != nil {
+               panic(err)
+       }
+
+       opts := map[string]string{
+               driver.OptionAccount:                 cfg.Account,
+               adbc.OptionKeyUsername:               cfg.User,
+               driver.OptionDatabase:                cfg.Database,
+               driver.OptionSchema:                  cfg.Schema,
+               driver.OptionAuthType:                driver.OptionValueAuthJwt,
+               driver.OptionJwtPrivateKeyPkcs8Value: keyValue,
+       }
+
+       if cfg.Warehouse != "" {
+               opts[driver.OptionWarehouse] = cfg.Warehouse
+       }
+
+       if cfg.Host != "" {
+               opts[driver.OptionHost] = cfg.Host
+       }
+
+       // if doing encrypted
+       if passcode != "" {
+               opts[driver.OptionJwtPrivateKeyPkcs8Password] = passcode
+       }
+
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+       adbcDriver := driver.NewDriver(mem)
+       db, err := adbcDriver.NewDatabase(opts)
+       if err != nil {
+               panic(err)
+       }
+
+       cnxn, err := db.Open(context.Background())
+       if err != nil {
+               panic(err)
+       }
+       defer cnxn.Close()
+}
+
 func (suite *SnowflakeTests) TestJwtPrivateKey() {
        // grab the username from the DSN
        cfg, err := gosnowflake.ParseDSN(suite.Quirks.dsn)
diff --git a/go/adbc/driver/snowflake/record_reader.go 
b/go/adbc/driver/snowflake/record_reader.go
index 3b76f23e..a0563c85 100644
--- a/go/adbc/driver/snowflake/record_reader.go
+++ b/go/adbc/driver/snowflake/record_reader.go
@@ -449,36 +449,28 @@ func newRecordReader(ctx context.Context, alloc 
memory.Allocator, ld gosnowflake
        }
 
        if len(batches) == 0 {
-               if ld.TotalRows() != 0 {
-                       // 
XXX(https://github.com/apache/arrow-adbc/issues/863): Snowflake won't return 
Arrow data for certain queries
-                       schema, err := rowTypesToArrowSchema(ctx, ld, 
useHighPrecision)
-                       if err != nil {
-                               return nil, adbc.Error{
-                                       Msg:  err.Error(),
-                                       Code: adbc.StatusInternal,
-                               }
-                       }
-
-                       bldr := array.NewRecordBuilder(alloc, schema)
-                       defer bldr.Release()
-
-                       rec, err := jsonDataToArrow(ctx, bldr, ld)
-                       if err != nil {
-                               return nil, err
-                       }
-                       defer rec.Release()
-
-                       return array.NewRecordReader(schema, 
[]arrow.Record{rec})
-               }
-               schema := arrow.NewSchema([]arrow.Field{}, nil)
-               reader, err := array.NewRecordReader(schema, []arrow.Record{})
+               schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision)
                if err != nil {
                        return nil, adbc.Error{
                                Msg:  err.Error(),
                                Code: adbc.StatusInternal,
                        }
                }
-               return reader, nil
+
+               bldr := array.NewRecordBuilder(alloc, schema)
+               defer bldr.Release()
+
+               rec, err := jsonDataToArrow(ctx, bldr, ld)
+               if err != nil {
+                       return nil, err
+               }
+               defer rec.Release()
+
+               if ld.TotalRows() != 0 {
+                       return array.NewRecordReader(schema, 
[]arrow.Record{rec})
+               } else {
+                       return array.NewRecordReader(schema, []arrow.Record{})
+               }
        }
 
        ch := make(chan arrow.Record, bufferSize)
diff --git a/go/adbc/driver/snowflake/snowflake_database.go 
b/go/adbc/driver/snowflake/snowflake_database.go
index 1d005705..45e3aab4 100644
--- a/go/adbc/driver/snowflake/snowflake_database.go
+++ b/go/adbc/driver/snowflake/snowflake_database.go
@@ -34,6 +34,7 @@ import (
        "github.com/apache/arrow-adbc/go/adbc"
        "github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
        "github.com/snowflakedb/gosnowflake"
+       "github.com/youmark/pkcs8"
 )
 
 var (
@@ -359,6 +360,46 @@ func (d *databaseImpl) SetOptions(cnOptions 
map[string]string) error {
                        }
 
                        d.cfg.PrivateKey = key
+               case OptionJwtPrivateKeyPkcs8Value:
+                       block, _ := pem.Decode([]byte(v))
+
+                       if block == nil {
+                               return adbc.Error{
+                                       Msg:  "Failed to parse PEM block 
containing the private key",
+                                       Code: adbc.StatusInvalidArgument,
+                               }
+                       }
+
+                       var parsedKey any
+
+                       if block.Type == "ENCRYPTED PRIVATE KEY" {
+                               passcode, ok := 
cnOptions[OptionJwtPrivateKeyPkcs8Password]
+                               if ok {
+                                       parsedKey, err = 
pkcs8.ParsePKCS8PrivateKey(block.Bytes, []byte(passcode))
+                               } else {
+                                       return adbc.Error{
+                                               Msg:  
OptionJwtPrivateKeyPkcs8Password + " is not configured",
+                                               Code: 
adbc.StatusInvalidArgument,
+                                       }
+                               }
+                       } else if block.Type == "PRIVATE KEY" {
+                               parsedKey, err = 
pkcs8.ParsePKCS8PrivateKey(block.Bytes)
+                       } else {
+                               return adbc.Error{
+                                       Msg:  block.Type + " is not supported",
+                                       Code: adbc.StatusInvalidArgument,
+                               }
+                       }
+
+                       if err != nil {
+                               return adbc.Error{
+                                       Msg:  "[Snowflake] failed parsing PKCS8 
private key: " + err.Error(),
+                                       Code: adbc.StatusInvalidArgument,
+                               }
+                       }
+
+                       d.cfg.PrivateKey = parsedKey.(*rsa.PrivateKey)
+
                case OptionClientRequestMFAToken:
                        switch v {
                        case adbc.OptionValueEnabled:
diff --git a/go/adbc/go.mod b/go/adbc/go.mod
index b9e3825d..4d698a7d 100644
--- a/go/adbc/go.mod
+++ b/go/adbc/go.mod
@@ -81,6 +81,7 @@ require (
        github.com/pmezard/go-difflib v1.0.0 // indirect
        github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // 
indirect
        github.com/sirupsen/logrus v1.9.3 // indirect
+       github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a //indirect
        github.com/zeebo/xxh3 v1.0.2 // indirect
        golang.org/x/crypto v0.14.0 // indirect
        golang.org/x/mod v0.12.0 // indirect
diff --git a/go/adbc/go.sum b/go/adbc/go.sum
index a7645a8f..7db46fc6 100644
--- a/go/adbc/go.sum
+++ b/go/adbc/go.sum
@@ -141,6 +141,8 @@ github.com/stretchr/objx v0.5.0 
h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
 github.com/stretchr/testify v1.7.0/go.mod 
h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.8.4 
h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
 github.com/stretchr/testify v1.8.4/go.mod 
h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a 
h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
+github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod 
h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
 github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
 github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
 github.com/zeebo/xxh3 v1.0.2/go.mod 
h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
diff --git a/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py 
b/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py
index f45045c1..701de23e 100644
--- a/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py
+++ b/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py
@@ -65,6 +65,15 @@ class DatabaseOptions(enum.Enum):
     #: this should point to a file containing a PKCS1 private key to be
     #: loaded. Commonly encoded in PEM blocks of type "RSA PRIVATE KEY"
     JWT_PRIVATE_KEY = "adbc.snowflake.sql.client_option.jwt_private_key"
+    #: parses a private key in PKCS #8, ASN.1 DER form. Specify the private key
+    #: value without having to load it from the file system.
+    JWT_PRIVATE_KEY_VALUE = (
+        "adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_value"
+    )
+    #: a passcode to use with encrypted private keys for JWT authentication.
+    JWT_PRIVATE_KEY_PASSWORD = (
+        "adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_password"
+    )
     #: Login retry timeout EXCLUDING network roundtrip and reading http 
response
     #: use format like http://pkg.go.dev/time#ParseDuration such as
     #: "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values

Reply via email to