This is an automated email from the ASF dual-hosted git repository. zhasheng pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 24ea5b0 Factorization machine example & sparse example folder re-org (#8767) 24ea5b0 is described below commit 24ea5b02d2a402e02c7aaab0138a853eadedcef5 Author: Haibin Lin <linhaibin.e...@gmail.com> AuthorDate: Fri Dec 8 15:32:36 2017 -0800 Factorization machine example & sparse example folder re-org (#8767) * draft for fm * add example of checkpointing sparse model * remove kvtore code * add back kvstore and dummy iter * add fm folder * add missing files * add metric * update optimizer * bug fix for metrics * update default num epochs * re-org example folder * better aws cmd * add hyperparams to args parser * move dummy iter to common file. renmae get_data.py * remove dataset. * use fluent method * move dummy iter to test_utils * add error msg * doc update of dummy itr --- example/sparse/factorization_machine/README.md | 16 +++ example/sparse/factorization_machine/metric.py | 88 +++++++++++++ example/sparse/factorization_machine/model.py | 54 ++++++++ example/sparse/factorization_machine/train.py | 142 +++++++++++++++++++++ example/sparse/linear_classification/README.md | 17 +++ example/sparse/linear_classification/data.py | 33 +++++ .../{ => linear_classification}/linear_model.py | 0 .../train.py} | 12 +- .../weighted_softmax_ce.py | 0 example/sparse/matrix_factorization/README.md | 8 ++ .../{get_data.py => matrix_factorization/data.py} | 33 +---- .../model.py} | 0 .../train.py} | 5 +- example/sparse/readme.md | 21 --- python/mxnet/test_utils.py | 31 +++++ 15 files changed, 398 insertions(+), 62 deletions(-) diff --git a/example/sparse/factorization_machine/README.md b/example/sparse/factorization_machine/README.md new file mode 100644 index 0000000..7609f31 --- /dev/null +++ b/example/sparse/factorization_machine/README.md @@ -0,0 +1,16 @@ +Factorization Machine +=========== +This example trains a factorization machine model using the criteo dataset. 
+ +## Download the Dataset + +The criteo dataset is available at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#criteo +The data was used in a competition on click-through rate prediction jointly hosted by Criteo and Kaggle in 2014, +with 1,000,000 features. There are 45,840,617 training examples and 6,042,135 testing examples. +It takes more than 30 GB to download and extract the dataset. + +## Train the Model + +- python train.py --data-train /path/to/criteo.kaggle2014.test.svm --data-test /path/to/criteo.kaggle2014.test.svm + +[Rendle, Steffen. "Factorization machines." In Data Mining (ICDM), 2010 IEEE 10th International Conference on, pp. 995-1000. IEEE, 2010. ](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf) diff --git a/example/sparse/factorization_machine/metric.py b/example/sparse/factorization_machine/metric.py new file mode 100644 index 0000000..07a7e01 --- /dev/null +++ b/example/sparse/factorization_machine/metric.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +import numpy as np + +@mx.metric.register +@mx.metric.alias('log_loss') +class LogLossMetric(mx.metric.EvalMetric): + """Computes the negative log-likelihood loss. 
+ + The negative log-likelihood loss over a batch of sample size :math:`N` is given by + + .. math:: + -\\sum_{n=1}^{N}\\sum_{k=1}^{K}t_{nk}\\log (y_{nk}), + + where :math:`K` is the number of classes, :math:`y_{nk}` is the predicted probability for + :math:`k`-th class for :math:`n`-th sample. :math:`t_{nk}=1` if and only if sample + :math:`n` belongs to class :math:`k`. + + Parameters + ---------- + eps : float + Negative log-likelihood loss is undefined when the predicted value is 0, + so a small constant is added to the predicted values. + name : str + Name of this metric instance for display. + output_names : list of str, or None + Name of predictions that should be used when updating with update_dict. + By default include all predictions. + label_names : list of str, or None + Name of labels that should be used when updating with update_dict. + By default include all labels. + + Examples + -------- + >>> predicts = [mx.nd.array([[0.3], [0], [0.4]])] + >>> labels = [mx.nd.array([0, 1, 1])] + >>> log_loss= mx.metric.NegativeLogLikelihood() + >>> log_loss.update(labels, predicts) + >>> print log_loss.get() + ('log-loss', 0.57159948348999023) + """ + def __init__(self, eps=1e-12, name='log-loss', + output_names=None, label_names=None): + super(LogLossMetric, self).__init__( + name, eps=eps, + output_names=output_names, label_names=label_names) + self.eps = eps + + def update(self, labels, preds): + """Updates the internal evaluation result. + + Parameters + ---------- + labels : list of `NDArray` + The labels of the data. + + preds : list of `NDArray` + Predicted values. 
+ """ + mx.metric.check_label_shapes(labels, preds) + + for label, pred in zip(labels, preds): + label = label.asnumpy() + pred = pred.asnumpy() + pred = np.column_stack((1 - pred, pred)) + + label = label.ravel() + num_examples = pred.shape[0] + assert label.shape[0] == num_examples, (label.shape[0], num_examples) + prob = pred[np.arange(num_examples, dtype=np.int64), np.int64(label)] + self.sum_metric += (-np.log(prob + self.eps)).sum() + self.num_inst += num_examples diff --git a/example/sparse/factorization_machine/model.py b/example/sparse/factorization_machine/model.py new file mode 100644 index 0000000..f0af2e6 --- /dev/null +++ b/example/sparse/factorization_machine/model.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import mxnet as mx + +def factorization_machine_model(factor_size, num_features, + lr_mult_config, wd_mult_config, init_config): + """ builds factorization machine network with proper formulation: + y = w_0 + \sum(x_i w_i) + 0.5(\sum\sum<v_i,v_j>x_ix_j - \sum<v_i,v_i>x_i^2) + """ + x = mx.symbol.Variable("data", stype='csr') + # factor, linear and bias terms + v = mx.symbol.Variable("v", shape=(num_features, factor_size), stype='row_sparse', + init=init_config['v'], lr_mult=lr_mult_config['v'], + wd_mult=wd_mult_config['v']) + w = mx.symbol.var('w', shape=(num_features, 1), stype='row_sparse', + init=init_config['w'], lr_mult=lr_mult_config['w'], + wd_mult=wd_mult_config['w']) + w0 = mx.symbol.var('w0', shape=(1,), init=init_config['w0'], + lr_mult=lr_mult_config['w0'], wd_mult=wd_mult_config['w0']) + w1 = mx.symbol.broadcast_add(mx.symbol.dot(x, w), w0) + + # squared terms for subtracting self interactions + v_s = mx.symbol._internal._square_sum(data=v, axis=1, keepdims=True) + x_s = x.square() + bd_sum = mx.sym.dot(x_s, v_s) + + # interactions + w2 = mx.symbol.dot(x, v) + w2_squared = 0.5 * mx.symbol.square(data=w2) + + # putting everything together + w_all = mx.symbol.Concat(w1, w2_squared, dim=1) + sum1 = w_all.sum(axis=1, keepdims=True) + sum2 = -0.5 * bd_sum + model = sum1 + sum2 + + y = mx.symbol.Variable("softmax_label") + model = mx.symbol.LogisticRegressionOutput(data=model, label=y) + return model diff --git a/example/sparse/factorization_machine/train.py b/example/sparse/factorization_machine/train.py new file mode 100644 index 0000000..741cf95 --- /dev/null +++ b/example/sparse/factorization_machine/train.py @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +from metric import * +from mxnet.test_utils import * +from model import * +import argparse, os + +parser = argparse.ArgumentParser(description="Run factorization machine with criteo dataset", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--data-train', type=str, default=None, + help='training dataset in LibSVM format.') +parser.add_argument('--data-test', type=str, default=None, + help='test dataset in LibSVM format.') +parser.add_argument('--num-epoch', type=int, default=1, + help='number of epochs to train') +parser.add_argument('--batch-size', type=int, default=1000, + help='number of examples per batch') +parser.add_argument('--input-size', type=int, default=1000000, + help='number of features in the input') +parser.add_argument('--factor-size', type=int, default=16, + help='number of latent variables') +parser.add_argument('--factor-lr', type=float, default=0.0001, + help='learning rate for factor terms') +parser.add_argument('--linear-lr', type=float, default=0.001, + help='learning rate for linear terms') +parser.add_argument('--bias-lr', type=float, default=0.1, + help='learning rate for bias terms') +parser.add_argument('--factor-wd', type=float, default=0.00001, + help='weight decay rate for factor terms') +parser.add_argument('--linear-wd', type=float, default=0.001, + help='weight decay rate for linear terms') 
+parser.add_argument('--bias-wd', type=float, default=0.01, + help='weight decay rate for bias terms') +parser.add_argument('--factor-sigma', type=float, default=0.001, + help='standard deviation for initialization of factor terms') +parser.add_argument('--linear-sigma', type=float, default=0.01, + help='standard deviation for initialization of linear terms') +parser.add_argument('--bias-sigma', type=float, default=0.01, + help='standard deviation for initialization of bias terms') +parser.add_argument('--log-interval', type=int, default=100, + help='number of batches between logging messages') +parser.add_argument('--kvstore', type=str, default='local', + help='what kvstore to use', choices=["dist_async", "local"]) + +if __name__ == '__main__': + import logging + head = '%(asctime)-15s %(message)s' + logging.basicConfig(level=logging.INFO, format=head) + + # arg parser + args = parser.parse_args() + logging.info(args) + num_epoch = args.num_epoch + batch_size = args.batch_size + kvstore = args.kvstore + factor_size = args.factor_size + num_features = args.input_size + log_interval = args.log_interval + assert(args.data_train is not None and args.data_test is not None), \ + "dataset for training or test is missing" + + # create kvstore + kv = mx.kvstore.create(kvstore) + # data iterator + train_data = mx.io.LibSVMIter(data_libsvm=args.data_train, data_shape=(num_features,), + batch_size=batch_size) + eval_data = mx.io.LibSVMIter(data_libsvm=args.data_test, data_shape=(num_features,), + batch_size=batch_size) + # model + lr_config = {'v': args.factor_lr, 'w': args.linear_lr, 'w0': args.bias_lr} + wd_config = {'v': args.factor_wd, 'w': args.linear_wd, 'w0': args.bias_wd} + init_config = {'v': mx.initializer.Normal(args.factor_sigma), + 'w': mx.initializer.Normal(args.linear_sigma), + 'w0': mx.initializer.Normal(args.bias_sigma)} + model = factorization_machine_model(factor_size, num_features, lr_config, wd_config, init_config) + + # module + mod = 
mx.mod.Module(symbol=model) + mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label) + mod.init_params() + optimizer_params=(('learning_rate', 1), ('wd', 1), ('beta1', 0.9), + ('beta2', 0.999), ('epsilon', 1e-8)) + mod.init_optimizer(optimizer='adam', kvstore=kv, optimizer_params=optimizer_params) + + # metrics + metric = mx.metric.create(['log_loss']) + speedometer = mx.callback.Speedometer(batch_size, log_interval) + + # get the sparse weight parameter + w_index = mod._exec_group.param_names.index('w') + w_param = mod._exec_group.param_arrays[w_index] + v_index = mod._exec_group.param_names.index('v') + v_param = mod._exec_group.param_arrays[v_index] + + logging.info('Training started ...') + train_iter = iter(train_data) + eval_iter = iter(eval_data) + for epoch in range(num_epoch): + nbatch = 0 + metric.reset() + for batch in train_iter: + nbatch += 1 + # manually pull sparse weights from kvstore so that _square_sum + # only computes the rows necessary + row_ids = batch.data[0].indices + kv.row_sparse_pull('w', w_param, row_ids=[row_ids], priority=-w_index) + kv.row_sparse_pull('v', v_param, row_ids=[row_ids], priority=-v_index) + mod.forward_backward(batch) + # update all parameters (including the weight parameter) + mod.update() + # update training metric + mod.update_metric(metric, batch.label) + speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch, + eval_metric=metric, locals=locals()) + speedometer(speedometer_param) + + # pull all updated rows before validation + kv.row_sparse_pull('w', w_param, row_ids=[row_ids], priority=-w_index) + kv.row_sparse_pull('v', v_param, row_ids=[row_ids], priority=-v_index) + # evaluate metric on validation dataset + score = mod.score(eval_iter, ['log_loss']) + logging.info("epoch %d, eval log loss = %s" % (epoch, score[0][1])) + # reset the iterator for next pass of data + train_iter.reset() + eval_iter.reset() + logging.info('Training completed.') diff --git 
a/example/sparse/linear_classification/README.md b/example/sparse/linear_classification/README.md new file mode 100644 index 0000000..7e2a7ad --- /dev/null +++ b/example/sparse/linear_classification/README.md @@ -0,0 +1,17 @@ +Linear Classification Using Sparse Matrix Multiplication +=========== +This example trains a linear model using the sparse feature in MXNet. This is for demonstration purposes only. + +The example utilizes the sparse data loader ([mx.io.LibSVMIter](https://mxnet.incubator.apache.org/versions/master/api/python/io.html#mxnet.io.LibSVMIter)), +the sparse dot operator and [sparse gradient updaters](https://mxnet.incubator.apache.org/versions/master/api/python/ndarray/sparse.html#updater) +to train a linear model on the +[Avazu](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#avazu) click-through-prediction dataset. + +The example also shows how to perform distributed training with the sparse feature. + +- `python train.py` + +Notes on Distributed Training: + +- For distributed training, please use the `../../../tools/launch.py` script to launch a cluster. +- For example, to run two workers and two servers with one machine, run `../../../tools/launch.py -n 2 --launcher=local python train.py --kvstore=dist_async` diff --git a/example/sparse/linear_classification/data.py b/example/sparse/linear_classification/data.py new file mode 100644 index 0000000..0298473 --- /dev/null +++ b/example/sparse/linear_classification/data.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os, gzip +import sys +import mxnet as mx + +def get_avazu_data(data_dir, data_name, url): + if not os.path.isdir(data_dir): + os.mkdir(data_dir) + os.chdir(data_dir) + if (not os.path.exists(data_name)): + print("Dataset " + data_name + " not present. Downloading now ...") + import urllib + zippath = os.path.join(data_dir, data_name + ".bz2") + urllib.urlretrieve(url + data_name + ".bz2", zippath) + os.system("bzip2 -d %r" % data_name + ".bz2") + print("Dataset " + data_name + " is now present.") + os.chdir("..") diff --git a/example/sparse/linear_model.py b/example/sparse/linear_classification/linear_model.py similarity index 100% rename from example/sparse/linear_model.py rename to example/sparse/linear_classification/linear_model.py diff --git a/example/sparse/linear_classification.py b/example/sparse/linear_classification/train.py similarity index 95% rename from example/sparse/linear_classification.py rename to example/sparse/linear_classification/train.py index 1d63c55..eb7871b 100644 --- a/example/sparse/linear_classification.py +++ b/example/sparse/linear_classification/train.py @@ -17,7 +17,7 @@ import mxnet as mx from mxnet.test_utils import * -from get_data import get_libsvm_data +from data import get_avazu_data from linear_model import * import argparse import os @@ -67,8 +67,8 @@ if __name__ == '__main__': data_dir = os.path.join(os.getcwd(), 'data') train_data = os.path.join(data_dir, AVAZU['train']) val_data = os.path.join(data_dir, AVAZU['test']) - get_libsvm_data(data_dir, AVAZU['train'], AVAZU['url']) - get_libsvm_data(data_dir, 
AVAZU['test'], AVAZU['url']) + get_avazu_data(data_dir, AVAZU['train'], AVAZU['url']) + get_avazu_data(data_dir, AVAZU['test'], AVAZU['url']) # data iterator train_data = mx.io.LibSVMIter(data_libsvm=train_data, data_shape=(num_features,), @@ -100,11 +100,10 @@ if __name__ == '__main__': speedometer = mx.callback.Speedometer(batch_size, 100) logging.info('Training started ...') - data_iter = iter(train_data) for epoch in range(num_epoch): nbatch = 0 metric.reset() - for batch in data_iter: + for batch in train_data: nbatch += 1 # for distributed training, we need to manually pull sparse weights from kvstore if kv: @@ -129,5 +128,6 @@ if __name__ == '__main__': save_optimizer_states = 'dist' not in kv.type if kv else True mod.save_checkpoint("checkpoint", epoch, save_optimizer_states=save_optimizer_states) # reset the iterator for next pass of data - data_iter.reset() + train_data.reset() + eval_data.reset() logging.info('Training completed.') diff --git a/example/sparse/weighted_softmax_ce.py b/example/sparse/linear_classification/weighted_softmax_ce.py similarity index 100% rename from example/sparse/weighted_softmax_ce.py rename to example/sparse/linear_classification/weighted_softmax_ce.py diff --git a/example/sparse/matrix_factorization/README.md b/example/sparse/matrix_factorization/README.md new file mode 100644 index 0000000..3ada5e8 --- /dev/null +++ b/example/sparse/matrix_factorization/README.md @@ -0,0 +1,8 @@ +Matrix Factorization w/ Sparse Embedding +=========== +The example demonstrates the basic usage of the SparseEmbedding operator in MXNet, adapted based on @leopd's recommender examples. +The operator is available on both CPU and GPU. This is for demonstration purpose only. 
+ +- `python train.py` +- To compare the training speed with (dense) Embedding, run `python train.py --use-dense` +- To run the example on the GPU, run `python train.py --use-gpu` diff --git a/example/sparse/get_data.py b/example/sparse/matrix_factorization/data.py similarity index 68% rename from example/sparse/get_data.py rename to example/sparse/matrix_factorization/data.py index 19c635f..fae2c23 100644 --- a/example/sparse/get_data.py +++ b/example/sparse/matrix_factorization/data.py @@ -18,38 +18,7 @@ import os, gzip import sys import mxnet as mx - -class DummyIter(mx.io.DataIter): - "A dummy iterator that always return the same batch, used for speed testing" - def __init__(self, real_iter): - super(DummyIter, self).__init__() - self.real_iter = real_iter - self.provide_data = real_iter.provide_data - self.provide_label = real_iter.provide_label - self.batch_size = real_iter.batch_size - - for batch in real_iter: - self.the_batch = batch - break - - def __iter__(self): - return self - - def next(self): - return self.the_batch - -def get_libsvm_data(data_dir, data_name, url): - if not os.path.isdir(data_dir): - os.mkdir(data_dir) - os.chdir(data_dir) - if (not os.path.exists(data_name)): - print("Dataset " + data_name + " not present. 
Downloading now ...") - import urllib - zippath = os.path.join(data_dir, data_name + ".bz2") - urllib.urlretrieve(url + data_name + ".bz2", zippath) - os.system("bzip2 -d %r" % data_name + ".bz2") - print("Dataset " + data_name + " is now present.") - os.chdir("..") +from mxnet.test_utils import DummyIter def get_movielens_data(prefix): if not os.path.exists("%s.zip" % prefix): diff --git a/example/sparse/matrix_fact_model.py b/example/sparse/matrix_factorization/model.py similarity index 100% rename from example/sparse/matrix_fact_model.py rename to example/sparse/matrix_factorization/model.py diff --git a/example/sparse/matrix_factorization.py b/example/sparse/matrix_factorization/train.py similarity index 97% rename from example/sparse/matrix_factorization.py rename to example/sparse/matrix_factorization/train.py index 3387706..14c6ca1 100644 --- a/example/sparse/matrix_factorization.py +++ b/example/sparse/matrix_factorization/train.py @@ -20,9 +20,8 @@ import logging import time import mxnet as mx import numpy as np -from get_data import get_movielens_iter, get_movielens_data -from matrix_fact_model import matrix_fact_net - +from data import get_movielens_iter, get_movielens_data +from model import matrix_fact_net logging.basicConfig(level=logging.DEBUG) diff --git a/example/sparse/readme.md b/example/sparse/readme.md deleted file mode 100644 index e443bfa..0000000 --- a/example/sparse/readme.md +++ /dev/null @@ -1,21 +0,0 @@ -Example -=========== -This folder contains examples using the sparse feature in MXNet. They are for demonstration purpose only. - -## Linear Classification Using Sparse Matrix Multiplication - -The example demonstrates the basic usage of the sparse feature in MXNet to speedup computation. It utilizes the sparse data loader, sparse operators and a sparse gradient updater to train a linear model on the [Avazu](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#avazu) click-through-prediction dataset. 
- -- `python linear_classification.py` - -Notes on Distributed Training: - -- For distributed training, please use the `../../tools/launch.py` script to launch a cluster. -- For example, to run two workers and two servers with one machine, run `../../tools/launch.py -n 2 --launcher=local python linear_classification.py --kvstore=dist_async` - -## Matrix Factorization Using Sparse Embedding - -The example demonstrates the basic usage of the SparseEmbedding operator in MXNet, adapted based on @leopd's recommender examples. - -- `python matrix_factorization.py` -- To compare the train speed with (dense) Embedding, run `python matrix_factorization.py --use-dense` diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 3e66736..0dfeec5 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1538,3 +1538,34 @@ def discard_stderr(): finally: os.dup2(old_stderr, stderr_fileno) bit_bucket.close() + +class DummyIter(mx.io.DataIter): + """A dummy iterator that always returns the same batch of data + (the first data batch of the real data iter). This is usually used for speed testing. + + Parameters + ---------- + real_iter: mx.io.DataIter + The real data iterator where the first batch of data comes from + """ + def __init__(self, real_iter): + super(DummyIter, self).__init__() + self.real_iter = real_iter + self.provide_data = real_iter.provide_data + self.provide_label = real_iter.provide_label + self.batch_size = real_iter.batch_size + self.the_batch = next(real_iter) + + def __iter__(self): + return self + + def next(self): + """Get a data batch from iterator. The first data batch of real iter is always returned. + StopIteration will never be raised. + + Returns + ------- + DataBatch + The data of next batch. + """ + return self.the_batch -- To stop receiving notification emails like this one, please contact ['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].