This is an automated email from the ASF dual-hosted git repository. lahirujayathilake pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/airavata-custos.git
commit 0457b25288a55f21fc7cd90893b8eb94604cdcdd Author: lahiruj <[email protected]> AuthorDate: Mon Apr 6 14:22:56 2026 -0400 add SSH extension validation and resolution logic with tests --- signer/internal/cert/extensions.go | 121 ++++++++++++++++++++ signer/internal/cert/extensions_test.go | 188 ++++++++++++++++++++++++++++++++ 2 files changed, 309 insertions(+) diff --git a/signer/internal/cert/extensions.go b/signer/internal/cert/extensions.go new file mode 100644 index 000000000..9bcdddfd6 --- /dev/null +++ b/signer/internal/cert/extensions.go @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cert + +import "fmt" + +// SSHExtension represents a valid OpenSSH certificate extension name. 
type SSHExtension string

// Extension names accepted by Validate. All except ExtNoTouchRequired are
// returned by AllStandardExtensions; ExtNoTouchRequired is FIDO-specific.
const (
	ExtPermitPTY             SSHExtension = "permit-pty"
	ExtPermitPortForwarding  SSHExtension = "permit-port-forwarding"
	ExtPermitUserRC          SSHExtension = "permit-user-rc"
	ExtPermitAgentForwarding SSHExtension = "permit-agent-forwarding"
	ExtPermitX11Forwarding   SSHExtension = "permit-X11-forwarding"
	ExtNoTouchRequired       SSHExtension = "no-touch-required"
)

// validExtensions is the lookup set used by Validate. It contains every
// constant above, including the FIDO-only no-touch-required.
var validExtensions = map[SSHExtension]bool{
	ExtPermitPTY:             true,
	ExtPermitPortForwarding:  true,
	ExtPermitUserRC:          true,
	ExtPermitAgentForwarding: true,
	ExtPermitX11Forwarding:   true,
	ExtNoTouchRequired:       true,
}

// AllStandardExtensions returns the 5 standard non-FIDO extensions.
// no-touch-required is excluded as it is FIDO-specific.
//
// A new slice is built on every call, so callers may modify the result
// without affecting later calls.
func AllStandardExtensions() []SSHExtension {
	return []SSHExtension{
		ExtPermitPTY,
		ExtPermitPortForwarding,
		ExtPermitUserRC,
		ExtPermitAgentForwarding,
		ExtPermitX11Forwarding,
	}
}

// Validate returns nil when e is one of the known extension names and an
// error quoting the offending value otherwise. Matching is exact and
// case-sensitive, so "permit-PTY" and "" are both rejected.
func (e SSHExtension) Validate() error {
	if !validExtensions[e] {
		return fmt.Errorf("unknown SSH extension: %q", string(e))
	}
	return nil
}

// ValidateExtensionList validates a list of extension name strings and returns
// them as typed SSHExtension values. Returns an error naming the first invalid entry.
//
// The result keeps the input order (duplicates included). A nil or empty input
// yields an empty, non-nil slice; on error the returned slice is nil.
func ValidateExtensionList(names []string) ([]SSHExtension, error) {
	result := make([]SSHExtension, 0, len(names))
	for _, name := range names {
		ext := SSHExtension(name)
		if err := ext.Validate(); err != nil {
			return nil, err
		}
		result = append(result, ext)
	}
	return result, nil
}

// ResolveExtensions computes the granted extension set by starting with all
// standard extensions and removing denied and excluded entries.
// The result follows the order of AllStandardExtensions. If any entry in
// either list is not a known extension name, no extensions are returned and
// the error identifies which list held the invalid entry.
+func ResolveExtensions(denied []string, excluded []string) ([]SSHExtension, error) { + // Validate inputs + if _, err := ValidateExtensionList(denied); err != nil { + return nil, fmt.Errorf("invalid denied extension: %w", err) + } + if _, err := ValidateExtensionList(excluded); err != nil { + return nil, fmt.Errorf("invalid excluded extension: %w", err) + } + + // Build removal set + remove := make(map[SSHExtension]bool, len(denied)+len(excluded)) + for _, d := range denied { + remove[SSHExtension(d)] = true + } + for _, e := range excluded { + remove[SSHExtension(e)] = true + } + + // Filter + var granted []SSHExtension + for _, ext := range AllStandardExtensions() { + if !remove[ext] { + granted = append(granted, ext) + } + } + return granted, nil +} + +// ExtensionsToMap converts a list of SSHExtension values to the map[string]string +// format required by golang.org/x/crypto/ssh.Certificate.Permissions.Extensions. +func ExtensionsToMap(exts []SSHExtension) map[string]string { + m := make(map[string]string, len(exts)) + for _, e := range exts { + m[string(e)] = "" + } + return m +} + +// ExtensionNames converts a list of SSHExtension values to plain strings. +func ExtensionNames(exts []SSHExtension) []string { + names := make([]string, len(exts)) + for i, e := range exts { + names[i] = string(e) + } + return names +} diff --git a/signer/internal/cert/extensions_test.go b/signer/internal/cert/extensions_test.go new file mode 100644 index 000000000..cc8cf9bb2 --- /dev/null +++ b/signer/internal/cert/extensions_test.go @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cert + +import ( + "testing" +) + +func TestSSHExtension_Validate_Valid(t *testing.T) { + valid := []SSHExtension{ + ExtPermitPTY, ExtPermitPortForwarding, ExtPermitUserRC, + ExtPermitAgentForwarding, ExtPermitX11Forwarding, ExtNoTouchRequired, + } + for _, ext := range valid { + if err := ext.Validate(); err != nil { + t.Errorf("expected %q to be valid, got: %v", ext, err) + } + } +} + +func TestSSHExtension_Validate_Invalid(t *testing.T) { + invalid := []SSHExtension{"unknown-ext", "permit-PTY", "PERMIT-PTY", ""} + for _, ext := range invalid { + if err := ext.Validate(); err == nil { + t.Errorf("expected %q to be invalid", ext) + } + } +} + +func TestValidateExtensionList_Valid(t *testing.T) { + names := []string{"permit-pty", "permit-port-forwarding"} + result, err := ValidateExtensionList(names) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result) != 2 { + t.Errorf("expected 2 extensions, got %d", len(result)) + } +} + +func TestValidateExtensionList_InvalidEntry(t *testing.T) { + names := []string{"permit-pty", "bogus-extension"} + _, err := ValidateExtensionList(names) + if err == nil { + t.Fatal("expected error for invalid extension") + } +} + +func TestValidateExtensionList_Empty(t *testing.T) { + result, err := ValidateExtensionList(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result) != 0 { + t.Errorf("expected 0 extensions, got %d", len(result)) + } +} + +func TestAllStandardExtensions_Count(t *testing.T) { + exts := AllStandardExtensions() + if len(exts) != 5 { + 
t.Errorf("expected 5 standard extensions, got %d", len(exts)) + } +} + +func TestAllStandardExtensions_ExcludesNoTouchRequired(t *testing.T) { + for _, ext := range AllStandardExtensions() { + if ext == ExtNoTouchRequired { + t.Error("AllStandardExtensions should not include no-touch-required") + } + } +} + +func TestResolveExtensions_NoDenials(t *testing.T) { + granted, err := ResolveExtensions(nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(granted) != 5 { + t.Errorf("expected 5 extensions, got %d", len(granted)) + } +} + +func TestResolveExtensions_DenyOne(t *testing.T) { + granted, err := ResolveExtensions([]string{"permit-port-forwarding"}, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(granted) != 4 { + t.Errorf("expected 4 extensions, got %d", len(granted)) + } + for _, ext := range granted { + if ext == ExtPermitPortForwarding { + t.Error("permit-port-forwarding should be denied") + } + } +} + +func TestResolveExtensions_ExcludeOne(t *testing.T) { + granted, err := ResolveExtensions(nil, []string{"permit-user-rc"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(granted) != 4 { + t.Errorf("expected 4 extensions, got %d", len(granted)) + } + for _, ext := range granted { + if ext == ExtPermitUserRC { + t.Error("permit-user-rc should be excluded") + } + } +} + +func TestResolveExtensions_DenyAndExcludeOverlap(t *testing.T) { + granted, err := ResolveExtensions( + []string{"permit-port-forwarding"}, + []string{"permit-port-forwarding", "permit-user-rc"}, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(granted) != 3 { + t.Errorf("expected 3 extensions, got %d", len(granted)) + } +} + +func TestResolveExtensions_DenyAll(t *testing.T) { + all := ExtensionNames(AllStandardExtensions()) + granted, err := ResolveExtensions(all, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(granted) != 0 { + t.Errorf("expected 0 extensions, got 
%d", len(granted)) + } +} + +func TestResolveExtensions_InvalidDenied(t *testing.T) { + _, err := ResolveExtensions([]string{"bogus"}, nil) + if err == nil { + t.Fatal("expected error for invalid denied extension") + } +} + +func TestResolveExtensions_InvalidExcluded(t *testing.T) { + _, err := ResolveExtensions(nil, []string{"bogus"}) + if err == nil { + t.Fatal("expected error for invalid excluded extension") + } +} + +func TestExtensionsToMap(t *testing.T) { + exts := []SSHExtension{ExtPermitPTY, ExtPermitPortForwarding} + m := ExtensionsToMap(exts) + if len(m) != 2 { + t.Errorf("expected 2 entries, got %d", len(m)) + } + if _, ok := m["permit-pty"]; !ok { + t.Error("expected permit-pty in map") + } + if v := m["permit-pty"]; v != "" { + t.Errorf("expected empty value for permit-pty, got %q", v) + } +} + +func TestExtensionNames(t *testing.T) { + exts := []SSHExtension{ExtPermitPTY, ExtPermitUserRC} + names := ExtensionNames(exts) + if len(names) != 2 { + t.Errorf("expected 2 names, got %d", len(names)) + } + if names[0] != "permit-pty" || names[1] != "permit-user-rc" { + t.Errorf("unexpected names: %v", names) + } +}
