This is an automated email from the ASF dual-hosted git repository.

marsevilspirit pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git


The following commit(s) were added to refs/heads/develop by this push:
     new 47f73f746 fix: resolve memory leaks in goroutine management, file handles, and extension caches (#3023)
47f73f746 is described below

commit 47f73f74609cb6c8595cc3887c7e45beaaec76d4
Author: Eric Wang <[email protected]>
AuthorDate: Thu Sep 18 12:30:51 2025 +0800

    fix: resolve memory leaks in goroutine management, file handles, and extension caches (#3023)
---
 common/extension/filter.go               |  20 ++++
 common/extension/loadbalance.go          |  14 +++
 common/extension/memory_leak_test.go     | 157 +++++++++++++++++++++++++++++
 common/extension/protocol.go             |  15 +++
 filter/accesslog/filter.go               | 163 ++++++++++++++++++++++++++++---
 filter/accesslog/filter_test.go          |  12 +--
 filter/accesslog/memory_leak_test.go     | 137 ++++++++++++++++++++++++++
 registry/nacos/service_discovery.go      |  17 ++++
 registry/nacos/service_discovery_test.go |   2 +-
 9 files changed, 519 insertions(+), 18 deletions(-)

diff --git a/common/extension/filter.go b/common/extension/filter.go
index 84e922136..826646a08 100644
--- a/common/extension/filter.go
+++ b/common/extension/filter.go
@@ -58,3 +58,23 @@ func GetRejectedExecutionHandler(name string) (filter.RejectedExecutionHandler,
        }
        return creator(), nil
 }
+
+// UnregisterFilter removes the filter extension with @name
+// This helps prevent memory leaks in dynamic extension scenarios
+func UnregisterFilter(name string) {
+       delete(filters, name)
+}
+
+// UnregisterRejectedExecutionHandler removes the RejectedExecutionHandler with @name
+func UnregisterRejectedExecutionHandler(name string) {
+       delete(rejectedExecutionHandler, name)
+}
+
+// GetAllFilterNames returns all registered filter names
+func GetAllFilterNames() []string {
+       names := make([]string, 0, len(filters))
+       for name := range filters {
+               names = append(names, name)
+       }
+       return names
+}
diff --git a/common/extension/loadbalance.go b/common/extension/loadbalance.go
index 3308b405d..d04236f37 100644
--- a/common/extension/loadbalance.go
+++ b/common/extension/loadbalance.go
@@ -37,3 +37,17 @@ func GetLoadbalance(name string) loadbalance.LoadBalance {
 
        return loadbalances[name]()
 }
+
+// UnregisterLoadbalance removes the loadbalance extension with @name
+func UnregisterLoadbalance(name string) {
+       delete(loadbalances, name)
+}
+
+// GetAllLoadbalanceNames returns all registered loadbalance names
+func GetAllLoadbalanceNames() []string {
+       names := make([]string, 0, len(loadbalances))
+       for name := range loadbalances {
+               names = append(names, name)
+       }
+       return names
+}
diff --git a/common/extension/memory_leak_test.go b/common/extension/memory_leak_test.go
new file mode 100644
index 000000000..83160bf5b
--- /dev/null
+++ b/common/extension/memory_leak_test.go
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package extension
+
+import (
+       "context"
+       "testing"
+)
+
+import (
+       "github.com/stretchr/testify/assert"
+)
+
+import (
+       "dubbo.apache.org/dubbo-go/v3/cluster/loadbalance"
+       "dubbo.apache.org/dubbo-go/v3/filter"
+       "dubbo.apache.org/dubbo-go/v3/protocol/base"
+       "dubbo.apache.org/dubbo-go/v3/protocol/result"
+)
+
+// Mock implementations for testing
+type MockProtocol struct {
+       base.BaseProtocol
+}
+
+type MockFilter struct{}
+
+func (m *MockFilter) Invoke(ctx context.Context, invoker base.Invoker, invocation base.Invocation) result.Result {
+       return &result.RPCResult{}
+}
+
+func (m *MockFilter) OnResponse(ctx context.Context, res result.Result, invoker base.Invoker, invocation base.Invocation) result.Result {
+       return res
+}
+
+type MockLoadBalance struct{}
+
+func (m *MockLoadBalance) Select(invokers []base.Invoker, invocation base.Invocation) base.Invoker {
+       return nil
+}
+
+// TestProtocolMemoryManagement tests protocol registration and unregistration
+func TestProtocolMemoryManagement(t *testing.T) {
+       testName := "test-protocol-memory"
+
+       // Get initial count
+       initialCount := len(GetAllProtocolNames())
+
+       // Register a protocol
+       SetProtocol(testName, func() base.Protocol {
+               return &MockProtocol{}
+       })
+
+       // Verify registration
+       afterRegisterCount := len(GetAllProtocolNames())
+       assert.Equal(t, initialCount+1, afterRegisterCount, "Protocol should be registered")
+
+       // Verify protocol can be retrieved
+       assert.NotPanics(t, func() {
+               GetProtocol(testName)
+       }, "Should be able to get registered protocol")
+
+       // Unregister the protocol
+       UnregisterProtocol(testName)
+
+       // Verify unregistration
+       afterUnregisterCount := len(GetAllProtocolNames())
+       assert.Equal(t, initialCount, afterUnregisterCount, "Protocol should be unregistered")
+
+       // Verify protocol cannot be retrieved
+       assert.Panics(t, func() {
+               GetProtocol(testName)
+       }, "Should panic when trying to get unregistered protocol")
+}
+
+// TestFilterMemoryManagement tests filter registration and unregistration
+func TestFilterMemoryManagement(t *testing.T) {
+       testName := "test-filter-memory"
+
+       // Get initial count
+       initialCount := len(GetAllFilterNames())
+
+       // Register a filter
+       SetFilter(testName, func() filter.Filter {
+               return &MockFilter{}
+       })
+
+       // Verify registration
+       afterRegisterCount := len(GetAllFilterNames())
+       assert.Equal(t, initialCount+1, afterRegisterCount, "Filter should be registered")
+
+       // Verify filter can be retrieved
+       f, exists := GetFilter(testName)
+       assert.True(t, exists, "Should be able to get registered filter")
+       assert.NotNil(t, f, "Retrieved filter should not be nil")
+
+       // Unregister the filter
+       UnregisterFilter(testName)
+
+       // Verify unregistration
+       afterUnregisterCount := len(GetAllFilterNames())
+       assert.Equal(t, initialCount, afterUnregisterCount, "Filter should be unregistered")
+
+       // Verify filter cannot be retrieved
+       f, exists = GetFilter(testName)
+       assert.False(t, exists, "Should not be able to get unregistered filter")
+       assert.Nil(t, f, "Retrieved filter should be nil")
+}
+
+// TestLoadbalanceMemoryManagement tests loadbalance registration and unregistration
+func TestLoadbalanceMemoryManagement(t *testing.T) {
+       testName := "test-loadbalance-memory"
+
+       // Get initial count
+       initialCount := len(GetAllLoadbalanceNames())
+
+       // Register a loadbalance
+       SetLoadbalance(testName, func() loadbalance.LoadBalance {
+               return &MockLoadBalance{}
+       })
+
+       // Verify registration
+       afterRegisterCount := len(GetAllLoadbalanceNames())
+       assert.Equal(t, initialCount+1, afterRegisterCount, "LoadBalance should be registered")
+
+       // Verify loadbalance can be retrieved
+       assert.NotPanics(t, func() {
+               GetLoadbalance(testName)
+       }, "Should be able to get registered loadbalance")
+
+       // Unregister the loadbalance
+       UnregisterLoadbalance(testName)
+
+       // Verify unregistration
+       afterUnregisterCount := len(GetAllLoadbalanceNames())
+       assert.Equal(t, initialCount, afterUnregisterCount, "LoadBalance should be unregistered")
+
+       // Verify loadbalance cannot be retrieved
+       assert.Panics(t, func() {
+               GetLoadbalance(testName)
+       }, "Should panic when trying to get unregistered loadbalance")
+}
diff --git a/common/extension/protocol.go b/common/extension/protocol.go
index b599d4545..6a4f33c91 100644
--- a/common/extension/protocol.go
+++ b/common/extension/protocol.go
@@ -35,3 +35,18 @@ func GetProtocol(name string) base.Protocol {
        }
        return protocols[name]()
 }
+
+// UnregisterProtocol removes the protocol extension with @name
+// This helps prevent memory leaks in dynamic extension scenarios
+func UnregisterProtocol(name string) {
+       delete(protocols, name)
+}
+
+// GetAllProtocolNames returns all registered protocol names
+func GetAllProtocolNames() []string {
+       names := make([]string, 0, len(protocols))
+       for name := range protocols {
+               names = append(names, name)
+       }
+       return names
+}
diff --git a/filter/accesslog/filter.go b/filter/accesslog/filter.go
index c5865b146..cba9d3beb 100644
--- a/filter/accesslog/filter.go
+++ b/filter/accesslog/filter.go
@@ -84,18 +84,25 @@ func init() {
  * AccessLogFilter is designed to be singleton
  */
 type Filter struct {
-       logChan chan Data
+       logChan      chan Data
+       fileLock     sync.RWMutex // protects fileCache
+       fileCache    map[string]*os.File
+       ctx          context.Context
+       cancel       context.CancelFunc
+       shutdownOnce sync.Once
 }
 
 func newFilter() filter.Filter {
        if accessLogFilter == nil {
                once.Do(func() {
-                       accessLogFilter = &Filter{logChan: make(chan Data, LogMaxBuffer)}
-                       go func() {
-                               for accessLogData := range accessLogFilter.logChan {
-                                       accessLogFilter.writeLogToFile(accessLogData)
-                               }
-                       }()
+                       ctx, cancel := context.WithCancel(context.Background())
+                       accessLogFilter = &Filter{
+                               logChan:   make(chan Data, LogMaxBuffer),
+                               fileCache: make(map[string]*os.File),
+                               ctx:       ctx,
+                               cancel:    cancel,
+                       }
+                       go accessLogFilter.processLogs()
                })
        }
        return accessLogFilter
@@ -182,6 +189,63 @@ func (f *Filter) OnResponse(_ context.Context, result result.Result, _ base.Invo
        return result
 }
 
+// processLogs runs in a background goroutine to process log data
+func (f *Filter) processLogs() {
+       defer func() {
+               if r := recover(); r != nil {
+                       logger.Errorf("AccessLog processLogs panic: %v", r)
+               }
+               f.drainLogs()
+       }()
+
+       for {
+               select {
+               case accessLogData, ok := <-f.logChan:
+                       if !ok {
+                               return
+                       }
+                       f.writeLogToFileWithTimeout(accessLogData, 5*time.Second)
+               case <-f.ctx.Done():
+                       return
+               }
+       }
+}
+
+// drainLogs drains remaining log data with timeout protection
+func (f *Filter) drainLogs() {
+       timeout := time.After(5 * time.Second)
+       for {
+               select {
+               case accessLogData, ok := <-f.logChan:
+                       if !ok {
+                               return
+                       }
+                       f.writeLogToFileWithTimeout(accessLogData, 1*time.Second)
+               case <-timeout:
+                       logger.Warnf("AccessLog drain timeout, some logs may be lost")
+                       return
+               default:
+                       return
+               }
+       }
+}
+
+// writeLogToFileWithTimeout writes log with timeout protection
+func (f *Filter) writeLogToFileWithTimeout(data Data, timeout time.Duration) {
+       done := make(chan struct{})
+       go func() {
+               defer close(done)
+               f.writeLogToFile(data)
+       }()
+
+       select {
+       case <-done:
+               logger.Debugf("AccessLog successfully written for: %s", data.accessLog)
+       case <-time.After(timeout):
+               logger.Warnf("AccessLog writeLogToFile timeout for: %s", data.accessLog)
+       }
+}
+
 // writeLogToFile actually write the logs into file
 func (f *Filter) writeLogToFile(data Data) {
        accessLog := data.accessLog
@@ -190,7 +254,7 @@ func (f *Filter) writeLogToFile(data Data) {
                return
        }
 
-       logFile, err := f.openLogFile(accessLog)
+       logFile, err := f.getOrOpenLogFile(accessLog)
        if err != nil {
                logger.Warnf("Can not open the access log file: %s, %v", accessLog, err)
                return
@@ -204,12 +268,56 @@ func (f *Filter) writeLogToFile(data Data) {
        }
 }
 
+// needLogRotation checks if the log file needs rotation based on date
+func needLogRotation(logFile *os.File) bool {
+       now := time.Now().Format(FileDateFormat)
+       if fileInfo, err := logFile.Stat(); err == nil {
+               last := fileInfo.ModTime().Format(FileDateFormat)
+               return now != last
+       }
+       return true // If we can't stat the file, assume rotation is needed
+}
+
+// getOrOpenLogFile gets or opens the log file with proper caching and handle management
+func (f *Filter) getOrOpenLogFile(accessLog string) (*os.File, error) {
+       f.fileLock.RLock()
+       if logFile, exists := f.fileCache[accessLog]; exists {
+               // Check if we need to rotate the log
+               if !needLogRotation(logFile) {
+                       f.fileLock.RUnlock()
+                       return logFile, nil
+               }
+       }
+       f.fileLock.RUnlock()
+
+       // Need to open new file or rotate existing one
+       f.fileLock.Lock()
+       defer f.fileLock.Unlock()
+
+       // Double-check after acquiring write lock
+       if logFile, exists := f.fileCache[accessLog]; exists {
+               if !needLogRotation(logFile) {
+                       return logFile, nil
+               }
+               // Close the old file before rotation
+               if err := logFile.Close(); err != nil {
+                       logger.Warnf("Failed to close old log file %s: %v", accessLog, err)
+               }
+               delete(f.fileCache, accessLog)
+       }
+
+       logFile, err := f.openLogFile(accessLog)
+       if err != nil {
+               return nil, err
+       }
+
+       f.fileCache[accessLog] = logFile
+       return logFile, nil
+}
+
 // openLogFile will open the log file with append mode.
 // If the file is not found, it will create the file.
 // Actually, the accessLog is the filename
-// You may find out that, once we want to write access log into log file,
-// we open the file again and again.
-// It needs to be optimized.
 func (f *Filter) openLogFile(accessLog string) (*os.File, error) {
        logFile, err := os.OpenFile(accessLog, os.O_CREATE|os.O_APPEND|os.O_RDWR, LogFileMode)
        if err != nil {
@@ -287,3 +395,36 @@ func (d *Data) toLogMessage() string {
        }
        return builder.String()
 }
+
+// Shutdown gracefully shuts down the access log filter
+// This should be called during application shutdown to prevent goroutine leaks
+func Shutdown() {
+       if accessLogFilter != nil {
+               accessLogFilter.shutdown()
+       }
+}
+
+// shutdown gracefully shuts down this filter instance
+func (f *Filter) shutdown() {
+       f.shutdownOnce.Do(func() {
+               // Cancel the context to signal goroutine to stop
+               if f.cancel != nil {
+                       f.cancel()
+               }
+
+               // Close the channel to stop accepting new logs
+               if f.logChan != nil {
+                       close(f.logChan)
+               }
+
+               // Close all cached file handles
+               f.fileLock.Lock()
+               defer f.fileLock.Unlock()
+               for path, file := range f.fileCache {
+                       if err := file.Close(); err != nil {
+                               logger.Warnf("Error closing access log file %s: %v", path, err)
+                       }
+                       delete(f.fileCache, path)
+               }
+       })
+}
diff --git a/filter/accesslog/filter_test.go b/filter/accesslog/filter_test.go
index 946f15a57..9c65a85ea 100644
--- a/filter/accesslog/filter_test.go
+++ b/filter/accesslog/filter_test.go
@@ -50,8 +50,8 @@ func TestFilter_Invoke_Not_Config(t *testing.T) {
        attach := make(map[string]any, 10)
        inv := invocation.NewRPCInvocation("MethodName", []any{"OK", "Hello"}, attach)
 
-       accessLogFilter := &Filter{}
-       invokeResult := accessLogFilter.Invoke(context.Background(), invoker, inv)
+       filter := &Filter{}
+       invokeResult := filter.Invoke(context.Background(), invoker, inv)
        assert.Nil(t, invokeResult.Error())
 }
 
@@ -71,14 +71,14 @@ func TestFilterInvokeDefaultConfig(t *testing.T) {
        attach[constant.GroupKey] = "MyGroup"
        inv := invocation.NewRPCInvocation("MethodName", []any{"OK", "Hello"}, attach)
 
-       accessLogFilter := &Filter{}
-       invokeResult := accessLogFilter.Invoke(context.Background(), invoker, inv)
+       filter := &Filter{}
+       invokeResult := filter.Invoke(context.Background(), invoker, inv)
        assert.Nil(t, invokeResult.Error())
 }
 
 func TestFilterOnResponse(t *testing.T) {
        rpcResult := &result.RPCResult{}
-       accessLogFilter := &Filter{}
-       response := accessLogFilter.OnResponse(context.TODO(), rpcResult, nil, nil)
+       filter := &Filter{}
+       response := filter.OnResponse(context.TODO(), rpcResult, nil, nil)
        assert.Equal(t, rpcResult, response)
 }
diff --git a/filter/accesslog/memory_leak_test.go b/filter/accesslog/memory_leak_test.go
new file mode 100644
index 000000000..f021c27f5
--- /dev/null
+++ b/filter/accesslog/memory_leak_test.go
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package accesslog
+
+import (
+       "context"
+       "os"
+       "runtime"
+       "sync"
+       "testing"
+       "time"
+)
+
+import (
+       "github.com/stretchr/testify/assert"
+)
+
+import (
+       "dubbo.apache.org/dubbo-go/v3/common"
+       "dubbo.apache.org/dubbo-go/v3/common/constant"
+       "dubbo.apache.org/dubbo-go/v3/protocol/base"
+       invocation_impl "dubbo.apache.org/dubbo-go/v3/protocol/invocation"
+       "dubbo.apache.org/dubbo-go/v3/protocol/result"
+)
+
+// resetGlobalState resets the global state for testing
+func resetGlobalState() {
+       once.Do(func() {}) // Trigger once
+       accessLogFilter = nil
+       once = sync.Once{}
+}
+
+// TestAccessLogFilterGoroutineShutdown tests that the goroutine is properly shut down
+func TestAccessLogFilterGoroutineShutdown(t *testing.T) {
+       resetGlobalState()
+
+       // Count goroutines before
+       initialGoroutines := runtime.NumGoroutine()
+
+       // Create filter (this should start the goroutine)
+       filter := newFilter()
+       assert.NotNil(t, filter)
+
+       // Give the goroutine time to start
+       time.Sleep(100 * time.Millisecond)
+       postCreateGoroutines := runtime.NumGoroutine()
+
+       // Should have at least one more goroutine
+       assert.Greater(t, postCreateGoroutines, initialGoroutines)
+
+       // Shutdown the filter
+       Shutdown()
+
+       // Give goroutine time to exit
+       time.Sleep(100 * time.Millisecond)
+       runtime.GC() // Force garbage collection
+
+       postShutdownGoroutines := runtime.NumGoroutine()
+
+       // Goroutine count should be back to original or less
+       assert.LessOrEqual(t, postShutdownGoroutines, initialGoroutines+1,
+               "Goroutines should be cleaned up after shutdown")
+}
+
+// TestAccessLogFilterFileHandleManagement tests proper file handle management
+func TestAccessLogFilterFileHandleManagement(t *testing.T) {
+       resetGlobalState()
+
+       tempFile := "/tmp/test_access_log.log"
+       defer os.Remove(tempFile)
+
+       // Create filter
+       filter := newFilter().(*Filter)
+
+       // Create test URL and invocation
+       url := common.NewURLWithOptions(
+               common.WithParamsValue(constant.AccessLogFilterKey, tempFile),
+       )
+
+       invoker := &MockInvoker{url: url}
+       invocation := &invocation_impl.RPCInvocation{}
+
+       // Invoke multiple times to test file handle caching
+       for i := 0; i < 5; i++ {
+               filter.Invoke(context.Background(), invoker, invocation)
+       }
+
+       // Wait for logs to be processed
+       time.Sleep(100 * time.Millisecond)
+
+       // Check that file is in cache
+       filter.fileLock.RLock()
+       cachedFile, exists := filter.fileCache[tempFile]
+       filter.fileLock.RUnlock()
+
+       assert.True(t, exists, "File should be cached")
+       assert.NotNil(t, cachedFile, "Cached file should not be nil")
+
+       // Shutdown and verify files are closed
+       Shutdown()
+
+       // Check that cache is cleared
+       filter.fileLock.RLock()
+       cacheSize := len(filter.fileCache)
+       filter.fileLock.RUnlock()
+
+       assert.Equal(t, 0, cacheSize, "File cache should be empty after shutdown")
+}
+
+// MockInvoker for testing
+type MockInvoker struct {
+       base.BaseInvoker
+       url *common.URL
+}
+
+func (m *MockInvoker) GetURL() *common.URL {
+       return m.url
+}
+
+func (m *MockInvoker) Invoke(ctx context.Context, invocation base.Invocation) result.Result {
+       return &result.RPCResult{}
+}
diff --git a/registry/nacos/service_discovery.go b/registry/nacos/service_discovery.go
index 0fc3bdaa5..6bf062292 100644
--- a/registry/nacos/service_discovery.go
+++ b/registry/nacos/service_discovery.go
@@ -86,6 +86,23 @@ func (n *nacosServiceDiscovery) Destroy() error {
                        logger.Errorf("Unregister nacos instance:%+v, err:%+v", inst, err)
                }
        }
+
+       // Clean up listeners to prevent potential leaks
+       n.listenerLock.Lock()
+       defer n.listenerLock.Unlock()
+       // Unsubscribe from all services to stop callbacks
+       for serviceName := range n.instanceListenerMap {
+               err := n.namingClient.Client().Unsubscribe(&vo.SubscribeParam{
+                       ServiceName: serviceName,
+                       GroupName:   n.group,
+               })
+               if err != nil {
+                       logger.Warnf("Failed to unsubscribe from service %s: %v", serviceName, err)
+               }
+       }
+       // Clear the listener map
+       n.instanceListenerMap = make(map[string]*gxset.HashSet)
+
        n.namingClient.Close()
        return nil
 }
diff --git a/registry/nacos/service_discovery_test.go b/registry/nacos/service_discovery_test.go
index 10924c46e..20a58636b 100644
--- a/registry/nacos/service_discovery_test.go
+++ b/registry/nacos/service_discovery_test.go
@@ -324,7 +324,7 @@ func (c mockClient) Subscribe(param *vo.SubscribeParam) error {
 }
 
 func (c mockClient) Unsubscribe(param *vo.SubscribeParam) error {
-       panic("implement me")
+       return nil
 }
 
 func (c mockClient) GetAllServicesInfo(param vo.GetAllServiceInfoParam) (model.ServiceList, error) {

Reply via email to