zheng-da commented on a change in pull request #13392: add csr sample op
URL: https://github.com/apache/incubator-mxnet/pull/13392#discussion_r236952140
 
 

 ##########
 File path: src/operator/contrib/dgl_graph.cc
 ##########
 @@ -35,6 +35,833 @@
 namespace mxnet {
 namespace op {
 
+typedef int64_t dgl_id_t;  // integer type for vertex ids, edge ids and CSR offsets
+
+////////////////////////////// Graph Sampling ///////////////////////////////
+
+unsigned int seed = 123;  // rand_r() state shared by the sampling helpers below
+
+/*
+ * This is used for BFS traversal: one queue entry recording a vertex
+ * together with the BFS level at which it was reached.
+ */
+struct ver_node {
+  dgl_id_t vertex_id;  // id of the vertex this entry refers to
+  int level;           // BFS depth; SampleSubgraph uses 0 for seed vertices
+};
+
+/*
+ * ArrayHeap is used to sample elements from vector
+ *
+ * The weights live in the leaves of a complete binary tree stored in an
+ * array (root at index 1, children of node j at 2j and 2j+1). Every inner
+ * node holds the sum of its two children, so drawing an index or changing
+ * one weight only walks a single root-to-leaf path. Weights are expected
+ * to be non-negative. Random draws go through rand_r() on the file-level
+ * `seed`, which is not synchronized.
+ */
+class ArrayHeap {
+ public:
+  // ctor & dctor
+  explicit ArrayHeap(const std::vector<float>& prob) {
+    this->vec_size = prob.size();
+    // Smallest bit_len with 2^bit_len >= vec_size. Integer arithmetic keeps
+    // an empty vector well defined (log2(0) is -inf and cannot be cast).
+    this->bit_len = 0;
+    while ((1 << this->bit_len) < this->vec_size) {
+      ++this->bit_len;
+    }
+    this->limit = 1 << bit_len;
+    // allocate twice the size
+    this->heap.resize(limit << 1, 0);
+    // allocate the leaves
+    for (int i = 0; i < vec_size; ++i) {
+      heap[limit + i] = prob[i];
+    }
+    // iterate up the tree (this is O(m)); children are always visited
+    // before their parent because indices are walked in decreasing order
+    for (int j = limit - 1; j >= 1; --j) {
+      heap[j] = heap[j << 1] + heap[(j << 1) + 1];
+    }
+  }
+  ~ArrayHeap() {}
+
+  /*
+   * Remove term from index (this costs O(log m) steps)
+   *
+   * The leaf is zeroed and every ancestor is rebuilt from its children.
+   * Subtracting the removed weight instead would leave floating point
+   * residue in subtrees whose leaves are all gone, and Sample() could then
+   * walk into such a subtree and return an already removed index.
+   */
+  void Delete(size_t index) {
+    size_t i = index + limit;
+    heap[i] = 0;
+    for (i >>= 1; i >= 1; i >>= 1) {
+      heap[i] = heap[i << 1] + heap[(i << 1) + 1];
+    }
+  }
+
+  /*
+   * Add value w to index (this costs O(log m) steps)
+   */
+  void Add(size_t index, float w) {
+    size_t i = index + limit;
+    for (int j = bit_len; j >= 0; --j) {
+      heap[i] += w;
+      i = i >> 1;
+    }
+  }
+
+  /*
+   * Sample from arrayHeap
+   *
+   * Returns an index drawn with probability proportional to its current
+   * weight. If every weight is zero the result is index 0.
+   */
+  size_t Sample() {
+    // Uniform draw in [0, total weight). Dividing by RAND_MAX + 1 uses the
+    // full resolution of rand_r() and can reach the whole cumulative range.
+    float xi = heap[1] * (rand_r(&seed) / (RAND_MAX + 1.0));
+    int i = 1;
+    while (i < limit) {
+      i = i << 1;
+      // Only step right when the right subtree still carries weight, so
+      // rounding can never end on a removed or padding leaf.
+      if (xi >= heap[i] && heap[i + 1] > 0) {
+        xi -= heap[i];
+        i += 1;
+      }
+    }
+    return i - limit;
+  }
+
+  /*
+   * Sample a vector by given the size n
+   *
+   * Writes n drawn indices into the first n slots of *samples (which must
+   * already hold at least n elements; at() throws otherwise) and removes
+   * each drawn index from the heap. n should not exceed the number of
+   * entries with positive weight, otherwise index 0 is repeated.
+   */
+  void SampleWithoutReplacement(size_t n, std::vector<size_t>* samples) {
+    // sample n elements
+    for (size_t i = 0; i < n; ++i) {
+      samples->at(i) = this->Sample();
+      this->Delete(samples->at(i));
+    }
+  }
+
+ private:
+  int vec_size;  // sample size
+  int bit_len;   // bit size
+  int limit;     // number of leaves: 2^bit_len
+  std::vector<float> heap;
+};
+
+/*
+ * Parameters read by the neighbor sampling functions in this file.
+ */
+struct NeighborSampleParam : public dmlc::Parameter<NeighborSampleParam> {
+  int num_args;               // number of input NDArrays, at least 2
+  dgl_id_t num_hops;          // default 1
+  dgl_id_t num_neighbor;      // default 2
+  dgl_id_t max_num_vertices;  // default 100
+  DMLC_DECLARE_PARAMETER(NeighborSampleParam) {
+    DMLC_DECLARE_FIELD(num_args)
+      .set_lower_bound(2)
+      .describe("Number of input NDArray.");
+    DMLC_DECLARE_FIELD(num_hops).set_default(1).describe("Number of hops.");
+    DMLC_DECLARE_FIELD(num_neighbor).set_default(2).describe("Number of neighbor.");
+    DMLC_DECLARE_FIELD(max_num_vertices).set_default(100).describe("Max number of vertices.");
+  }
+};
+
+DMLC_REGISTER_PARAMETER(NeighborSampleParam);
+
+/*
+ * Storage type inference for the uniform sampler.
+ * Inputs are the csr graph followed by one seed vector per subgraph. The
+ * outputs form three groups of num_subgraphs entries each: sample_id
+ * (default storage), sub_graph (CSR) and sub_layer (default storage).
+ * The dispatch mode is always set to kFComputeEx.
+ */
+static bool CSRNeighborUniformSampleStorageType(const nnvm::NodeAttrs& attrs,
+                                                const int dev_mask,
+                                                DispatchMode* dispatch_mode,
+                                                std::vector<int> *in_attrs,
+                                                std::vector<int> *out_attrs) {
+  const NeighborSampleParam& params =
+    nnvm::get<NeighborSampleParam>(attrs.parsed);
+  const size_t num_subgraphs = params.num_args - 1;
+  CHECK_EQ(out_attrs->size(), 3 * num_subgraphs);
+
+  // input[0] is csr_graph
+  CHECK_EQ(in_attrs->at(0), mxnet::kCSRStorage);
+  // the rest input ndarray is seed_vector
+  for (size_t i = 0; i < num_subgraphs; i++) {
+    CHECK_EQ(in_attrs->at(1 + i), mxnet::kDefaultStorage);
+  }
+
+  // Storage type of each output group, in output order:
+  // sample_id, sub_graph, sub_layer.
+  const int group_stype[] = {mxnet::kDefaultStorage,
+                             mxnet::kCSRStorage,
+                             mxnet::kDefaultStorage};
+  bool success = true;
+  for (size_t group = 0; group < 3; ++group) {
+    for (size_t i = 0; i < num_subgraphs; ++i) {
+      int* slot = &(*out_attrs)[group * num_subgraphs + i];
+      // keep going after a failure so every output gets its assignment
+      if (!type_assign(slot, group_stype[group])) {
+        success = false;
+      }
+    }
+  }
+
+  *dispatch_mode = DispatchMode::kFComputeEx;
+  return success;
+}
+
+/*
+ * Storage type inference for the non-uniform sampler.
+ * Inputs are the csr graph, the probability array and one seed vector per
+ * subgraph. The outputs form four groups of num_subgraphs entries each:
+ * sample_id, sub_graph (CSR), sub_probability and sub_layer; all but the
+ * sub_graph group use default storage. The dispatch mode is always set to
+ * kFComputeEx.
+ */
+static bool CSRNeighborNonUniformSampleStorageType(const nnvm::NodeAttrs& attrs,
+                                                   const int dev_mask,
+                                                   DispatchMode* dispatch_mode,
+                                                   std::vector<int> *in_attrs,
+                                                   std::vector<int> *out_attrs) {
+  const NeighborSampleParam& params =
+    nnvm::get<NeighborSampleParam>(attrs.parsed);
+  const size_t num_subgraphs = params.num_args - 2;
+  CHECK_EQ(out_attrs->size(), 4 * num_subgraphs);
+
+  // input[0] is csr_graph
+  CHECK_EQ(in_attrs->at(0), mxnet::kCSRStorage);
+  // input[1] is probability
+  CHECK_EQ(in_attrs->at(1), mxnet::kDefaultStorage);
+  // the rest input ndarray is seed_vector
+  for (size_t i = 0; i < num_subgraphs; i++) {
+    CHECK_EQ(in_attrs->at(2 + i), mxnet::kDefaultStorage);
+  }
+
+  // Storage type of each output group, in output order:
+  // sample_id, sub_graph, sub_probability, sub_layer.
+  const int group_stype[] = {mxnet::kDefaultStorage,
+                             mxnet::kCSRStorage,
+                             mxnet::kDefaultStorage,
+                             mxnet::kDefaultStorage};
+  bool success = true;
+  for (size_t group = 0; group < 4; ++group) {
+    for (size_t i = 0; i < num_subgraphs; ++i) {
+      int* slot = &(*out_attrs)[group * num_subgraphs + i];
+      // keep going after a failure so every output gets its assignment
+      if (!type_assign(slot, group_stype[group])) {
+        success = false;
+      }
+    }
+  }
+
+  *dispatch_mode = DispatchMode::kFComputeEx;
+  return success;
+}
+
+/*
+ * Shape inference for the uniform sampler.
+ * Requires a square 2-D csr graph and 1-D seed vectors. Output shapes per
+ * group: sample_id is (max_num_vertices + 1), sub_csr is
+ * (max_num_vertices, number of graph columns), sub_layer is
+ * (max_num_vertices). Returns false if any output shape ends up unknown
+ * or empty.
+ */
+static bool CSRNeighborUniformSampleShape(const nnvm::NodeAttrs& attrs,
+                                          std::vector<TShape> *in_attrs,
+                                          std::vector<TShape> *out_attrs) {
+  const NeighborSampleParam& params =
+    nnvm::get<NeighborSampleParam>(attrs.parsed);
+  const size_t num_subgraphs = params.num_args - 1;
+  CHECK_EQ(out_attrs->size(), 3 * num_subgraphs);
+
+  // input[0] is csr graph
+  CHECK_EQ(in_attrs->at(0).ndim(), 2U);
+  CHECK_EQ(in_attrs->at(0)[0], in_attrs->at(0)[1]);
+  // the rest input ndarray is seed vector
+  for (size_t i = 0; i < num_subgraphs; i++) {
+    CHECK_EQ(in_attrs->at(1 + i).ndim(), 1U);
+  }
+
+  // sample_id: the extra last element stores the actual number of
+  // vertices in the subgraph.
+  TShape out_shape(1);
+  out_shape[0] = params.max_num_vertices + 1;
+  // sub_csr
+  TShape out_csr_shape(2);
+  out_csr_shape[0] = params.max_num_vertices;
+  out_csr_shape[1] = in_attrs->at(0)[1];
+  // sub_layer
+  TShape out_layer_shape(1);
+  out_layer_shape[0] = params.max_num_vertices;
+
+  // Assign group by group, in output order.
+  const TShape* group_shape[] = {&out_shape, &out_csr_shape, &out_layer_shape};
+  bool success = true;
+  for (size_t group = 0; group < 3; ++group) {
+    for (size_t i = 0; i < num_subgraphs; ++i) {
+      const size_t slot = i + group * num_subgraphs;
+      SHAPE_ASSIGN_CHECK(*out_attrs, slot, *group_shape[group]);
+      const TShape& assigned = out_attrs->at(slot);
+      if (assigned.ndim() == 0U || assigned.Size() == 0U) {
+        success = false;
+      }
+    }
+  }
+
+  return success;
+}
+
+/*
+ * Check non-uniform Shape
+ *
+ * Requires a square 2-D csr graph, a 1-D probability array and 1-D seed
+ * vectors. Output shapes per group: sample_id is (max_num_vertices + 1),
+ * sub_csr is (max_num_vertices, number of graph columns), sub_probability
+ * and sub_layer are (max_num_vertices). Returns false if any output shape
+ * ends up unknown or empty.
+ */
+static bool CSRNeighborNonUniformSampleShape(const nnvm::NodeAttrs& attrs,
+                                             std::vector<TShape> *in_attrs,
+                                             std::vector<TShape> *out_attrs) {
+  const NeighborSampleParam& params =
+    nnvm::get<NeighborSampleParam>(attrs.parsed);
+
+  size_t num_subgraphs = params.num_args - 2;
+  CHECK_EQ(out_attrs->size(), 4 * num_subgraphs);
+  // input[0] is csr graph
+  CHECK_EQ(in_attrs->at(0).ndim(), 2U);
+  CHECK_EQ(in_attrs->at(0)[0], in_attrs->at(0)[1]);
+
+  // input[1] is probability
+  CHECK_EQ(in_attrs->at(1).ndim(), 1U);
+
+  // the rest ndarray is seed vector
+  for (size_t i = 0; i < num_subgraphs; i++) {
+    CHECK_EQ(in_attrs->at(2 + i).ndim(), 1U);
+  }
+
+  // Output
+  bool success = true;
+  TShape out_shape(1);
+  // We use the last element to store the actual
+  // number of vertices in the subgraph.
+  out_shape[0] = params.max_num_vertices + 1;
+  for (size_t i = 0; i < num_subgraphs; i++) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, i, out_shape);
+    success = success &&
+              out_attrs->at(i).ndim() != 0U &&
+              out_attrs->at(i).Size() != 0U;
+  }
+  // sub_csr
+  TShape out_csr_shape(2);
+  out_csr_shape[0] = params.max_num_vertices;
+  out_csr_shape[1] = in_attrs->at(0)[1];
+  for (size_t i = 0; i < num_subgraphs; i++) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, i + num_subgraphs, out_csr_shape);
+    success = success &&
+              out_attrs->at(i + num_subgraphs).ndim() != 0U &&
+              out_attrs->at(i + num_subgraphs).Size() != 0U;
+  }
+  // sub_probability
+  TShape out_prob_shape(1);
+  out_prob_shape[0] = params.max_num_vertices;
+  for (size_t i = 0; i < num_subgraphs; i++) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, i + 2*num_subgraphs, out_prob_shape);
+    success = success &&
+              out_attrs->at(i + 2*num_subgraphs).ndim() != 0U &&
+              out_attrs->at(i + 2*num_subgraphs).Size() != 0U;
+  }
+  // sub_layer: assigned from its own shape object rather than from
+  // out_prob_shape, so the two groups can diverge without a silent mix-up.
+  TShape out_layer_shape(1);
+  out_layer_shape[0] = params.max_num_vertices;
+  for (size_t i = 0; i < num_subgraphs; i++) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, i + 3*num_subgraphs, out_layer_shape);
+    success = success &&
+              out_attrs->at(i + 3*num_subgraphs).ndim() != 0U &&
+              out_attrs->at(i + 3*num_subgraphs).Size() != 0U;
+  }
+
+  return success;
+}
+
+/*
+ * Data type inference for the uniform sampler.
+ * For every subgraph the sample_id and sub_layer outputs take the type of
+ * input 1 (the first seed vector) and the sub_csr output takes the type
+ * of input 0 (the graph). Returns false while any output type is still -1.
+ */
+static bool CSRNeighborUniformSampleType(const nnvm::NodeAttrs& attrs,
+                                         std::vector<int> *in_attrs,
+                                         std::vector<int> *out_attrs) {
+  const NeighborSampleParam& params =
+    nnvm::get<NeighborSampleParam>(attrs.parsed);
+  const size_t num_subgraphs = params.num_args - 1;
+  CHECK_EQ(out_attrs->size(), 3 * num_subgraphs);
+
+  bool success = true;
+  for (size_t i = 0; i < num_subgraphs; ++i) {
+    const size_t id_slot = i;
+    const size_t csr_slot = i + num_subgraphs;
+    const size_t layer_slot = i + 2 * num_subgraphs;
+    TYPE_ASSIGN_CHECK(*out_attrs, id_slot, in_attrs->at(1));
+    TYPE_ASSIGN_CHECK(*out_attrs, csr_slot, in_attrs->at(0));
+    TYPE_ASSIGN_CHECK(*out_attrs, layer_slot, in_attrs->at(1));
+    const bool all_known = out_attrs->at(id_slot) != -1 &&
+                           out_attrs->at(csr_slot) != -1 &&
+                           out_attrs->at(layer_slot) != -1;
+    if (!all_known) {
+      success = false;
+    }
+  }
+
+  return success;
+}
+
+/*
+ * Data type inference for the non-uniform sampler.
+ * For every subgraph the sample_id and sub_layer outputs take the type of
+ * input 2 (the first seed vector), the sub_csr output takes the type of
+ * input 0 (the graph) and the sub_probability output takes the type of
+ * input 1 (the probability array). Returns false while any output type is
+ * still -1.
+ */
+static bool CSRNeighborNonUniformSampleType(const nnvm::NodeAttrs& attrs,
+                                            std::vector<int> *in_attrs,
+                                            std::vector<int> *out_attrs) {
+  const NeighborSampleParam& params =
+    nnvm::get<NeighborSampleParam>(attrs.parsed);
+  const size_t num_subgraphs = params.num_args - 2;
+  CHECK_EQ(out_attrs->size(), 4 * num_subgraphs);
+
+  bool success = true;
+  for (size_t i = 0; i < num_subgraphs; ++i) {
+    const size_t id_slot = i;
+    const size_t csr_slot = i + num_subgraphs;
+    const size_t prob_slot = i + 2 * num_subgraphs;
+    const size_t layer_slot = i + 3 * num_subgraphs;
+    TYPE_ASSIGN_CHECK(*out_attrs, id_slot, in_attrs->at(2));
+    TYPE_ASSIGN_CHECK(*out_attrs, csr_slot, in_attrs->at(0));
+    TYPE_ASSIGN_CHECK(*out_attrs, prob_slot, in_attrs->at(1));
+    TYPE_ASSIGN_CHECK(*out_attrs, layer_slot, in_attrs->at(2));
+    const bool all_known = out_attrs->at(id_slot) != -1 &&
+                           out_attrs->at(csr_slot) != -1 &&
+                           out_attrs->at(prob_slot) != -1 &&
+                           out_attrs->at(layer_slot) != -1;
+    if (!all_known) {
+      success = false;
+    }
+  }
+
+  return success;
+}
+
+/*
+ * Collect the source vertices and edge ids stored for one destination
+ * vertex: every entry in [indptr[dst_id], indptr[dst_id + 1]) of col_list
+ * and val_list is appended to src_list and edge_list respectively.
+ */
+static void GetSrcList(const dgl_id_t* val_list,
+                       const dgl_id_t* col_list,
+                       const dgl_id_t* indptr,
+                       const dgl_id_t dst_id,
+                       std::vector<dgl_id_t>* src_list,
+                       std::vector<dgl_id_t>* edge_list) {
+  const dgl_id_t row_begin = indptr[dst_id];
+  const dgl_id_t row_end = indptr[dst_id + 1];
+  for (dgl_id_t pos = row_begin; pos < row_end; ++pos) {
+    src_list->push_back(col_list[pos]);
+    edge_list->push_back(val_list[pos]);
+  }
+}
+
+/*
+ * Draw `num` distinct indices from [0, set_size) by rejection sampling
+ * and store them in *out (previous contents are discarded). The order of
+ * the result is unspecified. Requires num <= set_size. Uses rand_r() on
+ * the file-level `seed`.
+ */
+static void RandomSample(size_t set_size,
+                         size_t num,
+                         std::vector<size_t>* out) {
+  // Asking for more distinct values than the range holds would make the
+  // loop below spin forever, and set_size == 0 would be a modulo by zero.
+  CHECK_LE(num, set_size);
+  std::unordered_set<size_t> sampled_idxs;
+  while (sampled_idxs.size() < num) {
+    sampled_idxs.insert(rand_r(&seed) % set_size);
+  }
+  out->assign(sampled_idxs.begin(), sampled_idxs.end());
+}
+
+/*
+ * Append to *out, in increasing order, every index in [0, set_size) that
+ * does not occur in idxs. idxs must be sorted in increasing order and all
+ * of its entries must be smaller than set_size. An empty idxs yields the
+ * whole range.
+ */
+static void NegateSet(const std::vector<size_t> &idxs,
+                      size_t set_size,
+                      std::vector<size_t>* out) {
+  // idxs must have been sorted.
+  // back() is only defined on a non-empty vector.
+  if (!idxs.empty()) {
+    CHECK_GT(set_size, idxs.back());
+  }
+  auto it = idxs.begin();
+  for (size_t i = 0; i < set_size; i++) {
+    if (it != idxs.end() && *it == i) {
+      ++it;
+      continue;
+    }
+    out->push_back(i);
+  }
+}
+
+/*
+ * Uniformly pick at most max_num_neighbor entries from ver_list and the
+ * matching entries of edge_list, appending them to out_ver / out_edge in
+ * their original relative order. When the list is not longer than the
+ * limit, everything is copied.
+ */
+static void GetUniformSample(const std::vector<dgl_id_t>& ver_list,
+                             const std::vector<dgl_id_t>& edge_list,
+                             const size_t max_num_neighbor,
+                             std::vector<dgl_id_t>* out_ver,
+                             std::vector<dgl_id_t>* out_edge) {
+  CHECK_EQ(ver_list.size(), edge_list.size());
+  const size_t num_candidates = ver_list.size();
+
+  // Nothing to choose: take the whole neighbor list.
+  if (num_candidates <= max_num_neighbor) {
+    out_ver->insert(out_ver->end(), ver_list.begin(), ver_list.end());
+    out_edge->insert(out_edge->end(), edge_list.begin(), edge_list.end());
+    return;
+  }
+
+  std::vector<size_t> sorted_idxs;
+  if (num_candidates > max_num_neighbor * 2) {
+    // Few picks out of many candidates: draw the kept positions directly.
+    sorted_idxs.reserve(max_num_neighbor);
+    RandomSample(num_candidates, max_num_neighbor, &sorted_idxs);
+    std::sort(sorted_idxs.begin(), sorted_idxs.end());
+  } else {
+    // Most candidates are kept: draw the positions to drop instead and
+    // keep the complement.
+    const size_t num_dropped = num_candidates - max_num_neighbor;
+    std::vector<size_t> dropped;
+    dropped.reserve(num_dropped);
+    RandomSample(num_candidates, num_dropped, &dropped);
+    std::sort(dropped.begin(), dropped.end());
+    NegateSet(dropped, num_candidates, &sorted_idxs);
+  }
+
+  // verify the result: right count, strictly increasing positions.
+  CHECK_EQ(sorted_idxs.size(), max_num_neighbor);
+  for (size_t i = 1; i < sorted_idxs.size(); i++) {
+    CHECK_GT(sorted_idxs[i], sorted_idxs[i - 1]);
+  }
+
+  for (const size_t pos : sorted_idxs) {
+    out_ver->push_back(ver_list[pos]);
+    out_edge->push_back(edge_list[pos]);
+  }
+}
+
+/*
+ * Weighted pick of at most max_num_neighbor entries from ver_list and the
+ * matching entries of edge_list, using ArrayHeap. The weight of a
+ * candidate is probability[vertex id]. When the list is not longer than
+ * the limit, everything is appended to the outputs; otherwise the outputs
+ * are resized to max_num_neighbor and overwritten.
+ */
+static void GetNonUniformSample(const float* probability,
+                                const std::vector<dgl_id_t>& ver_list,
+                                const std::vector<dgl_id_t>& edge_list,
+                                const size_t max_num_neighbor,
+                                std::vector<dgl_id_t>* out_ver,
+                                std::vector<dgl_id_t>* out_edge) {
+  CHECK_EQ(ver_list.size(), edge_list.size());
+
+  // Nothing to choose: take the whole neighbor list.
+  if (ver_list.size() <= max_num_neighbor) {
+    out_ver->insert(out_ver->end(), ver_list.begin(), ver_list.end());
+    out_edge->insert(out_edge->end(), edge_list.begin(), edge_list.end());
+    return;
+  }
+
+  // Look up the weight of every candidate vertex.
+  std::vector<float> weights;
+  weights.reserve(ver_list.size());
+  for (const dgl_id_t vertex : ver_list) {
+    weights.push_back(probability[vertex]);
+  }
+
+  // Draw positions into the candidate list without replacement.
+  std::vector<size_t> chosen(max_num_neighbor);
+  ArrayHeap arrayHeap(weights);
+  arrayHeap.SampleWithoutReplacement(max_num_neighbor, &chosen);
+
+  out_ver->resize(max_num_neighbor);
+  out_edge->resize(max_num_neighbor);
+  for (size_t k = 0; k < max_num_neighbor; ++k) {
+    const size_t pos = chosen[k];
+    (*out_ver)[k] = ver_list[pos];
+    (*out_edge)[k] = edge_list[pos];
+  }
+}
+
+/*
+ * Used for subgraph sampling: the neighbor ids sampled for one vertex and
+ * the ids of the edges that lead to them, stored as copies.
+ */
+struct neigh_list {
+  std::vector<dgl_id_t> neighs;
+  std::vector<dgl_id_t> edges;
+  neigh_list(const std::vector<dgl_id_t> &sampled_neighs,
+             const std::vector<dgl_id_t> &sampled_edges)
+    : neighs(sampled_neighs),
+      edges(sampled_edges) {}
+};
+
+/*
+ * Sample sub-graph from csr graph
+ *
+ * Starting from the seed vertices, walks the graph breadth first for at
+ * most num_hops levels, keeping at most num_neighbor sampled neighbors
+ * per visited vertex and at most max_num_vertices vertices in total.
+ *
+ *   csr              input graph; its values are used as edge ids
+ *   seed_arr         seed vertex ids; repeated ids are visited once
+ *   sampled_ids      output, max_num_vertices + 1 entries: the sorted
+ *                    sampled vertex ids, padded with -1, followed by the
+ *                    number of sampled vertices in the last entry
+ *   sub_csr          output, sampled edges; row i belongs to the i-th
+ *                    sampled vertex
+ *   sub_prob         optional output, probability of each sampled vertex
+ *                    (-1 for padding); skipped when nullptr. When it is
+ *                    given, probability must be given as well.
+ *   sub_layer        output, BFS level of each sampled vertex (-1 for
+ *                    padding)
+ *   probability      per-vertex sampling weights; nullptr selects uniform
+ *                    sampling
+ */
+static void SampleSubgraph(const NDArray &csr,
+                           const NDArray &seed_arr,
+                           const NDArray &sampled_ids,
+                           const NDArray &sub_csr,
+                           float* sub_prob,
+                           const NDArray &sub_layer,
+                           const float* probability,
+                           dgl_id_t num_hops,
+                           dgl_id_t num_neighbor,
+                           dgl_id_t max_num_vertices) {
+  // A negative budget would slip through the mixed signed/unsigned
+  // comparison below.
+  CHECK_GE(max_num_vertices, 0);
+  const size_t vertex_cap = static_cast<size_t>(max_num_vertices);
+  size_t num_seeds = seed_arr.shape().Size();
+  CHECK_GE(max_num_vertices, num_seeds);
+
+  const dgl_id_t* val_list = csr.data().dptr<dgl_id_t>();
+  const dgl_id_t* col_list = csr.aux_data(csr::kIdx).dptr<dgl_id_t>();
+  const dgl_id_t* indptr = csr.aux_data(csr::kIndPtr).dptr<dgl_id_t>();
+  const dgl_id_t* seed_ids = seed_arr.data().dptr<dgl_id_t>();
+  dgl_id_t* out = sampled_ids.data().dptr<dgl_id_t>();
+  dgl_id_t* out_layer = sub_layer.data().dptr<dgl_id_t>();
+
+  // BFS traverse the graph and sample vertices
+  std::unordered_set<dgl_id_t> sub_ver_mp;
+  std::unordered_map<dgl_id_t, int> layer;
+  std::queue<ver_node> node_queue;
+  // add seed vertices
+  for (size_t i = 0; i < num_seeds; ++i) {
+    ver_node node;
+    node.vertex_id = seed_ids[i];
+    node.level = 0;
+    // A repeated seed is queued only once; expanding it twice would count
+    // its sampled edges twice and trip the edge-count check at the end.
+    if (sub_ver_mp.insert(node.vertex_id).second) {
+      node_queue.push(node);
+      layer[node.vertex_id] = node.level;
+    }
+  }
+
+  std::vector<dgl_id_t> tmp_src_list;
+  std::vector<dgl_id_t> tmp_edge_list;
+  std::vector<dgl_id_t> tmp_sampled_src_list;
+  std::vector<dgl_id_t> tmp_sampled_edge_list;
+  std::unordered_map<dgl_id_t, neigh_list> neigh_mp;
+  size_t num_edges = 0;
+  bool budget_reached = false;
+  while (!budget_reached && !node_queue.empty()) {
+    const ver_node cur_node = node_queue.front();
+    node_queue.pop();
+    if (cur_node.level >= num_hops) {
+      continue;
+    }
+    const dgl_id_t dst_id = cur_node.vertex_id;
+    tmp_src_list.clear();
+    tmp_edge_list.clear();
+    tmp_sampled_src_list.clear();
+    tmp_sampled_edge_list.clear();
+    GetSrcList(val_list,
+               col_list,
+               indptr,
+               dst_id,
+               &tmp_src_list,
+               &tmp_edge_list);
+    if (probability == nullptr) {  // uniform-sample
+      GetUniformSample(tmp_src_list,
+                       tmp_edge_list,
+                       num_neighbor,
+                       &tmp_sampled_src_list,
+                       &tmp_sampled_edge_list);
+    } else {  // non-uniform-sample
+      GetNonUniformSample(probability,
+                          tmp_src_list,
+                          tmp_edge_list,
+                          num_neighbor,
+                          &tmp_sampled_src_list,
+                          &tmp_sampled_edge_list);
+    }
+    neigh_mp.insert(std::pair<dgl_id_t, neigh_list>(dst_id,
+      neigh_list(tmp_sampled_src_list, tmp_sampled_edge_list)));
+    num_edges += tmp_sampled_src_list.size();
+
+    const int next_level = cur_node.level + 1;
+    for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
+      const dgl_id_t src_id = tmp_sampled_src_list[i];
+      if (sub_ver_mp.count(src_id) != 0) {
+        continue;
+      }
+      // The budget is checked against the real set size and before the
+      // insertion: an equality test on a separate counter is skipped when
+      // the seeds alone already fill the budget, and the set would then
+      // grow past the size of the output buffers.
+      if (sub_ver_mp.size() >= vertex_cap) {
+        budget_reached = true;
+        break;
+      }
+      sub_ver_mp.insert(src_id);
+      layer[src_id] = next_level;
+      if (next_level < num_hops) {
+        ver_node new_node;
+        new_node.vertex_id = src_id;
+        new_node.level = next_level;
+        node_queue.push(new_node);
+      }
+      if (sub_ver_mp.size() >= vertex_cap) {
+        budget_reached = true;
+        break;
+      }
+    }
+  }
+
+  // Copy sub_ver_mp to output[0]
+  const size_t num_vertices = sub_ver_mp.size();
+  std::copy(sub_ver_mp.begin(), sub_ver_mp.end(), out);
+  std::sort(out, out + num_vertices);
+  // The rest data will be set to -1
+  for (size_t i = num_vertices; i < vertex_cap; ++i) {
+    out[i] = -1;
+  }
+  // The last element stores the actual
+  // number of vertices in the subgraph.
+  out[vertex_cap] = num_vertices;
+
+  // Copy sub_probability
+  if (sub_prob != nullptr) {
+    for (size_t i = 0; i < vertex_cap; ++i) {
+      const dgl_id_t vid = out[i];
+      if (vid != -1) {
+        sub_prob[i] = probability[vid];
+      } else {
+        sub_prob[i] = -1;
+      }
+    }
+  }
+  // Copy layer
+  for (size_t i = 0; i < vertex_cap; ++i) {
+    const dgl_id_t vid = out[i];
+    if (vid != -1) {
+      out_layer[i] = layer[vid];
+    } else {
+      out_layer[i] = -1;
+    }
+  }
+
+  // Construct sub_csr_graph
+  TShape shape_1(1);
+  TShape shape_2(1);
+  shape_1[0] = num_edges;
+  shape_2[0] = max_num_vertices+1;
+  sub_csr.CheckAndAllocData(shape_1);
+  sub_csr.CheckAndAllocAuxData(csr::kIdx, shape_1);
+  sub_csr.CheckAndAllocAuxData(csr::kIndPtr, shape_2);
+  dgl_id_t* val_list_out = sub_csr.data().dptr<dgl_id_t>();
+  dgl_id_t* col_list_out = sub_csr.aux_data(csr::kIdx).dptr<dgl_id_t>();
+  dgl_id_t* indptr_out = sub_csr.aux_data(csr::kIndPtr).dptr<dgl_id_t>();
+
+  // Every row offset below is derived from the previous one, so the first
+  // offset has to be written explicitly.
+  indptr_out[0] = 0;
+  size_t collected_nedges = 0;
+  for (size_t i = 0; i < num_vertices; i++) {
+    auto it = neigh_mp.find(out[i]);
+    if (it != neigh_mp.end()) {
+      const auto &edges = it->second.edges;
+      const auto &neighs = it->second.neighs;
+      CHECK_EQ(edges.size(), neighs.size());
+      std::copy(edges.begin(), edges.end(), val_list_out + collected_nedges);
+      std::copy(neighs.begin(), neighs.end(), col_list_out + collected_nedges);
+      collected_nedges += edges.size();
+    }
+    indptr_out[i + 1] = collected_nedges;
+  }
+  // Padding rows are empty.
+  for (size_t i = num_vertices + 1; i <= vertex_cap; ++i) {
+    indptr_out[i] = indptr_out[i-1];
+  }
+  CHECK_EQ(collected_nedges, num_edges);
+}
+
+/*
+ * Operator: contrib_csr_neighbor_uniform_sample
+ *
+ * inputs[0] is the csr graph and every further input is one seed vector.
+ * For seed vector i the results go to outputs[i] (sample ids),
+ * outputs[i + n] (sub csr) and outputs[i + 2n] (layers), where n is the
+ * number of seed vectors. Subgraphs are sampled in an OpenMP parallel
+ * loop.
+ * NOTE(review): the uniform sampling path calls rand_r() on the
+ * file-level `seed` (see RandomSample), which is not synchronized across
+ * the threads of this loop -- confirm whether that is acceptable.
+ */
+static void CSRNeighborUniformSampleComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                          const OpContext& ctx,
+                                          const std::vector<NDArray>& inputs,
+                                          const std::vector<OpReqType>& req,
+                                          const std::vector<NDArray>& outputs) {
+  const NeighborSampleParam& params =
+    nnvm::get<NeighborSampleParam>(attrs.parsed);
+
+  int num_subgraphs = inputs.size() - 1;
+  CHECK_EQ(outputs.size(), 3 * num_subgraphs);
+
+#pragma omp parallel for
+  for (int g = 0; g < num_subgraphs; g++) {
+    const NDArray& graph_csr = inputs[0];
+    const NDArray& seed_vector = inputs[g + 1];
+    const NDArray& sample_id = outputs[g];
+    const NDArray& sub_csr = outputs[g + 1*num_subgraphs];
+    const NDArray& sample_id_layer = outputs[g + 2*num_subgraphs];
+    // No probability input and no probability output: uniform sampling.
+    SampleSubgraph(graph_csr,
+                   seed_vector,
+                   sample_id,
+                   sub_csr,
+                   nullptr,
+                   sample_id_layer,
+                   nullptr,
+                   params.num_hops,
+                   params.num_neighbor,
+                   params.max_num_vertices);
+  }
+}
+
+NNVM_REGISTER_OP(_contrib_csr_neighbor_uniform_sample)
 
 Review comment:
   Could you rename it to `_contrib_dgl_neighbor_uniform_sample`? It's quite 
DGL-specific.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to