Repository: incubator-hivemall Updated Branches: refs/heads/master f8beee36b -> 4ca1c19c7
[HIVEMALL-212] Fix Classifier/Regressor not to forward zero weighted values ## What changes were proposed in this pull request? Features with weight = 0.0 need not be saved in the prediction model. It is preferable to reduce the size of the prediction model. So, this PR fixes Classifier/Regressor so that zero-weighted values are not forwarded. ## What type of PR is it? Improvement ## What is the Jira issue? https://issues.apache.org/jira/browse/HIVEMALL-212 ## How was this patch tested? unit tests and manual tests ## Checklist (Please remove this section if not needed; check `x` for YES, blank for NO) - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit? - [x] Did you run system tests on Hive (or Spark)? Author: Makoto Yui <m...@apache.org> Closes #157 from myui/HIVEMALL-212. Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/4ca1c19c Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/4ca1c19c Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/4ca1c19c Branch: refs/heads/master Commit: 4ca1c19c7e00120cb0abd42d6c1f1b48176846e8 Parents: f8beee3 Author: Makoto Yui <m...@apache.org> Authored: Wed Aug 29 00:42:45 2018 +0900 Committer: Makoto Yui <m...@apache.org> Committed: Wed Aug 29 00:42:45 2018 +0900 ---------------------------------------------------------------------- .../java/hivemall/GeneralLearnerBaseUDTF.java | 33 ++++++++++++++++---- .../hivemall/ensemble/ArgminKLDistanceUDAF.java | 14 ++++++--- .../optimizer/SparseOptimizerFactory.java | 6 +++- 3 files changed, 42 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4ca1c19c/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java ---------------------------------------------------------------------- diff --git 
a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java index 5c3967b..0198e77 100644 --- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -452,9 +452,13 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { cvState.incrLoss(loss); // retain cumulative loss to check convergence final float dloss = lossFunction.dloss(predicted, target); + if (dloss == 0.f) { + optimizer.proceedStep(); + return; + } + if (is_mini_batch) { accumulateUpdate(features, dloss); - if (sampled >= mini_batch_size) { batchUpdate(); } @@ -494,7 +498,11 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { for (Map.Entry<Object, FloatAccumulator> e : accumulated.entrySet()) { Object feature = e.getKey(); FloatAccumulator v = e.getValue(); - float new_weight = v.get(); // w_i - (eta / M) * (delta_1 + delta_2 + ... + delta_M) + final float new_weight = v.get(); // w_i - (eta / M) * (delta_1 + delta_2 + ... 
+ delta_M) + if (new_weight == 0.f) { + model.delete(feature); + continue; + } model.setWeight(feature, new_weight); } @@ -507,7 +515,11 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { Object feature = f.getFeature(); float xi = f.getValueAsFloat(); float weight = model.getWeight(feature); - float new_weight = optimizer.update(feature, weight, dloss * xi); + final float new_weight = optimizer.update(feature, weight, dloss * xi); + if (new_weight == 0.f) { + model.delete(feature); + continue; + } model.setWeight(feature, new_weight); } } @@ -701,9 +713,14 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { if (!probe.isTouched()) { continue; // skip outputting untouched weights } + final float v = probe.get(); + final float cv = probe.getCovariance(); + if (v == 0.f && cv == 0.f) { + continue; + } + fv.set(v); + cov.set(cv); Object k = itor.getKey(); - fv.set(probe.get()); - cov.set(probe.getCovariance()); forwardMapObj[0] = k; forwardMapObj[1] = fv; forwardMapObj[2] = cov; @@ -720,8 +737,12 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { if (!probe.isTouched()) { continue; // skip outputting untouched weights } + final float v = probe.get(); + if (v == 0.f) { + continue; + } + fv.set(v); Object k = itor.getKey(); - fv.set(probe.get()); forwardMapObj[0] = k; forwardMapObj[1] = fv; forward(forwardMapObj); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4ca1c19c/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java b/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java index 136ca0d..774db6f 100644 --- a/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java +++ b/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java @@ -38,8 +38,8 @@ public final class ArgminKLDistanceUDAF extends UDAF { float 
sum_inv_covar; PartialResult() { - this.sum_mean_div_covar = 0f; - this.sum_inv_covar = 0f; + this.sum_mean_div_covar = 0.f; + this.sum_inv_covar = 0.f; } } @@ -54,7 +54,10 @@ public final class ArgminKLDistanceUDAF extends UDAF { if (partial == null) { this.partial = new PartialResult(); } - float covar_f = covar.get(); + final float covar_f = covar.get(); + if (covar_f == 0.f) {// avoid null division + return true; + } partial.sum_mean_div_covar += (mean.get() / covar_f); partial.sum_inv_covar += (1.f / covar_f); return true; @@ -80,7 +83,10 @@ public final class ArgminKLDistanceUDAF extends UDAF { if (partial == null) { return null; } - float mean = (1f / partial.sum_inv_covar) * partial.sum_mean_div_covar; + if (partial.sum_inv_covar == 0.f) {// avoid null division + return new FloatWritable(0.f); + } + float mean = (1.f / partial.sum_inv_covar) * partial.sum_mean_div_covar; return new FloatWritable(mean); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4ca1c19c/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java index 7cf61d8..1254740 100644 --- a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java +++ b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java @@ -177,7 +177,11 @@ public final class SparseOptimizerFactory { } else { auxWeight.set(weight); } - return update(auxWeight, gradient); + final float newWeight = update(auxWeight, gradient); + if (newWeight == 0.f) { + auxWeights.remove(feature); + } + return newWeight; } }