Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 61711fbc2 -> f8beee36b


[HIVEMALL-211][BUGFIX] Fixed Optimizer for regularization updates

## What changes were proposed in this pull request?

This PR fixes a bug in the regularization scheme of the Optimizer: regularization is now applied to the optimizer-specific delta at weight-update time, rather than to the raw gradient before the delta is computed.
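
As a rough illustration of the corrected ordering (not Hivemall code), the sketch below uses a fixed learning rate, a plain L2 penalty, and an AdaGrad-style accumulator as stand-ins for the `_eta` (learning-rate) and `_reg` (regularization) strategies and the `IWeightValue` accessors seen in the diff; all constants are hypothetical.

```java
// Illustrative sketch only: fixed learning rate, plain L2 penalty, and an
// AdaGrad-style accumulator stand in for Hivemall's learning-rate,
// regularization, and weight-value abstractions.
public final class RegularizedUpdateSketch {

    private static final float ETA = 0.1f;     // hypothetical learning rate
    private static final float LAMBDA = 1e-4f; // hypothetical L2 strength
    private static final float EPS = 1e-6f;

    private static float sumSqGrad = 0.f;      // accumulated squared gradients

    // Optimizer-specific delta; with the fix it always receives the raw
    // gradient, so the accumulator is not polluted by the penalty term.
    private static float computeDelta(final float gradient) {
        sumSqGrad += gradient * gradient;
        return gradient / (float) Math.sqrt(sumSqGrad + EPS);
    }

    // Fixed ordering: compute the delta from the raw gradient, then apply
    // regularization to the delta when the weight is updated.
    static float update(final float oldWeight, final float gradient) {
        float delta = computeDelta(gradient);
        float regularized = delta + LAMBDA * oldWeight; // stand-in for _reg.regularize()
        return oldWeight - ETA * regularized;
    }

    public static void main(String[] args) {
        System.out.println(update(0.5f, 0.2f));
    }
}
```

This mirrors the updated `Optimizer.update()` shown in the diff below, where `computeDelta()` now operates on the raw gradient.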

## What type of PR is it?

Bug Fix

## What is the Jira issue?

https://issues.apache.org/jira/browse/HIVEMALL-211

## How was this patch tested?

Unit tests and manual tests on EMR.

## Checklist

(Please remove this section if not needed; check `x` for YES, blank for NO)

- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [x] Did you run system tests on Hive (or Spark)?

Author: Makoto Yui <m...@apache.org>

Closes #156 from myui/HIVEMALL-211.


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/f8beee36
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/f8beee36
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/f8beee36

Branch: refs/heads/master
Commit: f8beee36b36274d8eb8948f6838aacd955817eba
Parents: 61711fb
Author: Makoto Yui <m...@apache.org>
Authored: Fri Aug 24 18:44:40 2018 +0900
Committer: Makoto Yui <m...@apache.org>
Committed: Fri Aug 24 18:44:40 2018 +0900

----------------------------------------------------------------------
 .../hivemall/optimizer/DenseOptimizerFactory.java   |  2 +-
 .../main/java/hivemall/optimizer/LossFunctions.java |  3 +++
 .../src/main/java/hivemall/optimizer/Optimizer.java | 16 ++++++++--------
 3 files changed, 12 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f8beee36/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
index b1fe917..5985868 100644
--- a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
+++ b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
@@ -48,7 +48,7 @@ public final class DenseOptimizerFactory {
                 && "adagrad".equalsIgnoreCase(optimizerName) == false) {
             throw new IllegalArgumentException(
                 "`-regularization rda` is only supported for AdaGrad but `-optimizer "
-                        + optimizerName);
+                        + optimizerName + "`. Please specify `-regularization l1` and so on.");
         }
 
         final Optimizer optimizerImpl;

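For readability, here is the guard after this commit condensed into one place; `isRdaRegularization` is a hypothetical stand-in for the first half of the condition, which is elided in the hunk above.

```java
// Condensed view of the guard after this commit (sketch, not the exact source):
// RDA regularization is only valid together with AdaGrad, so any other
// optimizer is rejected with a message that now suggests an alternative.
if (isRdaRegularization && "adagrad".equalsIgnoreCase(optimizerName) == false) {
    throw new IllegalArgumentException(
        "`-regularization rda` is only supported for AdaGrad but `-optimizer " + optimizerName
                + "`. Please specify `-regularization l1` and so on.");
}
```
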
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f8beee36/core/src/main/java/hivemall/optimizer/LossFunctions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/LossFunctions.java b/core/src/main/java/hivemall/optimizer/LossFunctions.java
index c4705c0..f76eb0e 100644
--- a/core/src/main/java/hivemall/optimizer/LossFunctions.java
+++ b/core/src/main/java/hivemall/optimizer/LossFunctions.java
@@ -584,6 +584,9 @@ public final class LossFunctions {
         }
     }
 
+    /**
+     * logistic loss function where target is 0 (negative) or 1 (positive).
+     */
     public static float logisticLoss(final float target, final float predicted) {
         if (predicted > -100.d) {
             return target - (float) MathUtils.sigmoid(predicted);

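A self-contained sketch of the documented contract, assuming `MathUtils.sigmoid` is the standard logistic function; the clipping branch for very negative predictions lies outside the hunk and is omitted here.

```java
// Sketch of logisticLoss for target in {0, 1}; illustrative only.
static float logisticLoss(final float target, final float predicted) {
    final double sigmoid = 1.d / (1.d + Math.exp(-predicted)); // standard logistic
    return target - (float) sigmoid;
}
```
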
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f8beee36/core/src/main/java/hivemall/optimizer/Optimizer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java
index 4b1ef0a..0cbac42 100644
--- a/core/src/main/java/hivemall/optimizer/Optimizer.java
+++ b/core/src/main/java/hivemall/optimizer/Optimizer.java
@@ -70,9 +70,8 @@ public interface Optimizer {
          */
         protected float update(@Nonnull final IWeightValue weight, final float gradient) {
             float oldWeight = weight.get();
-            float g = _reg.regularize(oldWeight, gradient);
-            float delta = computeDelta(weight, g);
-            float newWeight = oldWeight - _eta.eta(_numStep) * delta;
+            float delta = computeDelta(weight, gradient);
+            float newWeight = oldWeight - _eta.eta(_numStep) * _reg.regularize(oldWeight, delta);
             weight.set(newWeight);
             return newWeight;
         }
@@ -123,10 +122,10 @@ public interface Optimizer {
 
         @Override
         protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) {
-            float new_scaled_sum_sqgrad =
-                    weight.getSumOfSquaredGradients() + gradient * (gradient / scale);
-            weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad);
-            return gradient / ((float) Math.sqrt(new_scaled_sum_sqgrad * scale) + eps);
+            float old_scaled_gg = weight.getSumOfSquaredGradients();
+            float new_scaled_gg = old_scaled_gg + gradient * (gradient / scale);
+            weight.setSumOfSquaredGradients(new_scaled_gg);
+            return (float) (gradient / Math.sqrt(eps + ((double) old_scaled_gg) * scale));
         }
 
         @Override
@@ -156,7 +155,8 @@ public interface Optimizer {
             float new_scaled_sum_sqgrad = (decay * old_scaled_sum_sqgrad)
                     + ((1.f - decay) * gradient * (gradient / scale));
             float delta = (float) Math.sqrt(
-                (old_sum_squared_delta_x + eps) / (new_scaled_sum_sqgrad * scale + eps)) * gradient;
+                (old_sum_squared_delta_x + eps) / ((double) new_scaled_sum_sqgrad * scale + eps))
+                    * gradient;
             float new_sum_squared_delta_x =
                     (decay * old_sum_squared_delta_x) + ((1.f - decay) * delta * delta);
             weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad);

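Taken together, the AdaGrad and AdaDelta hunks above move the denominator arithmetic to double precision, put `eps` inside the square root for AdaGrad, and scale the AdaGrad delta by the accumulator value before the current gradient is added. Below is a self-contained sketch of the rewritten AdaGrad delta, using a one-element array as a hypothetical stand-in for the `IWeightValue` getter/setter.

```java
// Sketch of the AdaGrad delta as rewritten above (illustrative only).
// sumSqGradHolder[0] plays the role of weight.get/setSumOfSquaredGradients().
static float adagradDelta(final float gradient, final float[] sumSqGradHolder,
        final float scale, final float eps) {
    final float oldScaledGG = sumSqGradHolder[0];
    sumSqGradHolder[0] = oldScaledGG + gradient * (gradient / scale);
    // the denominator uses the accumulator value *before* this step's gradient
    // and is evaluated in double precision, with eps inside the square root
    return (float) (gradient / Math.sqrt(eps + ((double) oldScaledGG) * scale));
}
```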