Nikhil created MADLIB-1350:
------------------------------
Summary: Warm start with madlib_keras_fit()
Key: MADLIB-1350
URL: https://issues.apache.org/jira/browse/MADLIB-1350
Project: Apache MADlib
Issue Type: Improvement
Components: Deep Learning
Reporter: Nikhil
Fix For: v1.16
Many deep neural nets are not trained from scratch in one shot. Training may
happen over time, depending on available resources, so when you restart
training you want to pick up from where you left off.
As a data scientist,
I want to continue training a model based on weights that I have from a
previous run,
so that I don't have to start from scratch.
* e.g., continue training from where you left off
Interface
Add a `warm_start` Boolean parameter to fit(), as in MLP:
http://madlib.apache.org/docs/latest/group__grp__nn.html
{code}
madlib_keras_fit(
    source_table VARCHAR,
    model VARCHAR,
    model_arch_table VARCHAR,
    model_arch_id INTEGER,
    compile_params VARCHAR,
    fit_params VARCHAR,
    num_iterations INTEGER,
    gpus_per_host INTEGER,
    validation_table VARCHAR,
    warm_start BOOLEAN,    -- <-- NEW PARAMETER
    name VARCHAR,
    description VARCHAR
)
{code}
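A minimal usage sketch of the proposed signature. The table names (`iris_train`,
`iris_model`, `model_arch_library`), the model id, and the compile/fit parameter
strings below are placeholders for illustration, not part of this JIRA:
{code}
SELECT madlib.madlib_keras_fit(
    'iris_train',          -- source_table (placeholder)
    'iris_model',          -- model, i.e., output table (placeholder)
    'model_arch_library',  -- model_arch_table (placeholder)
    1,                     -- model_arch_id
    $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,  -- compile_params
    $$ batch_size=5, epochs=1 $$,                   -- fit_params
    10,                    -- num_iterations
    NULL,                  -- gpus_per_host
    NULL,                  -- validation_table
    TRUE,                  -- warm_start  <-- NEW PARAMETER
    'warm start demo',     -- name
    'continue training from previously saved weights'  -- description
);
{code}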
Logic
{code}
if warm_start = TRUE
    use weights from the model output table
else
    use weights from the model arch table if there are any
    (if not, use the initialization defined in the model arch in Keras)
{code}
This JIRA is for the first branch of the if, i.e., the `warm_start = TRUE` case.
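From the caller's side, the two branches look like the sketch below. The table
names and parameter strings are the same placeholders as above, and omitting the
trailing `name`/`description` arguments is an assumption about their optionality:
{code}
-- Cold start (warm_start = FALSE): weights come from the model arch table,
-- or from the Keras initializers if its weights column is NULL.
SELECT madlib.madlib_keras_fit('iris_train', 'iris_model', 'model_arch_library', 1,
    $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,
    $$ batch_size=5, epochs=1 $$, 10, NULL, NULL, FALSE);

-- Warm start (warm_start = TRUE): weights come from the existing 'iris_model'
-- output table, and the output and summary tables are replaced.
SELECT madlib.madlib_keras_fit('iris_train', 'iris_model', 'model_arch_library', 1,
    $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,
    $$ batch_size=5, epochs=1 $$, 10, NULL, NULL, TRUE);
{code}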
Details
1. The user should be able to change `compile_params` and `fit_params` between
warm starts. However, the model architecture is fixed between warm starts.
2. Ensure that the weight initialization defined in the model arch is not
overwritten by MADlib when the model arch table has NULL in the weights column
(see the sketch after this list).
3. Overwrite, i.e., replace, the model output and model summary tables when
warm start is used.
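A hedged check for detail #2. The column names `model_id` and `model_weights`
are assumptions about the model arch table layout, used here only to illustrate
the NULL-weights case:
{code}
-- If model_weights IS NULL, MADlib should leave weight initialization to the
-- initializers defined in the Keras model architecture (assumed column names).
SELECT model_id, (model_weights IS NULL) AS use_keras_initialization
FROM model_arch_library
WHERE model_id = 1;
{code}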
Acceptance
1. Train a model for n iterations
2. Start training again with `warm_start = TRUE`, using the saved model state as
the input to the 2nd round of training. The 2nd round should pick up where the
1st left off, which you can verify by checking that the loss or accuracy curve
shows no discontinuity (see the sketch after this list).
3. Repeat step 2 with a 3rd round of training.
4. Repeat steps 2-3 with a different metric besides accuracy.
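A sketch of the continuity check in steps 2-4, assuming the fit summary table
records the loss and metric history of a run; the table and column names below
are illustrative, not a confirmed output schema:
{code}
-- Inspect the recorded loss/metric history after each round; consecutive
-- rounds should not show a jump back toward the untrained values.
SELECT training_loss, training_metrics, metrics_type
FROM iris_model_summary;
{code}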