This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new e8444219 feat(go/adbc/driver/snowflake): enable passing private key
for JWT via string and not file (#1207)
e8444219 is described below
commit e8444219d4f2723c081a5df5ccacca284f47ac5a
Author: davidhcoe <[email protected]>
AuthorDate: Mon Oct 30 11:23:57 2023 -0400
feat(go/adbc/driver/snowflake): enable passing private key for JWT via
string and not file (#1207)
Adds support for passing PKCS #8 private key values to Snowflake without
having to read them from a file.
- Includes C# integration test updates
- Includes support for encrypted PKCS #8 keys
- Includes a test case in the Snowflake driver
Resolves https://github.com/apache/arrow-adbc/issues/1204
---------
Co-authored-by: David Coe <[email protected]>
Co-authored-by: David Li <[email protected]>
---
.../src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs | 2 +-
csharp/src/Client/readme.md | 4 +-
csharp/test/Drivers/Snowflake/ClientTests.cs | 29 ++++++--
csharp/test/Drivers/Snowflake/DriverTests.cs | 12 +--
.../Drivers/Snowflake/Resources/SnowflakeData.sql | 6 +-
.../Snowflake/Resources/snowflakeconfig.json | 4 +
.../Snowflake/SnowflakeTestConfiguration.cs | 24 ++++++
.../Drivers/Snowflake/SnowflakeTestingUtils.cs | 56 ++++++++++----
docs/source/driver/snowflake.rst | 9 +++
go/adbc/driver/snowflake/driver.go | 9 ++-
go/adbc/driver/snowflake/driver_test.go | 87 ++++++++++++++++++++++
go/adbc/driver/snowflake/record_reader.go | 40 ++++------
go/adbc/driver/snowflake/snowflake_database.go | 41 ++++++++++
go/adbc/go.mod | 1 +
go/adbc/go.sum | 2 +
.../adbc_driver_snowflake/__init__.py | 9 +++
16 files changed, 278 insertions(+), 57 deletions(-)
diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
index 626c8357..914a7f22 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
@@ -333,7 +333,7 @@ namespace Apache.Arrow.Adbc.C
{
if (arrowArray is BooleanArray)
{
- return
Convert.ToBoolean(((BooleanArray)arrowArray).Values[index]);
+ return ((BooleanArray)arrowArray).GetValue(index);
}
else if (arrowArray is Date32Array)
{
diff --git a/csharp/src/Client/readme.md b/csharp/src/Client/readme.md
index 5ad4698b..91e33ca9 100644
--- a/csharp/src/Client/readme.md
+++ b/csharp/src/Client/readme.md
@@ -67,9 +67,9 @@ For example, when using the [Snowflake ADBC Go
Driver](https://arrow.apache.org/
if using the default user name and password authentication, but look like
```
-
adbc.snowflake.sql.account={account};adbc.snowflake.sql.warehouse={warehouse};username={user};password={password};adbc.snowflake.sql.auth_type=snowflake_jwt;adbc.snowflake.sql.client_option.auth_token={token}
+
adbc.snowflake.sql.account={account};adbc.snowflake.sql.warehouse={warehouse};username={user};password={password};adbc.snowflake.sql.auth_type=snowflake_jwt;adbc.snowflake.sql.client_option.jwt_private_key={private_key_file}
```
-when using JWT authentication.
+when using JWT authentication with an unencrypted key file.
Other ADBC drivers will have different connection parameters, so be sure to
check the documentation for each driver.
diff --git a/csharp/test/Drivers/Snowflake/ClientTests.cs
b/csharp/test/Drivers/Snowflake/ClientTests.cs
index abd1ac48..3eaf13c0 100644
--- a/csharp/test/Drivers/Snowflake/ClientTests.cs
+++ b/csharp/test/Drivers/Snowflake/ClientTests.cs
@@ -156,6 +156,11 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
while (reader.Read())
{
count++;
+
+ for (int i = 0; i < reader.FieldCount; i++)
+ {
+ Console.WriteLine($"{reader.GetName(i)}:
{reader.GetValue(i)}");
+ }
}
}
finally { reader.Close(); }
@@ -251,19 +256,31 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
DbConnectionStringBuilder builder = new
DbConnectionStringBuilder(true);
- builder["adbc.snowflake.sql.account"] = testConfiguration.Account;
- builder["adbc.snowflake.sql.warehouse"] =
testConfiguration.Warehouse;
- builder["username"] = testConfiguration.User;
+ builder[SnowflakeParameters.ACCOUNT] = testConfiguration.Account;
+ builder[SnowflakeParameters.WAREHOUSE] =
testConfiguration.Warehouse;
+ builder[SnowflakeParameters.HOST] = testConfiguration.Host;
+ builder[SnowflakeParameters.DATABASE] = testConfiguration.Database;
+ builder[SnowflakeParameters.USERNAME] = testConfiguration.User;
if
(!string.IsNullOrEmpty(testConfiguration.AuthenticationTokenPath))
{
+ builder[SnowflakeParameters.AUTH_TYPE] =
testConfiguration.AuthenticationType;
+
string privateKey =
File.ReadAllText(testConfiguration.AuthenticationTokenPath);
- builder["adbc.snowflake.sql.auth_type"] =
testConfiguration.AuthenticationType;
- builder["adbc.snowflake.sql.client_option.auth_token"] =
privateKey;
+
+ if (testConfiguration.AuthenticationType.Equals("auth_jwt",
StringComparison.OrdinalIgnoreCase))
+ {
+ builder[SnowflakeParameters.PKCS8_VALUE] = privateKey;
+
+ if(!string.IsNullOrEmpty(testConfiguration.Pkcs8Passcode))
+ {
+ builder[SnowflakeParameters.PKCS8_PASS] =
testConfiguration.Pkcs8Passcode;
+ }
+ }
}
else
{
- builder["password"] = testConfiguration.Password;
+ builder[SnowflakeParameters.PASSWORD] =
testConfiguration.Password;
}
AdbcDriver snowflakeDriver =
SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration);
diff --git a/csharp/test/Drivers/Snowflake/DriverTests.cs
b/csharp/test/Drivers/Snowflake/DriverTests.cs
index 21593f5c..41527e99 100644
--- a/csharp/test/Drivers/Snowflake/DriverTests.cs
+++ b/csharp/test/Drivers/Snowflake/DriverTests.cs
@@ -132,8 +132,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
string tableName = metadataTestConfiguration.Metadata.Table;
string columnName = null;
- parameters["adbc.snowflake.sql.db"] = databaseName;
- parameters["adbc.snowflake.sql.schema"] = schemaName;
+ parameters[SnowflakeParameters.DATABASE] = databaseName;
+ parameters[SnowflakeParameters.SCHEMA] = schemaName;
AdbcDatabase adbcDatabase = driver.Open(parameters);
AdbcConnection adbcConnection = adbcDatabase.Connect(new
Dictionary<string, string>());
@@ -158,6 +158,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
.Select(c => c.Columns)
.FirstOrDefault();
+ Assert.IsTrue(columns != null, "Columns cannot be null");
+
Assert.AreEqual(metadataTestConfiguration.Metadata.ExpectedColumnCount,
columns.Count);
}
}
@@ -215,9 +217,9 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
StringArray stringArray =
(StringArray)recordBatch.Column("table_type");
List<string> known_types = new List<string>
- {
- "BASE TABLE", "TEMPORARY TABLE", "VIEW"
- };
+ {
+ "BASE TABLE", "TEMPORARY TABLE", "VIEW"
+ };
int results = 0;
diff --git a/csharp/test/Drivers/Snowflake/Resources/SnowflakeData.sql
b/csharp/test/Drivers/Snowflake/Resources/SnowflakeData.sql
index 48ceead9..d17dfd4a 100644
--- a/csharp/test/Drivers/Snowflake/Resources/SnowflakeData.sql
+++ b/csharp/test/Drivers/Snowflake/Resources/SnowflakeData.sql
@@ -14,7 +14,7 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
-CREATE OR REPLACE TRANSIENT TABLE {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} (
+CREATE OR REPLACE TRANSIENT TABLE {ADBC_CATALOG}.{ADBC_SCHEMA}.{ADBC_TABLE} (
NUMBERTYPE NUMBER(38,0),
DECIMALTYPE NUMBER(18,1),
NUMERICTYPE NUMBER(28,1),
@@ -47,7 +47,7 @@ CREATE OR REPLACE TRANSIENT TABLE
{ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} (
TIMESTAMPTZTYPE TIMESTAMP_TZ(9)
);
-INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE}
+INSERT INTO {ADBC_CATALOG}.{ADBC_SCHEMA}.{ADBC_TABLE}
(
NUMBERTYPE, DECIMALTYPE, NUMERICTYPE, INTTYPE, INTEGERTYPE, BIGINTTYPE,
SMALLINTTYPE, TINYINTTYPE,
BYTEINTTYPE, FLOATTYPE, FLOAT4TYPE, FLOAT8TYPE, DOUBLETYPE,
DOUBLEPRECISIONTYPE, REALTYPE, VARCHARTYPE,
@@ -62,7 +62,7 @@ VALUES
'12:34:56', '2023-07-28 12:34:56', '2023-07-28 12:34:56', '2023-07-28
12:34:56', '2023-07-28 12:34:56'
);
-INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE}
+INSERT INTO {ADBC_CATALOG}.{ADBC_SCHEMA}.{ADBC_TABLE}
(
NUMBERTYPE, DECIMALTYPE, NUMERICTYPE, INTTYPE, INTEGERTYPE, BIGINTTYPE,
SMALLINTTYPE, TINYINTTYPE,
BYTEINTTYPE, FLOATTYPE, FLOAT4TYPE, FLOAT8TYPE, DOUBLETYPE,
DOUBLEPRECISIONTYPE, REALTYPE, VARCHARTYPE,
diff --git a/csharp/test/Drivers/Snowflake/Resources/snowflakeconfig.json
b/csharp/test/Drivers/Snowflake/Resources/snowflakeconfig.json
index e14e2d41..d3b43581 100644
--- a/csharp/test/Drivers/Snowflake/Resources/snowflakeconfig.json
+++ b/csharp/test/Drivers/Snowflake/Resources/snowflakeconfig.json
@@ -2,11 +2,15 @@
"driverPath": "",
"driverEntryPoint": "",
"account": "",
+ "host": "",
+ "database": "",
"user": "",
"password": "",
"warehouse": "",
"authenticationType": "",
"authenticationTokenPath": "",
+ "pkcs8Passcode": "",
+ "useHighPrecision": true,
"metadata": {
"catalog": "",
"schema": "",
diff --git a/csharp/test/Drivers/Snowflake/SnowflakeTestConfiguration.cs
b/csharp/test/Drivers/Snowflake/SnowflakeTestConfiguration.cs
index 591f851d..96bdcc22 100644
--- a/csharp/test/Drivers/Snowflake/SnowflakeTestConfiguration.cs
+++ b/csharp/test/Drivers/Snowflake/SnowflakeTestConfiguration.cs
@@ -42,6 +42,18 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
[JsonPropertyName("account")]
public string Account { get; set; }
+ /// <summary>
+ /// The Snowflake host.
+ /// </summary>
+ [JsonPropertyName("host")]
+ public string Host { get; set; }
+
+ /// <summary>
+ /// The Snowflake database.
+ /// </summary>
+ [JsonPropertyName("database")]
+ public string Database { get; set; }
+
/// <summary>
/// The Snowflake user.
/// </summary>
@@ -71,5 +83,17 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
/// </summary>
[JsonPropertyName("authenticationTokenPath"), JsonIgnore(Condition =
JsonIgnoreCondition.WhenWritingDefault)]
public string AuthenticationTokenPath { get; set; }
+
+ /// <summary>
+ /// The passcode used to decrypt the PKCS #8 private key, if it is encrypted.
+ /// </summary>
+ [JsonPropertyName("pkcs8Passcode"), JsonIgnore(Condition =
JsonIgnoreCondition.WhenWritingDefault)]
+ public string Pkcs8Passcode { get; set; }
+
+ /// <summary>
+ /// Whether to enable the driver's high-precision option (defaults to true).
+ /// </summary>
+ [JsonPropertyName("useHighPrecision")]
+ public bool UseHighPrecision { get; set; } = true;
}
}
diff --git a/csharp/test/Drivers/Snowflake/SnowflakeTestingUtils.cs
b/csharp/test/Drivers/Snowflake/SnowflakeTestingUtils.cs
index 3ef8e718..0da3eff7 100644
--- a/csharp/test/Drivers/Snowflake/SnowflakeTestingUtils.cs
+++ b/csharp/test/Drivers/Snowflake/SnowflakeTestingUtils.cs
@@ -23,6 +23,21 @@ using Apache.Arrow.Adbc.C;
namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
{
+    /// <summary>
+    /// Option keys passed to the Snowflake ADBC driver by these tests.
+    /// Holds only constants, so it is declared static.
+    /// </summary>
+    internal static class SnowflakeParameters
+    {
+        /// <summary>Snowflake database name.</summary>
+        public const string DATABASE = "adbc.snowflake.sql.db";
+
+        /// <summary>Snowflake schema name.</summary>
+        public const string SCHEMA = "adbc.snowflake.sql.schema";
+
+        /// <summary>Snowflake account identifier.</summary>
+        public const string ACCOUNT = "adbc.snowflake.sql.account";
+
+        /// <summary>User name to authenticate as.</summary>
+        public const string USERNAME = "username";
+
+        /// <summary>Password for user name and password authentication.</summary>
+        public const string PASSWORD = "password";
+
+        /// <summary>Snowflake warehouse to use.</summary>
+        public const string WAREHOUSE = "adbc.snowflake.sql.warehouse";
+
+        /// <summary>Authentication type (for example, <c>auth_jwt</c>).</summary>
+        public const string AUTH_TYPE = "adbc.snowflake.sql.auth_type";
+
+        /// <summary>Snowflake host name.</summary>
+        public const string HOST = "adbc.snowflake.sql.uri.host";
+
+        /// <summary>PEM-encoded PKCS #8 private key passed by value for JWT authentication.</summary>
+        public const string PKCS8_VALUE = "adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_value";
+
+        /// <summary>Password used to decrypt an encrypted PKCS #8 private key.</summary>
+        public const string PKCS8_PASS = "adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_password";
+
+        /// <summary>Enables the driver's high-precision option; the value is <c>"true"</c> or <c>"false"</c>.</summary>
+        public const string USE_HIGH_PRECISION = "adbc.snowflake.sql.client_option.use_high_precision";
+    }
+
internal class SnowflakeTestingUtils
{
internal const string SNOWFLAKE_TEST_CONFIG_VARIABLE =
"SNOWFLAKE_TEST_CONFIG_FILE";
@@ -43,13 +58,24 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
parameters = new Dictionary<string, string>
{
- { "adbc.snowflake.sql.account", testConfiguration.Account },
- { "username", testConfiguration.User },
- { "password", testConfiguration.Password },
- { "adbc.snowflake.sql.warehouse", testConfiguration.Warehouse
},
- { "adbc.snowflake.sql.auth_type",
testConfiguration.AuthenticationType }
+ { SnowflakeParameters.ACCOUNT, testConfiguration.Account },
+ { SnowflakeParameters.USERNAME, testConfiguration.User },
+ { SnowflakeParameters.PASSWORD, testConfiguration.Password },
+ { SnowflakeParameters.WAREHOUSE, testConfiguration.Warehouse },
+ { SnowflakeParameters.AUTH_TYPE,
testConfiguration.AuthenticationType },
+ { SnowflakeParameters.USE_HIGH_PRECISION,
testConfiguration.UseHighPrecision.ToString().ToLowerInvariant() }
};
+ if(!string.IsNullOrWhiteSpace(testConfiguration.Host))
+ {
+ parameters[SnowflakeParameters.HOST] = testConfiguration.Host;
+ }
+
+ if(!string.IsNullOrWhiteSpace(testConfiguration.Database))
+ {
+ parameters[SnowflakeParameters.DATABASE] =
testConfiguration.Database;
+ }
+
Dictionary<string, string> options = new Dictionary<string,
string>() { };
AdbcDriver snowflakeDriver =
CAdbcDriverImporter.Load(testConfiguration.DriverPath,
testConfiguration.DriverEntryPoint);
@@ -67,7 +93,6 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
SnowflakeTestConfiguration testConfiguration
)
{
- Dictionary<string, string> options = new Dictionary<string,
string>() { };
AdbcDriver snowflakeDriver =
CAdbcDriverImporter.Load(testConfiguration.DriverPath,
testConfiguration.DriverEntryPoint);
return snowflakeDriver;
@@ -83,22 +108,25 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
string[] sql = File.ReadAllLines("resources/SnowflakeData.sql");
- string placeholder = "{ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE}";
+ Dictionary<string, string> placeholderValues = new
Dictionary<string, string>() {
+ {"{ADBC_CATALOG}", testConfiguration.Metadata.Catalog },
+ {"{ADBC_SCHEMA}", testConfiguration.Metadata.Schema },
+ {"{ADBC_TABLE}", testConfiguration.Metadata.Table }
+ };
foreach (string line in sql)
{
if (!line.TrimStart().StartsWith("--"))
{
- if (line.Contains(placeholder))
- {
- string modifiedLine = line.Replace(placeholder,
$"{testConfiguration.Metadata.Catalog}.{testConfiguration.Metadata.Schema}.{testConfiguration.Metadata.Table}");
+ string modifiedLine = line;
- content.AppendLine(modifiedLine);
- }
- else
+ foreach(string key in placeholderValues.Keys)
{
- content.AppendLine(line);
+ if(modifiedLine.Contains(key))
+ modifiedLine = modifiedLine.Replace(key,
placeholderValues[key]);
}
+
+ content.AppendLine(modifiedLine);
}
}
diff --git a/docs/source/driver/snowflake.rst b/docs/source/driver/snowflake.rst
index 784eded8..04023a62 100644
--- a/docs/source/driver/snowflake.rst
+++ b/docs/source/driver/snowflake.rst
@@ -406,6 +406,15 @@ These options map 1:1 with the Snowflake `Config object
<https://pkg.go.dev/gith
private key to be read in and parsed. Commonly encoded in PEM blocks
of type "RSA PRIVATE KEY".
+``adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_value``
+ A PEM-encoded PKCS #8 private key (encrypted or unencrypted), passed by
+ value instead of being read from the file system. If the key is encrypted,
+ ``adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_password`` must
+ also be set; it is used to decrypt the key.
+
+``adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_password``
+ Passcode to use when passing an encrypted PKCS #8 value.
+
``adbc.snowflake.sql.client_option.disable_telemetry``
The Snowflake driver allows for telemetry information which can be
disabled by setting this to ``true``. Value should be either ``true``
diff --git a/go/adbc/driver/snowflake/driver.go
b/go/adbc/driver/snowflake/driver.go
index 51a9d056..1d8ea30a 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/driver.go
@@ -87,8 +87,13 @@ const (
// specify the RSA private key to use to sign the JWT
// this should point to a file containing a PKCS1 private key to be
// loaded. Commonly encoded in PEM blocks of type "RSA PRIVATE KEY"
- OptionJwtPrivateKey =
"adbc.snowflake.sql.client_option.jwt_private_key"
- OptionDisableTelemetry =
"adbc.snowflake.sql.client_option.disable_telemetry"
+ OptionJwtPrivateKey = "adbc.snowflake.sql.client_option.jwt_private_key"
+ // a PEM-encoded PKCS #8 private key (PEM block type "PRIVATE KEY" or
+ // "ENCRYPTED PRIVATE KEY"), passed by value instead of loaded from the
+ // file system.
+ OptionJwtPrivateKeyPkcs8Value =
"adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_value"
+ // a passcode to use with encrypted private keys for JWT authentication
+ OptionJwtPrivateKeyPkcs8Password =
"adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_password"
+ OptionDisableTelemetry =
"adbc.snowflake.sql.client_option.disable_telemetry"
// snowflake driver logging level
OptionLogTracing = "adbc.snowflake.sql.client_option.tracing"
// When true, the MFA token is cached in the credential manager. True
by default
diff --git a/go/adbc/driver/snowflake/driver_test.go
b/go/adbc/driver/snowflake/driver_test.go
index 3d795463..846bc256 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -27,6 +27,7 @@ import (
"encoding/pem"
"fmt"
"os"
+ "runtime"
"strconv"
"strings"
"testing"
@@ -678,6 +679,92 @@ func (suite *SnowflakeTests) TestUseHighPrecision() {
suite.Equal(9876543210.99, rec.Column(1).(*array.Float64).Value(1))
}
+func TestJwtAuthenticationUnencryptedValue(t *testing.T) {
+	// test doesn't participate in SnowflakeTests because
+	// JWT auth has a different behavior
+	uri, ok := os.LookupEnv("SNOWFLAKE_URI")
+	if !ok {
+		t.Skip("Cannot find the `SNOWFLAKE_URI` value")
+	}
+
+	keyValue, ok := os.LookupEnv("SNOWFLAKE_TEST_PKCS8_VALUE")
+	if !ok {
+		t.Skip("Cannot find the `SNOWFLAKE_TEST_PKCS8_VALUE` value")
+	}
+
+	ConnectWithJwt(t, uri, keyValue, "")
+}
+
+func TestJwtAuthenticationEncryptedValue(t *testing.T) {
+	// test doesn't participate in SnowflakeTests because
+	// JWT auth has a different behavior
+	uri, ok := os.LookupEnv("SNOWFLAKE_URI")
+	if !ok {
+		t.Skip("Cannot find the `SNOWFLAKE_URI` value")
+	}
+
+	keyValue, ok := os.LookupEnv("SNOWFLAKE_TEST_PKCS8_EN_VALUE")
+	if !ok {
+		t.Skip("Cannot find the `SNOWFLAKE_TEST_PKCS8_EN_VALUE` value")
+	}
+
+	passcode, ok := os.LookupEnv("SNOWFLAKE_TEST_PKCS8_PASS")
+	if !ok {
+		t.Skip("Cannot find the `SNOWFLAKE_TEST_PKCS8_PASS` value")
+	}
+
+	ConnectWithJwt(t, uri, keyValue, passcode)
+}
+
+// ConnectWithJwt opens a Snowflake connection using JWT authentication with
+// the PKCS #8 private key given by value in keyValue. A non-empty passcode is
+// passed as the password for an encrypted key. Account, user, database and
+// schema (plus warehouse and host, when set) are taken from the DSN in uri.
+// Failures are reported through t instead of panicking, so a bad key or DSN
+// fails only the calling test rather than aborting the whole test binary.
+func ConnectWithJwt(t *testing.T, uri, keyValue, passcode string) {
+	t.Helper()
+
+	// On Windows the key value read from the environment is expected to hold
+	// literal "\r" / "\n" escape sequences; turn them back into the real line
+	// breaks that PEM decoding in the driver needs.
+	if runtime.GOOS == "windows" {
+		keyValue = strings.ReplaceAll(keyValue, "\\r", "\r")
+		keyValue = strings.ReplaceAll(keyValue, "\\n", "\n")
+	}
+
+	cfg, err := gosnowflake.ParseDSN(uri)
+	if err != nil {
+		t.Fatalf("failed to parse SNOWFLAKE_URI: %v", err)
+	}
+
+	opts := map[string]string{
+		driver.OptionAccount:                 cfg.Account,
+		adbc.OptionKeyUsername:               cfg.User,
+		driver.OptionDatabase:                cfg.Database,
+		driver.OptionSchema:                  cfg.Schema,
+		driver.OptionAuthType:                driver.OptionValueAuthJwt,
+		driver.OptionJwtPrivateKeyPkcs8Value: keyValue,
+	}
+
+	if cfg.Warehouse != "" {
+		opts[driver.OptionWarehouse] = cfg.Warehouse
+	}
+
+	if cfg.Host != "" {
+		opts[driver.OptionHost] = cfg.Host
+	}
+
+	// only needed when the PKCS #8 key is encrypted
+	if passcode != "" {
+		opts[driver.OptionJwtPrivateKeyPkcs8Password] = passcode
+	}
+
+	mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+	adbcDriver := driver.NewDriver(mem)
+	db, err := adbcDriver.NewDatabase(opts)
+	if err != nil {
+		t.Fatalf("failed to create database with JWT options: %v", err)
+	}
+
+	cnxn, err := db.Open(context.Background())
+	if err != nil {
+		t.Fatalf("failed to open connection with JWT authentication: %v", err)
+	}
+
+	// surface close errors instead of silently dropping them via defer
+	if err := cnxn.Close(); err != nil {
+		t.Errorf("failed to close connection: %v", err)
+	}
+}
+
func (suite *SnowflakeTests) TestJwtPrivateKey() {
// grab the username from the DSN
cfg, err := gosnowflake.ParseDSN(suite.Quirks.dsn)
diff --git a/go/adbc/driver/snowflake/record_reader.go
b/go/adbc/driver/snowflake/record_reader.go
index 3b76f23e..a0563c85 100644
--- a/go/adbc/driver/snowflake/record_reader.go
+++ b/go/adbc/driver/snowflake/record_reader.go
@@ -449,36 +449,28 @@ func newRecordReader(ctx context.Context, alloc
memory.Allocator, ld gosnowflake
}
if len(batches) == 0 {
- if ld.TotalRows() != 0 {
- //
XXX(https://github.com/apache/arrow-adbc/issues/863): Snowflake won't return
Arrow data for certain queries
- schema, err := rowTypesToArrowSchema(ctx, ld,
useHighPrecision)
- if err != nil {
- return nil, adbc.Error{
- Msg: err.Error(),
- Code: adbc.StatusInternal,
- }
- }
-
- bldr := array.NewRecordBuilder(alloc, schema)
- defer bldr.Release()
-
- rec, err := jsonDataToArrow(ctx, bldr, ld)
- if err != nil {
- return nil, err
- }
- defer rec.Release()
-
- return array.NewRecordReader(schema,
[]arrow.Record{rec})
- }
- schema := arrow.NewSchema([]arrow.Field{}, nil)
- reader, err := array.NewRecordReader(schema, []arrow.Record{})
+ schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision)
if err != nil {
return nil, adbc.Error{
Msg: err.Error(),
Code: adbc.StatusInternal,
}
}
- return reader, nil
+
+ bldr := array.NewRecordBuilder(alloc, schema)
+ defer bldr.Release()
+
+ rec, err := jsonDataToArrow(ctx, bldr, ld)
+ if err != nil {
+ return nil, err
+ }
+ defer rec.Release()
+
+ if ld.TotalRows() != 0 {
+ return array.NewRecordReader(schema,
[]arrow.Record{rec})
+ } else {
+ return array.NewRecordReader(schema, []arrow.Record{})
+ }
}
ch := make(chan arrow.Record, bufferSize)
diff --git a/go/adbc/driver/snowflake/snowflake_database.go
b/go/adbc/driver/snowflake/snowflake_database.go
index 1d005705..45e3aab4 100644
--- a/go/adbc/driver/snowflake/snowflake_database.go
+++ b/go/adbc/driver/snowflake/snowflake_database.go
@@ -34,6 +34,7 @@ import (
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
"github.com/snowflakedb/gosnowflake"
+ "github.com/youmark/pkcs8"
)
var (
@@ -359,6 +360,46 @@ func (d *databaseImpl) SetOptions(cnOptions
map[string]string) error {
}
d.cfg.PrivateKey = key
+ case OptionJwtPrivateKeyPkcs8Value:
+ block, _ := pem.Decode([]byte(v))
+
+ if block == nil {
+ return adbc.Error{
+ Msg: "Failed to parse PEM block
containing the private key",
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+
+ var parsedKey any
+
+ if block.Type == "ENCRYPTED PRIVATE KEY" {
+ passcode, ok :=
cnOptions[OptionJwtPrivateKeyPkcs8Password]
+ if ok {
+ parsedKey, err =
pkcs8.ParsePKCS8PrivateKey(block.Bytes, []byte(passcode))
+ } else {
+ return adbc.Error{
+ Msg:
OptionJwtPrivateKeyPkcs8Password + " is not configured",
+ Code:
adbc.StatusInvalidArgument,
+ }
+ }
+ } else if block.Type == "PRIVATE KEY" {
+ parsedKey, err =
pkcs8.ParsePKCS8PrivateKey(block.Bytes)
+ } else {
+ return adbc.Error{
+ Msg: block.Type + " is not supported",
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+
+ if err != nil {
+ return adbc.Error{
+ Msg: "[Snowflake] failed parsing PKCS8
private key: " + err.Error(),
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+
+ d.cfg.PrivateKey = parsedKey.(*rsa.PrivateKey)
+
case OptionClientRequestMFAToken:
switch v {
case adbc.OptionValueEnabled:
diff --git a/go/adbc/go.mod b/go/adbc/go.mod
index b9e3825d..4d698a7d 100644
--- a/go/adbc/go.mod
+++ b/go/adbc/go.mod
@@ -81,6 +81,7 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec //
indirect
github.com/sirupsen/logrus v1.9.3 // indirect
+ github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a //indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/mod v0.12.0 // indirect
diff --git a/go/adbc/go.sum b/go/adbc/go.sum
index a7645a8f..7db46fc6 100644
--- a/go/adbc/go.sum
+++ b/go/adbc/go.sum
@@ -141,6 +141,8 @@ github.com/stretchr/objx v0.5.0
h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/testify v1.7.0/go.mod
h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4
h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod
h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
+github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod
h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod
h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
diff --git a/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py
b/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py
index f45045c1..701de23e 100644
--- a/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py
+++ b/python/adbc_driver_snowflake/adbc_driver_snowflake/__init__.py
@@ -65,6 +65,15 @@ class DatabaseOptions(enum.Enum):
#: this should point to a file containing a PKCS1 private key to be
#: loaded. Commonly encoded in PEM blocks of type "RSA PRIVATE KEY"
JWT_PRIVATE_KEY = "adbc.snowflake.sql.client_option.jwt_private_key"
+ #: a PEM-encoded PKCS #8 private key (encrypted or unencrypted), passed by
+ #: value instead of loaded from the file system.
+ JWT_PRIVATE_KEY_VALUE = (
+ "adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_value"
+ )
+ #: a passcode to use with encrypted private keys for JWT authentication.
+ JWT_PRIVATE_KEY_PASSWORD = (
+ "adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_password"
+ )
#: Login retry timeout EXCLUDING network roundtrip and reading http
response
#: use format like http://pkg.go.dev/time#ParseDuration such as
#: "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values