fmcquillan99 commented on pull request #558:
URL: https://github.com/apache/madlib/pull/558#issuecomment-792925492


   This works for me with the following model architecture on CIFAR-10:
   ```
   # Build and compile a small convolutional classifier for CIFAR-10 with
   # tf.keras. The script stops after compile(); no fit/evaluate call is made.
   from __future__ import print_function
   from tensorflow import keras
   from tensorflow.keras.datasets import cifar10
   # NOTE(review): ImageDataGenerator and os are not used in this snippet.
   from tensorflow.keras.preprocessing.image import ImageDataGenerator
   from tensorflow.keras.models import Sequential
   from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
   from tensorflow.keras.layers import Conv2D, MaxPooling2D
   import os
   
   # Configuration. Only num_classes is referenced below; the other values are
   # unused here (presumably meant for a training step not shown - confirm).
   batch_size = 32
   num_classes = 10
   epochs = 2
   data_augmentation = True
   num_predictions = 20
   
   # The data, split between train and test sets:
   (x_train, y_train), (x_test, y_test) = cifar10.load_data()
   print('x_train shape:', x_train.shape)
   print(x_train.shape[0], 'train samples')
   print(x_test.shape[0], 'test samples')
   
   # Convert class vectors to one-hot (binary class) matrices; the
   # categorical_crossentropy loss used below expects one-hot targets.
   y_train = keras.utils.to_categorical(y_train, num_classes)
   y_test = keras.utils.to_categorical(y_test, num_classes)
   
   model = Sequential()
   # Block 1: two 32-filter 3x3 convolutions, 2x2 max-pooling, dropout.
   # The input shape is taken from the loaded training images.
   model.add(Conv2D(32, (3, 3), padding='same',
                    input_shape=x_train.shape[1:]))
   model.add(Activation('relu'))
   model.add(Conv2D(32, (3, 3)))
   model.add(Activation('relu'))
   model.add(MaxPooling2D(pool_size=(2, 2)))
   model.add(Dropout(0.25))
   
   # Block 2: same pattern with 64 filters.
   model.add(Conv2D(64, (3, 3), padding='same'))
   model.add(Activation('relu'))
   model.add(Conv2D(64, (3, 3)))
   model.add(Activation('relu'))
   model.add(MaxPooling2D(pool_size=(2, 2)))
   model.add(Dropout(0.25))
   
   # Classifier head: flatten, one 512-unit hidden layer, softmax over
   # num_classes outputs.
   model.add(Flatten())
   model.add(Dense(512))
   model.add(Activation('relu'))
   model.add(Dropout(0.5))
   model.add(Dense(num_classes))
   model.add(Activation('softmax'))
   
   # RMSprop with a small learning rate and learning-rate decay.
   # `learning_rate` is the supported keyword; `lr` is deprecated in tf.keras.
   # NOTE(review): `decay` is only accepted by the legacy optimizers. On newer
   # TF releases use keras.optimizers.legacy.RMSprop or a LearningRateSchedule
   # instead - confirm against the TF version in use.
   opt = keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
   
   # Compile for multi-class classification, tracking accuracy.
   model.compile(loss='categorical_crossentropy',
                 optimizer=opt,
                 metrics=['accuracy'])
   ```
   LGTM


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to