[ https://issues.apache.org/jira/browse/MAHOUT-1093?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=13472159#comment-13472159 ]
Eric Springer edited comment on MAHOUT-1093 at 10/9/12 5:59 AM:
----------------------------------------------------------------
There's no "upload file"??
Raw Diff on github:
https://github.com/espringe/mahout/commit/831ca2200df9802f24c8a92077377f677be746ef.diff
or inline:
----
commit 831ca2200df9802f24c8a92077377f677be746ef
Author: Eric Springer <[email protected]>
Date: Tue Oct 9 12:36:14 2012 +1100

    CrossFoldLearner shouldn't train on all folds if TrackingKey is negative

diff --git a/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java b/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
index 33f0266..f8b5b67 100644
--- a/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
+++ b/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
@@ -123,7 +123,7 @@ public class CrossFoldLearner extends AbstractVectorClassifier implements Online
     record++;
     int k = 0;
     for (OnlineLogisticRegression model : models) {
-      if (k == trackingKey % models.size()) {
+      if (k == mod(trackingKey, models.size())) {
         Vector v = model.classifyFull(instance);
         double score = Math.max(v.get(actual), MIN_SCORE);
         logLikelihood += (Math.log(score) - logLikelihood) / Math.min(record, windowSize);
@@ -140,6 +140,11 @@ public class CrossFoldLearner extends AbstractVectorClassifier implements Online
     }
   }
 
+  private int mod(int x, int y) {
+    int r = x % y;
+    return r < 0 ? r + y : r;
+  }
+
   @Override
   public void close() {
     for (OnlineLogisticRegression m : models) {
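To make the arithmetic behind the patch concrete, here is a minimal standalone sketch. It assumes, going by the commit title, that the fold whose index equals the remainder is the one held out from training. The class FoldIndexDemo and the method heldOutFold are hypothetical names used only for illustration and are not part of Mahout; mod copies the helper from the diff.

```java
// Illustrative sketch only: FoldIndexDemo and heldOutFold are hypothetical
// names, not Mahout code. Only mod() mirrors the helper added by the patch.
class FoldIndexDemo {
  // Same arithmetic as the patch: for y > 0 the result is always in [0, y).
  static int mod(int x, int y) {
    int r = x % y;
    return r < 0 ? r + y : r;
  }

  // Returns the fold index in [0, numFolds) that equals the given remainder,
  // or -1 when no fold matches.
  static int heldOutFold(int remainder, int numFolds) {
    for (int k = 0; k < numFolds; k++) {
      if (k == remainder) {
        return k;
      }
    }
    return -1;
  }

  public static void main(String[] args) {
    int trackingKey = -7;
    int numFolds = 5;

    // Java's % takes the sign of the dividend, so -7 % 5 is -2 and no fold matches.
    System.out.println(heldOutFold(trackingKey % numFolds, numFolds)); // -1

    // With the non-negative remainder exactly one fold matches.
    System.out.println(heldOutFold(mod(trackingKey, numFolds), numFolds)); // 3

    // On Java 8 and later, Math.floorMod gives the same value for a positive divisor.
    System.out.println(Math.floorMod(trackingKey, numFolds)); // 3
  }
}
```

With the plain % operator a negative tracking key never matches any fold index, which is consistent with the behaviour described in the issue title.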
> CrossFoldLearner trains on all folds if tracking key is negative
> ----------------------------------------------------------------
>
> Key: MAHOUT-1093
> URL: https://issues.apache.org/jira/browse/MAHOUT-1093
> Project: Mahout
> Issue Type: Bug
> Components: Classification
> Reporter: Eric Springer
>
> See: https://github.com/apache/mahout/pull/7