This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 84d61a1  [Perl] - ndarray to native array conversion fix (#16635)
84d61a1 is described below

commit 84d61a1df3eca95be68c15d39fe057064a4da018
Author: Robert Stone <ta...@trap.mtview.ca.us>
AuthorDate: Sun Oct 27 12:31:50 2019 -0700

    [Perl] - ndarray to native array conversion fix (#16635)
---
 perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm |  6 +++++-
 perl-package/AI-MXNet/t/test_ndarray.t        | 19 ++++++++++++++++++-
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm
index f75cc84..1d968c1 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm
@@ -116,7 +116,11 @@ method STORABLE_thaw($cloning, $buf, $writable)
 
 method split_array(@args)
 {
-     $self->shape->[0] > 1 ? $self->split(num_outputs => $self->shape->[0], squeeze_axis => @{ $self->shape } > 1 ? 1 : 0, axis => 0) : [$self];
+    my $shape = $self->shape;
+    return [] if $shape->[0] == 0;
+    my $list = $self->split(num_outputs=>$shape->[0],
+        squeeze_axis=>int(@$shape > 1), axis=>0);
+    $shape->[0] == 1 ? [ $list ] : $list;
 }
 
 method at(Index @indices)
diff --git a/perl-package/AI-MXNet/t/test_ndarray.t b/perl-package/AI-MXNet/t/test_ndarray.t
index a6cd113..1e290b4 100644
--- a/perl-package/AI-MXNet/t/test_ndarray.t
+++ b/perl-package/AI-MXNet/t/test_ndarray.t
@@ -19,7 +19,7 @@ use strict;
 use warnings;
 use AI::MXNet qw(mx);
 use AI::MXNet::TestUtils qw(almost_equal same rand_ndarray randint zip);
-use Test::More tests => 251;
+use Test::More tests => 261;
 use PDL;
 use File::Temp qw(tempdir);
 use IO::File;
@@ -217,6 +217,22 @@ sub test_histogram
     ok(same($bins->aspdl, pdl([10, 20, 30, 60])));
 }
 
+sub test_array_overload
+{
+    # array conversions are largely calls to mx->nd->split(), but have
+    # special cases around dimensions of length 0 and 1.
+    is_deeply([ @{ mx->nd->array(zeros(7, 0)) } ], []);
+    is_deeply(mx->nd->zeros([3, 7])->[0]->shape, [ 7 ]);
+    is_deeply(mx->nd->zeros([2, 7])->[0]->shape, [ 7 ]);
+    is_deeply(mx->nd->zeros([1, 7])->[0]->shape, [ 7 ]);
+    is_deeply(mx->nd->zeros([3, 7, 11])->[0]->shape, [7, 11]);
+    is_deeply(mx->nd->zeros([2, 7, 11])->[0]->shape, [7, 11]);
+    is_deeply(mx->nd->zeros([1, 7, 11])->[0]->shape, [7, 11]);
+    is_deeply(mx->nd->zeros([3, 7, 11, 13])->[0]->shape, [7, 11, 13]);
+    is_deeply(mx->nd->zeros([2, 7, 11, 13])->[0]->shape, [7, 11, 13]);
+    is_deeply(mx->nd->zeros([1, 7, 11, 13])->[0]->shape, [7, 11, 13]);
+}
+
 test_ndarray_slice();
 test_ndarray_reshape();
 test_moveaxis();
@@ -226,3 +242,4 @@ test_linalg_gemm2();
 test_image_to_tensor();
 test_buffer_load();
 test_histogram();
+test_array_overload();

Reply via email to