This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new d009345  [1.x][LT] Add forward, backward test for linalg.gemm2 (#18784)
d009345 is described below

commit d0093458e3be5e76d78750043c4e5a3f01a7d056
Author: Chaitanya Prakash Bapat <chai.ba...@gmail.com>
AuthorDate: Mon Jul 27 20:28:43 2020 -0700

    [1.x][LT] Add forward, backward test for linalg.gemm2 (#18784)
    
    * added forward, backward test for gemm2
    
    * add backward check
    
    * correct gradient assert
    
    * move test inside linalg_ops
    
    * add shape checks
---
 tests/nightly/test_large_array.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index 020a707..306c827 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -1207,6 +1207,25 @@ def test_linalg():
         assert A.grad[0,0,0] == 4
         assert_almost_equal(A.grad[1,0,0], nd.array([0.4]), rtol=1e-3, 
atol=1e-5)
 
+    def check_gemm2():
+        def run_gemm2(inp1, inp2):
+            inp1.attach_grad()
+            inp2.attach_grad()
+            with mx.autograd.record():
+                out = mx.nd.linalg.gemm2(inp1, inp2)
+            return inp1.grad, inp2.grad, out
+        # inp1/inp2 are all ones except inp1[0, 0] = 0.1, so out[0, 0] = LARGE_X - 0.9.
+        inp1 = mx.nd.ones(shape=(SMALL_Y, LARGE_X))
+        inp1[0, 0] = 0.1
+        inp2 = mx.nd.ones(shape=(LARGE_X, SMALL_Y))
+        inp1_grad, inp2_grad, out = run_gemm2(inp1, inp2)
+        assert out.shape == (SMALL_Y, SMALL_Y)
+        assert_almost_equal(out[0, 0], nd.array([LARGE_X - 0.9]), rtol=1e-3, atol=1e-5)
+        out.backward()
+        assert inp1_grad.shape == (SMALL_Y, LARGE_X) and inp2_grad.shape == (LARGE_X, SMALL_Y)
+        # grad of inp2[0, 0] = column-0 sum of inp1; index it instead of copying the whole grad to host.
+        assert_almost_equal(inp2_grad[0, 0], nd.array([SMALL_Y - 1 + 0.1]), rtol=1e-3, atol=1e-5)
+
     def check_det():
         def run_det(inp):
             inp.attach_grad()
@@ -1321,6 +1340,7 @@ def test_linalg():
     check_potrf()
     check_potri()
     check_syrk_batch()
+    check_gemm2()
     check_det()
     check_inverse()
     check_trmm()

Reply via email to