This is an automated email from the ASF dual-hosted git repository.
marsevilspirit pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git
The following commit(s) were added to refs/heads/develop by this push:
new 47f73f746 fix: resolve memory leaks in goroutine management, file
handles, and extension caches (#3023)
47f73f746 is described below
commit 47f73f74609cb6c8595cc3887c7e45beaaec76d4
Author: Eric Wang <[email protected]>
AuthorDate: Thu Sep 18 12:30:51 2025 +0800
fix: resolve memory leaks in goroutine management, file handles, and
extension caches (#3023)
---
common/extension/filter.go | 20 ++++
common/extension/loadbalance.go | 14 +++
common/extension/memory_leak_test.go | 157 +++++++++++++++++++++++++++++
common/extension/protocol.go | 15 +++
filter/accesslog/filter.go | 163 ++++++++++++++++++++++++++++---
filter/accesslog/filter_test.go | 12 +--
filter/accesslog/memory_leak_test.go | 137 ++++++++++++++++++++++++++
registry/nacos/service_discovery.go | 17 ++++
registry/nacos/service_discovery_test.go | 2 +-
9 files changed, 519 insertions(+), 18 deletions(-)
diff --git a/common/extension/filter.go b/common/extension/filter.go
index 84e922136..826646a08 100644
--- a/common/extension/filter.go
+++ b/common/extension/filter.go
@@ -58,3 +58,23 @@ func GetRejectedExecutionHandler(name string)
(filter.RejectedExecutionHandler,
}
return creator(), nil
}
+
+// UnregisterFilter deletes the entry stored under name in the filters map.
+// Deleting a name that is not present is a no-op.
+// The stated purpose is to avoid leaks when extensions are registered dynamically.
+// NOTE(review): no lock is taken around the map write here; confirm callers
+// never race this with lookups or registrations.
+func UnregisterFilter(name string) {
+ delete(filters, name)
+}
+
+// UnregisterRejectedExecutionHandler deletes the entry stored under name in
+// the rejectedExecutionHandler map. Deleting a missing name is a no-op.
+// NOTE(review): the map write is not guarded by a lock here; confirm that is acceptable.
+func UnregisterRejectedExecutionHandler(name string) {
+ delete(rejectedExecutionHandler, name)
+}
+
+// GetAllFilterNames returns the name of every entry currently in the filters
+// map. The slice is freshly allocated on each call and follows map iteration
+// order, so the order of names is not stable between calls.
+func GetAllFilterNames() []string {
+	registered := make([]string, len(filters))
+	idx := 0
+	for key := range filters {
+		registered[idx] = key
+		idx++
+	}
+	return registered
+}
diff --git a/common/extension/loadbalance.go b/common/extension/loadbalance.go
index 3308b405d..d04236f37 100644
--- a/common/extension/loadbalance.go
+++ b/common/extension/loadbalance.go
@@ -37,3 +37,17 @@ func GetLoadbalance(name string) loadbalance.LoadBalance {
return loadbalances[name]()
}
+
+// UnregisterLoadbalance deletes the entry stored under name in the
+// loadbalances map. Deleting a name that is not present is a no-op.
+// NOTE(review): the map write is not guarded by a lock here; confirm that is acceptable.
+func UnregisterLoadbalance(name string) {
+ delete(loadbalances, name)
+}
+
+// GetAllLoadbalanceNames returns the name of every entry currently in the
+// loadbalances map. The slice is freshly allocated on each call and follows
+// map iteration order, so the order of names is not stable between calls.
+func GetAllLoadbalanceNames() []string {
+	registered := make([]string, len(loadbalances))
+	idx := 0
+	for key := range loadbalances {
+		registered[idx] = key
+		idx++
+	}
+	return registered
+}
diff --git a/common/extension/memory_leak_test.go
b/common/extension/memory_leak_test.go
new file mode 100644
index 000000000..83160bf5b
--- /dev/null
+++ b/common/extension/memory_leak_test.go
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package extension
+
+import (
+ "context"
+ "testing"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+)
+
+import (
+ "dubbo.apache.org/dubbo-go/v3/cluster/loadbalance"
+ "dubbo.apache.org/dubbo-go/v3/filter"
+ "dubbo.apache.org/dubbo-go/v3/protocol/base"
+ "dubbo.apache.org/dubbo-go/v3/protocol/result"
+)
+
+// Mock implementations for testing
+
+// MockProtocol embeds base.BaseProtocol and adds nothing of its own; the
+// protocol test in this file returns it from the creator it registers.
+type MockProtocol struct {
+ base.BaseProtocol
+}
+
+// MockFilter is a stateless stub; the filter test in this file returns it
+// from the creator it registers.
+type MockFilter struct{}
+
+// Invoke ignores all of its arguments and returns a new, empty result.RPCResult.
+func (m *MockFilter) Invoke(ctx context.Context, invoker base.Invoker,
invocation base.Invocation) result.Result {
+ return &result.RPCResult{}
+}
+
+// OnResponse returns res unchanged and ignores the other arguments.
+func (m *MockFilter) OnResponse(ctx context.Context, res result.Result,
invoker base.Invoker, invocation base.Invocation) result.Result {
+ return res
+}
+
+// MockLoadBalance is a stateless stub; the loadbalance test in this file
+// returns it from the creator it registers.
+type MockLoadBalance struct{}
+
+// Select ignores its arguments and always returns nil.
+func (m *MockLoadBalance) Select(invokers []base.Invoker, invocation
base.Invocation) base.Invoker {
+ return nil
+}
+
+// TestProtocolMemoryManagement registers a protocol creator, checks that it is
+// listed and can be fetched, then unregisters it and checks that the list is
+// back to its starting size and that fetching the name panics.
+func TestProtocolMemoryManagement(t *testing.T) {
+	const protocolKey = "test-protocol-memory"
+
+	// Size of the registry before this test touches it.
+	baseline := len(GetAllProtocolNames())
+
+	SetProtocol(protocolKey, func() base.Protocol {
+		return &MockProtocol{}
+	})
+
+	// Exactly one additional name must be listed after registration.
+	assert.Equal(t, baseline+1, len(GetAllProtocolNames()), "Protocol should be registered")
+
+	// Fetching a registered name must not panic.
+	assert.NotPanics(t, func() {
+		GetProtocol(protocolKey)
+	}, "Should be able to get registered protocol")
+
+	UnregisterProtocol(protocolKey)
+
+	// The registry must shrink back to its starting size.
+	assert.Equal(t, baseline, len(GetAllProtocolNames()), "Protocol should be unregistered")
+
+	// Fetching the removed name is expected to panic.
+	assert.Panics(t, func() {
+		GetProtocol(protocolKey)
+	}, "Should panic when trying to get unregistered protocol")
+}
+
+// TestFilterMemoryManagement registers a filter creator, checks that it is
+// listed and can be fetched, then unregisters it and checks that the list is
+// back to its starting size and that the lookup reports a miss.
+func TestFilterMemoryManagement(t *testing.T) {
+	const filterKey = "test-filter-memory"
+
+	// Size of the registry before this test touches it.
+	baseline := len(GetAllFilterNames())
+
+	SetFilter(filterKey, func() filter.Filter {
+		return &MockFilter{}
+	})
+
+	// Exactly one additional name must be listed after registration.
+	assert.Equal(t, baseline+1, len(GetAllFilterNames()), "Filter should be registered")
+
+	// The registered name must resolve to a non-nil filter.
+	got, found := GetFilter(filterKey)
+	assert.True(t, found, "Should be able to get registered filter")
+	assert.NotNil(t, got, "Retrieved filter should not be nil")
+
+	UnregisterFilter(filterKey)
+
+	// The registry must shrink back to its starting size.
+	assert.Equal(t, baseline, len(GetAllFilterNames()), "Filter should be unregistered")
+
+	// The removed name must no longer resolve.
+	got, found = GetFilter(filterKey)
+	assert.False(t, found, "Should not be able to get unregistered filter")
+	assert.Nil(t, got, "Retrieved filter should be nil")
+}
+
+// TestLoadbalanceMemoryManagement registers a loadbalance creator, checks that
+// it is listed and can be fetched, then unregisters it and checks that the
+// list is back to its starting size and that fetching the name panics.
+func TestLoadbalanceMemoryManagement(t *testing.T) {
+	const loadbalanceKey = "test-loadbalance-memory"
+
+	// Size of the registry before this test touches it.
+	baseline := len(GetAllLoadbalanceNames())
+
+	SetLoadbalance(loadbalanceKey, func() loadbalance.LoadBalance {
+		return &MockLoadBalance{}
+	})
+
+	// Exactly one additional name must be listed after registration.
+	assert.Equal(t, baseline+1, len(GetAllLoadbalanceNames()), "LoadBalance should be registered")
+
+	// Fetching a registered name must not panic.
+	assert.NotPanics(t, func() {
+		GetLoadbalance(loadbalanceKey)
+	}, "Should be able to get registered loadbalance")
+
+	UnregisterLoadbalance(loadbalanceKey)
+
+	// The registry must shrink back to its starting size.
+	assert.Equal(t, baseline, len(GetAllLoadbalanceNames()), "LoadBalance should be unregistered")
+
+	// Fetching the removed name is expected to panic.
+	assert.Panics(t, func() {
+		GetLoadbalance(loadbalanceKey)
+	}, "Should panic when trying to get unregistered loadbalance")
+}
diff --git a/common/extension/protocol.go b/common/extension/protocol.go
index b599d4545..6a4f33c91 100644
--- a/common/extension/protocol.go
+++ b/common/extension/protocol.go
@@ -35,3 +35,18 @@ func GetProtocol(name string) base.Protocol {
}
return protocols[name]()
}
+
+// UnregisterProtocol deletes the entry stored under name in the protocols map.
+// Deleting a name that is not present is a no-op.
+// The stated purpose is to avoid leaks when extensions are registered dynamically.
+// NOTE(review): no lock is taken around the map write here; confirm callers
+// never race this with lookups or registrations.
+func UnregisterProtocol(name string) {
+ delete(protocols, name)
+}
+
+// GetAllProtocolNames returns the name of every entry currently in the
+// protocols map. The slice is freshly allocated on each call and follows map
+// iteration order, so the order of names is not stable between calls.
+func GetAllProtocolNames() []string {
+	registered := make([]string, len(protocols))
+	idx := 0
+	for key := range protocols {
+		registered[idx] = key
+		idx++
+	}
+	return registered
+}
diff --git a/filter/accesslog/filter.go b/filter/accesslog/filter.go
index c5865b146..cba9d3beb 100644
--- a/filter/accesslog/filter.go
+++ b/filter/accesslog/filter.go
@@ -84,18 +84,25 @@ func init() {
* AccessLogFilter is designed to be singleton
*/
type Filter struct {
- logChan chan Data
+ logChan chan Data // queue of pending entries consumed by processLogs
+ fileLock sync.RWMutex // protects fileCache
+ fileCache map[string]*os.File // open log files keyed by access-log path
+ ctx context.Context // cancelled by shutdown to stop processLogs
+ cancel context.CancelFunc // cancels ctx
+ shutdownOnce sync.Once // ensures shutdown's body runs only once
}
func newFilter() filter.Filter {
if accessLogFilter == nil {
once.Do(func() {
- accessLogFilter = &Filter{logChan: make(chan Data,
LogMaxBuffer)}
- go func() {
- for accessLogData := range
accessLogFilter.logChan {
-
accessLogFilter.writeLogToFile(accessLogData)
- }
- }()
+ ctx, cancel := context.WithCancel(context.Background())
+ accessLogFilter = &Filter{
+ logChan: make(chan Data, LogMaxBuffer),
+ fileCache: make(map[string]*os.File),
+ ctx: ctx,
+ cancel: cancel,
+ }
+ go accessLogFilter.processLogs()
})
}
return accessLogFilter
@@ -182,6 +189,63 @@ func (f *Filter) OnResponse(_ context.Context, result
result.Result, _ base.Invo
return result
}
+// processLogs runs in a background goroutine to process log data.
+// It writes every entry received on logChan, allowing five seconds per write,
+// until the channel is closed or ctx is cancelled. On the way out, including
+// after a recovered panic, it calls drainLogs for entries that are still queued.
+func (f *Filter) processLogs() {
+ defer func() {
+ // Log a panic instead of letting it propagate, then still run the drain below.
+ if r := recover(); r != nil {
+ logger.Errorf("AccessLog processLogs panic: %v", r)
+ }
+ f.drainLogs()
+ }()
+
+ for {
+ select {
+ case accessLogData, ok := <-f.logChan:
+ // ok is false once logChan has been closed and emptied.
+ if !ok {
+ return
+ }
+ f.writeLogToFileWithTimeout(accessLogData,
5*time.Second)
+ case <-f.ctx.Done():
+ return
+ }
+ }
+}
+
+// drainLogs writes the entries that are already queued on logChan, allowing
+// one second per write and five seconds overall. It never waits for new
+// entries: it returns as soon as the channel is empty or closed.
+func (f *Filter) drainLogs() {
+ // Single overall deadline shared by every iteration of the loop.
+ timeout := time.After(5 * time.Second)
+ for {
+ select {
+ case accessLogData, ok := <-f.logChan:
+ if !ok {
+ return
+ }
+ f.writeLogToFileWithTimeout(accessLogData,
1*time.Second)
+ case <-timeout:
+ logger.Warnf("AccessLog drain timeout, some logs may be
lost")
+ return
+ default:
+ // Nothing is ready to receive right now, so the drain is finished.
+ return
+ }
+ }
+}
+
+// writeLogToFileWithTimeout runs writeLogToFile for data in a helper goroutine
+// and waits at most timeout for it to finish.
+//
+// If the deadline passes first, a warning is logged and the call returns. The
+// helper goroutine is not interrupted; it keeps running until writeLogToFile
+// returns on its own.
+func (f *Filter) writeLogToFileWithTimeout(data Data, timeout time.Duration) {
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		f.writeLogToFile(data)
+	}()
+
+	// This runs once per log entry. An explicit timer that is stopped on
+	// return releases its resources as soon as the write completes, whereas
+	// time.After would keep each timer pending for the full timeout.
+	timer := time.NewTimer(timeout)
+	defer timer.Stop()
+
+	select {
+	case <-done:
+		logger.Debugf("AccessLog successfully written for: %s", data.accessLog)
+	case <-timer.C:
+		logger.Warnf("AccessLog writeLogToFile timeout for: %s", data.accessLog)
+	}
+}
+
// writeLogToFile actually write the logs into file
func (f *Filter) writeLogToFile(data Data) {
accessLog := data.accessLog
@@ -190,7 +254,7 @@ func (f *Filter) writeLogToFile(data Data) {
return
}
- logFile, err := f.openLogFile(accessLog)
+ logFile, err := f.getOrOpenLogFile(accessLog)
if err != nil {
logger.Warnf("Can not open the access log file: %s, %v",
accessLog, err)
return
@@ -204,12 +268,56 @@ func (f *Filter) writeLogToFile(data Data) {
}
}
+// needLogRotation reports whether logFile should be rotated. It compares the
+// current time and the file's modification time, both rendered with
+// FileDateFormat, and returns true when they differ. A file that cannot be
+// stat'ed is treated as needing rotation.
+func needLogRotation(logFile *os.File) bool {
+	today := time.Now().Format(FileDateFormat)
+
+	info, statErr := logFile.Stat()
+	if statErr != nil {
+		// Without metadata the dates cannot be compared; assume rotation is needed.
+		return true
+	}
+
+	lastWritten := info.ModTime().Format(FileDateFormat)
+	return today != lastWritten
+}
+
+// getOrOpenLogFile returns the open file for accessLog, reusing the handle in
+// fileCache when one exists and needLogRotation reports it is still current.
+// Otherwise it closes and evicts any stale handle, opens the file through
+// openLogFile and caches the new handle. If opening fails it returns a nil
+// file and the error, and nothing is cached.
+func (f *Filter) getOrOpenLogFile(accessLog string) (*os.File, error) {
+ // Fast path: only the read lock is held while checking the cache.
+ f.fileLock.RLock()
+ if logFile, exists := f.fileCache[accessLog]; exists {
+ // Check if we need to rotate the log
+ if !needLogRotation(logFile) {
+ f.fileLock.RUnlock()
+ return logFile, nil
+ }
+ }
+ f.fileLock.RUnlock()
+
+ // Need to open new file or rotate existing one
+ f.fileLock.Lock()
+ defer f.fileLock.Unlock()
+
+ // Double-check after acquiring write lock: the cache may have been
+ // refreshed between releasing the read lock and taking the write lock.
+ if logFile, exists := f.fileCache[accessLog]; exists {
+ if !needLogRotation(logFile) {
+ return logFile, nil
+ }
+ // Close the old file before rotation
+ if err := logFile.Close(); err != nil {
+ logger.Warnf("Failed to close old log file %s: %v",
accessLog, err)
+ }
+ delete(f.fileCache, accessLog)
+ }
+
+ logFile, err := f.openLogFile(accessLog)
+ if err != nil {
+ return nil, err
+ }
+
+ f.fileCache[accessLog] = logFile
+ return logFile, nil
+}
+
// openLogFile will open the log file with append mode.
// If the file is not found, it will create the file.
// Actually, the accessLog is the filename
-// You may find out that, once we want to write access log into log file,
-// we open the file again and again.
-// It needs to be optimized.
func (f *Filter) openLogFile(accessLog string) (*os.File, error) {
logFile, err := os.OpenFile(accessLog,
os.O_CREATE|os.O_APPEND|os.O_RDWR, LogFileMode)
if err != nil {
@@ -287,3 +395,36 @@ func (d *Data) toLogMessage() string {
}
return builder.String()
}
+
+// Shutdown stops the package-level access log filter if one has been created,
+// and does nothing otherwise. It is meant to be called while the application
+// is shutting down so that the filter's background goroutine does not leak.
+func Shutdown() {
+	if accessLogFilter == nil {
+		return
+	}
+	accessLogFilter.shutdown()
+}
+
+// shutdown stops this filter instance. It may be called more than once; only
+// the first call does anything.
+//
+// It cancels ctx so that processLogs leaves its loop and drains what is
+// already queued, then closes and forgets every cached file handle.
+func (f *Filter) shutdown() {
+	f.shutdownOnce.Do(func() {
+		// Cancelling the context is what tells the background goroutine to stop.
+		if f.cancel != nil {
+			f.cancel()
+		}
+
+		// logChan is deliberately left open. This function is not the sender,
+		// and sending on a closed channel panics, so closing it here could
+		// crash a request that is still being logged. processLogs already
+		// returns on ctx.Done(), so the close is not needed to stop it.
+
+		// Close all cached file handles.
+		f.fileLock.Lock()
+		defer f.fileLock.Unlock()
+		for path, file := range f.fileCache {
+			if err := file.Close(); err != nil {
+				logger.Warnf("Error closing access log file %s: %v", path, err)
+			}
+			delete(f.fileCache, path)
+		}
+		// NOTE(review): nothing here waits for processLogs to finish its
+		// drain, so a late write could reopen a file after this loop. Confirm
+		// whether a completion signal is needed.
+	})
+}
diff --git a/filter/accesslog/filter_test.go b/filter/accesslog/filter_test.go
index 946f15a57..9c65a85ea 100644
--- a/filter/accesslog/filter_test.go
+++ b/filter/accesslog/filter_test.go
@@ -50,8 +50,8 @@ func TestFilter_Invoke_Not_Config(t *testing.T) {
attach := make(map[string]any, 10)
inv := invocation.NewRPCInvocation("MethodName", []any{"OK", "Hello"},
attach)
- accessLogFilter := &Filter{}
- invokeResult := accessLogFilter.Invoke(context.Background(), invoker,
inv)
+ filter := &Filter{}
+ invokeResult := filter.Invoke(context.Background(), invoker, inv)
assert.Nil(t, invokeResult.Error())
}
@@ -71,14 +71,14 @@ func TestFilterInvokeDefaultConfig(t *testing.T) {
attach[constant.GroupKey] = "MyGroup"
inv := invocation.NewRPCInvocation("MethodName", []any{"OK", "Hello"},
attach)
- accessLogFilter := &Filter{}
- invokeResult := accessLogFilter.Invoke(context.Background(), invoker,
inv)
+ filter := &Filter{}
+ invokeResult := filter.Invoke(context.Background(), invoker, inv)
assert.Nil(t, invokeResult.Error())
}
+// TestFilterOnResponse checks that OnResponse on a zero-value Filter hands
+// back the same result it was given.
func TestFilterOnResponse(t *testing.T) {
rpcResult := &result.RPCResult{}
- accessLogFilter := &Filter{}
- response := accessLogFilter.OnResponse(context.TODO(), rpcResult, nil,
nil)
+ filter := &Filter{}
+ response := filter.OnResponse(context.TODO(), rpcResult, nil, nil)
assert.Equal(t, rpcResult, response)
}
diff --git a/filter/accesslog/memory_leak_test.go
b/filter/accesslog/memory_leak_test.go
new file mode 100644
index 000000000..f021c27f5
--- /dev/null
+++ b/filter/accesslog/memory_leak_test.go
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package accesslog
+
+import (
+ "context"
+ "os"
+ "runtime"
+ "sync"
+ "testing"
+ "time"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+)
+
+import (
+ "dubbo.apache.org/dubbo-go/v3/common"
+ "dubbo.apache.org/dubbo-go/v3/common/constant"
+ "dubbo.apache.org/dubbo-go/v3/protocol/base"
+ invocation_impl "dubbo.apache.org/dubbo-go/v3/protocol/invocation"
+ "dubbo.apache.org/dubbo-go/v3/protocol/result"
+)
+
+// resetGlobalState returns the package to its pre-initialisation state so
+// that the next newFilter call builds a fresh singleton.
+//
+// Any filter left over from an earlier test is shut down first. Otherwise its
+// background goroutine and cached file handles would outlive the test that
+// created it and skew later goroutine counts.
+func resetGlobalState() {
+	Shutdown() // no-op when no filter has been created yet
+	accessLogFilter = nil
+	once = sync.Once{}
+}
+
+// TestAccessLogFilterGoroutineShutdown checks that newFilter starts a
+// background goroutine and that Shutdown makes it exit.
+func TestAccessLogFilterGoroutineShutdown(t *testing.T) {
+	resetGlobalState()
+
+	// Count goroutines before
+	initialGoroutines := runtime.NumGoroutine()
+
+	// Create filter (this should start the goroutine)
+	filter := newFilter()
+	assert.NotNil(t, filter)
+
+	// Give the goroutine time to start
+	time.Sleep(100 * time.Millisecond)
+	postCreateGoroutines := runtime.NumGoroutine()
+
+	// Should have at least one more goroutine
+	assert.Greater(t, postCreateGoroutines, initialGoroutines)
+
+	// Shutdown the filter
+	Shutdown()
+
+	// Give goroutine time to exit
+	time.Sleep(100 * time.Millisecond)
+	runtime.GC() // Force garbage collection
+
+	postShutdownGoroutines := runtime.NumGoroutine()
+
+	// The count must drop below the value seen while the filter was running.
+	// Comparing against initialGoroutines+1 would also pass when the
+	// background goroutine never exits, which is the leak this test targets.
+	assert.Less(t, postShutdownGoroutines, postCreateGoroutines,
+		"Goroutines should be cleaned up after shutdown")
+}
+
+// TestAccessLogFilterFileHandleManagement checks that the file opened for an
+// access log is kept in the filter's cache while it is in use and that the
+// cache is emptied by Shutdown.
+func TestAccessLogFilterFileHandleManagement(t *testing.T) {
+	resetGlobalState()
+
+	// A per-test directory avoids a fixed /tmp path: it works on every
+	// platform, cannot collide with another run, and is removed by the
+	// testing package when the test ends.
+	tempFile := t.TempDir() + string(os.PathSeparator) + "test_access_log.log"
+
+	// Create filter
+	filter := newFilter().(*Filter)
+
+	// Create test URL and invocation
+	url := common.NewURLWithOptions(
+		common.WithParamsValue(constant.AccessLogFilterKey, tempFile),
+	)
+
+	invoker := &MockInvoker{url: url}
+	invocation := &invocation_impl.RPCInvocation{}
+
+	// Invoke multiple times to test file handle caching
+	for i := 0; i < 5; i++ {
+		filter.Invoke(context.Background(), invoker, invocation)
+	}
+
+	// Wait for logs to be processed
+	time.Sleep(100 * time.Millisecond)
+
+	// Check that file is in cache
+	filter.fileLock.RLock()
+	cachedFile, exists := filter.fileCache[tempFile]
+	filter.fileLock.RUnlock()
+
+	assert.True(t, exists, "File should be cached")
+	assert.NotNil(t, cachedFile, "Cached file should not be nil")
+
+	// Shutdown and verify files are closed
+	Shutdown()
+
+	// Check that cache is cleared
+	filter.fileLock.RLock()
+	cacheSize := len(filter.fileCache)
+	filter.fileLock.RUnlock()
+
+	assert.Equal(t, 0, cacheSize, "File cache should be empty after shutdown")
+}
+
+// MockInvoker is a test double that embeds base.BaseInvoker and carries the
+// URL that GetURL returns.
+type MockInvoker struct {
+ base.BaseInvoker
+ url *common.URL
+}
+
+// GetURL returns the URL stored in the mock.
+func (m *MockInvoker) GetURL() *common.URL {
+ return m.url
+}
+
+// Invoke ignores its arguments and returns a new, empty result.RPCResult.
+func (m *MockInvoker) Invoke(ctx context.Context, invocation base.Invocation)
result.Result {
+ return &result.RPCResult{}
+}
diff --git a/registry/nacos/service_discovery.go
b/registry/nacos/service_discovery.go
index 0fc3bdaa5..6bf062292 100644
--- a/registry/nacos/service_discovery.go
+++ b/registry/nacos/service_discovery.go
@@ -86,6 +86,23 @@ func (n *nacosServiceDiscovery) Destroy() error {
logger.Errorf("Unregister nacos instance:%+v, err:%+v",
inst, err)
}
}
+
+ // Clean up listeners to prevent potential leaks
+ n.listenerLock.Lock()
+ defer n.listenerLock.Unlock()
+ // Unsubscribe from all services to stop callbacks
+ for serviceName := range n.instanceListenerMap {
+ err := n.namingClient.Client().Unsubscribe(&vo.SubscribeParam{
+ ServiceName: serviceName,
+ GroupName: n.group,
+ })
+ if err != nil {
+ logger.Warnf("Failed to unsubscribe from service %s:
%v", serviceName, err)
+ }
+ }
+ // Clear the listener map
+ n.instanceListenerMap = make(map[string]*gxset.HashSet)
+
n.namingClient.Close()
return nil
}
diff --git a/registry/nacos/service_discovery_test.go
b/registry/nacos/service_discovery_test.go
index 10924c46e..20a58636b 100644
--- a/registry/nacos/service_discovery_test.go
+++ b/registry/nacos/service_discovery_test.go
@@ -324,7 +324,7 @@ func (c mockClient) Subscribe(param *vo.SubscribeParam)
error {
}
+// Unsubscribe is a stub: it ignores param and always returns nil.
func (c mockClient) Unsubscribe(param *vo.SubscribeParam) error {
- panic("implement me")
+ return nil
}
func (c mockClient) GetAllServicesInfo(param vo.GetAllServiceInfoParam)
(model.ServiceList, error) {