Github user njayaram2 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/229#discussion_r163048733
--- Diff: src/modules/convex/algo/igd.hpp ---
@@ -56,6 +59,62 @@ IGD<State, ConstState, Task>::transition(state_type
&state,
state.task.stepsize * tuple.weight);
}
+/**
+ * @brief Update the transition state in mini-batches
+ *
+ * Runs nEpochs passes of mini-batch gradient descent over the rows of this
+ * transition tuple, delegating the per-batch model update to the Task.
+ *
+ * Note: We assume that
+ * 1. Task defines a model_eigen_type
+ * 2. A batch of tuple.indVar is a Matrix
+ * 3. A batch of tuple.depVar is a ColumnVector
+ * 4. Task defines a getLossAndUpdateModel method
+ *
+ */
+template <class State, class ConstState, class Task>
+void
+IGD<State, ConstState, Task>::transitionInMiniBatch(
+        state_type &state,
+        const tuple_type &tuple) {
+
+    madlib_assert(tuple.indVar.rows() == tuple.depVar.rows(),
+                  std::runtime_error("Invalid data. Independent and dependent "
+                                     "batches don't have same number of rows."));
+
+    const int batch_size = state.algo.batchSize;
+    const int n_epochs = state.algo.nEpochs;
+
+    // n_rows/n_ind_cols are the rows/cols in a transition tuple.
+    const int n_rows = tuple.indVar.rows();
+    const int n_ind_cols = tuple.indVar.cols();
+    // Nothing to train on; also avoids a zero batch count below.
+    if (n_rows == 0) return;
+    // Ceiling division: any leftover rows form a final, smaller batch.
+    const int n_batches = (n_rows + batch_size - 1) / batch_size;
+
+    for (int curr_epoch = 0; curr_epoch < n_epochs; curr_epoch++) {
+        double loss = 0.0;
+        for (int curr_batch = 0, curr_batch_row_index = 0;
+                curr_batch < n_batches;
+                curr_batch++, curr_batch_row_index += batch_size) {
+            // Every batch has batch_size rows except possibly the last,
+            // which takes whatever rows remain.
+            const int curr_batch_size = (curr_batch == n_batches - 1) ?
+                    n_rows - curr_batch_row_index : batch_size;
+            const Matrix X_batch = tuple.indVar.block(
+                    curr_batch_row_index, 0, curr_batch_size, n_ind_cols);
+            const ColumnVector y_batch = tuple.depVar.segment(
+                    curr_batch_row_index, curr_batch_size);
+            loss += Task::getLossAndUpdateModel(
+                    state.task.model, X_batch, y_batch, state.task.stepsize);
+        }
+
+        // The first epoch will most likely have the most loss.
+        // So being pessimistic, we record only the first epoch's loss,
+        // averaged over n_batches so the reported value is a per-batch
+        // average rather than a sum that grows with the batch count.
+        if (curr_epoch == 0) state.algo.loss += loss / n_batches;
--- End diff --
Should we average this over `n_batches`?
---