piiswrong closed pull request #10511: add naming tutorial
URL: https://github.com/apache/incubator-mxnet/pull/10511
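The central behavioural change in this PR is that `Block.save_params` stops keying parameters by their prefixed names (such as `model0_dense0_weight`) and instead keys them by attribute path (such as `dense0.weight`), so two instances of the same model with different prefixes can exchange parameter files. The following is a minimal, MXNet-free sketch of that idea, modelled loosely on the `_collect_params_with_prefix` helper in the `block.py` hunk. `ToyParam`, `ToyBlock`, `ToyDense` and `ToyModel` are hypothetical stand-ins invented for illustration; they are not part of Gluon and skip everything except the naming logic.

```python
from collections import OrderedDict


class ToyParam:
    """Hypothetical stand-in for a Gluon Parameter; holds only a prefixed name."""

    def __init__(self, name):
        self.name = name


class ToyBlock:
    """Hypothetical stand-in for gluon.Block, reduced to child/parameter tracking."""

    def __init__(self):
        self._children = OrderedDict()  # attribute name -> child block
        self._reg_params = {}           # attribute name -> parameter

    def __setattr__(self, name, value):
        # Register children and parameters under the attribute name they are
        # assigned to, which is what makes structural keys possible.
        if isinstance(value, ToyBlock):
            self._children[name] = value
        elif isinstance(value, ToyParam):
            self._reg_params[name] = value
        object.__setattr__(self, name, value)

    def collect_params_with_prefix(self, prefix=""):
        # Build keys like 'dense0.weight' from the attribute path,
        # ignoring the prefixed parameter names entirely.
        if prefix:
            prefix += "."
        ret = {prefix + key: val for key, val in self._reg_params.items()}
        for name, child in self._children.items():
            ret.update(child.collect_params_with_prefix(prefix + name))
        return ret


class ToyDense(ToyBlock):
    def __init__(self, prefix):
        super().__init__()
        self.weight = ToyParam(prefix + "weight")
        self.bias = ToyParam(prefix + "bias")


class ToyModel(ToyBlock):
    def __init__(self, prefix):
        super().__init__()
        self.dense0 = ToyDense(prefix + "dense0_")
        self.mydense = ToyDense(prefix + "mydense_")


model0 = ToyModel("model0_")
model1 = ToyModel("model1_")
keys0 = sorted(model0.collect_params_with_prefix())
keys1 = sorted(model1.collect_params_with_prefix())
print(keys0)
print(model0.dense0.weight.name, model1.dense0.weight.name)
```

In this sketch the prefixed names of the two models differ while the structural keys are identical, which is why a file written with structural keys can be loaded by either instance. The real implementation additionally falls back to prefix-based loading for legacy files whose keys contain no `.`, as the `load_params` hunk shows.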
This is a PR merged from a forked repository. As GitHub hides the original diff of a foreign pull request (one from a fork) on merge, it is displayed below for the sake of provenance:

diff --git a/docs/tutorials/gluon/datasets.md b/docs/tutorials/gluon/datasets.md index 248ea02f5c1..0c9b5375d20 100644 --- a/docs/tutorials/gluon/datasets.md +++ b/docs/tutorials/gluon/datasets.md @@ -33,7 +33,7 @@ print(sample) ( [ 0.4375872 0.29753461 0.89177299] - <NDArray 3 @cpu(0)>, + <NDArray 3 @cpu(0)>, [ 0.83261985] <NDArray 1 @cpu(0)>) @@ -60,7 +60,7 @@ for X_batch, y_batch in data_loader: X_batch has shape (5, 3), and y_batch has shape (5, 1) -We can see 2 mini-batches of data (and labels), each with 5 samples, which makes sense given we started with a dataset of 10 samples. When comparing the shape of the batches to the samples returned by the [`Dataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=dataset#mxnet.gluon.data.Dataset), we've gained an extra dimension at the start which is sometimes called the batch axis. +We can see 2 mini-batches of data (and labels), each with 5 samples, which makes sense given we started with a dataset of 10 samples. When comparing the shape of the batches to the samples returned by the [`Dataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=dataset#mxnet.gluon.data.Dataset), we've gained an extra dimension at the start which is sometimes called the batch axis. Our `data_loader` loop will stop when every sample of `dataset` has been returned as part of a batch. Sometimes the dataset length isn't divisible by the mini-batch size, leaving a final batch with a smaller number of samples. 
[`DataLoader`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=dataloader#mxnet.gluon.data.DataLoader)'s default behavior is to return this smaller mini-batch, but this can be changed by setting the `last_batch` parameter to `discard` (which ignores the last batch) or `rollover` (which starts the next epoch with the remaining samples). @@ -137,7 +137,7 @@ def construct_net(): ctx = mx.cpu() net = construct_net() net.hybridize() -net.collect_params().initialize(mx.init.Xavier()) +net.initialize(mx.init.Xavier()) # define loss and trainer. criterion = gluon.loss.SoftmaxCrossEntropyLoss() trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1}) @@ -159,7 +159,7 @@ for epoch in range(epochs): cumulative_train_loss += loss.sum() training_samples += data.shape[0] train_loss = cumulative_train_loss.asscalar()/training_samples - + # validation loop cumulative_valid_loss = mx.nd.array([0]) valid_samples = 0 @@ -171,7 +171,7 @@ for epoch in range(epochs): cumulative_valid_loss += loss.sum() valid_samples += data.shape[0] valid_loss = cumulative_valid_loss.asscalar()/valid_samples - + print("Epoch {}, training loss: {:.2f}, validation loss: {:.2f}".format(epoch, train_loss, valid_loss)) ``` @@ -184,7 +184,7 @@ for epoch in range(epochs): # Using own data with included `Dataset`s -Gluon has a number of different [`Dataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=dataset#mxnet.gluon.data.Dataset) classes for working with your own image data straight out-of-the-box. You can get started quickly using the [`mxnet.gluon.data.vision.datasets.ImageFolderDataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=imagefolderdataset#mxnet.gluon.data.vision.datasets.ImageFolderDataset) which loads images directly from a user-defined folder, and infers the label (i.e. class) from the folders. 
+Gluon has a number of different [`Dataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=dataset#mxnet.gluon.data.Dataset) classes for working with your own image data straight out-of-the-box. You can get started quickly using the [`mxnet.gluon.data.vision.datasets.ImageFolderDataset`](https://mxnet.incubator.apache.org/api/python/gluon/data.html?highlight=imagefolderdataset#mxnet.gluon.data.vision.datasets.ImageFolderDataset) which loads images directly from a user-defined folder, and infers the label (i.e. class) from the folders. We will run through an example for image classification, but a similar process applies for other vision tasks. If you already have your own collection of images to work with you should partition your data into training and test sets, and place all objects of the same class into seperate folders. Similar to: @@ -307,4 +307,4 @@ data_iter_loader = DataIterLoader(data_iter) for X_batch, y_batch in data_iter_loader: assert X_batch.shape == (5, 3) assert y_batch.shape == (5, 1) -``` \ No newline at end of file +``` diff --git a/docs/tutorials/gluon/gluon.md b/docs/tutorials/gluon/gluon.md index a1688ea121d..518e99905c0 100644 --- a/docs/tutorials/gluon/gluon.md +++ b/docs/tutorials/gluon/gluon.md @@ -70,7 +70,7 @@ A network must be created and initialized before it can be used: net = Net() # Initialize on CPU. Replace with `mx.gpu(0)`, or `[mx.gpu(0), mx.gpu(1)]`, # etc to use one or more GPUs. 
-net.collect_params().initialize(mx.init.Xavier(), ctx=mx.cpu()) +net.initialize(mx.init.Xavier(), ctx=mx.cpu()) ``` Note that because we didn't specify input size to layers in Net's constructor, diff --git a/docs/tutorials/gluon/hybrid.md b/docs/tutorials/gluon/hybrid.md index 859ad934c7e..3554a15fa3b 100644 --- a/docs/tutorials/gluon/hybrid.md +++ b/docs/tutorials/gluon/hybrid.md @@ -77,7 +77,7 @@ is called, its `hybrid_forward` will be run: ```python net = Net() -net.collect_params().initialize() +net.initialize() x = mx.nd.random_normal(shape=(16, 1, 28, 28)) net(x) x = mx.nd.random_normal(shape=(16, 1, 28, 28)) @@ -117,7 +117,7 @@ x = mx.sym.var('data') y = net(x) print(y) y.save('model.json') -net.collect_params().save('model.params') +net.save_params('model.params') ``` If your network outputs more than one value, you can use `mx.sym.Group` to diff --git a/docs/tutorials/gluon/mnist.md b/docs/tutorials/gluon/mnist.md index 86c493b38fe..3a2a2cbe01b 100644 --- a/docs/tutorials/gluon/mnist.md +++ b/docs/tutorials/gluon/mnist.md @@ -102,7 +102,7 @@ initialized parameters. ```python gpus = mx.test_utils.list_gpus() ctx = [mx.gpu()] if gpus else [mx.cpu(0), mx.cpu(1)] -net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx) +net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.02}) ``` @@ -252,10 +252,9 @@ Training and prediction can be done in the similar way as we did for MLP. 
We will initialize the network parameters as follows: ```python - # set the context on GPU is available otherwise CPU ctx = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()] -net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx) +net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.03}) ``` diff --git a/docs/tutorials/gluon/naming.md b/docs/tutorials/gluon/naming.md new file mode 100644 index 00000000000..37b63fa08a9 --- /dev/null +++ b/docs/tutorials/gluon/naming.md @@ -0,0 +1,255 @@ + +# Naming of Gluon Parameter and Blocks + +In gluon, each Parameter or Block has a name (and prefix). Parameter names are specified by users and Block names can be either specified by users or automatically created. + +In this tutorial we talk about the best practices on naming. First, let's import MXNet and Gluon: + + +```python +from __future__ import print_function +import mxnet as mx +from mxnet import gluon +``` + +## Naming Blocks + +When creating a block, you can assign a prefix to it: + + +```python +mydense = gluon.nn.Dense(100, prefix='mydense_') +print(mydense.prefix) +``` + + mydense_ + + +When no prefix is given, Gluon will automatically generate one: + + +```python +dense0 = gluon.nn.Dense(100) +print(dense0.prefix) +``` + + dense0_ + + +When you create more Blocks of the same kind, they will be named with incrementing suffixes to avoid collision: + + +```python +dense1 = gluon.nn.Dense(100) +print(dense1.prefix) +``` + + dense1_ + + +## Naming Parameters + +Parameters within a Block will be named by prepending the prefix of the Block to the name of the Parameter: + + +```python +print(dense0.collect_params()) +``` + + dense0_ ( + Parameter dense0_weight (shape=(100, 0), dtype=<type 'numpy.float32'>) + Parameter dense0_bias (shape=(100,), dtype=<type 'numpy.float32'>) + ) + + +## Name scopes + +To manage the names of nested Blocks, each Block has a `name_scope` attached 
to it. All Blocks created within a name scope will have its parent Block's prefix prepended to its name. + +Let's demonstrate this by first defining a simple neural net: + + +```python +class Model(gluon.Block): + def __init__(self, **kwargs): + super(Model, self).__init__(**kwargs) + with self.name_scope(): + self.dense0 = gluon.nn.Dense(20) + self.dense1 = gluon.nn.Dense(20) + self.mydense = gluon.nn.Dense(20, prefix='mydense_') + + def forward(self, x): + x = mx.nd.relu(self.dense0(x)) + x = mx.nd.relu(self.dense1(x)) + return mx.nd.relu(self.mydense(x)) +``` + +Now let's instantiate our neural net. + +- Note that `model0.dense0` is named as `model0_dense0_` instead of `dense0_`. + +- Also note that although we specified `mydense_` as prefix for `model.mydense`, its parent's prefix is automatically prepended to generate the prefix `model0_mydense_`. + + +```python +model0 = Model() +model0.initialize() +model0(mx.nd.zeros((1, 20))) +print(model0.prefix) +print(model0.dense0.prefix) +print(model0.dense1.prefix) +print(model0.mydense.prefix) +``` + + model0_ + model0_dense0_ + model0_dense1_ + model0_mydense_ + + +If we instantiate `Model` again, it will be given a different name like shown before for `Dense`. + +- Note that `model1.dense0` is still named as `dense0_` instead of `dense2_`, following dense layers in previously created `model0`. This is because each instance of model's name scope is independent of each other. + + +```python +model1 = Model() +print(model1.prefix) +print(model1.dense0.prefix) +print(model1.dense1.prefix) +print(model1.mydense.prefix) +``` + + model1_ + model1_dense0_ + model1_dense1_ + model1_mydense_ + + +**It is recommended that you manually specify a prefix for the top level Block, i.e. `model = Model(prefix='mymodel_')`, to avoid potential confusions in naming.** + +The same principle also applies to container blocks like Sequential. 
`name_scope` can be used inside `__init__` as well as out side of `__init__`: + + +```python +net = gluon.nn.Sequential() +with net.name_scope(): + net.add(gluon.nn.Dense(20)) + net.add(gluon.nn.Dense(20)) +print(net.prefix) +print(net[0].prefix) +print(net[1].prefix) +``` + + sequential0_ + sequential0_dense0_ + sequential0_dense1_ + + +`gluon.model_zoo` also behaves similarly: + + +```python +net = gluon.nn.Sequential() +with net.name_scope(): + net.add(gluon.model_zoo.vision.alexnet(pretrained=True)) + net.add(gluon.model_zoo.vision.alexnet(pretrained=True)) +print(net.prefix, net[0].prefix, net[1].prefix) +``` + + sequential1_ sequential1_alexnet0_ sequential1_alexnet1_ + + +## Saving and loading + +Because model0 and model1 have different prefixes, their parameters also have different names: + + +```python +print(model0.collect_params(), '\n') +print(model1.collect_params()) +``` + + model0_ ( + Parameter model0_dense0_weight (shape=(20L, 20L), dtype=<type 'numpy.float32'>) + Parameter model0_dense0_bias (shape=(20L,), dtype=<type 'numpy.float32'>) + Parameter model0_dense1_weight (shape=(20L, 20L), dtype=<type 'numpy.float32'>) + Parameter model0_dense1_bias (shape=(20L,), dtype=<type 'numpy.float32'>) + Parameter model0_mydense_weight (shape=(20L, 20L), dtype=<type 'numpy.float32'>) + Parameter model0_mydense_bias (shape=(20L,), dtype=<type 'numpy.float32'>) + ) + + model1_ ( + Parameter model1_dense0_weight (shape=(20, 0), dtype=<type 'numpy.float32'>) + Parameter model1_dense0_bias (shape=(20,), dtype=<type 'numpy.float32'>) + Parameter model1_dense1_weight (shape=(20, 0), dtype=<type 'numpy.float32'>) + Parameter model1_dense1_bias (shape=(20,), dtype=<type 'numpy.float32'>) + Parameter model1_mydense_weight (shape=(20, 0), dtype=<type 'numpy.float32'>) + Parameter model1_mydense_bias (shape=(20,), dtype=<type 'numpy.float32'>) + ) + + +As a result, if you try to save parameters from model0 and load it with model1, you'll get an error due to unmatching 
names: + + +```python +model0.collect_params().save('model.params') +try: + model1.collect_params().load('model.params', mx.cpu()) +except Exception as e: + print(e) +``` + + Parameter 'model1_dense0_weight' is missing in file 'model.params', which contains parameters: 'model0_mydense_weight', 'model0_dense1_bias', 'model0_dense1_weight', 'model0_dense0_weight', 'model0_dense0_bias', 'model0_mydense_bias'. Please make sure source and target networks have the same prefix. + + +To solve this problem, we use `save_params`/`load_params` instead of `collect_params` and `save`/`load`. `save_params` uses model structure, instead of parameter name, to match parameters. + + +```python +model0.save_params('model.params') +model1.load_params('model.params') +print(mx.nd.load('model.params').keys()) +``` + + ['dense0.bias', 'mydense.bias', 'dense1.bias', 'dense1.weight', 'dense0.weight', 'mydense.weight'] + + +## Replacing Blocks from networks and fine-tuning + +Sometimes you may want to load a pretrained model, and replace certain Blocks in it for fine-tuning. + +For example, the alexnet in model zoo has 1000 output dimensions, but maybe you only have 100 classes in your application. + +To see how to do this, we first load a pretrained AlexNet. + +- In Gluon model zoo, all image classification models follow the format where the feature extraction layers are named `features` while the output layer is named `output`. +- Note that the output layer is a dense block with 1000 dimension outputs. + + +```python +alexnet = gluon.model_zoo.vision.alexnet(pretrained=True) +print(alexnet.output) +print(alexnet.output.prefix) +``` + + Dense(4096 -> 1000, linear) + alexnet0_dense2_ + + +To change the output to 100 dimension, we replace it with a new block. 
+ + +```python +with alexnet.name_scope(): + alexnet.output = gluon.nn.Dense(100) +alexnet.output.initialize() +print(alexnet.output) +print(alexnet.output.prefix) +``` + + Dense(None -> 100, linear) + alexnet0_dense3_ + + +<!-- INSERT SOURCE DOWNLOAD BUTTONS --> diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md index 00a15046ab9..04b7893c619 100644 --- a/docs/tutorials/index.md +++ b/docs/tutorials/index.md @@ -100,6 +100,8 @@ The Gluon and Module tutorials are in Python, but you can also find a variety of - [Designing a custom layer with gluon](http://gluon.mxnet.io/chapter03_deep-neural-networks/custom-layer.html) +- [Block and Parameter naming](/tutorials/gluon/naming.html) + - [Fast, portable neural networks with Gluon HybridBlocks](http://gluon.mxnet.io/chapter07_distributed-learning/hybridize.html) - [Training on multiple GPUs with gluon](http://gluon.mxnet.io/chapter07_distributed-learning/multiple-gpus-gluon.html) diff --git a/docs/tutorials/onnx/fine_tuning_gluon.md b/docs/tutorials/onnx/fine_tuning_gluon.md index 4116ff631eb..c3015428ad7 100644 --- a/docs/tutorials/onnx/fine_tuning_gluon.md +++ b/docs/tutorials/onnx/fine_tuning_gluon.md @@ -7,7 +7,7 @@ Fine-tuning is a common practice in Transfer Learning. One can take advantage of [Open Neural Network Exchange (ONNX)](https://github.com/onnx/onnx) provides an open source format for AI models. It defines an extensible computation graph model, as well as definitions of built-in operators and standard data types. 
In this tutorial we will: - + - learn how to pick a specific layer from a pre-trained .onnx model file - learn how to load this model in Gluon and fine-tune it on a different dataset @@ -63,7 +63,7 @@ We download a pre-trained model, in our case the [vgg16](https://arxiv.org/abs/1 ```python -base_url = "https://s3.amazonaws.com/download.onnx/models/" +base_url = "https://s3.amazonaws.com/download.onnx/models/" current_model = "vgg16" model_folder = "model" archive_file = "{}.tar.gz".format(current_model) @@ -135,7 +135,7 @@ We transform the dataset images using the following operations: def transform(image, label): resized = mx.image.resize_short(image, EDGE) cropped, crop_info = mx.image.center_crop(resized, SIZE) - transposed = nd.transpose(cropped, (2,0,1)) + transposed = nd.transpose(cropped, (2,0,1)) return transposed, label ``` @@ -162,7 +162,7 @@ We use num_workers=Number of CPU cores, which means the dataloading and pre-proc ```python dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, last_batch='discard', shuffle=True, num_workers=NUM_WORKERS) -dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, last_batch='discard', +dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, last_batch='discard', shuffle=True, num_workers=NUM_WORKERS) print("Train dataset: {} images, Test dataset: {} images".format(len(dataset_train), len(dataset_test))) ``` @@ -274,7 +274,7 @@ We create the new dense layer with the right new number of classes (101) and ini ```python dense_layer = gluon.nn.Dense(NUM_CLASSES) -dense_layer.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx) +dense_layer.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx) ``` We add the SymbolBlock and the new dense layer to a HybridSequential network @@ -309,8 +309,8 @@ The trainer will retrain and fine-tune the entire network. 
If we use `dense_laye ```python trainer = gluon.Trainer(net.collect_params(), 'sgd', - {'learning_rate': LEARNING_RATE, - 'wd':WDECAY, + {'learning_rate': LEARNING_RATE, + 'wd':WDECAY, 'momentum':MOMENTUM}) ``` @@ -353,20 +353,20 @@ for epoch in range(20): for i, (data, label) in enumerate(dataloader_train): data = data.astype(np.float32).as_in_context(ctx) label = label.as_in_context(ctx) - + if i%20==0 and i >0: print('Batch [{0}] loss: {1:.4f}'.format(i, loss.mean().asscalar())) - + with autograd.record(): output = net(data) loss = softmax_cross_entropy(output, label) loss.backward() trainer.step(data.shape[0]) - + nd.waitall() # wait at the end of the epoch new_val_accuracy = evaluate_accuracy_gluon(dataloader_test, net) print("Epoch [{0}] Test Accuracy {1:.4f} ".format(epoch, new_val_accuracy)) - + # We perform early-stopping regularization, to prevent the model from overfitting if val_accuracy > new_val_accuracy: print('Validation accuracy is decreasing, stopping training') @@ -385,7 +385,7 @@ Let's see if our network fine-tuned on Caltech101 is up for the task: ```python # Number of predictions to show -TOP_P = 3 +TOP_P = 3 ``` diff --git a/example/gluon/embedding_learning/train.py b/example/gluon/embedding_learning/train.py index 269caff414c..46f76b55614 100644 --- a/example/gluon/embedding_learning/train.py +++ b/example/gluon/embedding_learning/train.py @@ -246,7 +246,7 @@ def train(epochs, ctx): if val_accs[0] > best_val: best_val = val_accs[0] logging.info('Saving %s.' 
% opt.save_model_prefix) - net.collect_params().save('%s.params' % opt.save_model_prefix) + net.save_params('%s.params' % opt.save_model_prefix) return best_val diff --git a/example/gluon/kaggle_k_fold_cross_validation.py b/example/gluon/kaggle_k_fold_cross_validation.py index 7911e4d1a01..420e6fc53c8 100644 --- a/example/gluon/kaggle_k_fold_cross_validation.py +++ b/example/gluon/kaggle_k_fold_cross_validation.py @@ -88,7 +88,7 @@ def train(net, X_train, y_train, epochs, verbose_epoch, learning_rate, trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': learning_rate, 'wd': weight_decay}) - net.collect_params().initialize(force_reinit=True) + net.initialize(force_reinit=True) for epoch in range(epochs): for data, label in data_iter_train: with autograd.record(): diff --git a/example/gluon/learning_rate_manipulation.py b/example/gluon/learning_rate_manipulation.py index 1523102b795..be1ffc29024 100644 --- a/example/gluon/learning_rate_manipulation.py +++ b/example/gluon/learning_rate_manipulation.py @@ -32,13 +32,13 @@ net = gluon.nn.Sequential() # The output dimension is 1. net.add(gluon.nn.Dense(1)) -net.collect_params().initialize() +net.initialize() loss = gluon.loss.L2Loss() # Initialize the learning rate as 0.1. trainer = gluon.Trainer(net.collect_params(), 'sgd', optimizer_params={'learning_rate': 0.1}) -net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), +net.initialize(mx.init.Xavier(magnitude=2.24), force_reinit=True) train_data = mx.io.NDArrayIter(X, Y, batch_size=10, shuffle=True) @@ -60,4 +60,4 @@ for para_name, para_value in net.collect_params().items(): # Print all the parameter values after training. 
- print(para_name, para_value.data().asnumpy()[0]) \ No newline at end of file + print(para_name, para_value.data().asnumpy()[0]) diff --git a/example/gluon/lstm_crf.py b/example/gluon/lstm_crf.py index 857bfca5618..561b4c24bb6 100644 --- a/example/gluon/lstm_crf.py +++ b/example/gluon/lstm_crf.py @@ -197,7 +197,7 @@ def forward(self, sentence): # dont confuse this with _forward_alg above. tag2idx = {"B": 0, "I": 1, "O": 2, START_TAG: 3, STOP_TAG: 4} model = BiLSTM_CRF(len(word2idx), tag2idx, EMBEDDING_DIM, HIDDEN_DIM) -model.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=mx.cpu()) +model.initialize(mx.init.Xavier(magnitude=2.24), ctx=mx.cpu()) optimizer = gluon.Trainer(model.collect_params(), 'sgd', {'learning_rate': 0.01, 'wd': 1e-4}) # Check predictions before training diff --git a/example/gluon/style_transfer/main.py b/example/gluon/style_transfer/main.py index fa21a3695de..7fcc927f9cb 100644 --- a/example/gluon/style_transfer/main.py +++ b/example/gluon/style_transfer/main.py @@ -54,7 +54,7 @@ def train(args): style_model.initialize(init=mx.initializer.MSRAPrelu(), ctx=ctx) if args.resume is not None: print('Resuming, initializing using weight from {}.'.format(args.resume)) - style_model.collect_params().load(args.resume, ctx=ctx) + style_model.load_params(args.resume, ctx=ctx) print('style_model:',style_model) # optimizer and loss trainer = gluon.Trainer(style_model.collect_params(), 'adam', @@ -96,7 +96,7 @@ def train(args): total_loss = content_loss + style_loss total_loss.backward() - + trainer.step(args.batch_size) mx.nd.waitall() @@ -112,20 +112,20 @@ def train(args): ) print(mesg) - + if (batch_id + 1) % (4 * args.log_interval) == 0: # save model save_model_filename = "Epoch_" + str(e) + "iters_" + str(count) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str( args.content_weight) + "_" + str(args.style_weight) + ".params" save_model_path = os.path.join(args.save_model_dir, save_model_filename) - 
style_model.collect_params().save(save_model_path) + style_model.save_params(save_model_path) print("\nCheckpoint, trained model saved at", save_model_path) # save model save_model_filename = "Final_epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str( args.content_weight) + "_" + str(args.style_weight) + ".params" save_model_path = os.path.join(args.save_model_dir, save_model_filename) - style_model.collect_params().save(save_model_path) + style_model.save_params(save_model_path) print("\nDone, trained model saved at", save_model_path) @@ -140,7 +140,7 @@ def evaluate(args): style_image = utils.preprocess_batch(style_image) # model style_model = net.Net(ngf=args.ngf) - style_model.collect_params().load(args.model, ctx=ctx) + style_model.load_params(args.model, ctx=ctx) # forward style_model.setTarget(style_image) output = style_model(content_image) @@ -195,7 +195,7 @@ def optimize(args): trainer.step(1) if (e + 1) % args.log_interval == 0: print('loss:{:.2f}'.format(total_loss.asnumpy()[0])) - + # save the image output = utils.add_imagenet_mean_batch(output.data()) utils.tensor_save_bgrimage(output[0], args.output_image, args.cuda) @@ -209,7 +209,7 @@ def main(): raise ValueError("ERROR: specify the experiment type") if args.subcommand == "train": - # Training the model + # Training the model train(args) elif args.subcommand == 'eval': diff --git a/example/gluon/super_resolution.py b/example/gluon/super_resolution.py index 7963590c6db..38c3bec8949 100644 --- a/example/gluon/super_resolution.py +++ b/example/gluon/super_resolution.py @@ -144,7 +144,7 @@ def train(epoch, ctx): ctx = [ctx] net.initialize(mx.init.Orthogonal(), ctx=ctx) # re-initialize conv4's weight to be Orthogonal - net.conv4.collect_params().initialize(mx.init.Orthogonal(scale=1), force_reinit=True, ctx=ctx) + net.conv4.initialize(mx.init.Orthogonal(scale=1), force_reinit=True, ctx=ctx) trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': opt.lr}) 
loss = gluon.loss.L2Loss() diff --git a/example/gluon/tree_lstm/main.py b/example/gluon/tree_lstm/main.py index 67644f97d38..d2fe464638a 100644 --- a/example/gluon/tree_lstm/main.py +++ b/example/gluon/tree_lstm/main.py @@ -138,7 +138,7 @@ def test(ctx, data_iter, best, mode='validation', num_iter=-1): if test_r >= best: best = test_r logging.info('New optimum found: {}. Checkpointing.'.format(best)) - net.collect_params().save('childsum_tree_lstm_{}.params'.format(num_iter)) + net.save_params('childsum_tree_lstm_{}.params'.format(num_iter)) test(ctx, test_iter, -1, 'test') return best @@ -148,7 +148,7 @@ def train(epoch, ctx, train_data, dev_data): # initialization with context if isinstance(ctx, mx.Context): ctx = [ctx] - net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx[0]) + net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx[0]) net.embed.weight.set_data(vocab.embed.as_in_context(ctx[0])) train_data.set_context(ctx[0]) dev_data.set_context(ctx[0]) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 64003585dba..2f8cdd80fc7 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -23,14 +23,14 @@ import copy import warnings import re +from collections import OrderedDict from .. import symbol, ndarray, initializer from ..symbol import Symbol from ..ndarray import NDArray from .. import name as _name -from ..context import cpu from .parameter import Parameter, ParameterDict, DeferredInitializationError -from .utils import _indent +from .utils import _indent, _brief_print_list class _BlockScope(object): @@ -134,7 +134,6 @@ class Model(Block): def __init__(self, **kwargs): super(Model, self).__init__(**kwargs) # use name_scope to give child Blocks appropriate names. - # It also allows sharing Parameters between Blocks recursively. 
with self.name_scope(): self.dense0 = nn.Dense(20) self.dense1 = nn.Dense(20) @@ -154,10 +153,11 @@ def forward(self, x): Parameters ---------- prefix : str - Prefix acts like a name space. It will be prepended to the names of all - Parameters and child :py:class:`Block` s in this :py:class:`Block` 's - :py:meth:`name_scope` . - Prefix should be unique within one model to prevent name collisions. + Prefix acts like a name space. All children blocks created in parent block's + :py:meth:`name_scope` will have parent block's prefix in their name. + Please refer to + `naming tutorial <http://mxnet.incubator.apache.org/tutorials/basic/naming.html>`_ + for more info on prefix and naming. params : ParameterDict or None :py:class:`ParameterDict` for sharing weights with the new :py:class:`Block`. For example, if you want ``dense1`` to share ``dense0``'s weights, you can do:: @@ -170,15 +170,15 @@ def __init__(self, prefix=None, params=None): self._prefix, self._params = _BlockScope.create(prefix, params, self._alias()) self._name = self._prefix[:-1] if self._prefix.endswith('_') else self._prefix self._scope = _BlockScope(self) - self._children = [] + self._children = OrderedDict() + self._reg_params = {} def __repr__(self): s = '{name}(\n{modstr}\n)' modstr = '\n'.join([' ({key}): {block}'.format(key=key, block=_indent(block.__repr__(), 2)) for key, block in self.__dict__.items() if isinstance(block, Block)]) - return s.format(name=self.__class__.__name__, - modstr=modstr) + return s.format(name=self.__class__.__name__, modstr=modstr) def __setattr__(self, name, value): """Registers parameters.""" @@ -187,17 +187,17 @@ def __setattr__(self, name, value): existing = getattr(self, name) if isinstance(existing, (Parameter, Block)) and not isinstance(value, type(existing)): raise TypeError('Changing attribute type for {name} from {type1} to {type2}' \ - 'is not allowed.'.format(name=name, - type1=type(existing), - type2=type(value))) - if isinstance(existing, Block): - for i, 
c in enumerate(self._children): - if c is existing: - self._children[i] = value - elif isinstance(value, Block): - self.register_child(value) - elif isinstance(value, Block): - self.register_child(value) + 'is not allowed.'.format( + name=name, type1=type(existing), type2=type(value))) + + if isinstance(value, Block): + self.register_child(value, name) + elif isinstance(value, Parameter): + assert name not in self._reg_params, \ + "Overriding Parameter attribute %s is not allowed. " \ + "If you want to share parameters between blocks, please set " \ + "'params' at Block construction instead." + self._reg_params[name] = value super(Block, self).__setattr__(name, value) @@ -247,6 +247,10 @@ def name_scope(self): with self.name_scope(): self.dense = nn.Dense(20) + + Please refer to + `naming tutorial <http://mxnet.incubator.apache.org/tutorials/basic/naming.html>`_ + for more info on prefix and naming. """ return self._scope @@ -288,19 +292,29 @@ def collect_params(self, select=None): else: pattern = re.compile(select) ret.update({name:value for name, value in self.params.items() if pattern.match(name)}) - for cld in self._children: + for cld in self._children.values(): ret.update(cld.collect_params(select=select)) return ret + def _collect_params_with_prefix(self, prefix=''): + if prefix: + prefix += '.' + ret = {prefix + key : val for key, val in self._reg_params.items()} + for name, child in self._children.items(): + ret.update(child._collect_params_with_prefix(prefix + name)) + return ret + def save_params(self, filename): """Save parameters to file. filename : str Path to file. 
""" - self.collect_params().save(filename, strip_prefix=self.prefix) + params = self._collect_params_with_prefix() + arg_dict = {key : val._reduce() for key, val in params.items()} + ndarray.save(filename, arg_dict) - def load_params(self, filename, ctx=cpu(), allow_missing=False, + def load_params(self, filename, ctx=None, allow_missing=False, ignore_extra=False): """Load parameters from file. @@ -314,20 +328,58 @@ def load_params(self, filename, ctx=cpu(), allow_missing=False, Whether to silently ignore parameters from the file that are not present in this Block. """ - self.collect_params().load(filename, ctx, allow_missing, ignore_extra, - self.prefix) + loaded = ndarray.load(filename) + params = self._collect_params_with_prefix() + if not loaded and not params: + return - def register_child(self, block): + if not any('.' in i for i in loaded.keys()): + # legacy loading + del loaded + self.collect_params().load( + filename, ctx, allow_missing, ignore_extra, self.prefix) + return + + if not allow_missing: + for name in params.keys(): + assert name in loaded, \ + "Parameter '%s' is missing in file '%s', which contains parameters: %s. " \ + "Set allow_missing=True to ignore missing parameters."%( + name, filename, _brief_print_list(loaded.keys())) + for name in loaded: + if not ignore_extra and name not in params: + raise ValueError( + "Parameter '%s' loaded from file '%s' is not present in ParameterDict, " \ + "which contains parameters %s. Set ignore_extra=True to ignore. "%( + name, filename, _brief_print_list(self._params.keys()))) + params[name]._load_init(loaded[name], ctx) + + + def register_child(self, block, name=None): """Registers block as a child of self. 
:py:class:`Block` s assigned to self as attributes will be registered automatically.""" - self._children.append(block) + if name is None: + name = str(len(self._children)) + self._children[name] = block - def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False): + def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, + force_reinit=False): """Initializes :py:class:`Parameter` s of this :py:class:`Block` and its children. - Equivalent to ``block.collect_params().initialize(...)`` + + Parameters + ---------- + init : Initializer + Global default Initializer to be used when :py:meth:`Parameter.init` is ``None``. + Otherwise, :py:meth:`Parameter.init` takes precedence. + ctx : Context or list of Context + Keeps a copy of Parameters on one or many context(s). + verbose : bool, default False + Whether to verbosely print out details on initialization. + force_reinit : bool, default False + Whether to force re-initialization if parameter is already initialized. """ - self.collect_params().initialize(init, ctx, verbose) + self.collect_params().initialize(init, ctx, verbose, force_reinit) def hybridize(self, active=True, **kwargs): """Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on @@ -340,7 +392,7 @@ def hybridize(self, active=True, **kwargs): **kwargs : string Additional flags for hybridized operator. """ - for cld in self._children: + for cld in self._children.values(): cld.hybridize(active, **kwargs) def cast(self, dtype): @@ -351,7 +403,7 @@ def cast(self, dtype): dtype : str or numpy.dtype The new data type. 
""" - for child in self._children: + for child in self._children.values(): child.cast(dtype) for _, param in self.params.items(): param.cast(dtype) @@ -393,7 +445,6 @@ class HybridBlock(Block): """ def __init__(self, prefix=None, params=None): super(HybridBlock, self).__init__(prefix=prefix, params=params) - self._reg_params = {} self._cached_graph = () self._cached_op = None self._cached_op_args = None @@ -407,13 +458,6 @@ def __setattr__(self, name, value): super(HybridBlock, self).__setattr__(name, value) if isinstance(value, HybridBlock): self._clear_cached_op() - if isinstance(value, Parameter): - assert name not in self._reg_params or \ - not isinstance(self._reg_params[name], Parameter), \ - "Overriding Parameter attribute %s is not allowed. " \ - "Please pass in Parameters by specifying `params` at " \ - "Block construction instead." - self._reg_params[name] = value def _get_graph(self, *args): if not self._cached_graph: @@ -491,14 +535,14 @@ def _clear_cached_op(self): self._cached_op = None self._cached_op_args = None - def register_child(self, block): + def register_child(self, block, name=None): if not isinstance(block, HybridBlock): raise ValueError( "Children of HybridBlock must also be HybridBlock, " \ "but %s has type %s. 
If you are using Sequential, " \ - "please try HybridSequential instead"%( + "please try HybridSequential instead."%( str(block), str(type(block)))) - super(HybridBlock, self).register_child(block) + super(HybridBlock, self).register_child(block, name) self._clear_cached_op() def hybridize(self, active=True, **kwargs): diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index 88708884c51..eccdf18c1bb 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -51,7 +51,7 @@ def __init__(self, axis=-1, prefix=None, params=None): def forward(self, x): out = [] - for block in self._children: + for block in self._children.values(): out.append(block(x)) out = nd.concat(*out, dim=self.axis) return out @@ -84,7 +84,7 @@ def __init__(self, axis=-1, prefix=None, params=None): def hybrid_forward(self, F, x): out = [] - for block in self._children: + for block in self._children.values(): out.append(block(x)) out = F.concat(*out, dim=self.axis) return out diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py index f6113cc52b4..efca0c3d252 100644 --- a/python/mxnet/gluon/nn/basic_layers.py +++ b/python/mxnet/gluon/nn/basic_layers.py @@ -49,7 +49,7 @@ def add(self, *blocks): self.register_child(block) def forward(self, x): - for block in self._children: + for block in self._children.values(): x = block(x) return x @@ -57,13 +57,12 @@ def __repr__(self): s = '{name}(\n{modstr}\n)' modstr = '\n'.join([' ({key}): {block}'.format(key=key, block=_indent(block.__repr__(), 2)) - for key, block in enumerate(self._children) - if isinstance(block, Block)]) + for key, block in self._children.items()]) return s.format(name=self.__class__.__name__, modstr=modstr) def __getitem__(self, key): - return self._children[key] + return self._children[str(key)] def __len__(self): return len(self._children) @@ -79,9 +78,10 @@ def hybridize(self, active=True, 
**kwargs): **kwargs : string Additional flags for hybridized operator. """ - if self._children and all(isinstance(c, HybridBlock) for c in self._children): - warnings.warn('All children of this Sequential layer are HybridBlocks. Consider ' \ - 'using HybridSequential for the best performance.', stacklevel=2) + if self._children and all(isinstance(c, HybridBlock) for c in self._children.values()): + warnings.warn( + "All children of this Sequential layer '%s' are HybridBlocks. Consider " + "using HybridSequential for the best performance."%self.prefix, stacklevel=2) super(Sequential, self).hybridize(active, **kwargs) @@ -106,7 +106,7 @@ def add(self, *blocks): self.register_child(block) def hybrid_forward(self, F, x): - for block in self._children: + for block in self._children.values(): x = block(x) return x @@ -114,13 +114,12 @@ def __repr__(self): s = '{name}(\n{modstr}\n)' modstr = '\n'.join([' ({key}): {block}'.format(key=key, block=_indent(block.__repr__(), 2)) - for key, block in enumerate(self._children) - if isinstance(block, Block)]) + for key, block in self._children.items()]) return s.format(name=self.__class__.__name__, modstr=modstr) def __getitem__(self, key): - return self._children[key] + return self._children[str(key)] def __len__(self): return len(self._children) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 8d0c5ba7af4..ce82171a087 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -29,9 +29,9 @@ from ..base import mx_real_t, MXNetError from .. import symbol, ndarray, initializer, context -from ..context import Context +from ..context import Context, cpu from .. 
import autograd -from .utils import _indent +from .utils import _indent, _brief_print_list # pylint: disable= invalid-name tensor_types = (symbol.Symbol, ndarray.NDArray) @@ -206,13 +206,16 @@ def _load_init(self, data, ctx): ctx = [ctx] if self._data is None: if self._deferred_init: - assert set(ctx) == set(self._deferred_init[1]), \ + assert ctx is None or set(ctx) == set(self._deferred_init[1]), \ "Failed to load Parameter '%s' on %s because it was " \ "previous initialized on %s."%( self.name, str(ctx), str(self.list_ctx())) + ctx = self._deferred_init[1] + elif ctx is None: + ctx = [cpu()] self._init_impl(data, ctx) else: - assert set(ctx) == set(self.list_ctx()), \ + assert ctx is None or set(ctx) == set(self.list_ctx()), \ "Failed to load Parameter '%s' on %s because it was " \ "previous initialized on %s."%( self.name, str(ctx), str(self.list_ctx())) @@ -497,13 +500,9 @@ def _init_weight(self, _, arr): name, grad_req='null', shape=value.shape, dtype=value.dtype, init=init_name) - -def _brief_print_list(lst, limit=7): - """Print at most `limit` elements of list.""" - if len(lst) > limit: - return _brief_print_list(lst[:limit//2], limit) + ', ..., ' + \ - _brief_print_list(lst[-limit//2:], limit) - return ', '.join(["'%s'"%str(i) for i in lst]) + def __repr__(self): + s = 'Constant {name} (shape={shape}, dtype={dtype})' + return s.format(name=self.name, shape=self.shape, dtype=self.dtype) class ParameterDict(object): @@ -677,6 +676,8 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, Otherwise, :py:meth:`Parameter.init` takes precedence. ctx : Context or list of Context Keeps a copy of Parameters on one or many context(s). + verbose : bool, default False + Whether to verbosely print out details on initialization. force_reinit : bool, default False Whether to force re-initialization if parameter is already initialized. 
""" @@ -735,17 +736,17 @@ def save(self, filename, strip_prefix=''): weight = param._reduce() if not param.name.startswith(strip_prefix): raise ValueError( - "Prefix '%s' is to be striped before saving, but Parameter " \ - "'%s' does not start with '%s'. If you are using Block.save_params, " \ - "This may be due to your Block shares parameters from other " \ - "Blocks or you forgot to use ``with name_scope()`` during init. " \ - "Consider switching to Block.collect_params.save and " \ - "Block.collect_params.load instead."%( + "Prefix '%s' is to be striped before saving, but Parameter's " + "name '%s' does not start with '%s'. " + "this may be due to your Block shares parameters from other " + "Blocks or you forgot to use 'with name_scope()' when creating " + "child blocks. For more info on naming, please see " + "http://mxnet.incubator.apache.org/tutorials/basic/naming.html"%( strip_prefix, param.name, strip_prefix)) arg_dict[param.name[len(strip_prefix):]] = weight ndarray.save(filename, arg_dict) - def load(self, filename, ctx, allow_missing=False, + def load(self, filename, ctx=None, allow_missing=False, ignore_extra=False, restore_prefix=''): """Load parameters from file. diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py index f5c72f5f3e7..281aba45257 100644 --- a/python/mxnet/gluon/rnn/rnn_cell.py +++ b/python/mxnet/gluon/rnn/rnn_cell.py @@ -124,7 +124,7 @@ def reset(self): """Reset before re-using the cell for another graph.""" self._init_counter = -1 self._counter = -1 - for cell in self._children: + for cell in self._children.values(): cell.reset() def state_info(self, batch_size=0): @@ -639,7 +639,7 @@ def __repr__(self): s = '{name}(\n{modstr}\n)' return s.format(name=self.__class__.__name__, modstr='\n'.join(['({i}): {m}'.format(i=i, m=_indent(m.__repr__(), 2)) - for i, m in enumerate(self._children)])) + for i, m in self._children.items()])) def add(self, cell): """Appends a cell into the stack. 
@@ -652,19 +652,19 @@ def add(self, cell): self.register_child(cell) def state_info(self, batch_size=0): - return _cells_state_info(self._children, batch_size) + return _cells_state_info(self._children.values(), batch_size) def begin_state(self, **kwargs): assert not self._modified, \ "After applying modifier cells (e.g. ZoneoutCell) the base " \ "cell cannot be called directly. Call the modifier cell instead." - return _cells_begin_state(self._children, **kwargs) + return _cells_begin_state(self._children.values(), **kwargs) def __call__(self, inputs, states): self._counter += 1 next_states = [] p = 0 - for cell in self._children: + for cell in self._children.values(): assert not isinstance(cell, BidirectionalCell) n = len(cell.state_info()) state = states[p:p+n] @@ -683,7 +683,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N p = 0 next_states = [] - for i, cell in enumerate(self._children): + for i, cell in enumerate(self._children.values()): n = len(cell.state_info()) states = begin_state[p:p+n] p += n @@ -696,7 +696,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N return inputs, next_states def __getitem__(self, i): - return self._children[i] + return self._children[str(i)] def __len__(self): return len(self._children) @@ -900,8 +900,8 @@ class BidirectionalCell(HybridRecurrentCell): """ def __init__(self, l_cell, r_cell, output_prefix='bi_'): super(BidirectionalCell, self).__init__(prefix='', params=None) - self.register_child(l_cell) - self.register_child(r_cell) + self.register_child(l_cell, 'l_cell') + self.register_child(r_cell, 'r_cell') self._output_prefix = output_prefix def __call__(self, inputs, states): @@ -910,17 +910,17 @@ def __call__(self, inputs, states): def __repr__(self): s = '{name}(forward={l_cell}, backward={r_cell})' return s.format(name=self.__class__.__name__, - l_cell=self._children[0], - r_cell=self._children[1]) + l_cell=self._children['l_cell'], + 
r_cell=self._children['r_cell']) def state_info(self, batch_size=0): - return _cells_state_info(self._children, batch_size) + return _cells_state_info(self._children.values(), batch_size) def begin_state(self, **kwargs): assert not self._modified, \ "After applying modifier cells (e.g. DropoutCell) the base " \ "cell cannot be called directly. Call the modifier cell instead." - return _cells_begin_state(self._children, **kwargs) + return _cells_begin_state(self._children.values(), **kwargs) def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None, valid_length=None): @@ -938,7 +938,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size) states = begin_state - l_cell, r_cell = self._children + l_cell, r_cell = self._children.values() l_outputs, l_states = l_cell.unroll(length, inputs=inputs, begin_state=states[:len(l_cell.state_info(batch_size))], layout=layout, merge_outputs=merge_outputs, diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index cb784b7480d..7dd2a1a2f5e 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -242,3 +242,10 @@ def _get_repo_file_url(namespace, filename): return '{base_url}{namespace}/{filename}'.format(base_url=_get_repo_url(), namespace=namespace, filename=filename) + +def _brief_print_list(lst, limit=7): + """Print at most `limit` elements of list.""" + if len(lst) > limit: + return _brief_print_list(lst[:limit//2], limit) + ', ..., ' + \ + _brief_print_list(lst[-limit//2:], limit) + return ', '.join(["'%s'"%str(i) for i in lst]) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index d91b3f02cd3..ca1e121008d 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -557,7 +557,7 @@ def test_block_attr_regular(): b.c = gluon.Block() c2 = gluon.Block() b.c = c2 - assert b.c is c2 and 
b._children[0] is c2 + assert b.c is c2 and list(b._children.values())[0] is c2 @with_seed() @@ -589,18 +589,22 @@ def __init__(self, **kwargs): self.data = {'a': '4', 'b': 123} with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') model = Model1() model.collect_params() assert len(w) > 0 with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') model = Model2() model.collect_params() assert len(w) > 0 with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') model = Model3() model.collect_params() assert len(w) == 0 with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') model = Model4() model.collect_params() assert len(w) == 0 @@ -882,6 +886,14 @@ def check_dropout_axes(ratio, shape, axes): check_dropout_axes(0.25, nshape, axes = (1, 2, 3)) +def test_save_load(): + net = mx.gluon.model_zoo.vision.get_resnet(1, 18, pretrained=True) + net.save_params('test.params') + + net = mx.gluon.model_zoo.vision.get_resnet(1, 18) + net.output = mx.gluon.nn.Dense(1000) + + net.load_params('test.params') if __name__ == '__main__':
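The core of the diff above is that `save_params` now keys parameters by their position in the block tree rather than by prefix-based names, and `load_params` falls back to the old path when no key contains a `.`. The following is a minimal sketch of that naming scheme in plain Python, assuming nothing beyond the logic shown in the diff. `ToyBlock`, `make_dense` and `is_legacy_file` are hypothetical names used for illustration only; they are not part of MXNet, and parameter values are plain lists rather than `Parameter` objects.

```python
from collections import OrderedDict


class ToyBlock:
    """Hypothetical stand-in for gluon.Block; not the MXNet class."""

    def __init__(self):
        self._children = OrderedDict()  # child name -> ToyBlock
        self._reg_params = {}           # attribute name -> parameter value

    def register_child(self, block, name=None):
        # As in the diff: an unnamed child is keyed by its position.
        if name is None:
            name = str(len(self._children))
        self._children[name] = block

    def collect_params_with_prefix(self, prefix=''):
        # As in _collect_params_with_prefix: structural names joined by '.'.
        if prefix:
            prefix += '.'
        ret = {prefix + key: val for key, val in self._reg_params.items()}
        for name, child in self._children.items():
            ret.update(child.collect_params_with_prefix(prefix + name))
        return ret


def is_legacy_file(loaded):
    # As in load_params: no '.' in any key selects the legacy loading path.
    return not any('.' in key for key in loaded)


def make_dense(weight, bias):
    # Hypothetical helper: a leaf block holding two parameters.
    layer = ToyBlock()
    layer._reg_params = {'weight': weight, 'bias': bias}
    return layer


net = ToyBlock()
net.register_child(make_dense([1.0], [0.0]))            # keyed as '0'
net.register_child(make_dense([2.0], [0.5]), 'output')  # keyed as 'output'

structural = net.collect_params_with_prefix()
print(sorted(structural))
# ['0.bias', '0.weight', 'output.bias', 'output.weight']

print(is_legacy_file({'dense0_weight': [1.0]}))  # True
print(is_legacy_file(structural))                # False
```

Because the keys depend only on where a parameter sits in the tree, replacing a child (as `test_save_load` does with `net.output`) leaves the key `output.weight` unchanged. In this sketch, a block whose parameters are all registered directly on the root would produce keys with no `.` and so would be routed to the legacy path by the same check.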