stu1130 commented on a change in pull request #12131: [MXNET-737][WIP] Add last 
batch handle for imageiter
URL: https://github.com/apache/incubator-mxnet/pull/12131#discussion_r210995631
 
 

 ##########
 File path: tests/python/unittest/test_image.py
 ##########
 @@ -130,29 +129,81 @@ def test_color_normalize(self):
                 mx.nd.array(mean), mx.nd.array(std))
             assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, 
atol=1e-3)
 
-
     def test_imageiter(self):
         def check_imageiter(dtype='float32'):
             im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
-            test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, 
imglist=im_list,
-                path_root='', dtype=dtype)
-            for _ in range(3):
-                for batch in test_iter:
-                    pass
-                test_iter.reset()
-
-            # test with list file
             fname = './data/test_imageiter.lst'
-            file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \
-                for k, x in enumerate(TestImage.IMAGES)]
+            file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
+                         for k, x in enumerate(TestImage.IMAGES)]
             with open(fname, 'w') as f:
                 for line in file_list:
                     f.write(line + '\n')
+            
+            test_list = ['imglist', 'path_imglist']
 
-            test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, 
path_imglist=fname,
-                path_root='', dtype=dtype)
-            for batch in test_iter:
-                pass
+            for test in test_list:
+                imglist = im_list if test == 'imglist' else None
+                path_imglist = fname if test == 'path_imglist' else None
+                
+                test_iter = mx.image.ImageIter(2, (3, 224, 224), 
label_width=1, imglist=imglist, 
+                    path_imglist=path_imglist, path_root='', dtype=dtype)
+                # test batch data shape
+                for _ in range(3):
+                    for batch in test_iter:
+                        assert batch.data[0].shape == (2, 3, 224, 224)
+                    test_iter.reset()
+                # test last batch handle(discard)
+                test_iter = mx.image.ImageIter(3, (3, 224, 224), 
label_width=1, imglist=imglist,
+                    path_imglist=path_imglist, path_root='', dtype=dtype, 
last_batch_handle='discard')
+                i = 0
+                for batch in test_iter:
+                    i += 1
+                assert i == 5
+                # test last_batch_handle(pad)
+                test_iter = mx.image.ImageIter(3, (3, 224, 224), 
label_width=1, imglist=imglist, 
+                    path_imglist=path_imglist, path_root='', dtype=dtype, 
last_batch_handle='pad')
+                i = 0
+                for batch in test_iter:
+                    if i == 0:
+                        first_three_data = batch.data[0][:2]
+                    if i == 5:
+                        last_three_data = batch.data[0][1:]
+                    i += 1
+                assert i == 6
+                assert np.array_equal(first_three_data.asnumpy(), 
last_three_data.asnumpy())
+                # test last_batch_handle(roll_over)
+                test_iter = mx.image.ImageIter(3, (3, 224, 224), 
label_width=1, imglist=imglist,
+                    path_imglist=path_imglist, path_root='', dtype=dtype, 
last_batch_handle='roll_over')
+                i = 0
+                for batch in test_iter:
+                    if i == 0:
+                        first_image = batch.data[0][0]
+                    i += 1
+                assert i == 5
+                test_iter.reset()
+                first_batch_roll_over = test_iter.next()
+                assert np.array_equal(
+                    first_batch_roll_over.data[0][1].asnumpy(), 
first_image.asnumpy())
+                assert first_batch_roll_over.pad == 2
+                # test iterator works properly after calling reset several 
times when last_batch_handle is roll_over
+                for _ in test_iter:
+                    pass
+                test_iter.reset()
+                first_batch_roll_over_twice = test_iter.next()
+                assert np.array_equal(
+                    first_batch_roll_over_twice.data[0][2].asnumpy(), 
first_image.asnumpy())
+                assert first_batch_roll_over_twice.pad == 1
+                # we've called next once
+                i = 1
+                for _ in test_iter:
+                    i += 1
+                # test the third epoch with size 6
+                assert i == 6
+                # test shuffle option for sanity test
+                test_iter = mx.image.ImageIter(3, (3, 224, 224), 
label_width=1, imglist=imglist, shuffle=True,
 
 Review comment:
   Add a shuffle test case.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to