Ekta Khanna created MADLIB-1441:
-----------------------------------

             Summary: DL - Add support for custom loss functions 
                 Key: MADLIB-1441
                 URL: https://issues.apache.org/jira/browse/MADLIB-1441
             Project: Apache MADlib
          Issue Type: New Feature
          Components: Deep Learning
            Reporter: Ekta Khanna
             Fix For: v1.18.0


Keras supports custom loss functions, which MADlib would accept through the 
compile_params argument of the fit function. To support them in the MADlib 
deep_learning module, the following functions are affected:

1. madlib_keras_fit()

2. madlib_keras_fit_multiple_model()

3. madlib_keras_evaluate()

Currently, the compile_params passed in look like this:
 {{optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['mae']}}
 For a custom function, we would pass in the function name as the loss:
 {{optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss=custom_fn1, 
metrics=['mae']}} 

 

The object table definition, with a dill-serialized object for each custom 
function, is as follows:
||id||name||description||function||
|1|custom_fn1|test function|\x80026...|
|2|custom_fn2|test function2|\x8ab26...|

 

For fit_multiple, we want to support multiple configs in mst_table with a mix 
of custom and built-in Keras loss functions:
||mst_key||model_id||compile_params||fit_params||
|2|1|loss='custom_fn1', optimizer=....|batch_size=16, epochs=1|
|3|1|loss='custom_fn2', optimizer=....|batch_size=16, epochs=1|
|1|1|loss='categorical_crossentropy', optimizer=....|batch_size=16, epochs=1|
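
As a rough illustration of how a mixed mst_table could be handled, the sketch below splits a compile_params string into its top-level entries and checks whether the loss name appears in the object table. It is not the MADlib implementation: the helper names (split_top_level, parse_compile_params, classify_loss) are hypothetical, and treating any name not found in the object table as a built-in Keras loss is an assumption.

```python
def split_top_level(params):
    """Split on commas that are not nested inside brackets or quotes."""
    parts, current, depth, quote = [], [], 0, None
    for ch in params:
        if quote:
            # Inside a quoted string: only the matching quote closes it.
            if ch == quote:
                quote = None
        elif ch in "'\"":
            quote = ch
        elif ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
        elif ch == "," and depth == 0:
            # Top-level comma: finish the current key=value entry.
            parts.append("".join(current).strip())
            current = []
            continue
        current.append(ch)
    tail = "".join(current).strip()
    if tail:
        parts.append(tail)
    return parts


def parse_compile_params(params):
    """Return a dict of raw (unevaluated) key -> value strings."""
    result = {}
    for part in split_top_level(params):
        key, _, value = part.partition("=")
        result[key.strip()] = value.strip()
    return result


def classify_loss(params, custom_names):
    """Return (loss_name, 'custom' or 'keras').

    Assumption: a loss whose name is in the object table is custom;
    anything else is left for Keras to resolve as a built-in loss.
    """
    loss = parse_compile_params(params).get("loss", "").strip("'\"")
    kind = "custom" if loss in custom_names else "keras"
    return loss, kind


if __name__ == "__main__":
    # Names taken from the object table example in this issue.
    custom_names = {"custom_fn1", "custom_fn2"}
    optimizer = "optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True)"
    for loss in ("custom_fn1", "custom_fn2", "categorical_crossentropy"):
        row = "loss='{}', {}".format(loss, optimizer)
        print(classify_loss(row, custom_names))
```

Splitting only on top-level commas matters here because the optimizer value itself contains commas inside its parentheses.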

 
h4. API Changes
 # *Fit*: add a new optional param for passing in the object_table name.
 object_table (optional) VARCHAR: Name of the table containing Python objects, 
needed when custom loss functions or custom metrics are specified in the 
parameter {{compile_params}}.
{code:sql}
madlib_keras_fit(
    source_table,
    model,
    model_arch_table,
    model_id,
    compile_params,
    fit_params,
    num_iterations,
    use_gpus,
    validation_table,
    metrics_compute_frequency,
    warm_start,
    name,
    description,
    object_table  -- new parameter
    ){code}

 # *Fit_multiple*: No change to the {{madlib_keras_fit_multiple_model()}} 
function. Reads object_table information from the model_selection table.
 A summary table named <model>_summary is also created, which has the following 
new columns:
|model_selection_table |Name of the table containing model selection parameters 
to be tried.|
|object_table| Name of the object table containing the serialized Python 
objects for custom loss functions and custom metrics (read from the mst_summary 
table).|

 # *Evaluate*: No change to the {{madlib_keras_evaluate()}} function. Reads 
object_table information from the model table. The output table adds a new column:
|loss_type|Type of loss that was used in the training step.
 If a custom loss or metric is used, we should give its name. Otherwise, 
list the built-in one used.|



--
This message was sent by Atlassian Jira
(v8.3.4#803005)
