reductionista commented on a change in pull request #355: Keras fit interface
URL: https://github.com/apache/madlib/pull/355#discussion_r267584568
 
 

 ##########
 File path: src/ports/postgres/modules/convex/madlib_keras.py_in
 ##########
 @@ -0,0 +1,585 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import numpy as np
+import os
+import plpy
+import time
+
+from keras import backend as K
+from keras import utils as keras_utils
+from keras.layers import *
+from keras.models import *
+from keras.optimizers import *
+from keras.regularizers import *
+
+from madlib_keras_helper import KerasWeightsSerializer
+from madlib_keras_helper import get_data_as_np_array
+from madlib_keras_wrapper import *
+
+from utilities.model_arch_info import get_input_shape
+from utilities.validate_args import input_tbl_valid
+from utilities.validate_args import output_tbl_valid
+from utilities.utilities import _assert
+from utilities.utilities import add_postfix
+from utilities.utilities import is_var_valid
+from utilities.utilities import madlib_version
+
+
def _validate_input_table(source_table, independent_varname,
                          dependent_varname):
    """
    Check that the feature and label columns are valid for a table.

    The independent column is checked first, then the dependent column.
    Validity is decided by ``is_var_valid``. The first failing column aborts
    through ``_assert`` with a message naming the column and the table.

    :param source_table: table to inspect (training or validation table).
    :param independent_varname: column/expression holding the features.
    :param dependent_varname: column/expression holding the labels.
    """
    column_checks = (
        ("independent_varname", independent_varname),
        ("dependent_varname", dependent_varname),
    )
    for arg_label, column_expr in column_checks:
        failure_msg = ("model_keras error: invalid {label} "
                       "('{column}') for source_table "
                       "({table})!").format(label=arg_label,
                                            column=column_expr,
                                            table=source_table)
        _assert(is_var_valid(source_table, column_expr), failure_msg)
+
def _validate_input_args(
    source_table, dependent_varname, independent_varname, model_arch_table,
    validation_table, output_model_table, num_iterations):
    """
    Validate the user-facing arguments of ``fit`` before any work is done.

    Checks, in order:
      1. ``num_iterations`` is at least 1.
      2. The source table passes ``input_tbl_valid`` and exposes both columns.
      3. If a non-blank ``validation_table`` is given, it gets the same checks
         as the source table, because both must share the same schema.
      4. ``model_arch_table`` passes ``input_tbl_valid``.
      5. Neither ``output_model_table`` nor ``<output_model_table>_summary``
         fails ``output_tbl_valid``. These tables are created later by
         ``fit``.
    """
    module_name = 'model_keras'
    _assert(num_iterations > 0,
            "model_keras error: Number of iterations cannot be < 1.")

    # Training data (and optional validation data) must have the same columns.
    input_tbl_valid(source_table, module_name)
    _validate_input_table(source_table, independent_varname, dependent_varname)
    has_validation_table = (bool(validation_table) and
                            validation_table.strip() != '')
    if has_validation_table:
        input_tbl_valid(validation_table, module_name)
        _validate_input_table(validation_table, independent_varname,
                              dependent_varname)

    # The architecture table is only checked with input_tbl_valid here; its
    # columns are read later by fit().
    input_tbl_valid(model_arch_table, module_name)

    # Both the model table and its companion summary table are outputs.
    summary_table = add_postfix(output_model_table, "_summary")
    for out_table in (output_model_table, summary_table):
        output_tbl_valid(out_table, module_name)
+
def _validate_input_shapes(source_table, independent_varname, input_shape):
    """
    Validate that the image shape declared in the model architecture matches
    the shape of the images stored in ``independent_varname``.

    Each row of ``independent_varname`` is a buffer of images. Its first array
    dimension is the buffer size, and the remaining dimensions are the image
    dimensions. Only the first row returned by the query (``LIMIT 1``) is
    inspected.

    :param source_table: table that holds the independent variable column.
    :param independent_varname: name of the array column holding the images.
    :param input_shape: sequence of ints, the per-image input shape taken from
        the model architecture.
    :raises: a plpy error (via ``_assert`` / ``plpy.error``) when the table is
        empty, when the number of image dimensions differs, or when any
        dimension size differs.
    """
    n_image_dims = len(input_shape)

    # PostgreSQL array dimensions are 1-based, and dimension 1 is the buffer
    # dimension, so image dimension i (0-based) is array dimension i + 2.
    # array_ndims() is fetched too because array_upper() returns NULL (it
    # does not fail) for a dimension the array lacks. Without it, extra
    # trailing dimensions in the table would go unnoticed.
    select_exprs = ["array_ndims({0}) AS n_dims".format(independent_varname)]
    select_exprs.extend(
        "array_upper({0}, {1}) AS n_{2}".format(independent_varname, i + 2, i)
        for i in range(n_image_dims))
    query = """
        SELECT {0}
        FROM {1}
        LIMIT 1
    """.format(", ".join(select_exprs), source_table)
    rows = plpy.execute(query)
    _assert(len(rows) > 0,
            "model_keras error: Table {0} is empty, cannot validate the "
            "shape of column {1}.".format(source_table, independent_varname))
    result = rows[0]

    # Drop the buffer dimension. A NULL array yields NULL from array_ndims().
    table_ndims = result["n_dims"]
    table_image_dims = None if table_ndims is None else table_ndims - 1
    _assert(table_image_dims == n_image_dims,
            "model_keras error: The number of dimensions ({0}) of each image"
            " in model architecture and {1} in {2} ({3}) do not match.".format(
                n_image_dims, independent_varname, source_table,
                table_image_dims))

    input_shape_from_table = [result["n_{0}".format(dim)]
                              for dim in range(n_image_dims)]
    if input_shape_from_table != list(input_shape):
        plpy.error("model_keras error: Input shape {0} in the model"
                   " architecture does not match the input shape {1} of"
                   " column {2} in table {3}.".format(
                       input_shape, input_shape_from_table,
                       independent_varname, source_table))
+
def fit(schema_madlib, source_table, model, dependent_varname,
        independent_varname, model_arch_table, model_arch_id, compile_params,
        fit_params, num_iterations, num_classes, use_gpu=True,
        validation_table=None, name="", description="", **kwargs):
    """
    Train a Keras model with the distributed ``fit_step`` aggregate.

    Each of the ``num_iterations`` iterations runs ``<schema_madlib>.fit_step``
    over ``source_table``. The serialized state it returns (average loss,
    average accuracy, weights) seeds the next iteration. When a validation
    table is given, the master model is evaluated on it after every
    iteration.

    On completion two tables are created:
      * ``<model>``: one row, column ``model_data`` (bytea), holding the final
        serialized model state.
      * ``<model>_summary``: training metadata plus per-iteration training
        and validation metrics.

    :param schema_madlib: schema where the MADlib functions live.
    :param source_table: training table. It is grouped by ``gp_segment_id``,
        so a Greenplum-style distributed table is expected.
    :param model: name of the output model table.
    :param dependent_varname: label column (cast to SMALLINT[] for fit_step).
    :param independent_varname: feature column (cast to REAL[] for fit_step).
    :param model_arch_table: table with ``id``, ``model_arch`` (Keras JSON)
        and ``model_weights`` (optional warm-start weights) columns.
    :param model_arch_id: ``id`` of the architecture row to train.
    :param compile_params: Keras ``compile`` arguments as a string.
    :param fit_params: Keras ``fit`` arguments as a string, forwarded to
        fit_step.
    :param num_iterations: number of distributed training iterations (>= 1).
    :param num_classes: number of label classes, forwarded to fit_step and to
        the validation data loader.
    :param use_gpu: coerced to bool and forwarded to fit_step.
    :param validation_table: optional table with the same schema as
        ``source_table``. None, empty or blank means no validation.
    :param name: free-form name stored in the summary table.
    :param description: free-form description stored in the summary table.
    """
    _validate_input_args(source_table, dependent_varname, independent_varname,
                         model_arch_table, validation_table,
                         model, num_iterations)

    start_training_time = datetime.datetime.now()

    # Disable GPU on master
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    use_gpu = bool(use_gpu)

    # Use the same "is a validation table given" rule as _validate_input_args.
    # A blank name is skipped there, so it must not be queried here either.
    validation_set_provided = bool(validation_table and
                                   validation_table.strip())

    # Fetch the model architecture and optional warm-start weights. The id is
    # bound as a parameter rather than formatted into the SQL text.
    model_arch_plan = plpy.prepare(
        "SELECT model_arch, model_weights FROM {0} WHERE id = $1".format(
            model_arch_table), ["INTEGER"])
    query_result = plpy.execute(model_arch_plan, [model_arch_id])
    if not query_result:
        plpy.error("no model arch found in table {0} with id {1}".format(
            model_arch_table, model_arch_id))
    query_result = query_result[0]
    model_arch = query_result['model_arch']
    input_shape = get_input_shape(model_arch)
    _validate_input_shapes(source_table, independent_varname, input_shape)
    if validation_set_provided:
        _validate_input_shapes(
            validation_table, independent_varname, input_shape)
    model_weights_serialized = query_result['model_weights']

    # Build the model from JSON. Its freshly initialized weights give the
    # per-layer shapes needed to deserialize flat weight buffers.
    master_model = model_from_json(model_arch)
    model_weights = master_model.get_weights()
    model_shapes = [weight_arr.shape for weight_arr in model_weights]

    if model_weights_serialized:
        # Warm start from a previously trained model.
        model_weights = KerasWeightsSerializer.deserialize_weights_orig(
            model_weights_serialized, model_shapes)
        master_model.set_weights(model_weights)

    # Load the validation set and compile the master model once. Only
    # evaluate() needs the compiled model. Weights are refreshed every
    # iteration with set_weights(), which does not require recompiling.
    validation_aggregate_accuracy = []
    validation_aggregate_loss = []
    x_validation = None
    y_validation = None
    if validation_set_provided:
        x_validation, y_validation = get_data_as_np_array(
            validation_table, dependent_varname, independent_varname,
            input_shape, num_classes)
        master_model.compile(**convert_string_of_args_to_dict(compile_params))

    # Count the buffers (rows) stored on each segment.
    buffer_rows = plpy.execute(
        """ SELECT gp_segment_id, count(*) AS total_buffers_per_seg
            FROM {0}
            GROUP BY gp_segment_id
        """.format(source_table))
    seg_nums = [int(row["gp_segment_id"]) for row in buffer_rows]
    total_buffers_per_seg = [int(row["total_buffers_per_seg"])
                             for row in buffer_rows]

    # The architecture JSON and the user-supplied compile/fit strings are
    # bound as parameters. Splicing them into dollar-quoted literals broke
    # (and allowed SQL injection) whenever the text contained the quote tag.
    run_training_iteration = plpy.prepare("""
        SELECT {0}.fit_step(
            {1}::REAL[],
            {2}::SMALLINT[],
            gp_segment_id,
            {3}::INTEGER,
            ARRAY{4},
            ARRAY{5},
            $2::TEXT,
            $3::TEXT,
            $4::TEXT,
            $5::BOOLEAN,
            $1
        ) AS iteration_result
        FROM {6}
        """.format(schema_madlib, independent_varname, dependent_varname,
                   num_classes, seg_nums, total_buffers_per_seg,
                   source_table),
        ["bytea", "text", "text", "text", "boolean"])

    # Initial state: zero loss/accuracy/buffer count plus the starting weights.
    model_state = KerasWeightsSerializer.serialize_weights(
        0, 0, 0, model_weights)
    aggregate_loss, aggregate_accuracy, aggregate_runtime = [], [], []

    # Float division so that sizes below 1KB/1MB are not reported as 0.
    plpy.info("Model architecture size: {}KB".format(
        len(model_arch) / 1024.0))
    plpy.info("Model state (serialized) size: {}MB".format(
        len(model_state) / 1024.0 / 1024.0))

    # Run distributed training for the specified number of iterations.
    for i in range(num_iterations):
        start_iteration = time.time()
        try:
            iteration_result = plpy.execute(
                run_training_iteration,
                [model_state, model_arch, compile_params, fit_params,
                 use_gpu])[0]['iteration_result']
        except plpy.SPIError as e:
            plpy.error('A plpy error occurred in the step function: {0}'.
                       format(str(e)))
        end_iteration = time.time()
        plpy.info("Time for iteration {0}: {1} sec".
                  format(i + 1, end_iteration - start_iteration))
        # NOTE: this records a wall-clock timestamp per iteration (stored as
        # time_iter), not a duration.
        aggregate_runtime.append(datetime.datetime.now())
        avg_loss, avg_accuracy, model_state = \
            KerasWeightsSerializer.deserialize_iteration_state(
                iteration_result)
        plpy.info("Average loss after training iteration {0}: {1}".format(
            i + 1, avg_loss))
        plpy.info("Average accuracy after training iteration {0}: {1}".format(
            i + 1, avg_accuracy))
        if validation_set_provided:
            _, _, _, updated_weights = \
                KerasWeightsSerializer.deserialize_weights(model_state,
                                                           model_shapes)
            master_model.set_weights(updated_weights)
            evaluate_result = master_model.evaluate(x_validation, y_validation)
            if len(evaluate_result) < 2:
                plpy.error('Calling evaluate on validation data returned < 2 '
                           'metrics. Expected metrics are loss and accuracy')
            validation_loss = evaluate_result[0]
            validation_accuracy = evaluate_result[1]
            plpy.info("Validation set accuracy after iteration {0}: {1}".
                      format(i + 1, validation_accuracy))
            validation_aggregate_accuracy.append(validation_accuracy)
            validation_aggregate_loss.append(validation_loss)
        aggregate_loss.append(avg_loss)
        aggregate_accuracy.append(avg_accuracy)

    end_training_time = datetime.datetime.now()

    final_validation_acc = (validation_aggregate_accuracy[-1]
                            if validation_aggregate_accuracy else None)
    final_validation_loss = (validation_aggregate_loss[-1]
                             if validation_aggregate_loss else None)
    version = madlib_version(schema_madlib)

    create_output_summary_table = plpy.prepare("""
        CREATE TABLE {0}_summary AS
        SELECT
        $1 AS model_arch_table,
        $2 AS model_arch_id,
        $3 AS model_type,
        $4 AS start_training_time,
        $5 AS end_training_time,
        $6 AS source_table,
        $7 AS validation_table,
        $8 AS model,
        $9 AS dependent_varname,
        $10 AS independent_varname,
        $11 AS name,
        $12 AS description,
        $13 AS model_size,
        $14 AS madlib_version,
        $15 AS compile_params,
        $16 AS fit_params,
        $17 AS num_iterations,
        $18 AS num_classes,
        $19 AS accuracy,
        $20 AS loss,
        $21 AS accuracy_iter,
        $22 AS loss_iter,
        $23 AS time_iter,
        $24 AS accuracy_validation,
        $25 AS loss_validation,
        $26 AS accuracy_iter_validation,
        $27 AS loss_iter_validation
        """.format(model), ["TEXT", "INTEGER", "TEXT", "TIMESTAMP",
                            "TIMESTAMP", "TEXT", "TEXT", "TEXT",
                            "TEXT", "TEXT", "TEXT", "TEXT", "INTEGER",
                            "TEXT", "TEXT", "TEXT", "INTEGER",
                            "INTEGER", "DOUBLE PRECISION",
                            "DOUBLE PRECISION", "DOUBLE PRECISION[]",
                            "DOUBLE PRECISION[]", "TIMESTAMP[]",
                            "DOUBLE PRECISION", "DOUBLE PRECISION",
                            "DOUBLE PRECISION[]", "DOUBLE PRECISION[]"])
    plpy.execute(
        create_output_summary_table,
        [
            model_arch_table, model_arch_id,
            "madlib_keras",
            start_training_time, end_training_time,
            source_table, validation_table,
            model, dependent_varname,
            independent_varname, name, description,
            None, version, compile_params,
            fit_params, num_iterations, num_classes,
            aggregate_accuracy[-1],
            aggregate_loss[-1],
            aggregate_accuracy, aggregate_loss,
            aggregate_runtime, final_validation_acc,
            final_validation_loss,
            validation_aggregate_accuracy,
            validation_aggregate_loss
        ]
    )

    create_output_table = plpy.prepare("""
        CREATE TABLE {0} AS
        SELECT $1 as model_data""".format(model), ["bytea"])
    plpy.execute(create_output_table, [model_state])
+
+
+def fit_transition(state, ind_var, dep_var, current_seg_id, num_classes,
+                   all_seg_ids, total_buffers_per_seg, architecture,
+                   compile_params, fit_params, use_gpu, previous_state,
+                   **kwargs):
+
+    """
+
+    :param state:
+    :param ind_var:
+    :param dep_var:
+    :param current_seg_id:
+    :param num_classes:
+    :param all_seg_ids:
+    :param total_buffers_per_seg:
+    :param architecture:
+    :param compile_params:
+    :param fit_params:
+    :param use_gpu:
+    :param previous_state:
+    :param kwargs:
+    :return:
+    """
+    if not ind_var or not dep_var:
+        return state
+
+    start_transition = time.time()
+    SD = kwargs['SD']
+
+    gpus_per_host = 4
+    # Configure GPUs/CPUs
+    device_name = get_device_name_for_keras(
+        use_gpu, current_seg_id, gpus_per_host)
+
+    # Set up system if this is the first buffer on segment'
+    if not state:
+        set_keras_session(use_gpu)
+        segment_model = model_from_json(architecture)
+        compile_and_set_weights(segment_model, compile_params, device_name,
+                                previous_state)
+        SD['segment_model'] = segment_model
+        SD['buffer_count'] = 0
+    else:
+        segment_model = SD['segment_model']
+
+    agg_loss = 0
+    agg_accuracy = 0
+    input_shape = get_input_shape(architecture)
+
+    # Prepare the data
+    x_train = np.array(ind_var, dtype='float64').reshape(
+        len(ind_var), *input_shape)
+    y_train = np.array(dep_var)
+    y_train = keras_utils.to_categorical(y_train, num_classes)
+
+    # Fit segment model on data
+    start_fit = time.time()
+    with K.tf.device(device_name):
+        fit_params = convert_string_of_args_to_dict(fit_params)
+        history = segment_model.fit(x_train, y_train, **fit_params)
+        # loss, accuracy = prev_model.evaluate(x_train, y_train)
+        loss = history.history['loss'][0]
+        accuracy = history.history['acc'][0]
+    end_fit = time.time()
+
+    # Re-serialize the weights
+    # Update buffer count, check if we are done
+    SD['buffer_count'] += 1
+    updated_loss = agg_loss + loss
+    updated_accuracy = agg_accuracy + accuracy
+
+    with K.tf.device(device_name):
+        updated_weights = segment_model.get_weights()
+
+    total_buffers = total_buffers_per_seg[all_seg_ids.index(current_seg_id)]
+    if SD['buffer_count'] == total_buffers:
+        if total_buffers == 0:
+            plpy.error('total buffers is 0')
+
+        updated_loss /= total_buffers
+        updated_accuracy /= total_buffers
+        # plpy.info('final buffer loss {}, accuracy {}, buffer count 
{}'.format(loss, accuracy, SD['buffer_count']))
 
 Review comment:
   Remove commented code

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to