This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 126235b8 feat(go/adbc/driver/snowflake): support PEM decoding JWT
private keys (#1199)
126235b8 is described below
commit 126235b8ff234c67d141297f0cc1a5075721c1ce
Author: Aaron Ross <[email protected]>
AuthorDate: Fri Oct 20 13:05:02 2023 -0700
feat(go/adbc/driver/snowflake): support PEM decoding JWT private keys
(#1199)
Resolves #1198.
---
go/adbc/driver/snowflake/driver_test.go | 91 ++++++++++++++++++++++++++
go/adbc/driver/snowflake/snowflake_database.go | 26 +++++++-
2 files changed, 116 insertions(+), 1 deletion(-)
diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go
index c68ca0ab..3d795463 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -19,7 +19,12 @@ package snowflake_test
import (
"context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
"database/sql"
+ "encoding/base64"
+ "encoding/pem"
"fmt"
"os"
"strconv"
@@ -672,3 +677,104 @@ func (suite *SnowflakeTests) TestUseHighPrecision() {
suite.Equal(1234567.89, rec.Column(1).(*array.Float64).Value(0))
suite.Equal(9876543210.99, rec.Column(1).(*array.Float64).Value(1))
}
+
+func (suite *SnowflakeTests) TestJwtPrivateKey() {
+	// Grab the username from the DSN; its RSA public key is rotated below.
+	cfg, err := gosnowflake.ParseDSN(suite.Quirks.dsn)
+	suite.Require().NoError(err)
+	username := cfg.User
+
+	// writeKey writes key to a new temporary file (pattern as accepted by
+	// os.CreateTemp) and returns its path. The file is closed before
+	// returning so the driver can read it, and is removed (best effort)
+	// when the current (sub)test finishes, even if an assertion fails.
+	writeKey := func(pattern string, key []byte) string {
+		f, err := os.CreateTemp("", pattern)
+		suite.Require().NoError(err)
+		name := f.Name()
+		suite.T().Cleanup(func() { _ = os.Remove(name) })
+
+		_, err = f.Write(key)
+		closeErr := f.Close()
+		suite.Require().NoError(err)
+		suite.Require().NoError(closeErr)
+		return name
+	}
+
+	// setKey sets the user's RSA_PUBLIC_KEY to the public half of privKey,
+	// or clears it when privKey is nil. It uses non-fatal assertions
+	// because it also runs as the deferred reset at the end of the test.
+	setKey := func(privKey *rsa.PrivateKey) {
+		suite.NoError(suite.stmt.SetSqlQuery("USE ROLE ACCOUNTADMIN"))
+		_, err := suite.stmt.ExecuteUpdate(suite.ctx)
+		suite.NoError(err)
+
+		encodedKey := ""
+		if privKey != nil {
+			pubKeyBytes, err := x509.MarshalPKIXPublicKey(privKey.Public())
+			if !suite.NoError(err) {
+				return
+			}
+			encodedKey = base64.StdEncoding.EncodeToString(pubKeyBytes)
+		}
+
+		suite.NoError(suite.stmt.SetSqlQuery(fmt.Sprintf("ALTER USER %s SET RSA_PUBLIC_KEY='%s'", username, encodedKey)))
+		_, err = suite.stmt.ExecuteUpdate(suite.ctx)
+		suite.NoError(err)
+	}
+
+	// verifyKey opens a new connection using JWT authentication with the
+	// private key stored in keyFile and checks that a simple query runs.
+	verifyKey := func(keyFile string) {
+		opts := suite.Quirks.DatabaseOptions()
+		opts[driver.OptionAuthType] = driver.OptionValueAuthJwt
+		opts[driver.OptionJwtPrivateKey] = keyFile
+
+		db, err := suite.driver.NewDatabase(opts)
+		suite.Require().NoError(err)
+		cnxn, err := db.Open(suite.ctx)
+		suite.Require().NoError(err)
+		defer cnxn.Close()
+
+		stmt, err := cnxn.NewStatement()
+		suite.Require().NoError(err)
+		defer stmt.Close()
+
+		suite.Require().NoError(stmt.SetSqlQuery("SELECT 1"))
+		rdr, _, err := stmt.ExecuteQuery(suite.ctx)
+		// rdr is nil when err != nil, so only defer Release after the check.
+		suite.Require().NoError(err)
+		defer rdr.Release()
+		suite.True(rdr.Next(), "expected SELECT 1 to return a record")
+	}
+
+	// Generate a key and register it for the Snowflake user; when the test
+	// concludes, reset the user's key.
+	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
+	suite.Require().NoError(err)
+	setKey(rsaKey)
+	defer setKey(nil)
+
+	pkcs1DER := x509.MarshalPKCS1PrivateKey(rsaKey)
+	pkcs8DER, err := x509.MarshalPKCS8PrivateKey(rsaKey)
+	suite.Require().NoError(err)
+
+	// Each supported on-disk format runs as its own subtest so a failure in
+	// one does not hide the results of the others.
+	suite.Run("PKCS1 PEM", func() {
+		verifyKey(writeKey("key-*.pem", pem.EncodeToMemory(&pem.Block{
+			Type:  "RSA PRIVATE KEY",
+			Bytes: pkcs1DER,
+		})))
+	})
+	suite.Run("PKCS8 PEM", func() {
+		verifyKey(writeKey("key-*.p8", pem.EncodeToMemory(&pem.Block{
+			Type:  "PRIVATE KEY",
+			Bytes: pkcs8DER,
+		})))
+	})
+	// The raw PKCS#1 DER bytes, i.e. the PEM payload without armor.
+	suite.Run("PKCS1 DER", func() {
+		verifyKey(writeKey("key-*.bin", pkcs1DER))
+	})
+}
diff --git a/go/adbc/driver/snowflake/snowflake_database.go b/go/adbc/driver/snowflake/snowflake_database.go
index 418c08e3..1d005705 100644
--- a/go/adbc/driver/snowflake/snowflake_database.go
+++ b/go/adbc/driver/snowflake/snowflake_database.go
@@ -19,12 +19,16 @@ package snowflake
import (
"context"
+ "crypto/rsa"
"crypto/x509"
"database/sql"
+ "encoding/pem"
+ "errors"
"fmt"
"net/url"
"os"
"strconv"
+ "strings"
"time"
"github.com/apache/arrow-adbc/go/adbc"
@@ -328,13 +332,33 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
}
}
- d.cfg.PrivateKey, err = x509.ParsePKCS1PrivateKey(data)
+ var block []byte
+ if strings.Contains(string(data), "PRIVATE KEY") {
+ b, _ := pem.Decode(data)
+ block = b.Bytes
+ } else {
+ block = data
+ }
+
+ var key *rsa.PrivateKey
+ key, err = x509.ParsePKCS1PrivateKey(block)
+ if err != nil && strings.Contains(err.Error(), "use
ParsePKCS8PrivateKey instead") {
+ var pkcs8Key any
+ pkcs8Key, err = x509.ParsePKCS8PrivateKey(block)
+ key, ok = pkcs8Key.(*rsa.PrivateKey)
+ if !ok {
+ err = errors.New("file does not contain
an RSA private key")
+ }
+ }
+
if err != nil {
return adbc.Error{
Msg: "failed parsing private key file
'" + v + "': " + err.Error(),
Code: adbc.StatusInvalidArgument,
}
}
+
+ d.cfg.PrivateKey = key
case OptionClientRequestMFAToken:
switch v {
case adbc.OptionValueEnabled: