samskalicky commented on a change in pull request #15921: [WIP] dynamic custom operator support
URL: https://github.com/apache/incubator-mxnet/pull/15921#discussion_r320402090
 
 

 ##########
 File path: src/c_api/c_api.cc
 ##########
 @@ -92,16 +93,384 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
 
 // NOTE: return value is added in API_END
 
-// Loads library and initializes it
+/*!
+ * \brief Loads dynamic library and initializes it
+ * \param path library path
+ */
 int MXLoadLib(const char *path) {
   API_BEGIN();
   void *lib = LibraryInitializer::Get()->lib_load(path);
   if (!lib)
     LOG(FATAL) << "Unable to load library";
 
+  // check that library and MXNet use same version of library API
+  opVersion_t opVersion = get_func<opVersion_t>(lib, 
const_cast<char*>(MXLIB_OPVERSION_STR));
+  int libVersion =  opVersion();
+  if (MX_LIBRARY_VERSION != libVersion)
+    LOG(FATAL) << "Library version (" << libVersion << ") does not match MXNet 
version ("
+               << MX_LIBRARY_VERSION << ")";
+
+  // initialize library by passing MXNet version
   initialize_t initialize = get_func<initialize_t>(lib, 
const_cast<char*>(MXLIB_INITIALIZE_STR));
   if (!initialize(static_cast<int>(MXNET_VERSION)))
     LOG(FATAL) << "Library failed to initialize";
+
+  // get C type interface functions
+  opCallFree_t callFree = get_func<opCallFree_t>(lib, 
const_cast<char*>(MXLIB_OPCALLFREE_STR));
+
+  opCallParseAttrs_t callParseAttrs =
+    get_func<opCallParseAttrs_t>(lib, 
const_cast<char*>(MXLIB_OPCALLPARSEATTRS_STR));
+
+  opCallInferShape_t callInferShape =
+    get_func<opCallInferShape_t>(lib, 
const_cast<char*>(MXLIB_OPCALLINFERSHAPE_STR));
+
+  opCallInferType_t callInferType =
+    get_func<opCallInferType_t>(lib, 
const_cast<char*>(MXLIB_OPCALLINFERTYPE_STR));
+
+  opCallFComp_t callFComp =
+    get_func<opCallFComp_t>(lib, const_cast<char*>(MXLIB_OPCALLFCOMP_STR));
+
+  opCallMutateInputs_t callMutateInputs =
+    get_func<opCallMutateInputs_t>(lib, 
const_cast<char*>(MXLIB_OPCALLMUTATEINPUTS_STR));
+
+  // get number of operators registered in the library
+  opRegSize_t opRegSize = get_func<opRegSize_t>(lib, 
const_cast<char*>(MXLIB_OPREGSIZE_STR));
+  int numOps = opRegSize();
+  LOG(INFO) << "Found " << numOps << " operators in library";
+
+  /*
+   * The library has custom operators implementation
+   * loop and register each operator in the library to NNVM
+   */
+  opRegGet_t opRegGet = get_func<opRegGet_t>(lib, 
const_cast<char*>(MXLIB_OPREGGET_STR));
+  for (int i = 0; i < numOps; i++) {
+    const char* name;
+    // function pointers holding implementation from custom library
+    fcomp_t fcomp_fp = nullptr;
+    parseAttrs_t parse_fp = nullptr;
+    inferType_t type_fp = nullptr;
+    inferShape_t shape_fp = nullptr;
+    // optional attributes
+    mutateInputs_t mutate_fp = nullptr;
+
+    // get custom operator implemenation from the dynamic library
+    opRegGet(i, &name, &fcomp_fp, &parse_fp, &type_fp, &shape_fp, &mutate_fp);
+
+    // validate custom operator functions from the dynamic library
+    CHECK(fcomp_fp != nullptr) << "Error loading '" << name
+                            << "' custom op, FCompute function was not set.";
+    CHECK(parse_fp != nullptr) << "Error loading '" << name
+                            << "' custom op, ParseAttrs function was not set.";
+    CHECK(type_fp  != nullptr) << "Error loading '" << name
+                            << "' custom op, InferType function was not set.";
+    CHECK(shape_fp != nullptr) << "Error loading '" << name
+                            << "' custom op, InferShape function was not set.";
+
+    LOG(INFO) << "\tOp[" << i << "] " << name;
+    std::string name_str(name);
+
+    /*
+     * Below are a series of lambda functions that will be registered in the 
NNVM op registration
+     * Each one has the standard MXNet signature and converts to types 
supported by externally
+     * registered operators. 
+     */
+
+    // lambda function to call parse attributes
+    auto attr_parser = [=](const NodeAttrs* attrs) {
+      // convert attributes to vector of char
+      std::vector<const char*> attr_keys, attr_vals;
+      for (auto kv : attrs->dict) {
+        attr_keys.push_back(kv.first.c_str());
+        attr_vals.push_back(kv.second.c_str());
+      }
+
+      int num_in = -1;
+      int num_out = -1;
+      CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), 
attr_keys.size(),
+                           &num_in, &num_out))
+      << "Error calling ParseAttrs for custom operator '" << name_str << "'";
+
+      // return type void
+    };
+
+    // lambda function to call parse attributes and return the number of inputs
+    auto num_inputs = [=](const NodeAttrs& attrs) {
+      // convert attributes to vector of char
+      std::vector<const char*> attr_keys, attr_vals;
+      for (auto kv : attrs.dict) {
+        attr_keys.push_back(kv.first.c_str());
+        attr_vals.push_back(kv.second.c_str());
+      }
+
+      int num_in = -1;
+      int num_out = -1;
+      CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), 
attr_keys.size(),
+                           &num_in, &num_out))
+      << "Error calling ParseAttrs::num_inputs for custom operator '" << 
name_str << "'";
+
+      return num_in;
+    };
+
+    // lambda function to call parse attributes and return the number of 
outputs
+    auto num_outputs = [=](const NodeAttrs& attrs) {
+      // convert attributes to vector of char*
+      std::vector<const char*> attr_keys, attr_vals;
+      for (auto kv : attrs.dict) {
+        attr_keys.push_back(kv.first.c_str());
+        attr_vals.push_back(kv.second.c_str());
+      }
+
+      int num_in = -1;
+      int num_out = -1;
+      CHECK(callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), 
attr_keys.size(),
+                           &num_in, &num_out))
+      << "Error calling ParseAttrs::num_outputs for custom operator '" << 
name_str << "'";
+
+      return num_out;
+    };
+
+    // lambda function to call infer shape
+    auto infer_shape = [=] (const nnvm::NodeAttrs& attrs,
+                            mxnet::ShapeVector *in_shape,
+                            mxnet::ShapeVector *out_shape) {
+      // convert attributes to vector of char*
+      std::vector<const char*> attr_keys, attr_vals;
+      for (auto kv : attrs.dict) {
+        attr_keys.push_back(kv.first.c_str());
+        attr_vals.push_back(kv.second.c_str());
+      }
+
+      std::vector<uint32_t*> inshapes(in_shape->size());
+      std::vector<int> indims(in_shape->size());
+
+      // determine amount of memory needed to store all the input shapes
+      size_t buff_size = 0;
+      for (const auto& i : *in_shape) buff_size += i.ndim();
+
+      // copy input shapes from ShapeVector to raw memory layout
+      std::vector<uint32_t> inbuff(buff_size);
+      uint32_t *ptr = inbuff.data();
+      for (size_t i = 0; i < in_shape->size(); ++i) {
+        inshapes[i] = ptr;
+        indims[i] = (*in_shape)[i].ndim();
+        for (int j = 0; j < (*in_shape)[i].ndim(); ++j, ++ptr) {
+          *ptr = static_cast<uint32_t>((*in_shape)[i][j]);
+        }
+      }
+
+      // output shapes will be allocated by infer shape function
+      uint32_t** outshapes = nullptr;
+      int* outdims = nullptr;
+
+      CHECK(callInferShape(shape_fp, attr_keys.data(), attr_vals.data(), 
attr_keys.size(),
+                           inshapes.data(), indims.data(), in_shape->size(),
+                           &outshapes, &outdims, out_shape->size()))
+      << "Error calling InferShape for custom operator '" << name_str << "'";
+
+      std::vector<uint32_t*> out_shapes(out_shape->size());
+      // determine amount of memory needed to store all the output shapes
+      buff_size = 0;
+      for (unsigned i = 0; i < out_shape->size(); i++) {
+        buff_size += outdims[i];
+      }
+
+      // copy output shapes from custom op memory to MXNet memory
+      std::vector<uint32_t> outbuff(buff_size);
+      ptr = outbuff.data();
+      for (unsigned i = 0; i < out_shape->size(); ++i) {
+        out_shapes[i] = ptr;
+        for (int j = 0; j < outdims[i]; ++j, ++ptr) {
+          *ptr = static_cast<uint32_t>(outshapes[i][j]);
+        }
+      }
+
+      // assign output shapes to ShapeVector
+      for (unsigned i = 0; i < out_shape->size(); ++i) {
+        SHAPE_ASSIGN_CHECK(*out_shape, i,
+                           mxnet::TShape(out_shapes[i], 
out_shapes[i]+outdims[i]));
+      }
+
+      // free memory used by custom op to allocate shapes/dims
+      callFree(outdims);
+      for (unsigned i = 0; i < out_shape->size(); i++) {
+        callFree(outshapes[i]);
+      }
+      callFree(outshapes);
+
+      return true;
+    };
+
+    // lambda function to call infer type
+    auto infer_type = [=] (const nnvm::NodeAttrs& attrs,
+                            std::vector<int> *in_type,
+                            std::vector<int> *out_type) {
+      // convert attributes to vector of char*
+      std::vector<const char*> attr_keys, attr_vals;
+      for (auto kv : attrs.dict) {
+        attr_keys.push_back(kv.first.c_str());
+        attr_vals.push_back(kv.second.c_str());
+      }
+
+      // copy input types from in_type
+      std::vector<int> intypes(*in_type);
+
+      // output types will be populated by inferType function
+      std::vector<int> outtypes(out_type->size());
+
+      CHECK(callInferType(type_fp, attr_keys.data(), attr_vals.data(), 
attr_keys.size(),
+                           intypes.data(), in_type->size(),
+                           outtypes.data(), out_type->size()))
+      << "Error calling InferType for custom operator '" << name_str << "'";
+
+      // copy and assign output types from custom op to MXNet memory
+      for (size_t i = 0; i < out_type->size(); i++) {
+        TYPE_ASSIGN_CHECK(*out_type, i, outtypes[i]);
+      }
+
+      return true;
+    };
+
+    // lambda function to convert from external fcompute to internal MXNet 
types
+    auto fcomp_lambda = [=](const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<NDArray>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<NDArray>& outputs) {
+      // convert attributes to vector of char*
+      std::vector<const char*> attr_keys, attr_vals;
+      for (auto kv : attrs.dict) {
+        attr_keys.push_back(kv.first.c_str());
+        attr_vals.push_back(kv.second.c_str());
+      }
+
+      std::vector<void*> in_data, out_data;
+      std::vector<const int64_t *> in_shapes, out_shapes;
+      std::vector<int> in_dims, out_dims;
+      std::vector<int> in_types, out_types;
+
+      // convert input tensors to constituent parts
+      for (size_t i = 0; i < inputs.size(); i++) {
+        in_data.push_back(inputs[i].data().dptr_);
+        in_shapes.push_back(inputs[i].shape().data());
+        in_dims.push_back(inputs[i].shape().ndim());
+        in_types.push_back(inputs[i].dtype());
+      }
+
+      // convert output tensors to constituent parts
+      for (size_t i = 0; i < outputs.size(); i++) {
+        out_data.push_back(outputs[i].data().dptr_);
+        out_shapes.push_back(outputs[i].shape().data());
+        out_dims.push_back(outputs[i].shape().ndim());
+        out_types.push_back(outputs[i].dtype());
+      }
+
+      // get memory resource
+      const Resource &resource = ctx.requested[0];
+      mshadow::Stream<mxnet::cpu> *cpu_stream = ctx.get_stream<mxnet::cpu>();
+
+      // create lambda that captures stream & resource objects
+      auto cpu_alloc = [&](int size) {
+        mshadow::Tensor<mxnet::cpu, 1, char> data =
+        resource.get_space_typed<mxnet::cpu, 1, char>(mshadow::Shape1(size), 
cpu_stream);
+        return data.dptr_;
+      };
+
+      typedef decltype(cpu_alloc) alloc_type;
+
+      // create lambda without captures so that we can cast it to function 
pointer
+      // this needs to be a lambda function so that we can do the decltype cast
+      auto cpu_malloc = [](void* _cpu_alloc, int size) {
+        // cast the void* argument to the type for the cpu_alloc lambda 
function
+        alloc_type* cpualloc = static_cast<alloc_type*>(_cpu_alloc);
+
+        void* ptr = (*cpualloc)(size);
+        return ptr;
+      };
+
+      // call fcompute function
+      CHECK(callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), 
attr_keys.size(),
+                      in_shapes.data(), in_dims.data(), in_data.data(),
+                      in_types.data(), in_data.size(),
+                      out_shapes.data(), out_dims.data(), out_data.data(),
+                      out_types.data(), out_data.size(), cpu_malloc, 
&cpu_alloc))
+      << "Error calling FCompute for custom operator '" << name_str << "'";
+
+      // return type void
+    };
+
+    // lambda function to convert from external mutate_inputs to internal 
MXNet types
+    auto mutate_inputs = [=](const nnvm::NodeAttrs& attrs) {
+      // convert attributes to vector of char*
+      std::vector<const char*> attr_keys, attr_vals;
+      for (auto kv : attrs.dict) {
+        attr_keys.push_back(kv.first.c_str());
+        attr_vals.push_back(kv.second.c_str());
+      }
+
+      // C type placeholder for mutate input indices vector
+      int* mutate_indices = nullptr;
+      int indices_size = 0;
+
+      // call mutate inputs function
+      CHECK(callMutateInputs(mutate_fp, attr_keys.data(), attr_vals.data(), 
attr_keys.size(),
+                      &mutate_indices, &indices_size))
+      << "Error calling MutateInputs for custom operator '" << name_str << "'";
+
+      std::vector<uint32_t> mutate_indices_list(indices_size);
+      for (int i=0; i < indices_size; i++) {
+        mutate_indices_list[i] = static_cast<uint32_t>(mutate_indices[i]);
+      }
+
+      return mutate_indices_list;
+    };
+
+    auto infer_storage_type = [=](const nnvm::NodeAttrs& attrs,
+                                  const int dev_mask,
+                                  DispatchMode* dispatch_mode,
+                                  std::vector<int>* in_stypes,
+                                  std::vector<int>* out_stypes) {
 
 Review comment:
  Until we decide to actively support sparse storage, let's error out if `in_stypes` contains any sparse storage type.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to