This is an automated email from the ASF dual-hosted git repository.

lahirujayathilake pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airavata-custos.git

commit d1e8bc5f4a46b66e246f6c7bbd9ab9b52f1ceb60
Author: lahiruj <[email protected]>
AuthorDate: Thu Apr 2 21:44:24 2026 -0400

    Maintain a denied-extensions list in the client config, and move force-command
from the client config to the per-request certificate issuance
---
 signer/internal/audit/logger.go             |  6 +++
 signer/internal/cert/signer.go              |  6 +--
 signer/internal/cert/signer_test.go         |  6 ++-
 signer/internal/handler/certificates.go     | 30 ++++++++-----
 signer/internal/handler/sign.go             | 68 ++++++++++++++++++++---------
 signer/internal/policy/enforcer.go          | 17 ++++++--
 signer/internal/policy/enforcer_test.go     | 57 +++++++++++++++---------
 signer/internal/store/certificate_query.go  | 28 ++++++++++++
 signer/internal/store/client_config.go      | 22 +++-------
 signer/internal/store/issuance_log.go       | 21 ++++++---
 signer/migrations/001_initial_schema.up.sql |  5 ++-
 11 files changed, 179 insertions(+), 87 deletions(-)

diff --git a/signer/internal/audit/logger.go b/signer/internal/audit/logger.go
index b73529f3d..2da44e1ef 100644
--- a/signer/internal/audit/logger.go
+++ b/signer/internal/audit/logger.go
@@ -49,6 +49,8 @@ type IssuanceEntry struct {
        ValidAfter           time.Time
        ValidBefore          time.Time
        SourceIP             string
+       GrantedExtensions    []string
+       ForceCommand         *string
        UserAccessTokenHash  string
        CorrelationID        string
 }
@@ -66,6 +68,8 @@ func (l *Logger) LogIssuance(ctx context.Context, entry 
*IssuanceEntry) error {
                ValidAfter:           entry.ValidAfter,
                ValidBefore:          entry.ValidBefore,
                SourceIP:             entry.SourceIP,
+               GrantedExtensions:    entry.GrantedExtensions,
+               ForceCommand:         entry.ForceCommand,
                UserAccessTokenHash:  entry.UserAccessTokenHash,
        }
 
@@ -87,6 +91,8 @@ func (l *Logger) LogIssuance(ctx context.Context, entry 
*IssuanceEntry) error {
                "valid_after", entry.ValidAfter.Unix(),
                "valid_before", entry.ValidBefore.Unix(),
                "source_ip", entry.SourceIP,
+               "granted_extensions", entry.GrantedExtensions,
+               "force_command", entry.ForceCommand,
                "user_token_hash", entry.UserAccessTokenHash,
                "correlation_id", entry.CorrelationID,
        )
diff --git a/signer/internal/cert/signer.go b/signer/internal/cert/signer.go
index 173e162a9..adb1deb51 100644
--- a/signer/internal/cert/signer.go
+++ b/signer/internal/cert/signer.go
@@ -63,11 +63,7 @@ func SignCertificate(req *SignRequest) (*SignResult, error) {
 
        extensions := req.Extensions
        if extensions == nil {
-               extensions = map[string]string{
-                       "permit-pty":             "",
-                       "permit-port-forwarding": "",
-                       "permit-user-rc":         "",
-               }
+               extensions = make(map[string]string)
        }
 
        cert := &ssh.Certificate{
diff --git a/signer/internal/cert/signer_test.go 
b/signer/internal/cert/signer_test.go
index d17f26fe8..55cd105b4 100644
--- a/signer/internal/cert/signer_test.go
+++ b/signer/internal/cert/signer_test.go
@@ -64,6 +64,7 @@ func TestSignCertificate_Ed25519(t *testing.T) {
        userKey := generateTestEd25519UserKey(t)
 
        before := time.Now().UTC()
+       exts := ExtensionsToMap(AllStandardExtensions())
        result, err := SignCertificate(&SignRequest{
                PublicKey:       userKey,
                CAPrivateKeyPEM: caPrivPEM,
@@ -71,6 +72,7 @@ func TestSignCertificate_Ed25519(t *testing.T) {
                Principal:       "testuser",
                ClientID:        "webapp",
                TTLSeconds:      7200,
+               Extensions:      exts,
        })
        after := time.Now().UTC()
 
@@ -120,8 +122,8 @@ func TestSignCertificate_Ed25519(t *testing.T) {
                t.Errorf("expected UserCert type, got %d", cert.CertType)
        }
 
-       for _, ext := range []string{"permit-pty", "permit-port-forwarding", 
"permit-user-rc"} {
-               if _, ok := cert.Extensions[ext]; !ok {
+       for _, ext := range AllStandardExtensions() {
+               if _, ok := cert.Extensions[string(ext)]; !ok {
                        t.Errorf("expected extension %s", ext)
                }
        }
diff --git a/signer/internal/handler/certificates.go 
b/signer/internal/handler/certificates.go
index 19f7ef727..6bd9adde1 100644
--- a/signer/internal/handler/certificates.go
+++ b/signer/internal/handler/certificates.go
@@ -28,18 +28,20 @@ import (
 )
 
 type CertificateResponse struct {
-       SerialNumber         int64  `json:"serial_number"`
-       KeyID                string `json:"key_id"`
-       Principal            string `json:"principal"`
-       PublicKeyFingerprint string `json:"public_key_fingerprint"`
-       CAFingerprint        string `json:"ca_fingerprint"`
-       ValidAfter           int64  `json:"valid_after"`
-       ValidBefore          int64  `json:"valid_before"`
-       IssuedAt             int64  `json:"issued_at"`
-       SourceIP             string `json:"source_ip,omitempty"`
-       Revoked              bool   `json:"revoked"`
-       RevokedAt            *int64 `json:"revoked_at,omitempty"`
-       RevocationReason     string `json:"revocation_reason,omitempty"`
+       SerialNumber         int64    `json:"serial_number"`
+       KeyID                string   `json:"key_id"`
+       Principal            string   `json:"principal"`
+       PublicKeyFingerprint string   `json:"public_key_fingerprint"`
+       CAFingerprint        string   `json:"ca_fingerprint"`
+       ValidAfter           int64    `json:"valid_after"`
+       ValidBefore          int64    `json:"valid_before"`
+       IssuedAt             int64    `json:"issued_at"`
+       SourceIP             string   `json:"source_ip,omitempty"`
+       GrantedExtensions    []string `json:"granted_extensions,omitempty"`
+       ForceCommand         *string  `json:"force_command,omitempty"`
+       Revoked              bool     `json:"revoked"`
+       RevokedAt            *int64   `json:"revoked_at,omitempty"`
+       RevocationReason     string   `json:"revocation_reason,omitempty"`
 }
 
 type CertificateListResponse struct {
@@ -94,6 +96,8 @@ func (h *CertificatesHandler) HandleList(w 
http.ResponseWriter, r *http.Request)
                        ValidBefore:          c.ValidBefore.Unix(),
                        IssuedAt:             c.IssuedAt.Unix(),
                        SourceIP:             c.SourceIP,
+                       GrantedExtensions:    c.GrantedExtensions,
+                       ForceCommand:         c.ForceCommand,
                        Revoked:              c.Revoked,
                        RevocationReason:     c.RevocationReason,
                }
@@ -168,6 +172,8 @@ func toCertificateResponse(c *store.CertificateWithStatus) 
CertificateResponse {
                ValidBefore:          c.ValidBefore.Unix(),
                IssuedAt:             c.IssuedAt.Unix(),
                SourceIP:             c.SourceIP,
+               GrantedExtensions:    c.GrantedExtensions,
+               ForceCommand:         c.ForceCommand,
                Revoked:              c.Revoked,
                RevocationReason:     c.RevocationReason,
        }
diff --git a/signer/internal/handler/sign.go b/signer/internal/handler/sign.go
index fed25f3ea..a8ee07b5f 100644
--- a/signer/internal/handler/sign.go
+++ b/signer/internal/handler/sign.go
@@ -37,21 +37,25 @@ import (
 var principalRegex = regexp.MustCompile(`^[a-z_][a-z0-9_-]{0,31}$`)
 
 type SignRequest struct {
-       Principal       string `json:"principal"`
-       TTLSeconds      int    `json:"ttl_seconds"`
-       PublicKey       string `json:"public_key"`
-       UserAccessToken string `json:"user_access_token"`
+       Principal         string   `json:"principal"`
+       TTLSeconds        int      `json:"ttl_seconds"`
+       PublicKey         string   `json:"public_key"`
+       UserAccessToken   string   `json:"user_access_token"`
+       ForceCommand      string   `json:"force_command,omitempty"`
+       ExcludeExtensions []string `json:"exclude_extensions,omitempty"`
 }
 
 type SignResponse struct {
-       Certificate    string `json:"certificate"`
-       SerialNumber   int64  `json:"serial_number"`
-       ValidAfter     int64  `json:"valid_after"`
-       ValidBefore    int64  `json:"valid_before"`
-       CAFingerprint  string `json:"ca_fingerprint"`
-       TargetHost     string `json:"target_host"`
-       TargetPort     int    `json:"target_port"`
-       TargetUsername string `json:"target_username"`
+       Certificate       string   `json:"certificate"`
+       SerialNumber      int64    `json:"serial_number"`
+       ValidAfter        int64    `json:"valid_after"`
+       ValidBefore       int64    `json:"valid_before"`
+       CAFingerprint     string   `json:"ca_fingerprint"`
+       TargetHost        string   `json:"target_host"`
+       TargetPort        int      `json:"target_port"`
+       TargetUsername    string   `json:"target_username"`
+       ForceCommand      string   `json:"force_command,omitempty"`
+       GrantedExtensions []string `json:"granted_extensions"`
 }
 
 type SignHandler struct {
@@ -155,6 +159,16 @@ func (h *SignHandler) Handle(w http.ResponseWriter, r 
*http.Request) {
                return
        }
 
+       // Resolve extensions: all standard - client denied - request excluded
+       grantedExts, err := cert.ResolveExtensions(clientCfg.DeniedExtensions, 
req.ExcludeExtensions)
+       if err != nil {
+               metrics.SignRequestsTotal.WithLabelValues(tenantID, 
"error").Inc()
+               writeError(w, http.StatusBadRequest, "invalid_request", 
err.Error())
+               return
+       }
+       extensionsMap := cert.ExtensionsToMap(grantedExts)
+       grantedExtNames := cert.ExtensionNames(grantedExts)
+
        valResult, err := h.principalValidator.Validate(tenantID, clientID, 
req.Principal, identity.Subject)
        if err != nil {
                metrics.SignRequestsTotal.WithLabelValues(tenantID, 
"error").Inc()
@@ -190,7 +204,7 @@ func (h *SignHandler) Handle(w http.ResponseWriter, r 
*http.Request) {
        }
        metrics.VaultOperationsTotal.WithLabelValues("increment_serial", 
"success").Inc()
 
-       criticalOpts := policy.GetCriticalOptions(clientCfg)
+       criticalOpts := 
policy.BuildCriticalOptions(clientCfg.SourceAddressRestriction, 
req.ForceCommand)
 
        signReq := &cert.SignRequest{
                PublicKey:       sshPubKey,
@@ -200,6 +214,7 @@ func (h *SignHandler) Handle(w http.ResponseWriter, r 
*http.Request) {
                ClientID:        clientID,
                TTLSeconds:      uint64(req.TTLSeconds),
                CriticalOptions: criticalOpts,
+               Extensions:      extensionsMap,
        }
 
        signResult, err := cert.SignCertificate(signReq)
@@ -226,6 +241,8 @@ func (h *SignHandler) Handle(w http.ResponseWriter, r 
*http.Request) {
                ValidBefore:          time.Unix(int64(signResult.ValidBefore), 
0).UTC(),
                SourceIP:             sourceIP,
                UserAccessTokenHash:  tokenHash,
+               GrantedExtensions:    grantedExtNames,
+               ForceCommand:         stringPtrIfNonEmpty(req.ForceCommand),
        }
 
        if err := h.auditLogger.LogIssuance(r.Context(), auditEntry); err != 
nil {
@@ -277,14 +294,16 @@ func (h *SignHandler) Handle(w http.ResponseWriter, r 
*http.Request) {
        metrics.SignRequestsTotal.WithLabelValues(tenantID, "success").Inc()
 
        resp := SignResponse{
-               Certificate:    
base64.StdEncoding.EncodeToString(signResult.CertBytes),
-               SerialNumber:   serial,
-               ValidAfter:     int64(signResult.ValidAfter),
-               ValidBefore:    int64(signResult.ValidBefore),
-               CAFingerprint:  signResult.CAFingerprint,
-               TargetHost:     clientCfg.TargetHost,
-               TargetPort:     clientCfg.TargetPort,
-               TargetUsername: validatedPrincipal,
+               Certificate:       
base64.StdEncoding.EncodeToString(signResult.CertBytes),
+               SerialNumber:      serial,
+               ValidAfter:        int64(signResult.ValidAfter),
+               ValidBefore:       int64(signResult.ValidBefore),
+               CAFingerprint:     signResult.CAFingerprint,
+               TargetHost:        clientCfg.TargetHost,
+               TargetPort:        clientCfg.TargetPort,
+               TargetUsername:    validatedPrincipal,
+               ForceCommand:      req.ForceCommand,
+               GrantedExtensions: grantedExtNames,
        }
 
        w.Header().Set("Content-Type", "application/json")
@@ -296,3 +315,10 @@ func (h *SignHandler) Handle(w http.ResponseWriter, r 
*http.Request) {
 func isDuplicateKeyError(err error) bool {
        return err != nil && strings.Contains(err.Error(), "Duplicate entry")
 }
+
+func stringPtrIfNonEmpty(s string) *string {
+       if s == "" {
+               return nil
+       }
+       return &s
+}
diff --git a/signer/internal/policy/enforcer.go 
b/signer/internal/policy/enforcer.go
index b6d006793..815701a3e 100644
--- a/signer/internal/policy/enforcer.go
+++ b/signer/internal/policy/enforcer.go
@@ -89,9 +89,18 @@ func (e *Enforcer) Enforce(ttlSeconds int, sshKeyType 
string, sourceIP string, c
        return nil
 }
 
-func GetCriticalOptions(clientCfg *store.ClientConfig) map[string]string {
-       if clientCfg.CriticalOptions != nil && len(clientCfg.CriticalOptions) > 
0 {
-               return clientCfg.CriticalOptions
+// BuildCriticalOptions constructs the SSH certificate critical options map 
from
+// the client config's source address restriction and the per-request force 
command.
+func BuildCriticalOptions(sourceAddr *string, forceCommand string) 
map[string]string {
+       opts := make(map[string]string)
+       if sourceAddr != nil && *sourceAddr != "" {
+               opts["source-address"] = *sourceAddr
        }
-       return nil
+       if forceCommand != "" {
+               opts["force-command"] = forceCommand
+       }
+       if len(opts) == 0 {
+               return nil
+       }
+       return opts
 }
diff --git a/signer/internal/policy/enforcer_test.go 
b/signer/internal/policy/enforcer_test.go
index cc33de05a..6cf493751 100644
--- a/signer/internal/policy/enforcer_test.go
+++ b/signer/internal/policy/enforcer_test.go
@@ -21,12 +21,11 @@ import (
        "github.com/apache/airavata-custos/signer/internal/store"
 )
 
-func clientCfg(maxTTL int, keyTypes []string, sourceRestriction *string, 
critOpts map[string]string) *store.ClientConfig {
+func clientCfg(maxTTL int, keyTypes []string, sourceRestriction *string) 
*store.ClientConfig {
        return &store.ClientConfig{
                MaxTTLSeconds:            maxTTL,
                AllowedKeyTypes:          keyTypes,
                SourceAddressRestriction: sourceRestriction,
-               CriticalOptions:          critOpts,
        }
 }
 
@@ -36,7 +35,7 @@ func strPtr(s string) *string {
 
 func TestEnforce_TTLWithinLimit(t *testing.T) {
        e := NewEnforcer(86400, []string{"ed25519", "rsa", "ecdsa"})
-       cfg := clientCfg(86400, []string{"ed25519"}, nil, nil)
+       cfg := clientCfg(86400, []string{"ed25519"}, nil)
        err := e.Enforce(3600, "ssh-ed25519", "10.0.0.1", cfg)
        if err != nil {
                t.Errorf("expected no error, got: %v", err)
@@ -45,7 +44,7 @@ func TestEnforce_TTLWithinLimit(t *testing.T) {
 
 func TestEnforce_TTLExceedsLimit(t *testing.T) {
        e := NewEnforcer(86400, []string{"ed25519"})
-       cfg := clientCfg(3600, []string{"ed25519"}, nil, nil)
+       cfg := clientCfg(3600, []string{"ed25519"}, nil)
        err := e.Enforce(7200, "ssh-ed25519", "10.0.0.1", cfg)
        if err == nil {
                t.Fatal("expected error for TTL exceeding limit")
@@ -61,7 +60,7 @@ func TestEnforce_TTLExceedsLimit(t *testing.T) {
 
 func TestEnforce_TTLAtExactLimit(t *testing.T) {
        e := NewEnforcer(86400, []string{"ed25519"})
-       cfg := clientCfg(3600, []string{"ed25519"}, nil, nil)
+       cfg := clientCfg(3600, []string{"ed25519"}, nil)
        err := e.Enforce(3600, "ssh-ed25519", "10.0.0.1", cfg)
        if err != nil {
                t.Errorf("expected no error, got: %v", err)
@@ -70,7 +69,7 @@ func TestEnforce_TTLAtExactLimit(t *testing.T) {
 
 func TestEnforce_NullMaxTTLUsesDefault(t *testing.T) {
        e := NewEnforcer(86400, []string{"ed25519"})
-       cfg := clientCfg(0, []string{"ed25519"}, nil, nil) // 0 = null 
equivalent
+       cfg := clientCfg(0, []string{"ed25519"}, nil) // 0 = null equivalent
        err := e.Enforce(86400, "ssh-ed25519", "10.0.0.1", cfg)
        if err != nil {
                t.Errorf("expected no error, got: %v", err)
@@ -79,7 +78,7 @@ func TestEnforce_NullMaxTTLUsesDefault(t *testing.T) {
 
 func TestEnforce_TTLZero(t *testing.T) {
        e := NewEnforcer(86400, []string{"ed25519"})
-       cfg := clientCfg(86400, []string{"ed25519"}, nil, nil)
+       cfg := clientCfg(86400, []string{"ed25519"}, nil)
        err := e.Enforce(0, "ssh-ed25519", "10.0.0.1", cfg)
        if err == nil {
                t.Fatal("expected error for TTL 0")
@@ -88,7 +87,7 @@ func TestEnforce_TTLZero(t *testing.T) {
 
 func TestEnforce_NegativeTTL(t *testing.T) {
        e := NewEnforcer(86400, []string{"ed25519"})
-       cfg := clientCfg(86400, []string{"ed25519"}, nil, nil)
+       cfg := clientCfg(86400, []string{"ed25519"}, nil)
        err := e.Enforce(-100, "ssh-ed25519", "10.0.0.1", cfg)
        if err == nil {
                t.Fatal("expected error for negative TTL")
@@ -97,7 +96,7 @@ func TestEnforce_NegativeTTL(t *testing.T) {
 
 func TestEnforce_DisallowedKeyType(t *testing.T) {
        e := NewEnforcer(86400, []string{"ed25519"})
-       cfg := clientCfg(86400, []string{"ed25519"}, nil, nil)
+       cfg := clientCfg(86400, []string{"ed25519"}, nil)
        err := e.Enforce(3600, "ssh-rsa", "10.0.0.1", cfg)
        if err == nil {
                t.Fatal("expected error for disallowed key type")
@@ -110,7 +109,7 @@ func TestEnforce_DisallowedKeyType(t *testing.T) {
 
 func TestEnforce_AllowedKeyType(t *testing.T) {
        e := NewEnforcer(86400, []string{"ed25519", "rsa"})
-       cfg := clientCfg(86400, []string{"ed25519", "rsa"}, nil, nil)
+       cfg := clientCfg(86400, []string{"ed25519", "rsa"}, nil)
        err := e.Enforce(3600, "ssh-rsa", "10.0.0.1", cfg)
        if err != nil {
                t.Errorf("expected no error, got: %v", err)
@@ -120,7 +119,7 @@ func TestEnforce_AllowedKeyType(t *testing.T) {
 func TestEnforce_SourceAddressRejected(t *testing.T) {
        e := NewEnforcer(86400, []string{"ed25519"})
        cidr := "10.0.0.0/8"
-       cfg := clientCfg(86400, []string{"ed25519"}, &cidr, nil)
+       cfg := clientCfg(86400, []string{"ed25519"}, &cidr)
        err := e.Enforce(3600, "ssh-ed25519", "192.168.1.1", cfg)
        if err == nil {
                t.Fatal("expected error for source address outside CIDR")
@@ -130,7 +129,7 @@ func TestEnforce_SourceAddressRejected(t *testing.T) {
 func TestEnforce_SourceAddressAccepted(t *testing.T) {
        e := NewEnforcer(86400, []string{"ed25519"})
        cidr := "10.0.0.0/8"
-       cfg := clientCfg(86400, []string{"ed25519"}, &cidr, nil)
+       cfg := clientCfg(86400, []string{"ed25519"}, &cidr)
        err := e.Enforce(3600, "ssh-ed25519", "10.1.2.3", cfg)
        if err != nil {
                t.Errorf("expected no error, got: %v", err)
@@ -139,28 +138,44 @@ func TestEnforce_SourceAddressAccepted(t *testing.T) {
 
 func TestEnforce_NoSourceRestriction(t *testing.T) {
        e := NewEnforcer(86400, []string{"ed25519"})
-       cfg := clientCfg(86400, []string{"ed25519"}, nil, nil)
+       cfg := clientCfg(86400, []string{"ed25519"}, nil)
        err := e.Enforce(3600, "ssh-ed25519", "192.168.1.1", cfg)
        if err != nil {
                t.Errorf("expected no error, got: %v", err)
        }
 }
 
-func TestGetCriticalOptions(t *testing.T) {
-       opts := map[string]string{"source-address": "10.0.0.0/8"}
-       cfg := clientCfg(86400, nil, nil, opts)
-       result := GetCriticalOptions(cfg)
+func TestBuildCriticalOptions_SourceAddress(t *testing.T) {
+       addr := "10.0.0.0/8"
+       result := BuildCriticalOptions(&addr, "")
        if result == nil {
                t.Fatal("expected non-nil critical options")
        }
        if result["source-address"] != "10.0.0.0/8" {
-               t.Errorf("expected source-address, got %v", result)
+               t.Errorf("expected source-address=10.0.0.0/8, got %v", result)
        }
 }
 
-func TestGetCriticalOptions_Nil(t *testing.T) {
-       cfg := clientCfg(86400, nil, nil, nil)
-       result := GetCriticalOptions(cfg)
+func TestBuildCriticalOptions_ForceCommand(t *testing.T) {
+       result := BuildCriticalOptions(nil, "sbatch /path/to/job.sh")
+       if result == nil {
+               t.Fatal("expected non-nil critical options")
+       }
+       if result["force-command"] != "sbatch /path/to/job.sh" {
+               t.Errorf("expected force-command, got %v", result)
+       }
+}
+
+func TestBuildCriticalOptions_Both(t *testing.T) {
+       addr := "10.0.0.0/8"
+       result := BuildCriticalOptions(&addr, "sbatch /job.sh")
+       if len(result) != 2 {
+               t.Errorf("expected 2 critical options, got %d", len(result))
+       }
+}
+
+func TestBuildCriticalOptions_Empty(t *testing.T) {
+       result := BuildCriticalOptions(nil, "")
        if result != nil {
                t.Errorf("expected nil critical options, got %v", result)
        }
diff --git a/signer/internal/store/certificate_query.go 
b/signer/internal/store/certificate_query.go
index 3891920bf..5ceb864c7 100644
--- a/signer/internal/store/certificate_query.go
+++ b/signer/internal/store/certificate_query.go
@@ -17,6 +17,8 @@ package store
 
 import (
        "context"
+       "database/sql"
+       "encoding/json"
        "fmt"
        "time"
 )
@@ -35,6 +37,8 @@ type CertificateWithStatus struct {
        ValidBefore          time.Time
        IssuedAt             time.Time
        SourceIP             string
+       GrantedExtensions    []string
+       ForceCommand         *string
        Revoked              bool
        RevokedAt            *time.Time
        RevocationReason     string
@@ -74,6 +78,7 @@ func (d *DB) ListCertificatesByEmail(ctx context.Context, 
email string, limit, o
                        c.id, c.tenant_id, c.client_id, c.serial_number, 
c.key_id,
                        c.principal, COALESCE(c.user_email, ''), 
c.public_key_fingerprint, c.ca_fingerprint,
                        c.valid_after, c.valid_before, c.issued_at, 
COALESCE(c.source_ip, ''),
+                       c.granted_extensions, c.force_command,
                        CASE WHEN r.id IS NOT NULL THEN TRUE ELSE FALSE END AS 
revoked,
                        r.revoked_at,
                        COALESCE(r.reason, '') AS revocation_reason
@@ -93,16 +98,27 @@ func (d *DB) ListCertificatesByEmail(ctx context.Context, 
email string, limit, o
        for rows.Next() {
                var cert CertificateWithStatus
                var revokedAt *time.Time
+               var grantedExtensionsJSON []byte
+               var forceCommand sql.NullString
 
                if err := rows.Scan(
                        &cert.ID, &cert.TenantID, &cert.ClientID, 
&cert.SerialNumber, &cert.KeyID,
                        &cert.Principal, &cert.UserEmail, 
&cert.PublicKeyFingerprint, &cert.CAFingerprint,
                        &cert.ValidAfter, &cert.ValidBefore, &cert.IssuedAt, 
&cert.SourceIP,
+                       &grantedExtensionsJSON, &forceCommand,
                        &cert.Revoked, &revokedAt, &cert.RevocationReason,
                ); err != nil {
                        return nil, fmt.Errorf("scanning certificate row: %w", 
err)
                }
 
+               if grantedExtensionsJSON != nil {
+                       if err := json.Unmarshal(grantedExtensionsJSON, 
&cert.GrantedExtensions); err != nil {
+                               return nil, fmt.Errorf("unmarshaling 
granted_extensions: %w", err)
+                       }
+               }
+               if forceCommand.Valid {
+                       cert.ForceCommand = &forceCommand.String
+               }
                cert.RevokedAt = revokedAt
                certs = append(certs, cert)
        }
@@ -120,12 +136,15 @@ func (d *DB) ListCertificatesByEmail(ctx context.Context, 
email string, limit, o
 func (d *DB) GetCertificateBySerial(ctx context.Context, serial int64) 
(*CertificateWithStatus, error) {
        var cert CertificateWithStatus
        var revokedAt *time.Time
+       var grantedExtensionsJSON []byte
+       var forceCommand sql.NullString
 
        err := d.QueryRowContext(ctx,
                `SELECT
                        c.id, c.tenant_id, c.client_id, c.serial_number, 
c.key_id,
                        c.principal, COALESCE(c.user_email, ''), 
c.public_key_fingerprint, c.ca_fingerprint,
                        c.valid_after, c.valid_before, c.issued_at, 
COALESCE(c.source_ip, ''),
+                       c.granted_extensions, c.force_command,
                        CASE WHEN r.id IS NOT NULL THEN TRUE ELSE FALSE END AS 
revoked,
                        r.revoked_at,
                        COALESCE(r.reason, '') AS revocation_reason
@@ -137,12 +156,21 @@ func (d *DB) GetCertificateBySerial(ctx context.Context, 
serial int64) (*Certifi
                &cert.ID, &cert.TenantID, &cert.ClientID, &cert.SerialNumber, 
&cert.KeyID,
                &cert.Principal, &cert.UserEmail, &cert.PublicKeyFingerprint, 
&cert.CAFingerprint,
                &cert.ValidAfter, &cert.ValidBefore, &cert.IssuedAt, 
&cert.SourceIP,
+               &grantedExtensionsJSON, &forceCommand,
                &cert.Revoked, &revokedAt, &cert.RevocationReason,
        )
        if err != nil {
                return nil, fmt.Errorf("querying certificate: %w", err)
        }
 
+       if grantedExtensionsJSON != nil {
+               if err := json.Unmarshal(grantedExtensionsJSON, 
&cert.GrantedExtensions); err != nil {
+                       return nil, fmt.Errorf("unmarshaling 
granted_extensions: %w", err)
+               }
+       }
+       if forceCommand.Valid {
+               cert.ForceCommand = &forceCommand.String
+       }
        cert.RevokedAt = revokedAt
        return &cert, nil
 }
diff --git a/signer/internal/store/client_config.go 
b/signer/internal/store/client_config.go
index 7c2d95c48..606c2e32d 100644
--- a/signer/internal/store/client_config.go
+++ b/signer/internal/store/client_config.go
@@ -31,8 +31,7 @@ type ClientConfig struct {
        MaxTTLSeconds            int
        AllowedKeyTypes          []string
        SourceAddressRestriction *string
-       CriticalOptions          map[string]string
-       Extensions               map[string]string
+       DeniedExtensions         []string
        Enabled                  bool
 }
 
@@ -41,21 +40,20 @@ func (d *DB) GetClientConfig(ctx context.Context, tenantID, 
clientID string) (*C
                cc                       ClientConfig
                allowedKeyTypesJSON      []byte
                sourceAddressRestriction sql.NullString
-               criticalOptionsJSON      sql.NullString
-               extensionsJSON           sql.NullString
+               deniedExtensionsJSON     sql.NullString
        )
 
        err := d.QueryRowContext(ctx,
                `SELECT tenant_id, client_id, client_secret, target_host, 
target_port,
                        max_ttl_seconds, allowed_key_types, 
source_address_restriction,
-                       critical_options, extensions, enabled
+                       denied_extensions, enabled
                 FROM client_ssh_configs
                 WHERE tenant_id = ? AND client_id = ?`,
                tenantID, clientID,
        ).Scan(
                &cc.TenantID, &cc.ClientID, &cc.ClientSecret, &cc.TargetHost, 
&cc.TargetPort,
                &cc.MaxTTLSeconds, &allowedKeyTypesJSON, 
&sourceAddressRestriction,
-               &criticalOptionsJSON, &extensionsJSON, &cc.Enabled,
+               &deniedExtensionsJSON, &cc.Enabled,
        )
        if err != nil {
                if err == sql.ErrNoRows {
@@ -72,15 +70,9 @@ func (d *DB) GetClientConfig(ctx context.Context, tenantID, 
clientID string) (*C
                cc.SourceAddressRestriction = &sourceAddressRestriction.String
        }
 
-       if criticalOptionsJSON.Valid && criticalOptionsJSON.String != "" {
-               if err := json.Unmarshal([]byte(criticalOptionsJSON.String), 
&cc.CriticalOptions); err != nil {
-                       return nil, fmt.Errorf("unmarshaling critical_options: 
%w", err)
-               }
-       }
-
-       if extensionsJSON.Valid && extensionsJSON.String != "" {
-               if err := json.Unmarshal([]byte(extensionsJSON.String), 
&cc.Extensions); err != nil {
-                       return nil, fmt.Errorf("unmarshaling extensions: %w", 
err)
+       if deniedExtensionsJSON.Valid && deniedExtensionsJSON.String != "" {
+               if err := json.Unmarshal([]byte(deniedExtensionsJSON.String), 
&cc.DeniedExtensions); err != nil {
+                       return nil, fmt.Errorf("unmarshaling denied_extensions: 
%w", err)
                }
        }
 
diff --git a/signer/internal/store/issuance_log.go 
b/signer/internal/store/issuance_log.go
index 823656eca..f9ef4ef51 100644
--- a/signer/internal/store/issuance_log.go
+++ b/signer/internal/store/issuance_log.go
@@ -34,29 +34,40 @@ type IssuanceLog struct {
        ValidAfter           time.Time
        ValidBefore          time.Time
        SourceIP             string
+       GrantedExtensions    []string
+       ForceCommand         *string
        UserAccessTokenHash  string
        RequestMetadata      map[string]interface{}
 }
 
 // InsertIssuanceLog writes a certificate issuance entry. Fails on duplicate 
serial_number.
 func (d *DB) InsertIssuanceLog(ctx context.Context, log *IssuanceLog) error {
+       grantedExtensions := log.GrantedExtensions
+       if grantedExtensions == nil {
+               grantedExtensions = []string{}
+       }
+       grantedExtensionsJSON, err := json.Marshal(grantedExtensions)
+       if err != nil {
+               return fmt.Errorf("marshaling granted_extensions: %w", err)
+       }
+
        var metadataJSON []byte
        if log.RequestMetadata != nil {
-               var err error
                metadataJSON, err = json.Marshal(log.RequestMetadata)
                if err != nil {
                        return fmt.Errorf("marshaling request_metadata: %w", 
err)
                }
        }
 
-       _, err := d.ExecContext(ctx,
+       _, err = d.ExecContext(ctx,
                `INSERT INTO certificate_issuance_logs
                 (tenant_id, client_id, serial_number, key_id, principal, 
user_email, public_key_fingerprint,
-                 ca_fingerprint, valid_after, valid_before, source_ip, 
user_access_token_hash, request_metadata)
-                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+                 ca_fingerprint, valid_after, valid_before, source_ip, 
granted_extensions, force_command,
+                 user_access_token_hash, request_metadata)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
                log.TenantID, log.ClientID, log.SerialNumber, log.KeyID, 
log.Principal,
                log.UserEmail, log.PublicKeyFingerprint, log.CAFingerprint, 
log.ValidAfter, log.ValidBefore,
-               log.SourceIP, log.UserAccessTokenHash, metadataJSON,
+               log.SourceIP, grantedExtensionsJSON, log.ForceCommand, 
log.UserAccessTokenHash, metadataJSON,
        )
        if err != nil {
                return fmt.Errorf("inserting issuance log: %w", err)
diff --git a/signer/migrations/001_initial_schema.up.sql 
b/signer/migrations/001_initial_schema.up.sql
index 1fd2f2211..5da7969cd 100644
--- a/signer/migrations/001_initial_schema.up.sql
+++ b/signer/migrations/001_initial_schema.up.sql
@@ -11,8 +11,7 @@ CREATE TABLE IF NOT EXISTS client_ssh_configs
     max_ttl_seconds            INT          NOT NULL DEFAULT 86400,
     allowed_key_types          JSON         NOT NULL,
     source_address_restriction VARCHAR(255) NULL,
-    critical_options           JSON         NULL,
-    extensions                 JSON         NULL,
+    denied_extensions          JSON         NULL,
     enabled                    BOOLEAN      NOT NULL DEFAULT TRUE,
     created_at                 TIMESTAMP(6) NOT NULL DEFAULT 
CURRENT_TIMESTAMP(6),
     updated_at                 TIMESTAMP(6) NOT NULL DEFAULT 
CURRENT_TIMESTAMP(6)
@@ -40,6 +39,8 @@ CREATE TABLE IF NOT EXISTS certificate_issuance_logs
     valid_before           TIMESTAMP(6) NOT NULL,
     issued_at              TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
     source_ip              VARCHAR(45)  NULL,
+    granted_extensions     JSON         NOT NULL,
+    force_command          TEXT         NULL,
     user_access_token_hash VARCHAR(255) NULL,
     request_metadata       JSON         NULL,
 

Reply via email to