abhinavs95 commented on a change in pull request #14587: [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support for Gluon fit() API
URL: https://github.com/apache/incubator-mxnet/pull/14587#discussion_r271964229
 
 

 ##########
 File path: tests/python/unittest/test_gluon_estimator.py
 ##########
 @@ -275,3 +278,44 @@ def test_context():
                                   loss=loss,
                                   metrics=metrics,
                                   context='cpu')
+
+
+def test_batch_size():
+    '''Test batch size'''
+    num_samples = 32
+
+    # No Data Loader
+    data = mx.nd.random.uniform(shape=(num_samples, 3, 28, 28))
+    label = mx.nd.random.randint(low=0, high=2, shape=(num_samples,))
+    data_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=16)
+    net = get_model()
+    loss = mx.gluon.loss.L2Loss()
+    ctx = mx.cpu()
+    est = estimator.Estimator(net=net, loss=loss, context=ctx)
+    with assert_raises(ValueError):
+        est.fit(train_data=data_iter)
+
+    # Empty data loader
+    data = mx.nd.random.uniform(shape=(0,))
+    label = mx.nd.random.randint(low=0, high=2, shape=(0,))
+    batch_size = 2
+    data_arr = mx.gluon.data.dataset.ArrayDataset(data, label)
+    data_loader = mx.gluon.data.DataLoader(data_arr, batch_size=batch_size)
+    est = estimator.Estimator(net=net, loss=loss, context=ctx)
+    with assert_raises(ValueError):
+        est.fit(train_data=data_loader)
+
+    # Batch size less than context
+    ctx = [mx.gpu(i) for i in range(4)]
+    data = mx.nd.random.uniform(shape=(num_samples, 3, 28, 28))
+    label = mx.nd.random.randint(low=0, high=2, shape=(num_samples,))
+    batch_size = 2
+    data_arr = mx.gluon.data.dataset.ArrayDataset(data, label)
+    data_loader = mx.gluon.data.DataLoader(data_arr, batch_size=batch_size)
+    est = estimator.Estimator(net=net, loss=loss, context=ctx)
 
 Review comment:
   A new net needs to be created here; otherwise the estimator reuses the net created earlier in this test, and the test won't be accurate.
   
   The same problem is also present in the pre-existing test_context() test.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to