Author: dlyubimov
Date: Tue Mar 29 04:36:33 2011
New Revision: 1086473
URL: http://svn.apache.org/viewvc?rev=1086473&view=rev
Log:
MAHOUT-638, first installment: the fix itself. Tests covering the various
vector types will be added a bit later.
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java?rev=1086473&r1=1086472&r2=1086473&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java Tue Mar 29 04:36:33 2011
@@ -42,6 +42,8 @@ import org.apache.hadoop.mapreduce.lib.i
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.stochasticsvd.QJob.QJobKeyWritable;
@@ -148,14 +150,25 @@ public class BtJob {
// A
row
//
labels.
- int n = aRow.size();
Vector btRow = btValue.get();
- for (int i = 0; i < n; i++) {
- double mul = aRow.getQuick(i);
- for (int j = 0; j < kp; j++)
- btRow.setQuick(j, mul * qRow.getQuick(j));
- btKey.set(i);
- context.write(btKey, btValue);
+ if ( (aRow instanceof SequentialAccessSparseVector) ||
+ (aRow instanceof RandomAccessSparseVector )) {
+ for ( Vector.Element el:aRow ) {
+ double mul=el.get();
+ for ( int j =0; j < kp; j++ )
+ btRow.setQuick(j, mul * qRow.getQuick(j));
+ btKey.set(el.index());
+ context.write(btKey, btValue);
+ }
+ } else {
+ int n = aRow.size();
+ for (int i = 0; i < n; i++) {
+ double mul = aRow.getQuick(i);
+ for (int j = 0; j < kp; j++)
+ btRow.setQuick(j, mul * qRow.getQuick(j));
+ btKey.set(i);
+ context.write(btKey, btValue);
+ }
}
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java?rev=1086473&r1=1086472&r2=1086473&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java Tue Mar 29 04:36:33 2011
@@ -20,6 +20,7 @@ package org.apache.mahout.math.hadoop.st
import java.util.Arrays;
import java.util.Random;
+import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
@@ -59,7 +60,8 @@ public class Omega {
assert yRow.length == kp;
Arrays.fill(yRow, 0);
- if (aRow instanceof SequentialAccessSparseVector) {
+ if ((aRow instanceof SequentialAccessSparseVector)||
+ (aRow instanceof RandomAccessSparseVector)){
int j = 0;
for (Element el : aRow) {
accumDots(j, el.get(), yRow);