This is an automated email from the ASF dual-hosted git repository. wkcn pushed a commit to branch v1.5.x in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.5.x by this push: new 804403e Fix the bug of `MXEnginePushAsyncND` and `MXEnginePushSyncND` (#15751) (#15792) 804403e is described below commit 804403e999d1567f371c5243f5565127ad7f2f93 Author: JackieWu <w...@live.cn> AuthorDate: Thu Aug 8 13:55:35 2019 +0800 Fix the bug of `MXEnginePushAsyncND` and `MXEnginePushSyncND` (#15751) (#15792) * fix push sync nd api * align code * update test for syncnd * fix bug in tests/cpp/engine/threaded_engine_test * add more testcases for MXEnginePushSyncND and MXEnginePushAsyncND * fix test * fix * fix * lint * ci * retrigger CI --- include/mxnet/c_api.h | 22 +++--- src/c_api/c_api.cc | 40 +++++------ tests/cpp/engine/threaded_engine_test.cc | 117 +++++++++++++++++++------------ 3 files changed, 105 insertions(+), 74 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index a2da6db..c73b366 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -2863,12 +2863,12 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param, * \param wait Whether this is a WaitForVar operation. */ MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param, - EngineFuncParamDeleter deleter, ContextHandle ctx_handle, - NDArrayHandle const_nds_handle, int num_const_nds, - NDArrayHandle mutable_nds_handle, int num_mutable_nds, - EngineFnPropertyHandle prop_handle DEFAULT(NULL), - int priority DEFAULT(0), const char* opr_name DEFAULT(NULL), - bool wait DEFAULT(false)); + EngineFuncParamDeleter deleter, ContextHandle ctx_handle, + NDArrayHandle* const_nds_handle, int num_const_nds, + NDArrayHandle* mutable_nds_handle, int num_mutable_nds, + EngineFnPropertyHandle prop_handle DEFAULT(NULL), + int priority DEFAULT(0), const char* opr_name DEFAULT(NULL), + bool wait DEFAULT(false)); /*! * \brief Push a synchronous operation to the engine. @@ -2886,11 +2886,11 @@ MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param, * \param opr_name The operation name. */ MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param, - EngineFuncParamDeleter deleter, ContextHandle ctx_handle, - NDArrayHandle const_nds_handle, int num_const_nds, - NDArrayHandle mutable_nds_handle, int num_mutable_nds, - EngineFnPropertyHandle prop_handle DEFAULT(NULL), - int priority DEFAULT(0), const char* opr_name DEFAULT(NULL)); + EngineFuncParamDeleter deleter, ContextHandle ctx_handle, + NDArrayHandle* const_nds_handle, int num_const_nds, + NDArrayHandle* mutable_nds_handle, int num_mutable_nds, + EngineFnPropertyHandle prop_handle DEFAULT(NULL), + int priority DEFAULT(0), const char* opr_name DEFAULT(NULL)); #ifdef __cplusplus } diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 35bd3ee..6ba46bd 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1535,18 +1535,18 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param, } int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param, - EngineFuncParamDeleter deleter, ContextHandle ctx_handle, - NDArrayHandle const_nds_handle, int num_const_nds, - NDArrayHandle mutable_nds_handle, int num_mutable_nds, - EngineFnPropertyHandle prop_handle, int priority, - const char* opr_name, bool wait) { - API_BEGIN(); - NDArray* const_nds = static_cast<NDArray*>(const_nds_handle); - NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle); + EngineFuncParamDeleter deleter, ContextHandle ctx_handle, + NDArrayHandle* const_nds_handle, int num_const_nds, + NDArrayHandle* mutable_nds_handle, int num_mutable_nds, + EngineFnPropertyHandle prop_handle, int priority, + const char* opr_name, bool wait) { + API_BEGIN(); + NDArray** const_nds = reinterpret_cast<NDArray**>(const_nds_handle); + NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle); std::vector<VarHandle> const_var_vec(num_const_nds); - for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var(); + for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = const_nds[i]->var(); std::vector<VarHandle> mutable_var_vec(num_mutable_nds); - for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var(); + for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = mutable_nds[i]->var(); return MXEnginePushAsync(async_func, func_param, deleter, ctx_handle, const_var_vec.data(), num_const_nds, mutable_var_vec.data(), num_mutable_nds, @@ -1555,18 +1555,18 @@ int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param, } int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param, - EngineFuncParamDeleter deleter, ContextHandle ctx_handle, - NDArrayHandle const_nds_handle, int num_const_nds, - NDArrayHandle mutable_nds_handle, int num_mutable_nds, - EngineFnPropertyHandle prop_handle, int priority, - const char* opr_name) { - API_BEGIN(); - NDArray* const_nds = static_cast<NDArray*>(const_nds_handle); - NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle); + EngineFuncParamDeleter deleter, ContextHandle ctx_handle, + NDArrayHandle* const_nds_handle, int num_const_nds, + NDArrayHandle* mutable_nds_handle, int num_mutable_nds, + EngineFnPropertyHandle prop_handle, int priority, + const char* opr_name) { + API_BEGIN(); + NDArray** const_nds = reinterpret_cast<NDArray**>(const_nds_handle); + NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle); std::vector<VarHandle> const_var_vec(num_const_nds); - for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var(); + for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = const_nds[i]->var(); std::vector<VarHandle> mutable_var_vec(num_mutable_nds); - for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var(); + for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = mutable_nds[i]->var(); return MXEnginePushSync(sync_func, func_param, deleter, ctx_handle, const_var_vec.data(), num_const_nds, mutable_var_vec.data(), num_mutable_nds, diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc index 6b863f8..cea92a0 100644 --- a/tests/cpp/engine/threaded_engine_test.cc +++ b/tests/cpp/engine/threaded_engine_test.cc @@ -257,49 +257,80 @@ TEST(Engine, PushFunc) { TEST(Engine, PushFuncND) { auto ctx = mxnet::Context{}; - mxnet::NDArray nd(ctx); - - // Test #1 - LOG(INFO) << "===== Test #1: PushAsyncND param and deleter ====="; - int* a = new int(100); - int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0); - EXPECT_EQ(res, 0); - - // Test #2 - LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter ====="; - res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 0); - EXPECT_EQ(res, 0); - - // Test #3 - LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds ====="; - res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0); - EXPECT_EQ(res, -1); - - // Test #4 - LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds ====="; - res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1); - EXPECT_EQ(res, -1); - - // Test #5 - LOG(INFO) << "===== Test #5: PushSyncND param and deleter ====="; - int* b = new int(101); - res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0); - EXPECT_EQ(res, 0); - - // Test #6 - LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter ====="; - res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 1); - EXPECT_EQ(res, 0); - - // Test #7 - LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds ====="; - res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0); - EXPECT_EQ(res, -1); - - // Test #8 - LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds ====="; - res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1); - EXPECT_EQ(res, -1); + std::vector<mxnet::NDArray*> nds; + const int num_nds = 5; + for (int i = 0; i < num_nds; ++i) { + mxnet::NDArray *pnd = new mxnet::NDArray(ctx); + nds.push_back(pnd); + } + for (int num_const_nds = 0; num_const_nds <= num_nds; ++num_const_nds) { + int num_mutable_nds = num_nds - num_const_nds; + void** const_nds_handle = num_const_nds > 0 ? + reinterpret_cast<void**>(nds.data()) : nullptr; + void** mutable_nds_handle = num_mutable_nds > 0 ? + reinterpret_cast<void**>(nds.data() + num_const_nds) : nullptr; + + // Test #1 + LOG(INFO) << "===== Test #1: PushAsyncND param and deleter ====="; + int* a = new int(100); + int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx, + const_nds_handle, num_const_nds, + mutable_nds_handle, num_mutable_nds); + EXPECT_EQ(res, 0); + + // Test #2 + LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter ====="; + res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, + const_nds_handle, num_const_nds, + mutable_nds_handle, num_mutable_nds); + EXPECT_EQ(res, 0); + + // Test #3 + LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds ====="; + res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, + const_nds_handle, -1, + mutable_nds_handle, num_mutable_nds); + EXPECT_EQ(res, -1); + + // Test #4 + LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds ====="; + res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, + const_nds_handle, num_const_nds, + mutable_nds_handle, -1); + EXPECT_EQ(res, -1); + + // Test #5 + LOG(INFO) << "===== Test #5: PushSyncND param and deleter ====="; + int* b = new int(101); + res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx, + const_nds_handle, num_const_nds, + mutable_nds_handle, num_mutable_nds); + EXPECT_EQ(res, 0); + + // Test #6 + LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter ====="; + res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, + const_nds_handle, num_const_nds, + mutable_nds_handle, num_mutable_nds); + EXPECT_EQ(res, 0); + + // Test #7 + LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds ====="; + res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, + const_nds_handle, -1, + mutable_nds_handle, num_mutable_nds); + EXPECT_EQ(res, -1); + + // Test #8 + LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds ====="; + res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, + const_nds_handle, num_const_nds, + mutable_nds_handle, -1); + EXPECT_EQ(res, -1); + } + for (mxnet::NDArray* pnd : nds) { + delete pnd; + } } TEST(Engine, basics) {