samskalicky commented on a change in pull request #15921: [WIP] dynamic custom operator support
URL: https://github.com/apache/incubator-mxnet/pull/15921#discussion_r320309108
 
 

 ##########
 File path: src/c_api/c_api.cc
 ##########
 @@ -92,16 +93,384 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
 
 // NOTE: return value is added in API_END
 
-// Loads library and initializes it
+/*!
+ * \brief Loads dynamic library and initializes it
+ * \param path library path
+ */
 int MXLoadLib(const char *path) {
   API_BEGIN();
   void *lib = LibraryInitializer::Get()->lib_load(path);
   if (!lib)
     LOG(FATAL) << "Unable to load library";
 
+  // check that library and MXNet use same version of library API
+  opVersion_t opVersion = get_func<opVersion_t>(lib, 
const_cast<char*>(MXLIB_OPVERSION_STR));
+  int libVersion =  opVersion();
+  if (MX_LIBRARY_VERSION != libVersion)
+    LOG(FATAL) << "Library version (" << libVersion << ") does not match MXNet 
version ("
+               << MX_LIBRARY_VERSION << ")";
+
+  // initialize library by passing MXNet version
   initialize_t initialize = get_func<initialize_t>(lib, 
const_cast<char*>(MXLIB_INITIALIZE_STR));
   if (!initialize(static_cast<int>(MXNET_VERSION)))
     LOG(FATAL) << "Library failed to initialize";
+
+  // get C type interface functions
+  opCallFree_t callFree = get_func<opCallFree_t>(lib, 
const_cast<char*>(MXLIB_OPCALLFREE_STR));
+
+  opCallParseAttrs_t callParseAttrs =
+    get_func<opCallParseAttrs_t>(lib, 
const_cast<char*>(MXLIB_OPCALLPARSEATTRS_STR));
+
+  opCallInferShape_t callInferShape =
+    get_func<opCallInferShape_t>(lib, 
const_cast<char*>(MXLIB_OPCALLINFERSHAPE_STR));
+
+  opCallInferType_t callInferType =
+    get_func<opCallInferType_t>(lib, 
const_cast<char*>(MXLIB_OPCALLINFERTYPE_STR));
+
+  opCallFComp_t callFComp =
+    get_func<opCallFComp_t>(lib, const_cast<char*>(MXLIB_OPCALLFCOMP_STR));
+
+  opCallMutateInputs_t callMutateInputs =
+    get_func<opCallMutateInputs_t>(lib, 
const_cast<char*>(MXLIB_OPCALLMUTATEINPUTS_STR));
+
+  // get number of operators registered in the library
+  opRegSize_t opRegSize = get_func<opRegSize_t>(lib, 
const_cast<char*>(MXLIB_OPREGSIZE_STR));
+  int numOps = opRegSize();
+  LOG(INFO) << "Found " << numOps << " operators in library";
+
+  /*
+   * The library has custom operators implementation
+   * loop and register each operator in the library to NNVM
+   */
+  opRegGet_t opRegGet = get_func<opRegGet_t>(lib, 
const_cast<char*>(MXLIB_OPREGGET_STR));
+  for (int i = 0; i < numOps; i++) {
+    const char* name;
+    // function pointers holding implementation from custom library
+    fcomp_t fcomp_fp = nullptr;
+    parseAttrs_t parse_fp = nullptr;
+    inferType_t type_fp = nullptr;
+    inferShape_t shape_fp = nullptr;
+    // optional attributes
+    mutateInputs_t mutate_fp = nullptr;
+
+    // get custom operator implemenation from the dynamic library
+    opRegGet(i, &name, &fcomp_fp, &parse_fp, &type_fp, &shape_fp, &mutate_fp);
+
+    // validate custom operator functions from the dynamic library
+    CHECK(fcomp_fp != nullptr) << "Error loading '" << name
+                            << "' custom op, FCompute function was not set.";
+    CHECK(parse_fp != nullptr) << "Error loading '" << name
+                            << "' custom op, ParseAttrs function was not set.";
+    CHECK(type_fp  != nullptr) << "Error loading '" << name
+                            << "' custom op, InferType function was not set.";
+    CHECK(shape_fp != nullptr) << "Error loading '" << name
+                            << "' custom op, InferShape function was not set.";
+
 
 Review comment:
   Let's add a CHECK for mutate_fp too.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to