samskalicky commented on a change in pull request #15921: [WIP] dynamic custom operator support URL: https://github.com/apache/incubator-mxnet/pull/15921#discussion_r324862929
########## File path: include/mxnet/lib_api.h ########## @@ -18,33 +18,601 @@ */ /*! - * Copyright (c) 2015 by Contributors + * Copyright (c) 2019 by Contributors * \file lib_api.h * \brief APIs to interact with libraries + * This API specifies function prototypes to + * register custom ops for library authors */ + #ifndef MXNET_LIB_API_H_ #define MXNET_LIB_API_H_ +#include <stdint.h> +#include <vector> +#include <map> +#include <string> +#include <iostream> + +#define MX_LIBRARY_VERSION 1 + +/*! + * \brief External Tensor data types + */ +enum MXDType { + kFloat32 = 0, + kFloat64 = 1, + kFloat16 = 2, + kUint8 = 3, + kInt32 = 4, + kInt8 = 5, + kInt64 = 6, +}; + +enum MXReturnValue { + MX_FAIL = 0, + MX_SUCCESS = 1, +}; + +/*! + * \brief External Tensor data structure + */ +struct MXTensor { + MXTensor() : data(nullptr) {} + + MXTensor(void *data, const std::vector<int64_t> &shape, MXDType dtype) + : data{data}, shape{shape}, dtype{dtype} {} + + /*! + * \brief helper function to cast data pointer + */ + template<typename data_type> + data_type* getData() { + return reinterpret_cast<data_type*>(data); + } + + void *data; // not owned + std::vector<int64_t> shape; + MXDType dtype; +}; + +/*! + * \brief resource malloc function to allocate memory inside Forward/Backward functions + */ +typedef void* (*xpu_malloc_t)(void*, int); + +/*! + * \brief Class to provide resource APIs to Forward/Backward functions + */ +class OpResource { + public: + OpResource(xpu_malloc_t xm, void* _xm) : xpu_malloc(xm), _xpu_malloc(_xm) {} + + /*! + * \brief allocate memory controlled by MXNet + */ + void* alloc(int size) { + return xpu_malloc(_xpu_malloc, size); + } + private: + xpu_malloc_t xpu_malloc; + void* _xpu_malloc; +}; + +/*! + * \brief Macro to help passing serialized subgraph through attribute dict + */ +#define SUBGRAPH_SYM_JSON "subgraph_sym_json" + +/*! 
+ * \brief An abstract class for library author creating stateful op + * custom library should override Forward and destructor, and has an + * option to implement Backward + */ +class CustomStatefulOp { + public: + virtual MXReturnValue Forward(std::vector<MXTensor> inputs, + std::vector<MXTensor> outputs, + OpResource op_res) = 0; + virtual MXReturnValue Backward(std::vector<MXTensor> inputs, + std::vector<MXTensor> outputs, + OpResource op_res) { + std::cout << "Error! Operator does not support backward" << std::endl; + return MX_FAIL; + } + virtual ~CustomStatefulOp() = 0; Review comment: Let's not force the user to implement this. It also avoids the awkwardness of a pure virtual destructor, which still requires an out-of-line definition or derived classes fail to link. I suggest we remove this line. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services