hanahmily commented on code in PR #920: URL: https://github.com/apache/skywalking-banyandb/pull/920#discussion_r2661114046
########## banyand/metadata/dns/dns_test.go: ########## @@ -0,0 +1,1714 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package dns_test + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "sync" + "testing" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/health" + "google.golang.org/grpc/health/grpc_health_v1" + + commonv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/common/v1" + databasev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1" + "github.com/apache/skywalking-banyandb/banyand/metadata/dns" + "github.com/apache/skywalking-banyandb/banyand/metadata/schema" + "github.com/apache/skywalking-banyandb/pkg/logger" + "github.com/apache/skywalking-banyandb/pkg/test/flags" +) + +func TestDNS(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "DNS Discovery Suite", Label("integration")) +} + +var _ = BeforeSuite(func() { + Expect(logger.Init(logger.Logging{ + Env: "dev", + Level: flags.LogLevel, + })).To(Succeed()) +}) + +var _ = Describe("DNS Discovery Service", func() { + var ( + ctx context.Context + cancel context.CancelFunc + mockServer *mockNodeQueryServer + grpcServer *grpc.Server + listener net.Listener + ) + + BeforeEach(func() { + ctx, cancel = context.WithCancel(context.Background()) + + // Setup mock gRPC server + listener, grpcServer, mockServer = setupMockGRPCServer() + + // Add test node to mock server + mockServer.node = createTestNode("node1", listener.Addr().String()) + }) + + AfterEach(func() { + if grpcServer != nil { + grpcServer.Stop() + } + if listener != nil { + _ = listener.Close() + } + cancel() + }) + + Describe("NewService", func() { + It("should create service with valid config", func() { + config := dns.Config{ + SRVAddresses: []string{"_grpc._tcp.test.local"}, + InitInterval: 1 * time.Second, + InitDuration: 10 * time.Second, + PollInterval: 5 * time.Second, + GRPCTimeout: 3 * time.Second, + TLSEnabled: false, + } + + svc, err := dns.NewService(config) + Expect(err).NotTo(HaveOccurred()) + Expect(svc).NotTo(BeNil()) + Expect(svc.Close()).To(Succeed()) + }) + + It("should fail with empty SRV addresses", func() { + config := dns.Config{ + SRVAddresses: []string{}, + } + + _, err := dns.NewService(config) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("SRV addresses cannot be empty")) + }) + + It("should fail with invalid CA cert path when TLS enabled", func() { + config := dns.Config{ + SRVAddresses: []string{"_grpc._tcp.test.local"}, + TLSEnabled: true, + CACertPaths: []string{"/non/existent/path/ca.crt"}, + } + + _, err := dns.NewService(config) + Expect(err).To(HaveOccurred()) + }) + }) + + 
Describe("Node Registry Interface", func() { + var ( + svc *dns.Service + mockResolver *mockDNSResolver + grpcServer *grpc.Server + registryListener net.Listener + mockServer *mockNodeQueryServer + ) + + BeforeEach(func() { + mockResolver = newMockDNSResolver() + + // Setup mock gRPC server + registryListener, grpcServer, mockServer = setupMockGRPCServer() + + // Set up DNS to return our gRPC server + serverAddr := registryListener.Addr().String() + mockResolver.setResponse("_grpc._tcp.test.local", []*net.SRV{addrToSRV(serverAddr)}) + + // Configure mock server with test nodes + mockServer.node = createTestNode("registry-test-node", serverAddr, + databasev1.Role_ROLE_DATA, databasev1.Role_ROLE_LIAISON) + + var err error + svc, err = dns.NewServiceWithResolver(createDefaultConfig(), mockResolver) + Expect(err).NotTo(HaveOccurred()) + + // Start the discovery service + err = svc.Start(ctx) + Expect(err).NotTo(HaveOccurred()) + + // Wait for initial DNS discovery to complete + time.Sleep(300 * time.Millisecond) + }) + + AfterEach(func() { + if svc != nil { + _ = svc.Close() + svc = nil + } + if grpcServer != nil { + grpcServer.Stop() + grpcServer = nil + } + if registryListener != nil { + _ = registryListener.Close() + registryListener = nil + } + }) + + It("should return error for RegisterNode", func() { + node := &databasev1.Node{ + Metadata: &commonv1.Metadata{ + Name: "test-node", + }, + } + + err := svc.RegisterNode(ctx, node, false) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not supported in DNS discovery mode")) + }) + + It("should return error for UpdateNode", func() { + node := &databasev1.Node{ + Metadata: &commonv1.Metadata{ + Name: "test-node", + }, + } + + err := svc.UpdateNode(ctx, node) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not supported in DNS discovery mode")) + }) + + It("should list nodes by role after DNS discovery", func() { + // Verify node was discovered + nodes, err := svc.ListNode(ctx, databasev1.Role_ROLE_DATA) + Expect(err).NotTo(HaveOccurred()) + Expect(nodes).To(HaveLen(1)) + Expect(nodes[0].GetMetadata().GetName()).To(Equal("registry-test-node")) + Expect(nodes[0].GetRoles()).To(ContainElement(databasev1.Role_ROLE_DATA)) + + // Verify role filtering works + liaisonNodes, err := svc.ListNode(ctx, databasev1.Role_ROLE_LIAISON) + Expect(err).NotTo(HaveOccurred()) + Expect(liaisonNodes).To(HaveLen(1)) + + metadataNodes, err := svc.ListNode(ctx, databasev1.Role_ROLE_META) + Expect(err).NotTo(HaveOccurred()) + Expect(metadataNodes).To(HaveLen(0)) + + // Verify DNS resolver was called + Expect(mockResolver.getCallCount("_grpc._tcp.test.local")).To(BeNumerically(">=", 1)) + }) + + It("should get node by name after DNS discovery", func() { + // Verify node can be retrieved by name + node, err := svc.GetNode(ctx, "registry-test-node") + Expect(err).NotTo(HaveOccurred()) + Expect(node).NotTo(BeNil()) + Expect(node.GetMetadata().GetName()).To(Equal("registry-test-node")) + Expect(node.GetGrpcAddress()).NotTo(BeEmpty()) + + // Verify non-existent node returns error + _, err = svc.GetNode(ctx, "non-existent-node") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not found")) + }) + }) + + Describe("DNS Query Operations", func() { + var ( + svc *dns.Service + mockResolver *mockDNSResolver + config dns.Config + ) + + BeforeEach(func() { + mockResolver = newMockDNSResolver() + config = createDefaultConfig() + }) + + AfterEach(func() { + if svc != nil { + 
Expect(svc.Close()).To(Succeed()) + } + }) + + It("should successfully query single SRV record", func() { + mockResolver.setResponse("_grpc._tcp.test.local", []*net.SRV{ + {Target: "node1.test.local", Port: 17912}, + }) + + var queryErr error + svc, queryErr = dns.NewServiceWithResolver(config, mockResolver) + Expect(queryErr).NotTo(HaveOccurred()) + + addresses, queryErr := svc.QueryAllSRVRecords(ctx) + Expect(queryErr).NotTo(HaveOccurred()) + Expect(addresses).To(HaveLen(1)) + Expect(addresses).To(ContainElement("node1.test.local:17912")) + + // Verify DNS resolver was called exactly once + Expect(mockResolver.getCallCount("_grpc._tcp.test.local")).To(Equal(1)) + }) + + It("should successfully query and deduplicate multiple SRV records", func() { + config.SRVAddresses = []string{ + "_grpc._tcp.zone1.local", + "_grpc._tcp.zone2.local", + } + + mockResolver.setResponse("_grpc._tcp.zone1.local", []*net.SRV{ + {Target: "node1.test.local", Port: 17912}, + {Target: "node2.test.local", Port: 17912}, + }) + mockResolver.setResponse("_grpc._tcp.zone2.local", []*net.SRV{ + {Target: "node1.test.local", Port: 17912}, // Duplicate + {Target: "node3.test.local", Port: 17912}, + }) + + var queryErr error + svc, queryErr = dns.NewServiceWithResolver(config, mockResolver) + Expect(queryErr).NotTo(HaveOccurred()) + + addresses, queryErr := svc.QueryAllSRVRecords(ctx) + Expect(queryErr).NotTo(HaveOccurred()) + Expect(addresses).To(HaveLen(3)) // Deduplicated + Expect(addresses).To(ContainElements( + "node1.test.local:17912", + "node2.test.local:17912", + "node3.test.local:17912", + )) + + // Verify each DNS address was queried exactly once + Expect(mockResolver.getCallCount("_grpc._tcp.zone1.local")).To(Equal(1)) + Expect(mockResolver.getCallCount("_grpc._tcp.zone2.local")).To(Equal(1)) + }) + + It("should fail when all DNS queries fail (🎯 critical scenario)", func() { + mockResolver.setError("_grpc._tcp.test.local", fmt.Errorf("DNS server unavailable")) + + var queryErr error + svc, queryErr = dns.NewServiceWithResolver(config, mockResolver) + Expect(queryErr).NotTo(HaveOccurred()) + + addresses, queryErr := svc.QueryAllSRVRecords(ctx) + Expect(queryErr).To(HaveOccurred()) + Expect(queryErr.Error()).To(ContainSubstring("DNS server unavailable")) + Expect(addresses).To(BeEmpty()) + + // Verify DNS was still called (and failed) + Expect(mockResolver.getCallCount("_grpc._tcp.test.local")).To(Equal(1)) + }) + + It("should fail when any DNS query fails (🎯 critical scenario)", func() { Review Comment: It should not fail when part of the DNS query fails. ########## banyand/metadata/dns/dns.go: ########## @@ -0,0 +1,603 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// Package dns implements DNS-based node discovery for distributed metadata management. +package dns + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + + databasev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1" + "github.com/apache/skywalking-banyandb/banyand/metadata/schema" + "github.com/apache/skywalking-banyandb/banyand/observability" + "github.com/apache/skywalking-banyandb/pkg/grpchelper" + "github.com/apache/skywalking-banyandb/pkg/logger" + "github.com/apache/skywalking-banyandb/pkg/run" + pkgtls "github.com/apache/skywalking-banyandb/pkg/tls" +) + +// Service implements DNS-based node discovery. +type Service struct { + lastQueryTime time.Time + resolver Resolver + pathToReloader map[string]*pkgtls.Reloader + srvIndexToPath map[int]string + resolvedAddrToSRVIdx map[string]int + nodeCache map[string]*databasev1.Node + closer *run.Closer + log *logger.Logger + metrics *metrics + handlers map[string]schema.EventHandler + caCertPaths []string + srvAddresses []string + lastSuccessfulDNS []string + pollInterval time.Duration + initInterval time.Duration + initDuration time.Duration + grpcTimeout time.Duration + cacheMutex sync.RWMutex + handlersMutex sync.RWMutex + lastQueryMutex sync.RWMutex + resolvedAddrMutex sync.RWMutex + tlsEnabled bool +} + +// Config holds configuration for DNS discovery service. +type Config struct { + CACertPaths []string + SRVAddresses []string + InitInterval time.Duration + InitDuration time.Duration + PollInterval time.Duration + GRPCTimeout time.Duration + TLSEnabled bool +} + +// Resolver defines the interface for DNS SRV lookups. +type Resolver interface { + LookupSRV(ctx context.Context, name string) (string, []*net.SRV, error) +} + +// defaultResolver wraps net.DefaultResolver to implement Resolver. +type defaultResolver struct{} + +func (d *defaultResolver) LookupSRV(ctx context.Context, name string) (string, []*net.SRV, error) { + return net.DefaultResolver.LookupSRV(ctx, "", "", name) +} + +// NewService creates a new DNS discovery service. 
+func NewService(cfg Config) (*Service, error) { + // validation + if len(cfg.SRVAddresses) == 0 { + return nil, errors.New("DNS SRV addresses cannot be empty") + } + + // validate CA cert paths match SRV addresses when TLS is enabled + if cfg.TLSEnabled { + if len(cfg.CACertPaths) != len(cfg.SRVAddresses) { + return nil, fmt.Errorf("number of CA cert paths (%d) must match number of SRV addresses (%d)", + len(cfg.CACertPaths), len(cfg.SRVAddresses)) + } + } + + svc := &Service{ + srvAddresses: cfg.SRVAddresses, + initInterval: cfg.InitInterval, + initDuration: cfg.InitDuration, + pollInterval: cfg.PollInterval, + grpcTimeout: cfg.GRPCTimeout, + tlsEnabled: cfg.TLSEnabled, + caCertPaths: cfg.CACertPaths, + nodeCache: make(map[string]*databasev1.Node), + handlers: make(map[string]schema.EventHandler), + lastSuccessfulDNS: []string{}, + pathToReloader: make(map[string]*pkgtls.Reloader), + srvIndexToPath: make(map[int]string), + resolvedAddrToSRVIdx: make(map[string]int), + closer: run.NewCloser(1), + log: logger.GetLogger("metadata-discovery-dns"), + resolver: &defaultResolver{}, + } + + // create shared reloaders for CA certificates + if svc.tlsEnabled { + for srvIdx, certPath := range cfg.CACertPaths { + // Store the SRV index → cert path mapping + svc.srvIndexToPath[srvIdx] = certPath + + // check if we already have a Reloader for this path + if _, exists := svc.pathToReloader[certPath]; exists { + svc.log.Debug().Str("certPath", certPath).Int("srvIndex", srvIdx). + Msg("Reusing existing CA certificate reloader") + continue + } + + // create new Reloader for this unique path + reloader, reloaderErr := pkgtls.NewClientCertReloader(certPath, svc.log) + if reloaderErr != nil { + // clean up any already-created reloaders + for _, r := range svc.pathToReloader { + r.Stop() + } + return nil, fmt.Errorf("failed to initialize CA certificate reloader for path %s (SRV index %d): %w", + certPath, srvIdx, reloaderErr) + } + + svc.pathToReloader[certPath] = reloader + svc.log.Info().Str("certPath", certPath).Int("srvIndex", srvIdx). + Str("srvAddress", cfg.SRVAddresses[srvIdx]).Msg("Initialized DNS CA certificate reloader") + } + } + + return svc, nil +} + +// newServiceWithResolver creates a service with a custom resolver (for testing). 
+func newServiceWithResolver(cfg Config, resolver Resolver) (*Service, error) { + svc, err := NewService(cfg) + if err != nil { + return nil, err + } + svc.resolver = resolver + return svc, nil +} + +func (s *Service) getTLSDialOptions(address string) ([]grpc.DialOption, error) { + if !s.tlsEnabled { + return []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, nil + } + + // look up which Reloader to use for this address + if len(s.pathToReloader) > 0 { + // Find which SRV address this resolved address came from + s.resolvedAddrMutex.RLock() + srvIdx, addrExists := s.resolvedAddrToSRVIdx[address] + s.resolvedAddrMutex.RUnlock() + + if !addrExists { + return nil, fmt.Errorf("no SRV mapping found for address %s", address) + } + + // look up the cert path for this SRV index + certPath, pathExists := s.srvIndexToPath[srvIdx] + if !pathExists { + return nil, fmt.Errorf("no cert path found for SRV index %d (address %s)", srvIdx, address) + } + + // get the Reloader for this cert path + reloader, reloaderExists := s.pathToReloader[certPath] + if !reloaderExists { + return nil, fmt.Errorf("no reloader found for cert path %s (address %s)", certPath, address) + } + + // get fresh TLS config from the Reloader + tlsConfig, configErr := reloader.GetClientTLSConfig("") + if configErr != nil { + return nil, fmt.Errorf("failed to get TLS config from reloader for address %s: %w", address, configErr) + } + + creds := credentials.NewTLS(tlsConfig) + return []grpc.DialOption{grpc.WithTransportCredentials(creds)}, nil + } + + // fallback to static TLS config (when no reloaders configured) + opts, err := grpchelper.SecureOptions(nil, s.tlsEnabled, false, "") + if err != nil { + return nil, fmt.Errorf("failed to load TLS config: %w", err) + } + return opts, nil +} + +// Start begins the DNS discovery background process. 
+func (s *Service) Start(ctx context.Context) error { + s.log.Debug().Msg("Starting DNS-based node discovery service") + + // start all Reloaders + if len(s.pathToReloader) > 0 { + startedReloaders := make([]*pkgtls.Reloader, 0, len(s.pathToReloader)) + + for certPath, reloader := range s.pathToReloader { + if startErr := reloader.Start(); startErr != nil { + // stop any already-started reloaders + for _, r := range startedReloaders { + r.Stop() + } + return fmt.Errorf("failed to start CA certificate reloader for path %s: %w", certPath, startErr) + } + startedReloaders = append(startedReloaders, reloader) + s.log.Debug().Str("certPath", certPath).Msg("Started CA certificate reloader") + } + } + + go s.discoveryLoop(ctx) + + return nil +} + +func (s *Service) discoveryLoop(ctx context.Context) { + // add the init phase finish time + initPhaseEnd := time.Now().Add(s.initDuration) + + for { + if err := s.queryDNSAndUpdateNodes(ctx); err != nil { + s.log.Err(err).Msg("failed to query DNS and update nodes") + } + + // wait for next interval + var interval time.Duration + if time.Now().Before(initPhaseEnd) { + interval = s.initInterval + } else { + interval = s.pollInterval + } + + timer := time.NewTimer(interval) + select { + case <-ctx.Done(): + timer.Stop() + return + case <-s.closer.CloseNotify(): + timer.Stop() + return + case <-timer.C: + // continue to next iteration + } + } +} + +func (s *Service) queryDNSAndUpdateNodes(ctx context.Context) error { + // Record summary metrics + startTime := time.Now() + defer func() { + if s.metrics != nil { + duration := time.Since(startTime) + s.metrics.discoveryCount.Inc(1) + s.metrics.discoveryDuration.Observe(duration.Seconds()) + s.metrics.discoveryTotalDuration.Inc(duration.Seconds()) + } + }() + + addresses, queryErr := s.queryAllSRVRecords(ctx) + + if queryErr != nil { + s.log.Warn().Err(queryErr).Msg("DNS query failed, using last successful cache") + addresses = s.lastSuccessfulDNS + if len(addresses) == 0 { + if s.metrics != nil { + s.metrics.discoveryFailedCount.Inc(1) + } + return fmt.Errorf("DNS query failed and no cached addresses available: %w", queryErr) + } + } else { + s.lastSuccessfulDNS = addresses + if s.log.Debug().Enabled() { + s.log.Debug(). + Int("count", len(addresses)). + Strs("addresses", addresses). + Strs("srv_addresses", s.srvAddresses). 
+ Msg("DNS query successful") + } + } + + // Update node cache based on DNS results + updateErr := s.updateNodeCache(ctx, addresses) + if updateErr != nil && s.metrics != nil { + s.metrics.discoveryFailedCount.Inc(1) + } + s.lastQueryMutex.Lock() + s.lastQueryTime = time.Now() + s.lastQueryMutex.Unlock() + return updateErr +} + +func (s *Service) queryAllSRVRecords(ctx context.Context) ([]string, error) { + startTime := time.Now() + defer func() { + if s.metrics != nil { + duration := time.Since(startTime) + s.metrics.dnsQueryCount.Inc(1) + s.metrics.dnsQueryDuration.Observe(duration.Seconds()) + s.metrics.dnsQueryTotalDuration.Inc(duration.Seconds()) + } + }() + + allAddresses := make(map[string]bool) + // track which SRV address (by index) each resolved address came from + newAddrToSRVIdx := make(map[string]int) + var queryErrors []error + + for srvIdx, srvAddr := range s.srvAddresses { + _, addrs, lookupErr := s.resolver.LookupSRV(ctx, srvAddr) + if lookupErr != nil { + queryErrors = append(queryErrors, fmt.Errorf("lookup %s failed: %w", srvAddr, lookupErr)) + continue + } + + for _, srv := range addrs { + address := fmt.Sprintf("%s:%d", srv.Target, srv.Port) + allAddresses[address] = true + + // track which SRV address this resolved to (first-wins strategy) + if _, exists := newAddrToSRVIdx[address]; !exists { + newAddrToSRVIdx[address] = srvIdx + } + } + } + + // if there have any error occurred, + // then just return the query error to ignore the result to make sure the cache correct + if len(queryErrors) > 0 { + if s.metrics != nil { + s.metrics.dnsQueryFailedCount.Inc(1) + } + return nil, errors.Join(queryErrors...) + } + + // update the resolved address to SRV index mapping + s.resolvedAddrMutex.Lock() + s.resolvedAddrToSRVIdx = newAddrToSRVIdx + s.resolvedAddrMutex.Unlock() + + // convert map to slice + result := make([]string, 0, len(allAddresses)) + for addr := range allAddresses { + result = append(result, addr) + } + + return result, nil +} + +func (s *Service) updateNodeCache(ctx context.Context, addresses []string) error { + addressSet := make(map[string]bool) + for _, addr := range addresses { + addressSet[addr] = true + } + + var addErrors []error + + for addr := range addressSet { + s.cacheMutex.RLock() + _, exists := s.nodeCache[addr] + s.cacheMutex.RUnlock() + + if !exists { + // fetch node metadata from gRPC + node, fetchErr := s.fetchNodeMetadata(ctx, addr) + if fetchErr != nil { + s.log.Warn(). + Err(fetchErr). + Str("address", addr). + Msg("Failed to fetch node metadata") + addErrors = append(addErrors, fetchErr) + continue + } + + // update cache and notify handlers + s.cacheMutex.Lock() + s.nodeCache[addr] = node + s.cacheMutex.Unlock() + + s.notifyHandlers(schema.Metadata{ + TypeMeta: schema.TypeMeta{ + Kind: schema.KindNode, + Name: node.GetMetadata().GetName(), + }, + Spec: node, + }, true) + + s.log.Debug(). + Str("address", addr). + Str("name", node.GetMetadata().GetName()). + Msg("New node discovered and added to cache") + } + } + + // collect nodes to delete first + s.cacheMutex.Lock() + nodesToDelete := make(map[string]*databasev1.Node) + for addr, node := range s.nodeCache { + if !addressSet[addr] { + nodesToDelete[addr] = node + } + } + + // delete from cache while still holding lock + for addr, node := range nodesToDelete { + delete(s.nodeCache, addr) + s.log.Debug(). + Str("address", addr). + Str("name", node.GetMetadata().GetName()). 
+ Msg("Node removed from cache (no longer in DNS)") + } + cacheSize := len(s.nodeCache) + s.cacheMutex.Unlock() + + // Notify handlers after releasing lock + for _, node := range nodesToDelete { + s.notifyHandlers(schema.Metadata{ + TypeMeta: schema.TypeMeta{ + Kind: schema.KindNode, + Name: node.GetMetadata().GetName(), + }, + Spec: node, + }, false) + } + + // update total nodes metric + if s.metrics != nil { + s.metrics.totalNodesCount.Set(float64(cacheSize)) + } + + if len(addErrors) > 0 { + return errors.Join(addErrors...) + } + + return nil +} + +func (s *Service) fetchNodeMetadata(ctx context.Context, address string) (*databasev1.Node, error) { + // record gRPC query metrics + startTime := time.Now() + var grpcErr error + defer func() { + if s.metrics != nil { + duration := time.Since(startTime) + s.metrics.grpcQueryCount.Inc(1) + s.metrics.grpcQueryDuration.Observe(duration.Seconds()) + s.metrics.grpcQueryTotalDuration.Inc(duration.Seconds()) + if grpcErr != nil { + s.metrics.grpcQueryFailedCount.Inc(1) + } + } + }() + + ctxTimeout, cancel := context.WithTimeout(ctx, s.grpcTimeout) + defer cancel() + + // for TLS connections with other nodes to getting metadata + dialOpts, err := s.getTLSDialOptions(address) + if err != nil { + grpcErr = fmt.Errorf("failed to get TLS dial options: %w", err) + return nil, grpcErr + } + // nolint:contextcheck + conn, connErr := grpchelper.ConnWithAuth(address, s.grpcTimeout, "", "", dialOpts...) Review Comment: ```suggestion conn, connErr := grpchelper.Conn(address, s.grpcTimeout, dialOpts...) ``` ########## banyand/metadata/dns/dns.go: ########## @@ -0,0 +1,603 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Package dns implements DNS-based node discovery for distributed metadata management. +package dns + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + + databasev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1" + "github.com/apache/skywalking-banyandb/banyand/metadata/schema" + "github.com/apache/skywalking-banyandb/banyand/observability" + "github.com/apache/skywalking-banyandb/pkg/grpchelper" + "github.com/apache/skywalking-banyandb/pkg/logger" + "github.com/apache/skywalking-banyandb/pkg/run" + pkgtls "github.com/apache/skywalking-banyandb/pkg/tls" +) + +// Service implements DNS-based node discovery. 
+type Service struct { + lastQueryTime time.Time + resolver Resolver + pathToReloader map[string]*pkgtls.Reloader + srvIndexToPath map[int]string + resolvedAddrToSRVIdx map[string]int + nodeCache map[string]*databasev1.Node + closer *run.Closer + log *logger.Logger + metrics *metrics + handlers map[string]schema.EventHandler + caCertPaths []string + srvAddresses []string + lastSuccessfulDNS []string + pollInterval time.Duration + initInterval time.Duration + initDuration time.Duration + grpcTimeout time.Duration + cacheMutex sync.RWMutex + handlersMutex sync.RWMutex + lastQueryMutex sync.RWMutex + resolvedAddrMutex sync.RWMutex + tlsEnabled bool +} + +// Config holds configuration for DNS discovery service. +type Config struct { + CACertPaths []string + SRVAddresses []string + InitInterval time.Duration + InitDuration time.Duration + PollInterval time.Duration + GRPCTimeout time.Duration + TLSEnabled bool +} + +// Resolver defines the interface for DNS SRV lookups. +type Resolver interface { + LookupSRV(ctx context.Context, name string) (string, []*net.SRV, error) +} + +// defaultResolver wraps net.DefaultResolver to implement Resolver. +type defaultResolver struct{} + +func (d *defaultResolver) LookupSRV(ctx context.Context, name string) (string, []*net.SRV, error) { + return net.DefaultResolver.LookupSRV(ctx, "", "", name) +} + +// NewService creates a new DNS discovery service. +func NewService(cfg Config) (*Service, error) { + // validation + if len(cfg.SRVAddresses) == 0 { + return nil, errors.New("DNS SRV addresses cannot be empty") + } + + // validate CA cert paths match SRV addresses when TLS is enabled + if cfg.TLSEnabled { + if len(cfg.CACertPaths) != len(cfg.SRVAddresses) { + return nil, fmt.Errorf("number of CA cert paths (%d) must match number of SRV addresses (%d)", + len(cfg.CACertPaths), len(cfg.SRVAddresses)) + } + } + + svc := &Service{ + srvAddresses: cfg.SRVAddresses, + initInterval: cfg.InitInterval, + initDuration: cfg.InitDuration, + pollInterval: cfg.PollInterval, + grpcTimeout: cfg.GRPCTimeout, + tlsEnabled: cfg.TLSEnabled, + caCertPaths: cfg.CACertPaths, + nodeCache: make(map[string]*databasev1.Node), + handlers: make(map[string]schema.EventHandler), + lastSuccessfulDNS: []string{}, + pathToReloader: make(map[string]*pkgtls.Reloader), + srvIndexToPath: make(map[int]string), + resolvedAddrToSRVIdx: make(map[string]int), + closer: run.NewCloser(1), + log: logger.GetLogger("metadata-discovery-dns"), + resolver: &defaultResolver{}, + } + + // create shared reloaders for CA certificates + if svc.tlsEnabled { + for srvIdx, certPath := range cfg.CACertPaths { + // Store the SRV index → cert path mapping + svc.srvIndexToPath[srvIdx] = certPath + + // check if we already have a Reloader for this path + if _, exists := svc.pathToReloader[certPath]; exists { + svc.log.Debug().Str("certPath", certPath).Int("srvIndex", srvIdx). + Msg("Reusing existing CA certificate reloader") + continue + } + + // create new Reloader for this unique path + reloader, reloaderErr := pkgtls.NewClientCertReloader(certPath, svc.log) + if reloaderErr != nil { + // clean up any already-created reloaders + for _, r := range svc.pathToReloader { + r.Stop() + } + return nil, fmt.Errorf("failed to initialize CA certificate reloader for path %s (SRV index %d): %w", + certPath, srvIdx, reloaderErr) + } + + svc.pathToReloader[certPath] = reloader + svc.log.Info().Str("certPath", certPath).Int("srvIndex", srvIdx). 
+ Str("srvAddress", cfg.SRVAddresses[srvIdx]).Msg("Initialized DNS CA certificate reloader") + } + } + + return svc, nil +} + +// newServiceWithResolver creates a service with a custom resolver (for testing). +func newServiceWithResolver(cfg Config, resolver Resolver) (*Service, error) { + svc, err := NewService(cfg) + if err != nil { + return nil, err + } + svc.resolver = resolver + return svc, nil +} + +func (s *Service) getTLSDialOptions(address string) ([]grpc.DialOption, error) { + if !s.tlsEnabled { + return []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, nil + } + + // look up which Reloader to use for this address + if len(s.pathToReloader) > 0 { + // Find which SRV address this resolved address came from + s.resolvedAddrMutex.RLock() + srvIdx, addrExists := s.resolvedAddrToSRVIdx[address] + s.resolvedAddrMutex.RUnlock() + + if !addrExists { + return nil, fmt.Errorf("no SRV mapping found for address %s", address) + } + + // look up the cert path for this SRV index + certPath, pathExists := s.srvIndexToPath[srvIdx] + if !pathExists { + return nil, fmt.Errorf("no cert path found for SRV index %d (address %s)", srvIdx, address) + } + + // get the Reloader for this cert path + reloader, reloaderExists := s.pathToReloader[certPath] + if !reloaderExists { + return nil, fmt.Errorf("no reloader found for cert path %s (address %s)", certPath, address) + } + + // get fresh TLS config from the Reloader + tlsConfig, configErr := reloader.GetClientTLSConfig("") + if configErr != nil { + return nil, fmt.Errorf("failed to get TLS config from reloader for address %s: %w", address, configErr) + } + + creds := credentials.NewTLS(tlsConfig) + return []grpc.DialOption{grpc.WithTransportCredentials(creds)}, nil + } + + // fallback to static TLS config (when no reloaders configured) + opts, err := grpchelper.SecureOptions(nil, s.tlsEnabled, false, "") + if err != nil { + return nil, fmt.Errorf("failed to load TLS config: %w", err) + } + return opts, nil +} + +// Start begins the DNS discovery background process. 
+func (s *Service) Start(ctx context.Context) error { + s.log.Debug().Msg("Starting DNS-based node discovery service") + + // start all Reloaders + if len(s.pathToReloader) > 0 { + startedReloaders := make([]*pkgtls.Reloader, 0, len(s.pathToReloader)) + + for certPath, reloader := range s.pathToReloader { + if startErr := reloader.Start(); startErr != nil { + // stop any already-started reloaders + for _, r := range startedReloaders { + r.Stop() + } + return fmt.Errorf("failed to start CA certificate reloader for path %s: %w", certPath, startErr) + } + startedReloaders = append(startedReloaders, reloader) + s.log.Debug().Str("certPath", certPath).Msg("Started CA certificate reloader") + } + } + + go s.discoveryLoop(ctx) + + return nil +} + +func (s *Service) discoveryLoop(ctx context.Context) { + // add the init phase finish time + initPhaseEnd := time.Now().Add(s.initDuration) + + for { + if err := s.queryDNSAndUpdateNodes(ctx); err != nil { + s.log.Err(err).Msg("failed to query DNS and update nodes") + } + + // wait for next interval + var interval time.Duration + if time.Now().Before(initPhaseEnd) { + interval = s.initInterval + } else { + interval = s.pollInterval + } + + timer := time.NewTimer(interval) + select { + case <-ctx.Done(): + timer.Stop() + return + case <-s.closer.CloseNotify(): + timer.Stop() + return + case <-timer.C: + // continue to next iteration + } + } +} + +func (s *Service) queryDNSAndUpdateNodes(ctx context.Context) error { + // Record summary metrics + startTime := time.Now() + defer func() { + if s.metrics != nil { + duration := time.Since(startTime) + s.metrics.discoveryCount.Inc(1) + s.metrics.discoveryDuration.Observe(duration.Seconds()) + s.metrics.discoveryTotalDuration.Inc(duration.Seconds()) + } + }() + + addresses, queryErr := s.queryAllSRVRecords(ctx) + + if queryErr != nil { + s.log.Warn().Err(queryErr).Msg("DNS query failed, using last successful cache") + addresses = s.lastSuccessfulDNS + if len(addresses) == 0 { + if s.metrics != nil { + s.metrics.discoveryFailedCount.Inc(1) + } + return fmt.Errorf("DNS query failed and no cached addresses available: %w", queryErr) + } + } else { + s.lastSuccessfulDNS = addresses + if s.log.Debug().Enabled() { + s.log.Debug(). + Int("count", len(addresses)). + Strs("addresses", addresses). + Strs("srv_addresses", s.srvAddresses). 
+ Msg("DNS query successful") + } + } + + // Update node cache based on DNS results + updateErr := s.updateNodeCache(ctx, addresses) + if updateErr != nil && s.metrics != nil { + s.metrics.discoveryFailedCount.Inc(1) + } + s.lastQueryMutex.Lock() + s.lastQueryTime = time.Now() + s.lastQueryMutex.Unlock() + return updateErr +} + +func (s *Service) queryAllSRVRecords(ctx context.Context) ([]string, error) { + startTime := time.Now() + defer func() { + if s.metrics != nil { + duration := time.Since(startTime) + s.metrics.dnsQueryCount.Inc(1) + s.metrics.dnsQueryDuration.Observe(duration.Seconds()) + s.metrics.dnsQueryTotalDuration.Inc(duration.Seconds()) + } + }() + + allAddresses := make(map[string]bool) + // track which SRV address (by index) each resolved address came from + newAddrToSRVIdx := make(map[string]int) + var queryErrors []error + + for srvIdx, srvAddr := range s.srvAddresses { + _, addrs, lookupErr := s.resolver.LookupSRV(ctx, srvAddr) + if lookupErr != nil { + queryErrors = append(queryErrors, fmt.Errorf("lookup %s failed: %w", srvAddr, lookupErr)) + continue + } + + for _, srv := range addrs { + address := fmt.Sprintf("%s:%d", srv.Target, srv.Port) + allAddresses[address] = true + + // track which SRV address this resolved to (first-wins strategy) + if _, exists := newAddrToSRVIdx[address]; !exists { + newAddrToSRVIdx[address] = srvIdx + } + } + } + + // if there have any error occurred, + // then just return the query error to ignore the result to make sure the cache correct + if len(queryErrors) > 0 { + if s.metrics != nil { + s.metrics.dnsQueryFailedCount.Inc(1) + } + return nil, errors.Join(queryErrors...) + } + + // update the resolved address to SRV index mapping + s.resolvedAddrMutex.Lock() + s.resolvedAddrToSRVIdx = newAddrToSRVIdx + s.resolvedAddrMutex.Unlock() + + // convert map to slice + result := make([]string, 0, len(allAddresses)) + for addr := range allAddresses { + result = append(result, addr) + } + + return result, nil +} + +func (s *Service) updateNodeCache(ctx context.Context, addresses []string) error { + addressSet := make(map[string]bool) + for _, addr := range addresses { + addressSet[addr] = true + } + + var addErrors []error + + for addr := range addressSet { + s.cacheMutex.RLock() + _, exists := s.nodeCache[addr] + s.cacheMutex.RUnlock() + + if !exists { + // fetch node metadata from gRPC + node, fetchErr := s.fetchNodeMetadata(ctx, addr) + if fetchErr != nil { + s.log.Warn(). + Err(fetchErr). + Str("address", addr). 
+ Msg("Failed to fetch node metadata") + addErrors = append(addErrors, fetchErr) + continue + } + + // update cache and notify handlers + s.cacheMutex.Lock() + s.nodeCache[addr] = node + s.cacheMutex.Unlock() Review Comment: I prefer the check-lock-check pattern ########## banyand/metadata/client.go: ########## @@ -96,13 +115,54 @@ func (s *clientService) FlagSet() *run.FlagSet { fs.StringVar(&s.etcdTLSKeyFile, flagEtcdTLSKeyFile, "", "Private key for the etcd client certificate.") fs.DurationVar(&s.registryTimeout, "node-registry-timeout", 2*time.Minute, "The timeout for the node registry") fs.DurationVar(&s.etcdFullSyncInterval, "etcd-full-sync-interval", 30*time.Minute, "The interval for full sync etcd") + + // node discovery configuration + fs.StringVar(&s.nodeDiscoveryMode, "node-discovery-mode", NodeDiscoveryModeEtcd, + "Node discovery mode: 'etcd' for etcd-based discovery, 'dns' for DNS-based discovery") + fs.StringSliceVar(&s.dnsSRVAddresses, "node-discovery-dns-srv-addresses", []string{}, + "DNS SRV addresses for node discovery (e.g., _grpc._tcp.banyandb.svc.cluster.local)") + fs.DurationVar(&s.dnsFetchInitInterval, "node-discovery-dns-fetch-init-interval", 5*time.Second, + "DNS query interval during initialization phase") + fs.DurationVar(&s.dnsFetchInitDuration, "node-discovery-dns-fetch-init-duration", 5*time.Minute, + "Duration of the initialization phase for DNS discovery") + fs.DurationVar(&s.dnsFetchInterval, "node-discovery-dns-fetch-interval", 15*time.Second, + "DNS query interval after initialization phase") + fs.DurationVar(&s.grpcTimeout, "node-discovery-grpc-timeout", 5*time.Second, + "Timeout for gRPC calls to fetch node metadata") + fs.BoolVar(&s.dnsTLSEnabled, "node-discovery-dns-tls", false, + "Enable TLS for DNS discovery gRPC connections") + fs.StringSliceVar(&s.dnsCACertPaths, "node-discovery-dns-ca-certs", []string{}, + "Comma-separated list of CA certificate files to verify DNS discovered nodes (one per SRV address, in same order)") + return fs } func (s *clientService) Validate() error { - if s.endpoints == nil { - return errors.New("endpoints is empty") + if s.nodeDiscoveryMode != NodeDiscoveryModeEtcd && s.nodeDiscoveryMode != NodeDiscoveryModeDNS { + return fmt.Errorf("invalid node-discovery-mode: %s, must be '%s' or '%s'", s.nodeDiscoveryMode, NodeDiscoveryModeEtcd, NodeDiscoveryModeDNS) } + + // Validate etcd endpoints (required for both modes for schema storage) Review Comment: fix it ########## banyand/metadata/dns/dns.go: ########## @@ -0,0 +1,603 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Package dns implements DNS-based node discovery for distributed metadata management. 
+package dns + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + + databasev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1" + "github.com/apache/skywalking-banyandb/banyand/metadata/schema" + "github.com/apache/skywalking-banyandb/banyand/observability" + "github.com/apache/skywalking-banyandb/pkg/grpchelper" + "github.com/apache/skywalking-banyandb/pkg/logger" + "github.com/apache/skywalking-banyandb/pkg/run" + pkgtls "github.com/apache/skywalking-banyandb/pkg/tls" +) + +// Service implements DNS-based node discovery. +type Service struct { + lastQueryTime time.Time + resolver Resolver + pathToReloader map[string]*pkgtls.Reloader + srvIndexToPath map[int]string + resolvedAddrToSRVIdx map[string]int + nodeCache map[string]*databasev1.Node + closer *run.Closer + log *logger.Logger + metrics *metrics + handlers map[string]schema.EventHandler + caCertPaths []string + srvAddresses []string + lastSuccessfulDNS []string + pollInterval time.Duration + initInterval time.Duration + initDuration time.Duration + grpcTimeout time.Duration + cacheMutex sync.RWMutex + handlersMutex sync.RWMutex + lastQueryMutex sync.RWMutex + resolvedAddrMutex sync.RWMutex + tlsEnabled bool +} + +// Config holds configuration for DNS discovery service. +type Config struct { + CACertPaths []string + SRVAddresses []string + InitInterval time.Duration + InitDuration time.Duration + PollInterval time.Duration + GRPCTimeout time.Duration + TLSEnabled bool +} + +// Resolver defines the interface for DNS SRV lookups. +type Resolver interface { + LookupSRV(ctx context.Context, name string) (string, []*net.SRV, error) +} + +// defaultResolver wraps net.DefaultResolver to implement Resolver. +type defaultResolver struct{} + +func (d *defaultResolver) LookupSRV(ctx context.Context, name string) (string, []*net.SRV, error) { + return net.DefaultResolver.LookupSRV(ctx, "", "", name) +} + +// NewService creates a new DNS discovery service. 
+func NewService(cfg Config) (*Service, error) { + // validation + if len(cfg.SRVAddresses) == 0 { + return nil, errors.New("DNS SRV addresses cannot be empty") + } + + // validate CA cert paths match SRV addresses when TLS is enabled + if cfg.TLSEnabled { + if len(cfg.CACertPaths) != len(cfg.SRVAddresses) { + return nil, fmt.Errorf("number of CA cert paths (%d) must match number of SRV addresses (%d)", + len(cfg.CACertPaths), len(cfg.SRVAddresses)) + } + } + + svc := &Service{ + srvAddresses: cfg.SRVAddresses, + initInterval: cfg.InitInterval, + initDuration: cfg.InitDuration, + pollInterval: cfg.PollInterval, + grpcTimeout: cfg.GRPCTimeout, + tlsEnabled: cfg.TLSEnabled, + caCertPaths: cfg.CACertPaths, + nodeCache: make(map[string]*databasev1.Node), + handlers: make(map[string]schema.EventHandler), + lastSuccessfulDNS: []string{}, + pathToReloader: make(map[string]*pkgtls.Reloader), + srvIndexToPath: make(map[int]string), + resolvedAddrToSRVIdx: make(map[string]int), + closer: run.NewCloser(1), + log: logger.GetLogger("metadata-discovery-dns"), + resolver: &defaultResolver{}, + } + + // create shared reloaders for CA certificates + if svc.tlsEnabled { + for srvIdx, certPath := range cfg.CACertPaths { + // Store the SRV index → cert path mapping + svc.srvIndexToPath[srvIdx] = certPath + + // check if we already have a Reloader for this path + if _, exists := svc.pathToReloader[certPath]; exists { + svc.log.Debug().Str("certPath", certPath).Int("srvIndex", srvIdx). + Msg("Reusing existing CA certificate reloader") + continue + } + + // create new Reloader for this unique path + reloader, reloaderErr := pkgtls.NewClientCertReloader(certPath, svc.log) + if reloaderErr != nil { + // clean up any already-created reloaders + for _, r := range svc.pathToReloader { + r.Stop() + } + return nil, fmt.Errorf("failed to initialize CA certificate reloader for path %s (SRV index %d): %w", + certPath, srvIdx, reloaderErr) + } + + svc.pathToReloader[certPath] = reloader + svc.log.Info().Str("certPath", certPath).Int("srvIndex", srvIdx). + Str("srvAddress", cfg.SRVAddresses[srvIdx]).Msg("Initialized DNS CA certificate reloader") + } + } + + return svc, nil +} + +// newServiceWithResolver creates a service with a custom resolver (for testing). 
+func newServiceWithResolver(cfg Config, resolver Resolver) (*Service, error) { + svc, err := NewService(cfg) + if err != nil { + return nil, err + } + svc.resolver = resolver + return svc, nil +} + +func (s *Service) getTLSDialOptions(address string) ([]grpc.DialOption, error) { + if !s.tlsEnabled { + return []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, nil + } + + // look up which Reloader to use for this address + if len(s.pathToReloader) > 0 { + // Find which SRV address this resolved address came from + s.resolvedAddrMutex.RLock() + srvIdx, addrExists := s.resolvedAddrToSRVIdx[address] + s.resolvedAddrMutex.RUnlock() + + if !addrExists { + return nil, fmt.Errorf("no SRV mapping found for address %s", address) + } + + // look up the cert path for this SRV index + certPath, pathExists := s.srvIndexToPath[srvIdx] + if !pathExists { + return nil, fmt.Errorf("no cert path found for SRV index %d (address %s)", srvIdx, address) + } + + // get the Reloader for this cert path + reloader, reloaderExists := s.pathToReloader[certPath] + if !reloaderExists { + return nil, fmt.Errorf("no reloader found for cert path %s (address %s)", certPath, address) + } + + // get fresh TLS config from the Reloader + tlsConfig, configErr := reloader.GetClientTLSConfig("") + if configErr != nil { + return nil, fmt.Errorf("failed to get TLS config from reloader for address %s: %w", address, configErr) + } + + creds := credentials.NewTLS(tlsConfig) + return []grpc.DialOption{grpc.WithTransportCredentials(creds)}, nil + } + + // fallback to static TLS config (when no reloaders configured) + opts, err := grpchelper.SecureOptions(nil, s.tlsEnabled, false, "") + if err != nil { + return nil, fmt.Errorf("failed to load TLS config: %w", err) + } + return opts, nil +} + +// Start begins the DNS discovery background process. 
+func (s *Service) Start(ctx context.Context) error { + s.log.Debug().Msg("Starting DNS-based node discovery service") + + // start all Reloaders + if len(s.pathToReloader) > 0 { + startedReloaders := make([]*pkgtls.Reloader, 0, len(s.pathToReloader)) + + for certPath, reloader := range s.pathToReloader { + if startErr := reloader.Start(); startErr != nil { + // stop any already-started reloaders + for _, r := range startedReloaders { + r.Stop() + } + return fmt.Errorf("failed to start CA certificate reloader for path %s: %w", certPath, startErr) + } + startedReloaders = append(startedReloaders, reloader) + s.log.Debug().Str("certPath", certPath).Msg("Started CA certificate reloader") + } + } + + go s.discoveryLoop(ctx) + + return nil +} + +func (s *Service) discoveryLoop(ctx context.Context) { + // add the init phase finish time + initPhaseEnd := time.Now().Add(s.initDuration) + + for { + if err := s.queryDNSAndUpdateNodes(ctx); err != nil { + s.log.Err(err).Msg("failed to query DNS and update nodes") + } + + // wait for next interval + var interval time.Duration + if time.Now().Before(initPhaseEnd) { + interval = s.initInterval + } else { + interval = s.pollInterval + } + + timer := time.NewTimer(interval) + select { + case <-ctx.Done(): + timer.Stop() + return + case <-s.closer.CloseNotify(): + timer.Stop() + return + case <-timer.C: + // continue to next iteration + } + } +} + +func (s *Service) queryDNSAndUpdateNodes(ctx context.Context) error { + // Record summary metrics + startTime := time.Now() + defer func() { + if s.metrics != nil { + duration := time.Since(startTime) + s.metrics.discoveryCount.Inc(1) + s.metrics.discoveryDuration.Observe(duration.Seconds()) + s.metrics.discoveryTotalDuration.Inc(duration.Seconds()) + } + }() + + addresses, queryErr := s.queryAllSRVRecords(ctx) + + if queryErr != nil { + s.log.Warn().Err(queryErr).Msg("DNS query failed, using last successful cache") + addresses = s.lastSuccessfulDNS + if len(addresses) == 0 { + if s.metrics != nil { + s.metrics.discoveryFailedCount.Inc(1) + } + return fmt.Errorf("DNS query failed and no cached addresses available: %w", queryErr) + } + } else { + s.lastSuccessfulDNS = addresses + if s.log.Debug().Enabled() { + s.log.Debug(). + Int("count", len(addresses)). + Strs("addresses", addresses). + Strs("srv_addresses", s.srvAddresses). 
+ Msg("DNS query successful") + } + } + + // Update node cache based on DNS results + updateErr := s.updateNodeCache(ctx, addresses) + if updateErr != nil && s.metrics != nil { + s.metrics.discoveryFailedCount.Inc(1) + } + s.lastQueryMutex.Lock() + s.lastQueryTime = time.Now() + s.lastQueryMutex.Unlock() + return updateErr +} + +func (s *Service) queryAllSRVRecords(ctx context.Context) ([]string, error) { + startTime := time.Now() + defer func() { + if s.metrics != nil { + duration := time.Since(startTime) + s.metrics.dnsQueryCount.Inc(1) + s.metrics.dnsQueryDuration.Observe(duration.Seconds()) + s.metrics.dnsQueryTotalDuration.Inc(duration.Seconds()) + } + }() + + allAddresses := make(map[string]bool) + // track which SRV address (by index) each resolved address came from + newAddrToSRVIdx := make(map[string]int) + var queryErrors []error + + for srvIdx, srvAddr := range s.srvAddresses { + _, addrs, lookupErr := s.resolver.LookupSRV(ctx, srvAddr) + if lookupErr != nil { + queryErrors = append(queryErrors, fmt.Errorf("lookup %s failed: %w", srvAddr, lookupErr)) + continue + } + + for _, srv := range addrs { + address := fmt.Sprintf("%s:%d", srv.Target, srv.Port) + allAddresses[address] = true + + // track which SRV address this resolved to (first-wins strategy) + if _, exists := newAddrToSRVIdx[address]; !exists { + newAddrToSRVIdx[address] = srvIdx + } + } + } + + // if there have any error occurred, + // then just return the query error to ignore the result to make sure the cache correct + if len(queryErrors) > 0 { Review Comment: You should return resolved addresses, even if some cannot be resolved. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]

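For the earlier comment preferring the check-lock-check pattern around the `nodeCache` insert in `updateNodeCache`: here is a rough sketch of that shape, assuming the `Service` fields quoted in the PR (`cacheMutex`, `nodeCache`, `fetchNodeMetadata`, `notifyHandlers`). The `addNodeIfAbsent` helper name is invented for this sketch; the re-check under the write lock only pays off if discovery ever runs concurrently, otherwise it simply documents the intended locking discipline.

```go
// addNodeIfAbsent is a helper name invented for this sketch; it assumes the
// Service fields from the PR and shows the check-lock-check shape.
func (s *Service) addNodeIfAbsent(ctx context.Context, addr string) error {
	// 1) cheap check under the read lock
	s.cacheMutex.RLock()
	_, exists := s.nodeCache[addr]
	s.cacheMutex.RUnlock()
	if exists {
		return nil
	}

	// the slow gRPC fetch happens outside any lock
	node, err := s.fetchNodeMetadata(ctx, addr)
	if err != nil {
		return err
	}

	// 2) take the write lock and 3) re-check before inserting, in case another
	// goroutine added the same address while we were fetching
	s.cacheMutex.Lock()
	if _, exists := s.nodeCache[addr]; exists {
		s.cacheMutex.Unlock()
		return nil
	}
	s.nodeCache[addr] = node
	s.cacheMutex.Unlock()

	// only the goroutine that actually inserted the node notifies handlers
	s.notifyHandlers(schema.Metadata{
		TypeMeta: schema.TypeMeta{Kind: schema.KindNode, Name: node.GetMetadata().GetName()},
		Spec:     node,
	}, true)
	return nil
}
```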