Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 61711fbc2 -> f8beee36b
[HIVEMALL-211][BUGFIX] Fixed Optimizer for regularization updates

## What changes were proposed in this pull request?

This PR fixes a bug in the regularization scheme of `Optimizer`.

## What type of PR is it?

Bug Fix

## What is the Jira issue?

https://issues.apache.org/jira/browse/HIVEMALL-211

## How was this patch tested?

Unit tests, and manual tests on EMR

## Checklist

(Please remove this section if not needed; check `x` for YES, blank for NO)

- [x] Did you apply the source code formatter, i.e., `./bin/format_code.sh`, to your commit?
- [x] Did you run system tests on Hive (or Spark)?

Author: Makoto Yui <m...@apache.org>

Closes #156 from myui/HIVEMALL-211.

Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/f8beee36
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/f8beee36
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/f8beee36

Branch: refs/heads/master
Commit: f8beee36b36274d8eb8948f6838aacd955817eba
Parents: 61711fb
Author: Makoto Yui <m...@apache.org>
Authored: Fri Aug 24 18:44:40 2018 +0900
Committer: Makoto Yui <m...@apache.org>
Committed: Fri Aug 24 18:44:40 2018 +0900

----------------------------------------------------------------------
 .../hivemall/optimizer/DenseOptimizerFactory.java   |  2 +-
 .../main/java/hivemall/optimizer/LossFunctions.java |  3 +++
 .../src/main/java/hivemall/optimizer/Optimizer.java | 16 ++++++++--------
 3 files changed, 12 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f8beee36/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
index b1fe917..5985868 100644
--- a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
+++ b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
@@ -48,7 +48,7 @@ public final class DenseOptimizerFactory {
                 && "adagrad".equalsIgnoreCase(optimizerName) == false) {
             throw new IllegalArgumentException(
                 "`-regularization rda` is only supported for AdaGrad but `-optimizer "
-                        + optimizerName);
+                        + optimizerName + "`. Please specify `-regularization l1` and so on.");
         }
 
         final Optimizer optimizerImpl;
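The hunk above only improves the error text: the old message was cut off after the optimizer name (the opening backtick was never closed) and offered no remediation hint. As a minimal sketch, the guard in isolation behaves as follows; `RdaGuardSketch` and `checkRdaSupport` are hypothetical names, since the real check is inlined in `DenseOptimizerFactory`:

    // Hypothetical standalone rendering of the guard changed above.
    // RDA is only implemented on top of AdaGrad in Hivemall, so it cannot
    // be paired with any other optimizer.
    public final class RdaGuardSketch {

        static void checkRdaSupport(final String optimizerName, final String regName) {
            if ("rda".equalsIgnoreCase(regName)
                    && "adagrad".equalsIgnoreCase(optimizerName) == false) {
                throw new IllegalArgumentException(
                    "`-regularization rda` is only supported for AdaGrad but `-optimizer "
                            + optimizerName + "`. Please specify `-regularization l1` and so on.");
            }
        }

        public static void main(String[] args) {
            checkRdaSupport("adagrad", "rda"); // OK
            checkRdaSupport("sgd", "rda");     // throws with the new, complete message
        }
    }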
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f8beee36/core/src/main/java/hivemall/optimizer/LossFunctions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/LossFunctions.java b/core/src/main/java/hivemall/optimizer/LossFunctions.java
index c4705c0..f76eb0e 100644
--- a/core/src/main/java/hivemall/optimizer/LossFunctions.java
+++ b/core/src/main/java/hivemall/optimizer/LossFunctions.java
@@ -584,6 +584,9 @@ public final class LossFunctions {
         }
     }
 
+    /**
+     * logistic loss function where target is 0 (negative) or 1 (positive).
+     */
     public static float logisticLoss(final float target, final float predicted) {
         if (predicted > -100.d) {
             return target - (float) MathUtils.sigmoid(predicted);


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f8beee36/core/src/main/java/hivemall/optimizer/Optimizer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java
index 4b1ef0a..0cbac42 100644
--- a/core/src/main/java/hivemall/optimizer/Optimizer.java
+++ b/core/src/main/java/hivemall/optimizer/Optimizer.java
@@ -70,9 +70,8 @@ public interface Optimizer {
          */
         protected float update(@Nonnull final IWeightValue weight, final float gradient) {
             float oldWeight = weight.get();
-            float g = _reg.regularize(oldWeight, gradient);
-            float delta = computeDelta(weight, g);
-            float newWeight = oldWeight - _eta.eta(_numStep) * delta;
+            float delta = computeDelta(weight, gradient);
+            float newWeight = oldWeight - _eta.eta(_numStep) * _reg.regularize(oldWeight, delta);
             weight.set(newWeight);
             return newWeight;
         }
@@ -123,10 +122,10 @@ public interface Optimizer {
 
         @Override
         protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) {
-            float new_scaled_sum_sqgrad =
-                    weight.getSumOfSquaredGradients() + gradient * (gradient / scale);
-            weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad);
-            return gradient / ((float) Math.sqrt(new_scaled_sum_sqgrad * scale) + eps);
+            float old_scaled_gg = weight.getSumOfSquaredGradients();
+            float new_scaled_gg = old_scaled_gg + gradient * (gradient / scale);
+            weight.setSumOfSquaredGradients(new_scaled_gg);
+            return (float) (gradient / Math.sqrt(eps + ((double) old_scaled_gg) * scale));
         }
 
         @Override
@@ -156,7 +155,8 @@ public interface Optimizer {
             float new_scaled_sum_sqgrad = (decay * old_scaled_sum_sqgrad)
                     + ((1.f - decay) * gradient * (gradient / scale));
             float delta = (float) Math.sqrt(
-                (old_sum_squared_delta_x + eps) / (new_scaled_sum_sqgrad * scale + eps)) * gradient;
+                (old_sum_squared_delta_x + eps) / ((double) new_scaled_sum_sqgrad * scale + eps))
+                    * gradient;
             float new_sum_squared_delta_x = (decay * old_sum_squared_delta_x)
                     + ((1.f - decay) * delta * delta);
             weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad);
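The heart of the fix is the reordering in `update()`: before, the regularizer's penalty term was folded into the gradient ahead of `computeDelta()`, so the penalty also leaked into the optimizer's adaptive statistics such as AdaGrad's sum of squared gradients; now `computeDelta()` sees only the raw loss gradient, and regularization is applied to the scaled delta. The remaining hunks additionally widen intermediates to double so that the `scale`-multiplied accumulators do not lose precision in float arithmetic. The following minimal sketch contrasts the two update orders, assuming a plain L2 penalty; `RegularizeOrderSketch` and its members are illustrative stand-ins, not Hivemall classes:

    // Contrast of the old and new update orders for a single weight, using an
    // AdaGrad-style computeDelta(). Illustrative only; assumes an L2 penalty
    // g -> g + lambda * w rather than Hivemall's Regularization hierarchy.
    public final class RegularizeOrderSketch {

        private float sumSqGrad = 0.f; // per-weight AdaGrad accumulator
        private static final float EPS = 1e-7f;

        private float computeDelta(final float gradient) {
            sumSqGrad += gradient * gradient;
            return gradient / (float) Math.sqrt(sumSqGrad + EPS);
        }

        private static float regularize(final float weight, final float value,
                final float lambda) {
            return value + lambda * weight; // plain L2 penalty term
        }

        // Old (buggy) order: the penalty term is folded into the gradient first,
        // so it also inflates the accumulator via gradient * gradient.
        float oldUpdate(final float w, final float grad, final float eta, final float lambda) {
            float g = regularize(w, grad, lambda);
            return w - eta * computeDelta(g);
        }

        // New (fixed) order: the accumulator sees only the raw loss gradient;
        // the penalty is applied to the scaled delta afterwards.
        float newUpdate(final float w, final float grad, final float eta, final float lambda) {
            float delta = computeDelta(grad);
            return w - eta * regularize(w, delta, lambda);
        }
    }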