reductionista commented on a change in pull request #355: Keras fit interface URL: https://github.com/apache/madlib/pull/355#discussion_r267112373
########## File path: src/ports/postgres/modules/convex/madlib_keras.py_in ########## @@ -0,0 +1,633 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime +import os +import plpy +import time + +from keras import backend as K +from keras import utils as keras_utils +from keras.layers import * +from keras.models import * +from keras.optimizers import * +from keras.regularizers import * +import numpy as np + +from madlib_keras_helper import * +from utilities.model_arch_info import get_input_shape +from utilities.validate_args import input_tbl_valid +from utilities.validate_args import output_tbl_valid +from utilities.utilities import _assert +from utilities.utilities import add_postfix +from utilities.utilities import is_var_valid +from utilities.utilities import madlib_version + +def _validate_input_table(source_table, independent_varname, + dependent_varname): + _assert(is_var_valid(source_table, independent_varname), + "model_keras error: invalid independent_varname " + "('{independent_varname}') for source_table " + "({source_table})!".format( + independent_varname=independent_varname, + source_table=source_table)) + + _assert(is_var_valid(source_table, dependent_varname), + "model_keras error: invalid dependent_varname " + "('{dependent_varname}') for source_table " + "({source_table})!".format( + dependent_varname=dependent_varname, source_table=source_table)) + +def _validate_input_args( + source_table, dependent_varname, independent_varname, model_arch_table, + validation_table, output_model_table, num_iterations): + + module_name = 'model_keras' + _assert(num_iterations > 0, + "model_keras error: Number of iterations cannot be < 1.") + + output_summary_model_table = add_postfix(output_model_table, "_summary") + input_tbl_valid(source_table, module_name) + # Source table and validation tables must have the same schema + _validate_input_table(source_table, independent_varname, dependent_varname) + if validation_table and validation_table.strip() != '': + input_tbl_valid(validation_table, module_name) + _validate_input_table(validation_table, independent_varname, + dependent_varname) + # Validate model arch table's schema. + input_tbl_valid(model_arch_table, module_name) + # Validate output tables + output_tbl_valid(output_model_table, module_name) + output_tbl_valid(output_summary_model_table, module_name) + +def _validate_input_shapes(source_table, independent_varname, input_shape): + """ + Validate if the input shape specified in model architecture is the same + as the shape of the image specified in the indepedent var of the input + table. + """ + # The weird indexing with 'i+2' and 'i' below has two reasons: + # 1) The indexing for array_upper() starts from 1, but indexing in the + # input_shape list starts from 0. + # 2) Input_shape is only the image's dimension, whereas a row of + # independent varname in a table contains buffer size as the first + # dimension, followed by the image's dimension. So we must ignore + # the first dimension from independent varname. + array_upper_query = ", ".join("array_upper({0}, {1}) AS n_{2}".format( + independent_varname, i+2, i) for i in range(len(input_shape))) + query = """ + SELECT {0} + FROM {1} + LIMIT 1 + """.format(array_upper_query, source_table) + # This query will fail if an image in independent var does not have the + # same number of dimensions as the input_shape. + result = plpy.execute(query)[0] + _assert(len(result) == len(input_shape), + "model_keras error: The number of dimensions ({0}) of each image in" \ + " model architecture and {1} in {2} ({3}) do not match.".format( + len(input_shape), independent_varname, source_table, len(result))) + for i in range(len(input_shape)): + key_name = "n_{0}".format(i) + if result[key_name] != input_shape[i]: + # Construct the shape in independent varname to display meaningful + # error msg. + input_shape_from_table = [result["n_{0}".format(i)] for i in range( + 1, len(input_shape))] + plpy.error("model_keras error: Input shape {0} in the model" \ + " architecture does not match the input shape {1} of column" \ + " {2} in table {3}.".format(input_shape, input_shape_from_table, independent_varname, source_table)) + +def fit(schema_madlib, source_table, model, dependent_varname, + independent_varname, model_arch_table, model_arch_id, compile_params, + fit_params, num_iterations, num_classes, use_gpu = True, + validation_table=None, name="", description="", **kwargs): + _validate_input_args(source_table, dependent_varname, independent_varname, + model_arch_table, validation_table, + model, num_iterations) + + start_training_time = datetime.datetime.now() + + # Disable GPU on master + os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + + use_gpu = bool(use_gpu) + + # Get the serialized master model + start_deserialization = time.time() Review comment: start_deserialization is not used--remove? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
