cjolivier01 commented on a change in pull request #7903: Refactor AdaGrad optimizer to support sparse tensors
URL: https://github.com/apache/incubator-mxnet/pull/7903#discussion_r139765735
 
 

 ##########
 File path: tests/python/unittest/test_module.py
 ##########
 @@ -462,74 +462,105 @@ def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=N
 
 def test_factorization_machine_module():
     """ Test factorization machine model with sparse operators """
-    mx.random.seed(11)
-    rnd.seed(11)
-
-    def fm(factor_size, feature_dim, init):
-        x = mx.symbol.Variable("data", stype='csr')
-        v = mx.symbol.Variable("v", shape=(feature_dim, factor_size),
-                               init=init, stype='row_sparse')
-
-        w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1),
-                                  init=init, stype='row_sparse')
-        w1_bias = mx.symbol.var('w1_bias', shape=(1))
-        w1 = mx.symbol.broadcast_add(mx.symbol.dot(x, w1_weight), w1_bias)
-
-        v_s = mx.symbol._internal._square_sum(data=v, axis=1, keepdims=True)
-        x_s = mx.symbol.square(data=x)
-        bd_sum = mx.sym.dot(x_s, v_s)
-
-        w2 = mx.symbol.dot(x, v)
-        w2_squared = 0.5 * mx.symbol.square(data=w2)
-
-        w_all = mx.symbol.Concat(w1, w2_squared, dim=1)
-        sum1 = mx.symbol.sum(data=w_all, axis=1, keepdims=True)
-        sum2 = 0.5 * mx.symbol.negative(bd_sum)
-        model = mx.sym.elemwise_add(sum1, sum2)
-
-        y = mx.symbol.Variable("label")
-        model = mx.symbol.LinearRegressionOutput(data=model, label=y)
-        return model
-
-    # model
-    ctx = default_context()
-    init = mx.initializer.Normal(sigma=0.01)
-    factor_size = 4
-    feature_dim = 10000
-    model = fm(factor_size, feature_dim, init)
-
-    # data iter
-    num_batches = 5
-    batch_size = 64
-    num_samples = batch_size * num_batches
-    # generate some random csr data
-    csr_nd = rand_ndarray((num_samples, feature_dim), 'csr', 0.1)
-    label = mx.nd.ones((num_samples,1))
-    # the alternative is to use LibSVMIter
-    train_iter = mx.io.NDArrayIter(data=csr_nd, label={'label':label},
-                                   batch_size=batch_size, last_batch_handle='discard')
-    # create module
-    mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['label'])
-    # allocate memory given the input data and label shapes
-    mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
-    # initialize parameters with the normal initializer defined above
-    mod.init_params(initializer=init)
-    # use Adam with learning rate 0.001 to train
-    adam = mx.optimizer.Adam(clip_gradient=5.0, learning_rate=0.001, rescale_grad=1.0/batch_size)
-    mod.init_optimizer(optimizer=adam)
-    # use mean squared error as the metric
-    metric = mx.metric.create('MSE')
-    # train for 10 epochs
-    for epoch in range(10):
-        train_iter.reset()
-        metric.reset()
-        for batch in train_iter:
-            mod.forward(batch, is_train=True)       # compute predictions
-            mod.update_metric(metric, batch.label)  # accumulate prediction error
-            mod.backward()                          # compute gradients
-            mod.update()                            # update parameters
-        # print('Epoch %d, Training %s' % (epoch, metric.get()))
-    assert(metric.get()[1] < 0.05), metric.get()[1]
+    def check_factorization_machine_module(optimizer=None, num_epochs=None):
 
 Review comment:
  test_optimizer appears to test the C++ implementation against the Python one. Since AdaGrad only has a Python implementation, it's not clear what that test would compare against, so I am testing it through the test_module() test with an expected accuracy threshold instead.
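
  For reference, here is a rough sketch (not the actual code in this PR) of the kind of train-and-check test I mean: a small linear-regression Module on CSR data trained with AdaGrad, asserting that the final MSE falls below a threshold. The helpers rand_ndarray and default_context come from mxnet.test_utils; the density, learning rate, epoch count, and 0.05 threshold are illustrative values rather than the ones used in test_module.py, and the run assumes the sparse-aware AdaGrad update added by this PR.

```python
import mxnet as mx
from mxnet.test_utils import rand_ndarray, default_context

def check_sparse_adagrad_convergence(num_epochs=10, expected_mse=0.05):
    # illustrative sizes; small enough to run quickly in a unit test
    feature_dim, batch_size, num_batches = 100, 64, 5
    num_samples = batch_size * num_batches

    # simple linear regression graph on CSR input with a row_sparse weight
    x = mx.symbol.Variable('data', stype='csr')
    w = mx.symbol.Variable('w', shape=(feature_dim, 1), stype='row_sparse')
    y = mx.symbol.Variable('label')
    model = mx.symbol.LinearRegressionOutput(data=mx.symbol.dot(x, w), label=y)

    # random sparse inputs; labels are all ones, as in the removed test above
    csr_nd = rand_ndarray((num_samples, feature_dim), 'csr', 0.1)
    label = mx.nd.ones((num_samples, 1))
    train_iter = mx.io.NDArrayIter(data=csr_nd, label={'label': label},
                                   batch_size=batch_size,
                                   last_batch_handle='discard')

    # bind the module and initialize parameters
    mod = mx.mod.Module(symbol=model, data_names=['data'],
                        label_names=['label'], context=default_context())
    mod.bind(data_shapes=train_iter.provide_data,
             label_shapes=train_iter.provide_label)
    mod.init_params(initializer=mx.initializer.Normal(sigma=0.01))

    # train with AdaGrad; with a row_sparse weight this should exercise
    # the sparse update path this PR adds
    adagrad = mx.optimizer.AdaGrad(learning_rate=0.1,
                                   rescale_grad=1.0 / batch_size)
    mod.init_optimizer(optimizer=adagrad)

    metric = mx.metric.create('MSE')
    for _ in range(num_epochs):
        train_iter.reset()
        metric.reset()
        for batch in train_iter:
            mod.forward(batch, is_train=True)       # compute predictions
            mod.update_metric(metric, batch.label)  # accumulate MSE
            mod.backward()                          # compute gradients
            mod.update()                            # apply the AdaGrad update
    # the model should fit the constant labels well after a few epochs
    assert metric.get()[1] < expected_mse, metric.get()[1]
```

  The refactored test in the diff follows the same pattern, but keeps the factorization machine graph and takes the optimizer and epoch count as parameters (check_factorization_machine_module(optimizer, num_epochs)), presumably so the same training loop can exercise both Adam and AdaGrad.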
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
