This is an automated email from the ASF dual-hosted git repository. lahirujayathilake pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/airavata-custos.git
commit 76aa5b79b27ea18bc8195139d3f5603f0e5c2672 Author: lahiruj <[email protected]> AuthorDate: Mon Apr 6 18:46:41 2026 -0400 LDAP connection pooling --- signer/internal/validation/dispatcher.go | 27 +++++++- signer/internal/validation/ldap.go | 109 +++++++++++++++++++++---------- signer/internal/validation/ldap_test.go | 73 ++++++++++++++++++++- 3 files changed, 171 insertions(+), 38 deletions(-) diff --git a/signer/internal/validation/dispatcher.go b/signer/internal/validation/dispatcher.go index a75a725bb..55106381d 100644 --- a/signer/internal/validation/dispatcher.go +++ b/signer/internal/validation/dispatcher.go @@ -44,6 +44,7 @@ type ValidatorDispatcher struct { cacheTTL time.Duration negativeTTL time.Duration cache map[string]*cachedCreds + validators map[string]*LDAPValidator mu sync.RWMutex logger *slog.Logger } @@ -65,6 +66,7 @@ func NewValidatorDispatcher( cacheTTL: cacheTTL, negativeTTL: 30 * time.Second, cache: make(map[string]*cachedCreds), + validators: make(map[string]*LDAPValidator), logger: logger, } } @@ -124,6 +126,20 @@ func (d *ValidatorDispatcher) validateLDAP( } } + validator := d.getOrCreateValidator(tenantID, clientID, creds) + return validator.Validate(tenantID, clientID, principal, identitySubject) +} + +func (d *ValidatorDispatcher) getOrCreateValidator(tenantID, clientID string, creds *vault.ValidationCredentials) *LDAPValidator { + key := tenantID + ":" + clientID + + d.mu.RLock() + if v, ok := d.validators[key]; ok { + d.mu.RUnlock() + return v + } + d.mu.RUnlock() + verifySSL := true if creds.VerifySSL != nil { verifySSL = *creds.VerifySSL @@ -139,7 +155,16 @@ func (d *ValidatorDispatcher) validateLDAP( VerifySSL: verifySSL, }, d.ldapConnector) - return validator.Validate(tenantID, clientID, principal, identitySubject) + d.mu.Lock() + // Double-check after acquiring write lock + if v, ok := d.validators[key]; ok { + d.mu.Unlock() + return v + } + d.validators[key] = validator + d.mu.Unlock() + + return validator } func (d *ValidatorDispatcher) fetchCredentials(ctx context.Context, tenantID, clientID string) (*vault.ValidationCredentials, error) { diff --git a/signer/internal/validation/ldap.go b/signer/internal/validation/ldap.go index 79c497747..b8e82fef9 100644 --- a/signer/internal/validation/ldap.go +++ b/signer/internal/validation/ldap.go @@ -18,6 +18,7 @@ package validation import ( "crypto/tls" "fmt" + "sync" "github.com/go-ldap/ldap/v3" ) @@ -69,9 +70,13 @@ type LDAPConfig struct { } // LDAPValidator resolves an OIDC subject to a POSIX username via LDAP directory lookup. +// It maintains a persistent LDAP connection that is reused across requests and +// automatically reconnects on failure. type LDAPValidator struct { config LDAPConfig connector LDAPConnector + mu sync.Mutex + conn LDAPConnection } func NewLDAPValidator(config LDAPConfig, connector LDAPConnector) *LDAPValidator { @@ -87,51 +92,30 @@ func NewLDAPValidator(config LDAPConfig, connector LDAPConnector) *LDAPValidator } func (v *LDAPValidator) Validate(tenantID, clientID, principal, identitySubject string) (*ValidationResult, error) { - conn, err := v.connector.Connect(v.config.URL, v.config.VerifySSL) - if err != nil { - return nil, &ValidationError{ - Message: "Principal validation unavailable: LDAP connection failed", - ReasonCode: "LDAP_UNAVAILABLE", - } - } - defer conn.Close() + v.mu.Lock() + defer v.mu.Unlock() - if err := conn.Bind(v.config.BindDN, v.config.BindPassword); err != nil { - return nil, &ValidationError{ - Message: "Principal validation unavailable: LDAP bind failed", - ReasonCode: "LDAP_UNAVAILABLE", - } - } - - // Search using the OIDC subject, not the requested principal - filter := fmt.Sprintf(v.config.SearchFilter, ldap.EscapeFilter(identitySubject)) - - result, err := conn.Search(ldap.NewSearchRequest( - v.config.BaseDN, - ldap.ScopeWholeSubtree, - ldap.NeverDerefAliases, - 1, // SizeLimit - 10, // TimeLimit (seconds) - false, - filter, - []string{v.config.UsernameAttribute}, - nil, - )) + result, err := v.searchLDAP(identitySubject) if err != nil { - return nil, &ValidationError{ - Message: "Principal validation unavailable: LDAP search failed", - ReasonCode: "LDAP_UNAVAILABLE", + // Connection may be stale — close, reconnect, and retry once + v.closeConn() + result, err = v.searchLDAP(identitySubject) + if err != nil { + v.closeConn() + return nil, &ValidationError{ + Message: "Principal validation unavailable: " + err.Error(), + ReasonCode: "LDAP_UNAVAILABLE", + } } } if len(result.Entries) == 0 { return nil, &ValidationError{ - Message: fmt.Sprintf("No POSIX account found for identity subject in directory"), + Message: "No POSIX account found for identity subject in directory", ReasonCode: "LDAP_IDENTITY_NOT_FOUND", } } - // Extract the POSIX username from the directory entry resolvedUsername := result.Entries[0].GetAttributeValue(v.config.UsernameAttribute) if resolvedUsername == "" { return nil, &ValidationError{ @@ -140,7 +124,6 @@ func (v *LDAPValidator) Validate(tenantID, clientID, principal, identitySubject } } - // Verify the requested principal matches the directory-resolved username if resolvedUsername != principal { return nil, &ValidationError{ Message: fmt.Sprintf("Requested principal %q does not match directory account %q", principal, resolvedUsername), @@ -153,3 +136,59 @@ func (v *LDAPValidator) Validate(tenantID, clientID, principal, identitySubject ValidatedPrincipal: resolvedUsername, }, nil } + +// ensureConn returns the existing connection or creates a new one with bind. +func (v *LDAPValidator) ensureConn() (LDAPConnection, error) { + if v.conn != nil { + return v.conn, nil + } + + conn, err := v.connector.Connect(v.config.URL, v.config.VerifySSL) + if err != nil { + return nil, fmt.Errorf("LDAP connection failed") + } + + if err := conn.Bind(v.config.BindDN, v.config.BindPassword); err != nil { + conn.Close() + return nil, fmt.Errorf("LDAP bind failed") + } + + v.conn = conn + return conn, nil +} + +// searchLDAP executes the principal lookup on the current or newly established connection. +func (v *LDAPValidator) searchLDAP(identitySubject string) (*ldap.SearchResult, error) { + conn, err := v.ensureConn() + if err != nil { + return nil, err + } + + filter := fmt.Sprintf(v.config.SearchFilter, ldap.EscapeFilter(identitySubject)) + + return conn.Search(ldap.NewSearchRequest( + v.config.BaseDN, + ldap.ScopeWholeSubtree, + ldap.NeverDerefAliases, + 1, // SizeLimit + 10, // TimeLimit (seconds) + false, + filter, + []string{v.config.UsernameAttribute}, + nil, + )) +} + +func (v *LDAPValidator) closeConn() { + if v.conn != nil { + v.conn.Close() + v.conn = nil + } +} + +// Close releases the persistent LDAP connection. +func (v *LDAPValidator) Close() { + v.mu.Lock() + defer v.mu.Unlock() + v.closeConn() +} diff --git a/signer/internal/validation/ldap_test.go b/signer/internal/validation/ldap_test.go index f33ac0b07..5ff86bc44 100644 --- a/signer/internal/validation/ldap_test.go +++ b/signer/internal/validation/ldap_test.go @@ -37,11 +37,13 @@ func (m *mockLDAPConn) Close() error { return nil } // mockLDAPConnector implements LDAPConnector for testing. type mockLDAPConnector struct { - conn LDAPConnection - connErr error + conn LDAPConnection + connErr error + callCount int } func (m *mockLDAPConnector) Connect(url string, verifySSL bool) (LDAPConnection, error) { + m.callCount++ return m.conn, m.connErr } @@ -229,6 +231,73 @@ func TestLDAPValidator_DefaultUsernameAttribute(t *testing.T) { } } +func TestLDAPValidator_ConnectionReuse(t *testing.T) { + conn := &mockLDAPConn{ + searchRes: &ldap.SearchResult{ + Entries: []*ldap.Entry{ldapEntry("uid=jdoe,ou=people,dc=test", "jdoe")}, + }, + } + connector := &mockLDAPConnector{conn: conn} + v := NewLDAPValidator(baseLDAPConfig(), connector) + + // Two calls should reuse the same connection + v.Validate("t1", "c1", "jdoe", "sub123") + v.Validate("t1", "c1", "jdoe", "sub123") + + if connector.callCount != 1 { + t.Errorf("expected 1 connect call (reuse), got %d", connector.callCount) + } +} + +func TestLDAPValidator_ReconnectOnSearchError(t *testing.T) { + callCount := 0 + // First connection fails on search, second succeeds + failConn := &mockLDAPConn{searchErr: errors.New("connection reset")} + goodConn := &mockLDAPConn{ + searchRes: &ldap.SearchResult{ + Entries: []*ldap.Entry{ldapEntry("uid=jdoe,ou=people,dc=test", "jdoe")}, + }, + } + connector := &mockLDAPConnector{} + // Return fail conn first, then good conn + connector.conn = failConn + origConnect := connector.Connect + _ = origConnect // suppress unused + // Override with a function that switches behavior + switchingConnector := &switchConnector{ + conns: []LDAPConnection{failConn, goodConn}, + callCount: &callCount, + } + v := NewLDAPValidator(baseLDAPConfig(), switchingConnector) + + result, err := v.Validate("t1", "c1", "jdoe", "sub123") + if err != nil { + t.Fatalf("expected success after reconnect, got: %v", err) + } + if !result.Allowed { + t.Error("expected Allowed=true") + } + // Should have connected twice: once failed, once reconnected + if callCount != 2 { + t.Errorf("expected 2 connect calls (reconnect), got %d", callCount) + } +} + +// switchConnector returns different connections on successive calls. +type switchConnector struct { + conns []LDAPConnection + callCount *int +} + +func (c *switchConnector) Connect(url string, verifySSL bool) (LDAPConnection, error) { + idx := *c.callCount + *c.callCount++ + if idx < len(c.conns) { + return c.conns[idx], nil + } + return c.conns[len(c.conns)-1], nil +} + // filterCapturingConn wraps LDAPConnection to capture the search filter. type filterCapturingConn struct { LDAPConnection
