liuzh91 opened a new issue #16930: Plug and Play fit_batch() for the estimator 
class
URL: https://github.com/apache/incubator-mxnet/issues/16930
 
 
   ## Description
   In the current estimator implementation, fit_batch() is a method of 
the estimator class. In a typical fit_batch() workflow, the model 
`self.net` runs a forward pass on the training batch to produce outputs, which 
are then used to compute the loss. The problem is that this design is not 
flexible enough to accommodate different model forward interfaces on the same 
task. For example, the base estimator's fit_batch() trains on the current batch 
of a label prediction task as follows:
   ```
           with autograd.record():
               pred = [self.net(x) for x in data]
               loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]
   ```
   In the above example, the forward interface of `self.net` is `def 
forward(self, inputs)`, returning the predicted labels. The estimator 
is compatible with any model using this forward interface. However, if we have 
another model for the label prediction task with a different forward interface, 
such as `def forward(self, inputs, input_length)`, the base estimator is not compatible 
with this model even though both models share the same loss function and training 
and evaluation metrics. A real-world example can be found in the GluonNLP language models 
(https://github.com/dmlc/gluon-nlp/blob/c03665bafb1e0fe0fa5c2a59bbb4845393fbf9ba/src/gluonnlp/model/train/language_model.py):
 `StandardRNN` and `AWDRNN` share the same forward interface, whereas `BigRNN` 
has a different one.
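   For concreteness, here is a minimal sketch of the mismatch (the model classes, 
names, and sizes below are hypothetical, not the actual GluonNLP code): two models 
solve the same task, but the base estimator's hard-coded `self.net(x)` call only 
works with the first one.
   ```
from mxnet import gluon
from mxnet.gluon import nn


class SimpleLM(gluon.Block):
    """Hypothetical model matching the base estimator's `self.net(x)` call."""
    def __init__(self, vocab_size=10000, emb_size=200, **kwargs):
        super(SimpleLM, self).__init__(**kwargs)
        with self.name_scope():
            self.embedding = nn.Embedding(vocab_size, emb_size)
            self.decoder = nn.Dense(vocab_size, flatten=False)

    def forward(self, inputs):                   # single-argument forward
        return self.decoder(self.embedding(inputs))


class VariableLengthLM(gluon.Block):
    """Hypothetical model whose forward also needs the sequence lengths,
    analogous to BigRNN; the base estimator cannot call it as `self.net(x)`."""
    def __init__(self, vocab_size=10000, emb_size=200, **kwargs):
        super(VariableLengthLM, self).__init__(**kwargs)
        with self.name_scope():
            self.embedding = nn.Embedding(vocab_size, emb_size)
            self.decoder = nn.Dense(vocab_size, flatten=False)

    def forward(self, inputs, input_length):     # extra argument breaks `self.net(x)`
        # input_length could drive masking or sampled softmax; ignored in this sketch
        return self.decoder(self.embedding(inputs))
   ```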
   
   A straightforward workaround is to create a new customized estimator for 
each model interface. The downside is that a standalone estimator must be 
written every time a new model interface appears, even on the same 
task. In the machine learning community, it is common to see different model 
forward logic for the same task, so this approach would lead to prohibitively many 
estimators even for simple tasks. In the LM example above, we would need a 
`vanillaRNNEstimator` and a `BigRNNEstimator` even though most of their training 
logic is identical.
   
   To avoid this estimator explosion, we suggest adding support for 
a plug-and-play customized `fit_batch()`, similar to the 
`event_handlers` mechanism of the estimator. Given an existing estimator `est`, we 
modify the `fit_batch()` method to take an extra argument, 
`fit_batch_handler`. We can then call `est.fit_batch(train_data=data_loader, 
epochs=epochs, fit_batch_handler=fit_StandardRNN_batch)` or 
`est.fit_batch(train_data=data_loader, epochs=epochs, 
fit_batch_handler=fit_BigRNN_batch)` to use models with different interfaces.
   If no `fit_batch_handler` is provided, the default 
`fit_batch()` method is used.
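
   A minimal sketch of what such a handler could look like under this proposal. 
The handler signature, the `fit_variable_length_batch` name, and the argument 
unpacking below are assumptions for illustration, not a finalized API; the body 
simply mirrors the default forward/loss logic while calling the two-argument 
forward interface.
   ```
from mxnet import autograd


def fit_variable_length_batch(estimator, data, label, input_length):
    """Hypothetical fit_batch handler for a model whose forward takes
    (inputs, input_length); mirrors the default forward/loss/backward logic."""
    with autograd.record():
        pred = [estimator.net(x, length) for x, length in zip(data, input_length)]
        loss = [estimator.loss(y_hat, y) for y_hat, y in zip(pred, label)]
    for l in loss:
        l.backward()
    return pred, loss


# Proposed usage, following the call shown above; when no handler is given,
# the estimator would fall back to its default fit_batch().
# est.fit_batch(train_data=data_loader, epochs=epochs,
#               fit_batch_handler=fit_variable_length_batch)
   ```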
