This is an automated email from the ASF dual-hosted git repository. lahirujayathilake pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/airavata-custos.git
commit d1e8bc5f4a46b66e246f6c7bbd9ab9b52f1ceb60 Author: lahiruj <[email protected]> AuthorDate: Thu Apr 2 21:44:24 2026 -0400 maintain a denied list for the extensions in client config and move the force-command to the certificate issuance --- signer/internal/audit/logger.go | 6 +++ signer/internal/cert/signer.go | 6 +-- signer/internal/cert/signer_test.go | 6 ++- signer/internal/handler/certificates.go | 30 ++++++++----- signer/internal/handler/sign.go | 68 ++++++++++++++++++++--------- signer/internal/policy/enforcer.go | 17 ++++++-- signer/internal/policy/enforcer_test.go | 57 +++++++++++++++--------- signer/internal/store/certificate_query.go | 28 ++++++++++++ signer/internal/store/client_config.go | 22 +++------- signer/internal/store/issuance_log.go | 21 ++++++--- signer/migrations/001_initial_schema.up.sql | 5 ++- 11 files changed, 179 insertions(+), 87 deletions(-) diff --git a/signer/internal/audit/logger.go b/signer/internal/audit/logger.go index b73529f3d..2da44e1ef 100644 --- a/signer/internal/audit/logger.go +++ b/signer/internal/audit/logger.go @@ -49,6 +49,8 @@ type IssuanceEntry struct { ValidAfter time.Time ValidBefore time.Time SourceIP string + GrantedExtensions []string + ForceCommand *string UserAccessTokenHash string CorrelationID string } @@ -66,6 +68,8 @@ func (l *Logger) LogIssuance(ctx context.Context, entry *IssuanceEntry) error { ValidAfter: entry.ValidAfter, ValidBefore: entry.ValidBefore, SourceIP: entry.SourceIP, + GrantedExtensions: entry.GrantedExtensions, + ForceCommand: entry.ForceCommand, UserAccessTokenHash: entry.UserAccessTokenHash, } @@ -87,6 +91,8 @@ func (l *Logger) LogIssuance(ctx context.Context, entry *IssuanceEntry) error { "valid_after", entry.ValidAfter.Unix(), "valid_before", entry.ValidBefore.Unix(), "source_ip", entry.SourceIP, + "granted_extensions", entry.GrantedExtensions, + "force_command", entry.ForceCommand, "user_token_hash", entry.UserAccessTokenHash, "correlation_id", entry.CorrelationID, ) diff --git a/signer/internal/cert/signer.go b/signer/internal/cert/signer.go index 173e162a9..adb1deb51 100644 --- a/signer/internal/cert/signer.go +++ b/signer/internal/cert/signer.go @@ -63,11 +63,7 @@ func SignCertificate(req *SignRequest) (*SignResult, error) { extensions := req.Extensions if extensions == nil { - extensions = map[string]string{ - "permit-pty": "", - "permit-port-forwarding": "", - "permit-user-rc": "", - } + extensions = make(map[string]string) } cert := &ssh.Certificate{ diff --git a/signer/internal/cert/signer_test.go b/signer/internal/cert/signer_test.go index d17f26fe8..55cd105b4 100644 --- a/signer/internal/cert/signer_test.go +++ b/signer/internal/cert/signer_test.go @@ -64,6 +64,7 @@ func TestSignCertificate_Ed25519(t *testing.T) { userKey := generateTestEd25519UserKey(t) before := time.Now().UTC() + exts := ExtensionsToMap(AllStandardExtensions()) result, err := SignCertificate(&SignRequest{ PublicKey: userKey, CAPrivateKeyPEM: caPrivPEM, @@ -71,6 +72,7 @@ func TestSignCertificate_Ed25519(t *testing.T) { Principal: "testuser", ClientID: "webapp", TTLSeconds: 7200, + Extensions: exts, }) after := time.Now().UTC() @@ -120,8 +122,8 @@ func TestSignCertificate_Ed25519(t *testing.T) { t.Errorf("expected UserCert type, got %d", cert.CertType) } - for _, ext := range []string{"permit-pty", "permit-port-forwarding", "permit-user-rc"} { - if _, ok := cert.Extensions[ext]; !ok { + for _, ext := range AllStandardExtensions() { + if _, ok := cert.Extensions[string(ext)]; !ok { t.Errorf("expected extension %s", ext) } } diff --git a/signer/internal/handler/certificates.go b/signer/internal/handler/certificates.go index 19f7ef727..6bd9adde1 100644 --- a/signer/internal/handler/certificates.go +++ b/signer/internal/handler/certificates.go @@ -28,18 +28,20 @@ import ( ) type CertificateResponse struct { - SerialNumber int64 `json:"serial_number"` - KeyID string `json:"key_id"` - Principal string `json:"principal"` - PublicKeyFingerprint string `json:"public_key_fingerprint"` - CAFingerprint string `json:"ca_fingerprint"` - ValidAfter int64 `json:"valid_after"` - ValidBefore int64 `json:"valid_before"` - IssuedAt int64 `json:"issued_at"` - SourceIP string `json:"source_ip,omitempty"` - Revoked bool `json:"revoked"` - RevokedAt *int64 `json:"revoked_at,omitempty"` - RevocationReason string `json:"revocation_reason,omitempty"` + SerialNumber int64 `json:"serial_number"` + KeyID string `json:"key_id"` + Principal string `json:"principal"` + PublicKeyFingerprint string `json:"public_key_fingerprint"` + CAFingerprint string `json:"ca_fingerprint"` + ValidAfter int64 `json:"valid_after"` + ValidBefore int64 `json:"valid_before"` + IssuedAt int64 `json:"issued_at"` + SourceIP string `json:"source_ip,omitempty"` + GrantedExtensions []string `json:"granted_extensions,omitempty"` + ForceCommand *string `json:"force_command,omitempty"` + Revoked bool `json:"revoked"` + RevokedAt *int64 `json:"revoked_at,omitempty"` + RevocationReason string `json:"revocation_reason,omitempty"` } type CertificateListResponse struct { @@ -94,6 +96,8 @@ func (h *CertificatesHandler) HandleList(w http.ResponseWriter, r *http.Request) ValidBefore: c.ValidBefore.Unix(), IssuedAt: c.IssuedAt.Unix(), SourceIP: c.SourceIP, + GrantedExtensions: c.GrantedExtensions, + ForceCommand: c.ForceCommand, Revoked: c.Revoked, RevocationReason: c.RevocationReason, } @@ -168,6 +172,8 @@ func toCertificateResponse(c *store.CertificateWithStatus) CertificateResponse { ValidBefore: c.ValidBefore.Unix(), IssuedAt: c.IssuedAt.Unix(), SourceIP: c.SourceIP, + GrantedExtensions: c.GrantedExtensions, + ForceCommand: c.ForceCommand, Revoked: c.Revoked, RevocationReason: c.RevocationReason, } diff --git a/signer/internal/handler/sign.go b/signer/internal/handler/sign.go index fed25f3ea..a8ee07b5f 100644 --- a/signer/internal/handler/sign.go +++ b/signer/internal/handler/sign.go @@ -37,21 +37,25 @@ import ( var principalRegex = regexp.MustCompile(`^[a-z_][a-z0-9_-]{0,31}$`) type SignRequest struct { - Principal string `json:"principal"` - TTLSeconds int `json:"ttl_seconds"` - PublicKey string `json:"public_key"` - UserAccessToken string `json:"user_access_token"` + Principal string `json:"principal"` + TTLSeconds int `json:"ttl_seconds"` + PublicKey string `json:"public_key"` + UserAccessToken string `json:"user_access_token"` + ForceCommand string `json:"force_command,omitempty"` + ExcludeExtensions []string `json:"exclude_extensions,omitempty"` } type SignResponse struct { - Certificate string `json:"certificate"` - SerialNumber int64 `json:"serial_number"` - ValidAfter int64 `json:"valid_after"` - ValidBefore int64 `json:"valid_before"` - CAFingerprint string `json:"ca_fingerprint"` - TargetHost string `json:"target_host"` - TargetPort int `json:"target_port"` - TargetUsername string `json:"target_username"` + Certificate string `json:"certificate"` + SerialNumber int64 `json:"serial_number"` + ValidAfter int64 `json:"valid_after"` + ValidBefore int64 `json:"valid_before"` + CAFingerprint string `json:"ca_fingerprint"` + TargetHost string `json:"target_host"` + TargetPort int `json:"target_port"` + TargetUsername string `json:"target_username"` + ForceCommand string `json:"force_command,omitempty"` + GrantedExtensions []string `json:"granted_extensions"` } type SignHandler struct { @@ -155,6 +159,16 @@ func (h *SignHandler) Handle(w http.ResponseWriter, r *http.Request) { return } + // Resolve extensions: all standard - client denied - request excluded + grantedExts, err := cert.ResolveExtensions(clientCfg.DeniedExtensions, req.ExcludeExtensions) + if err != nil { + metrics.SignRequestsTotal.WithLabelValues(tenantID, "error").Inc() + writeError(w, http.StatusBadRequest, "invalid_request", err.Error()) + return + } + extensionsMap := cert.ExtensionsToMap(grantedExts) + grantedExtNames := cert.ExtensionNames(grantedExts) + valResult, err := h.principalValidator.Validate(tenantID, clientID, req.Principal, identity.Subject) if err != nil { metrics.SignRequestsTotal.WithLabelValues(tenantID, "error").Inc() @@ -190,7 +204,7 @@ func (h *SignHandler) Handle(w http.ResponseWriter, r *http.Request) { } metrics.VaultOperationsTotal.WithLabelValues("increment_serial", "success").Inc() - criticalOpts := policy.GetCriticalOptions(clientCfg) + criticalOpts := policy.BuildCriticalOptions(clientCfg.SourceAddressRestriction, req.ForceCommand) signReq := &cert.SignRequest{ PublicKey: sshPubKey, @@ -200,6 +214,7 @@ func (h *SignHandler) Handle(w http.ResponseWriter, r *http.Request) { ClientID: clientID, TTLSeconds: uint64(req.TTLSeconds), CriticalOptions: criticalOpts, + Extensions: extensionsMap, } signResult, err := cert.SignCertificate(signReq) @@ -226,6 +241,8 @@ func (h *SignHandler) Handle(w http.ResponseWriter, r *http.Request) { ValidBefore: time.Unix(int64(signResult.ValidBefore), 0).UTC(), SourceIP: sourceIP, UserAccessTokenHash: tokenHash, + GrantedExtensions: grantedExtNames, + ForceCommand: stringPtrIfNonEmpty(req.ForceCommand), } if err := h.auditLogger.LogIssuance(r.Context(), auditEntry); err != nil { @@ -277,14 +294,16 @@ func (h *SignHandler) Handle(w http.ResponseWriter, r *http.Request) { metrics.SignRequestsTotal.WithLabelValues(tenantID, "success").Inc() resp := SignResponse{ - Certificate: base64.StdEncoding.EncodeToString(signResult.CertBytes), - SerialNumber: serial, - ValidAfter: int64(signResult.ValidAfter), - ValidBefore: int64(signResult.ValidBefore), - CAFingerprint: signResult.CAFingerprint, - TargetHost: clientCfg.TargetHost, - TargetPort: clientCfg.TargetPort, - TargetUsername: validatedPrincipal, + Certificate: base64.StdEncoding.EncodeToString(signResult.CertBytes), + SerialNumber: serial, + ValidAfter: int64(signResult.ValidAfter), + ValidBefore: int64(signResult.ValidBefore), + CAFingerprint: signResult.CAFingerprint, + TargetHost: clientCfg.TargetHost, + TargetPort: clientCfg.TargetPort, + TargetUsername: validatedPrincipal, + ForceCommand: req.ForceCommand, + GrantedExtensions: grantedExtNames, } w.Header().Set("Content-Type", "application/json") @@ -296,3 +315,10 @@ func (h *SignHandler) Handle(w http.ResponseWriter, r *http.Request) { func isDuplicateKeyError(err error) bool { return err != nil && strings.Contains(err.Error(), "Duplicate entry") } + +func stringPtrIfNonEmpty(s string) *string { + if s == "" { + return nil + } + return &s +} diff --git a/signer/internal/policy/enforcer.go b/signer/internal/policy/enforcer.go index b6d006793..815701a3e 100644 --- a/signer/internal/policy/enforcer.go +++ b/signer/internal/policy/enforcer.go @@ -89,9 +89,18 @@ func (e *Enforcer) Enforce(ttlSeconds int, sshKeyType string, sourceIP string, c return nil } -func GetCriticalOptions(clientCfg *store.ClientConfig) map[string]string { - if clientCfg.CriticalOptions != nil && len(clientCfg.CriticalOptions) > 0 { - return clientCfg.CriticalOptions +// BuildCriticalOptions constructs the SSH certificate critical options map from +// the client config's source address restriction and the per-request force command. +func BuildCriticalOptions(sourceAddr *string, forceCommand string) map[string]string { + opts := make(map[string]string) + if sourceAddr != nil && *sourceAddr != "" { + opts["source-address"] = *sourceAddr } - return nil + if forceCommand != "" { + opts["force-command"] = forceCommand + } + if len(opts) == 0 { + return nil + } + return opts } diff --git a/signer/internal/policy/enforcer_test.go b/signer/internal/policy/enforcer_test.go index cc33de05a..6cf493751 100644 --- a/signer/internal/policy/enforcer_test.go +++ b/signer/internal/policy/enforcer_test.go @@ -21,12 +21,11 @@ import ( "github.com/apache/airavata-custos/signer/internal/store" ) -func clientCfg(maxTTL int, keyTypes []string, sourceRestriction *string, critOpts map[string]string) *store.ClientConfig { +func clientCfg(maxTTL int, keyTypes []string, sourceRestriction *string) *store.ClientConfig { return &store.ClientConfig{ MaxTTLSeconds: maxTTL, AllowedKeyTypes: keyTypes, SourceAddressRestriction: sourceRestriction, - CriticalOptions: critOpts, } } @@ -36,7 +35,7 @@ func strPtr(s string) *string { func TestEnforce_TTLWithinLimit(t *testing.T) { e := NewEnforcer(86400, []string{"ed25519", "rsa", "ecdsa"}) - cfg := clientCfg(86400, []string{"ed25519"}, nil, nil) + cfg := clientCfg(86400, []string{"ed25519"}, nil) err := e.Enforce(3600, "ssh-ed25519", "10.0.0.1", cfg) if err != nil { t.Errorf("expected no error, got: %v", err) @@ -45,7 +44,7 @@ func TestEnforce_TTLWithinLimit(t *testing.T) { func TestEnforce_TTLExceedsLimit(t *testing.T) { e := NewEnforcer(86400, []string{"ed25519"}) - cfg := clientCfg(3600, []string{"ed25519"}, nil, nil) + cfg := clientCfg(3600, []string{"ed25519"}, nil) err := e.Enforce(7200, "ssh-ed25519", "10.0.0.1", cfg) if err == nil { t.Fatal("expected error for TTL exceeding limit") @@ -61,7 +60,7 @@ func TestEnforce_TTLExceedsLimit(t *testing.T) { func TestEnforce_TTLAtExactLimit(t *testing.T) { e := NewEnforcer(86400, []string{"ed25519"}) - cfg := clientCfg(3600, []string{"ed25519"}, nil, nil) + cfg := clientCfg(3600, []string{"ed25519"}, nil) err := e.Enforce(3600, "ssh-ed25519", "10.0.0.1", cfg) if err != nil { t.Errorf("expected no error, got: %v", err) @@ -70,7 +69,7 @@ func TestEnforce_TTLAtExactLimit(t *testing.T) { func TestEnforce_NullMaxTTLUsesDefault(t *testing.T) { e := NewEnforcer(86400, []string{"ed25519"}) - cfg := clientCfg(0, []string{"ed25519"}, nil, nil) // 0 = null equivalent + cfg := clientCfg(0, []string{"ed25519"}, nil) // 0 = null equivalent err := e.Enforce(86400, "ssh-ed25519", "10.0.0.1", cfg) if err != nil { t.Errorf("expected no error, got: %v", err) @@ -79,7 +78,7 @@ func TestEnforce_NullMaxTTLUsesDefault(t *testing.T) { func TestEnforce_TTLZero(t *testing.T) { e := NewEnforcer(86400, []string{"ed25519"}) - cfg := clientCfg(86400, []string{"ed25519"}, nil, nil) + cfg := clientCfg(86400, []string{"ed25519"}, nil) err := e.Enforce(0, "ssh-ed25519", "10.0.0.1", cfg) if err == nil { t.Fatal("expected error for TTL 0") @@ -88,7 +87,7 @@ func TestEnforce_TTLZero(t *testing.T) { func TestEnforce_NegativeTTL(t *testing.T) { e := NewEnforcer(86400, []string{"ed25519"}) - cfg := clientCfg(86400, []string{"ed25519"}, nil, nil) + cfg := clientCfg(86400, []string{"ed25519"}, nil) err := e.Enforce(-100, "ssh-ed25519", "10.0.0.1", cfg) if err == nil { t.Fatal("expected error for negative TTL") @@ -97,7 +96,7 @@ func TestEnforce_NegativeTTL(t *testing.T) { func TestEnforce_DisallowedKeyType(t *testing.T) { e := NewEnforcer(86400, []string{"ed25519"}) - cfg := clientCfg(86400, []string{"ed25519"}, nil, nil) + cfg := clientCfg(86400, []string{"ed25519"}, nil) err := e.Enforce(3600, "ssh-rsa", "10.0.0.1", cfg) if err == nil { t.Fatal("expected error for disallowed key type") @@ -110,7 +109,7 @@ func TestEnforce_DisallowedKeyType(t *testing.T) { func TestEnforce_AllowedKeyType(t *testing.T) { e := NewEnforcer(86400, []string{"ed25519", "rsa"}) - cfg := clientCfg(86400, []string{"ed25519", "rsa"}, nil, nil) + cfg := clientCfg(86400, []string{"ed25519", "rsa"}, nil) err := e.Enforce(3600, "ssh-rsa", "10.0.0.1", cfg) if err != nil { t.Errorf("expected no error, got: %v", err) @@ -120,7 +119,7 @@ func TestEnforce_AllowedKeyType(t *testing.T) { func TestEnforce_SourceAddressRejected(t *testing.T) { e := NewEnforcer(86400, []string{"ed25519"}) cidr := "10.0.0.0/8" - cfg := clientCfg(86400, []string{"ed25519"}, &cidr, nil) + cfg := clientCfg(86400, []string{"ed25519"}, &cidr) err := e.Enforce(3600, "ssh-ed25519", "192.168.1.1", cfg) if err == nil { t.Fatal("expected error for source address outside CIDR") @@ -130,7 +129,7 @@ func TestEnforce_SourceAddressRejected(t *testing.T) { func TestEnforce_SourceAddressAccepted(t *testing.T) { e := NewEnforcer(86400, []string{"ed25519"}) cidr := "10.0.0.0/8" - cfg := clientCfg(86400, []string{"ed25519"}, &cidr, nil) + cfg := clientCfg(86400, []string{"ed25519"}, &cidr) err := e.Enforce(3600, "ssh-ed25519", "10.1.2.3", cfg) if err != nil { t.Errorf("expected no error, got: %v", err) @@ -139,28 +138,44 @@ func TestEnforce_SourceAddressAccepted(t *testing.T) { func TestEnforce_NoSourceRestriction(t *testing.T) { e := NewEnforcer(86400, []string{"ed25519"}) - cfg := clientCfg(86400, []string{"ed25519"}, nil, nil) + cfg := clientCfg(86400, []string{"ed25519"}, nil) err := e.Enforce(3600, "ssh-ed25519", "192.168.1.1", cfg) if err != nil { t.Errorf("expected no error, got: %v", err) } } -func TestGetCriticalOptions(t *testing.T) { - opts := map[string]string{"source-address": "10.0.0.0/8"} - cfg := clientCfg(86400, nil, nil, opts) - result := GetCriticalOptions(cfg) +func TestBuildCriticalOptions_SourceAddress(t *testing.T) { + addr := "10.0.0.0/8" + result := BuildCriticalOptions(&addr, "") if result == nil { t.Fatal("expected non-nil critical options") } if result["source-address"] != "10.0.0.0/8" { - t.Errorf("expected source-address, got %v", result) + t.Errorf("expected source-address=10.0.0.0/8, got %v", result) } } -func TestGetCriticalOptions_Nil(t *testing.T) { - cfg := clientCfg(86400, nil, nil, nil) - result := GetCriticalOptions(cfg) +func TestBuildCriticalOptions_ForceCommand(t *testing.T) { + result := BuildCriticalOptions(nil, "sbatch /path/to/job.sh") + if result == nil { + t.Fatal("expected non-nil critical options") + } + if result["force-command"] != "sbatch /path/to/job.sh" { + t.Errorf("expected force-command, got %v", result) + } +} + +func TestBuildCriticalOptions_Both(t *testing.T) { + addr := "10.0.0.0/8" + result := BuildCriticalOptions(&addr, "sbatch /job.sh") + if len(result) != 2 { + t.Errorf("expected 2 critical options, got %d", len(result)) + } +} + +func TestBuildCriticalOptions_Empty(t *testing.T) { + result := BuildCriticalOptions(nil, "") if result != nil { t.Errorf("expected nil critical options, got %v", result) } diff --git a/signer/internal/store/certificate_query.go b/signer/internal/store/certificate_query.go index 3891920bf..5ceb864c7 100644 --- a/signer/internal/store/certificate_query.go +++ b/signer/internal/store/certificate_query.go @@ -17,6 +17,8 @@ package store import ( "context" + "database/sql" + "encoding/json" "fmt" "time" ) @@ -35,6 +37,8 @@ type CertificateWithStatus struct { ValidBefore time.Time IssuedAt time.Time SourceIP string + GrantedExtensions []string + ForceCommand *string Revoked bool RevokedAt *time.Time RevocationReason string @@ -74,6 +78,7 @@ func (d *DB) ListCertificatesByEmail(ctx context.Context, email string, limit, o c.id, c.tenant_id, c.client_id, c.serial_number, c.key_id, c.principal, COALESCE(c.user_email, ''), c.public_key_fingerprint, c.ca_fingerprint, c.valid_after, c.valid_before, c.issued_at, COALESCE(c.source_ip, ''), + c.granted_extensions, c.force_command, CASE WHEN r.id IS NOT NULL THEN TRUE ELSE FALSE END AS revoked, r.revoked_at, COALESCE(r.reason, '') AS revocation_reason @@ -93,16 +98,27 @@ func (d *DB) ListCertificatesByEmail(ctx context.Context, email string, limit, o for rows.Next() { var cert CertificateWithStatus var revokedAt *time.Time + var grantedExtensionsJSON []byte + var forceCommand sql.NullString if err := rows.Scan( &cert.ID, &cert.TenantID, &cert.ClientID, &cert.SerialNumber, &cert.KeyID, &cert.Principal, &cert.UserEmail, &cert.PublicKeyFingerprint, &cert.CAFingerprint, &cert.ValidAfter, &cert.ValidBefore, &cert.IssuedAt, &cert.SourceIP, + &grantedExtensionsJSON, &forceCommand, &cert.Revoked, &revokedAt, &cert.RevocationReason, ); err != nil { return nil, fmt.Errorf("scanning certificate row: %w", err) } + if grantedExtensionsJSON != nil { + if err := json.Unmarshal(grantedExtensionsJSON, &cert.GrantedExtensions); err != nil { + return nil, fmt.Errorf("unmarshaling granted_extensions: %w", err) + } + } + if forceCommand.Valid { + cert.ForceCommand = &forceCommand.String + } cert.RevokedAt = revokedAt certs = append(certs, cert) } @@ -120,12 +136,15 @@ func (d *DB) ListCertificatesByEmail(ctx context.Context, email string, limit, o func (d *DB) GetCertificateBySerial(ctx context.Context, serial int64) (*CertificateWithStatus, error) { var cert CertificateWithStatus var revokedAt *time.Time + var grantedExtensionsJSON []byte + var forceCommand sql.NullString err := d.QueryRowContext(ctx, `SELECT c.id, c.tenant_id, c.client_id, c.serial_number, c.key_id, c.principal, COALESCE(c.user_email, ''), c.public_key_fingerprint, c.ca_fingerprint, c.valid_after, c.valid_before, c.issued_at, COALESCE(c.source_ip, ''), + c.granted_extensions, c.force_command, CASE WHEN r.id IS NOT NULL THEN TRUE ELSE FALSE END AS revoked, r.revoked_at, COALESCE(r.reason, '') AS revocation_reason @@ -137,12 +156,21 @@ func (d *DB) GetCertificateBySerial(ctx context.Context, serial int64) (*Certifi &cert.ID, &cert.TenantID, &cert.ClientID, &cert.SerialNumber, &cert.KeyID, &cert.Principal, &cert.UserEmail, &cert.PublicKeyFingerprint, &cert.CAFingerprint, &cert.ValidAfter, &cert.ValidBefore, &cert.IssuedAt, &cert.SourceIP, + &grantedExtensionsJSON, &forceCommand, &cert.Revoked, &revokedAt, &cert.RevocationReason, ) if err != nil { return nil, fmt.Errorf("querying certificate: %w", err) } + if grantedExtensionsJSON != nil { + if err := json.Unmarshal(grantedExtensionsJSON, &cert.GrantedExtensions); err != nil { + return nil, fmt.Errorf("unmarshaling granted_extensions: %w", err) + } + } + if forceCommand.Valid { + cert.ForceCommand = &forceCommand.String + } cert.RevokedAt = revokedAt return &cert, nil } diff --git a/signer/internal/store/client_config.go b/signer/internal/store/client_config.go index 7c2d95c48..606c2e32d 100644 --- a/signer/internal/store/client_config.go +++ b/signer/internal/store/client_config.go @@ -31,8 +31,7 @@ type ClientConfig struct { MaxTTLSeconds int AllowedKeyTypes []string SourceAddressRestriction *string - CriticalOptions map[string]string - Extensions map[string]string + DeniedExtensions []string Enabled bool } @@ -41,21 +40,20 @@ func (d *DB) GetClientConfig(ctx context.Context, tenantID, clientID string) (*C cc ClientConfig allowedKeyTypesJSON []byte sourceAddressRestriction sql.NullString - criticalOptionsJSON sql.NullString - extensionsJSON sql.NullString + deniedExtensionsJSON sql.NullString ) err := d.QueryRowContext(ctx, `SELECT tenant_id, client_id, client_secret, target_host, target_port, max_ttl_seconds, allowed_key_types, source_address_restriction, - critical_options, extensions, enabled + denied_extensions, enabled FROM client_ssh_configs WHERE tenant_id = ? AND client_id = ?`, tenantID, clientID, ).Scan( &cc.TenantID, &cc.ClientID, &cc.ClientSecret, &cc.TargetHost, &cc.TargetPort, &cc.MaxTTLSeconds, &allowedKeyTypesJSON, &sourceAddressRestriction, - &criticalOptionsJSON, &extensionsJSON, &cc.Enabled, + &deniedExtensionsJSON, &cc.Enabled, ) if err != nil { if err == sql.ErrNoRows { @@ -72,15 +70,9 @@ func (d *DB) GetClientConfig(ctx context.Context, tenantID, clientID string) (*C cc.SourceAddressRestriction = &sourceAddressRestriction.String } - if criticalOptionsJSON.Valid && criticalOptionsJSON.String != "" { - if err := json.Unmarshal([]byte(criticalOptionsJSON.String), &cc.CriticalOptions); err != nil { - return nil, fmt.Errorf("unmarshaling critical_options: %w", err) - } - } - - if extensionsJSON.Valid && extensionsJSON.String != "" { - if err := json.Unmarshal([]byte(extensionsJSON.String), &cc.Extensions); err != nil { - return nil, fmt.Errorf("unmarshaling extensions: %w", err) + if deniedExtensionsJSON.Valid && deniedExtensionsJSON.String != "" { + if err := json.Unmarshal([]byte(deniedExtensionsJSON.String), &cc.DeniedExtensions); err != nil { + return nil, fmt.Errorf("unmarshaling denied_extensions: %w", err) } } diff --git a/signer/internal/store/issuance_log.go b/signer/internal/store/issuance_log.go index 823656eca..f9ef4ef51 100644 --- a/signer/internal/store/issuance_log.go +++ b/signer/internal/store/issuance_log.go @@ -34,29 +34,40 @@ type IssuanceLog struct { ValidAfter time.Time ValidBefore time.Time SourceIP string + GrantedExtensions []string + ForceCommand *string UserAccessTokenHash string RequestMetadata map[string]interface{} } // InsertIssuanceLog writes a certificate issuance entry. Fails on duplicate serial_number. func (d *DB) InsertIssuanceLog(ctx context.Context, log *IssuanceLog) error { + grantedExtensions := log.GrantedExtensions + if grantedExtensions == nil { + grantedExtensions = []string{} + } + grantedExtensionsJSON, err := json.Marshal(grantedExtensions) + if err != nil { + return fmt.Errorf("marshaling granted_extensions: %w", err) + } + var metadataJSON []byte if log.RequestMetadata != nil { - var err error metadataJSON, err = json.Marshal(log.RequestMetadata) if err != nil { return fmt.Errorf("marshaling request_metadata: %w", err) } } - _, err := d.ExecContext(ctx, + _, err = d.ExecContext(ctx, `INSERT INTO certificate_issuance_logs (tenant_id, client_id, serial_number, key_id, principal, user_email, public_key_fingerprint, - ca_fingerprint, valid_after, valid_before, source_ip, user_access_token_hash, request_metadata) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + ca_fingerprint, valid_after, valid_before, source_ip, granted_extensions, force_command, + user_access_token_hash, request_metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, log.TenantID, log.ClientID, log.SerialNumber, log.KeyID, log.Principal, log.UserEmail, log.PublicKeyFingerprint, log.CAFingerprint, log.ValidAfter, log.ValidBefore, - log.SourceIP, log.UserAccessTokenHash, metadataJSON, + log.SourceIP, grantedExtensionsJSON, log.ForceCommand, log.UserAccessTokenHash, metadataJSON, ) if err != nil { return fmt.Errorf("inserting issuance log: %w", err) diff --git a/signer/migrations/001_initial_schema.up.sql b/signer/migrations/001_initial_schema.up.sql index 1fd2f2211..5da7969cd 100644 --- a/signer/migrations/001_initial_schema.up.sql +++ b/signer/migrations/001_initial_schema.up.sql @@ -11,8 +11,7 @@ CREATE TABLE IF NOT EXISTS client_ssh_configs max_ttl_seconds INT NOT NULL DEFAULT 86400, allowed_key_types JSON NOT NULL, source_address_restriction VARCHAR(255) NULL, - critical_options JSON NULL, - extensions JSON NULL, + denied_extensions JSON NULL, enabled BOOLEAN NOT NULL DEFAULT TRUE, created_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), updated_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) @@ -40,6 +39,8 @@ CREATE TABLE IF NOT EXISTS certificate_issuance_logs valid_before TIMESTAMP(6) NOT NULL, issued_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), source_ip VARCHAR(45) NULL, + granted_extensions JSON NOT NULL, + force_command TEXT NULL, user_access_token_hash VARCHAR(255) NULL, request_metadata JSON NULL,
