This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 126235b8 feat(go/adbc/driver/snowflake): support PEM decoding JWT private keys (#1199)
126235b8 is described below

commit 126235b8ff234c67d141297f0cc1a5075721c1ce
Author: Aaron Ross <[email protected]>
AuthorDate: Fri Oct 20 13:05:02 2023 -0700

    feat(go/adbc/driver/snowflake): support PEM decoding JWT private keys (#1199)
    
    Resolves #1198.
---
 go/adbc/driver/snowflake/driver_test.go        | 91 ++++++++++++++++++++++++++
 go/adbc/driver/snowflake/snowflake_database.go | 26 +++++++-
 2 files changed, 116 insertions(+), 1 deletion(-)

diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go
index c68ca0ab..3d795463 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -19,7 +19,12 @@ package snowflake_test
 
 import (
        "context"
+       "crypto/rand"
+       "crypto/rsa"
+       "crypto/x509"
        "database/sql"
+       "encoding/base64"
+       "encoding/pem"
        "fmt"
        "os"
        "strconv"
@@ -672,3 +677,89 @@ func (suite *SnowflakeTests) TestUseHighPrecision() {
        suite.Equal(1234567.89, rec.Column(1).(*array.Float64).Value(0))
        suite.Equal(9876543210.99, rec.Column(1).(*array.Float64).Value(1))
 }
+
+func (suite *SnowflakeTests) TestJwtPrivateKey() {
+       // grab the username from the DSN
+       cfg, err := gosnowflake.ParseDSN(suite.Quirks.dsn)
+       suite.NoError(err)
+       username := cfg.User
+
+       // write the generated RSA key out to a file
+       writeKey := func(filename string, key []byte) string {
+               f, err := os.CreateTemp("", filename)
+               suite.NoError(err)
+               _, err = f.Write(key)
+               suite.NoError(err)
+               return f.Name()
+       }
+
+       // set the Snowflake user's RSA public key
+       setKey := func(privKey *rsa.PrivateKey) {
+               suite.NoError(suite.stmt.SetSqlQuery("USE ROLE ACCOUNTADMIN"))
+               _, err := suite.stmt.ExecuteUpdate(suite.ctx)
+               suite.NoError(err)
+
+               if privKey != nil {
+                       pubKeyBytes, err := 
x509.MarshalPKIXPublicKey(privKey.Public())
+                       suite.NoError(err)
+                       encodedKey := 
base64.StdEncoding.EncodeToString(pubKeyBytes)
+                       suite.NoError(suite.stmt.SetSqlQuery(fmt.Sprintf("ALTER 
USER %s SET RSA_PUBLIC_KEY='%s'", username, encodedKey)))
+               } else {
+                       suite.NoError(suite.stmt.SetSqlQuery(fmt.Sprintf("ALTER 
USER %s SET RSA_PUBLIC_KEY=''", username)))
+               }
+               _, err = suite.stmt.ExecuteUpdate(suite.ctx)
+               suite.NoError(err)
+       }
+
+       // open a new connection using JWT authentication and verify that a 
simple query runs
+       verifyKey := func(keyFile string) {
+               opts := suite.Quirks.DatabaseOptions()
+               opts[driver.OptionAuthType] = driver.OptionValueAuthJwt
+               opts[driver.OptionJwtPrivateKey] = keyFile
+               db, err := suite.driver.NewDatabase(opts)
+               suite.NoError(err)
+               cnxn, err := db.Open(suite.ctx)
+               suite.NoError(err)
+               defer cnxn.Close()
+               stmt, err := cnxn.NewStatement()
+               suite.NoError(err)
+               defer stmt.Close()
+
+               suite.NoError(stmt.SetSqlQuery("SELECT 1"))
+               rdr, _, err := stmt.ExecuteQuery(suite.ctx)
+               defer rdr.Release()
+               suite.NoError(err)
+       }
+
+       // generate a key and set it the Snowflake user
+       rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
+       setKey(rsaKey)
+
+       // when the test concludes, reset the user's key
+       defer setKey(nil)
+
+       // PKCS1 key
+       rsaKeyPem := pem.EncodeToMemory(&pem.Block{
+               Type:  "RSA PRIVATE KEY",
+               Bytes: x509.MarshalPKCS1PrivateKey(rsaKey),
+       })
+       pkcs1Key := writeKey("key.pem", rsaKeyPem)
+       defer os.Remove(pkcs1Key)
+       verifyKey(pkcs1Key)
+
+       // PKCS8 key
+       rsaKeyP8Bytes, _ := x509.MarshalPKCS8PrivateKey(rsaKey)
+       rsaKeyP8 := pem.EncodeToMemory(&pem.Block{
+               Type:  "PRIVATE KEY",
+               Bytes: rsaKeyP8Bytes,
+       })
+       pkcs8Key := writeKey("key.p8", rsaKeyP8)
+       defer os.Remove(pkcs8Key)
+       verifyKey(pkcs8Key)
+
+       // binary key
+       block, _ := pem.Decode([]byte(rsaKeyPem))
+       binKey := writeKey("key.bin", block.Bytes)
+       defer os.Remove(binKey)
+       verifyKey(binKey)
+}
diff --git a/go/adbc/driver/snowflake/snowflake_database.go b/go/adbc/driver/snowflake/snowflake_database.go
index 418c08e3..1d005705 100644
--- a/go/adbc/driver/snowflake/snowflake_database.go
+++ b/go/adbc/driver/snowflake/snowflake_database.go
@@ -19,12 +19,16 @@ package snowflake
 
 import (
        "context"
+       "crypto/rsa"
        "crypto/x509"
        "database/sql"
+       "encoding/pem"
+       "errors"
        "fmt"
        "net/url"
        "os"
        "strconv"
+       "strings"
        "time"
 
        "github.com/apache/arrow-adbc/go/adbc"
@@ -328,13 +332,33 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
                                }
                        }
 
-                       d.cfg.PrivateKey, err = x509.ParsePKCS1PrivateKey(data)
+                       var block []byte
+                       if strings.Contains(string(data), "PRIVATE KEY") {
+                               b, _ := pem.Decode(data)
+                               block = b.Bytes
+                       } else {
+                               block = data
+                       }
+
+                       var key *rsa.PrivateKey
+                       key, err = x509.ParsePKCS1PrivateKey(block)
+                       if err != nil && strings.Contains(err.Error(), "use 
ParsePKCS8PrivateKey instead") {
+                               var pkcs8Key any
+                               pkcs8Key, err = x509.ParsePKCS8PrivateKey(block)
+                               key, ok = pkcs8Key.(*rsa.PrivateKey)
+                               if !ok {
+                                       err = errors.New("file does not contain 
an RSA private key")
+                               }
+                       }
+
                        if err != nil {
                                return adbc.Error{
                                        Msg:  "failed parsing private key file 
'" + v + "': " + err.Error(),
                                        Code: adbc.StatusInvalidArgument,
                                }
                        }
+
+                       d.cfg.PrivateKey = key
                case OptionClientRequestMFAToken:
                        switch v {
                        case adbc.OptionValueEnabled:

Reply via email to