Author: jmannix
Date: Tue Apr 12 05:05:06 2011
New Revision: 1091292

URL: http://svn.apache.org/viewvc?rev=1091292&view=rev
Log:
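Fixes MAHOUT-666.

(Not a bug number you leave open for long!)

Adds an opt-in configuration flag, DistributedRowMatrix.REMOVE_TEMP_DIRS
("DistributedMatrix.remove.temp.dirs", default false). When the flag is set in
the Configuration passed to setConf(), times() and timesSquared() delete their
per-call temporary output directory after the result vector has been read back.
times() now also writes under the filesystem-qualified temp base path, matching
timesSquared(). New tests cover both methods with the flag off and on.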

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java?rev=1091292&r1=1091291&r2=1091292&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java Tue Apr 12 05:05:06 2011
@@ -62,7 +62,8 @@ import java.util.Iterator;
  *
  */
 public class DistributedRowMatrix implements VectorIterable, Configurable {
-
+  public static final String REMOVE_TEMP_DIRS = 
"DistributedMatrix.remove.temp.dirs";
+  
   private static final Logger log = 
LoggerFactory.getLogger(DistributedRowMatrix.class);
 
   private final Path inputPath;
@@ -72,6 +73,7 @@ public class DistributedRowMatrix implem
   private Path outputTmpBasePath;
   private final int numRows;
   private final int numCols;
+  private boolean removeTempDirs;
 
   public DistributedRowMatrix(Path inputPathString,
                               Path outputTmpPathString,
@@ -81,6 +83,7 @@ public class DistributedRowMatrix implem
     this.outputTmpPath = outputTmpPathString;
     this.numRows = numRows;
     this.numCols = numCols;
+    this.removeTempDirs = false;
   }
 
   @Override
@@ -94,6 +97,7 @@ public class DistributedRowMatrix implem
     try {
       rowPath = FileSystem.get(conf).makeQualified(inputPath);
       outputTmpBasePath = FileSystem.get(conf).makeQualified(outputTmpPath);
+      removeTempDirs = conf.getBoolean(REMOVE_TEMP_DIRS, false);
     } catch (IOException ioe) {
       throw new IllegalStateException(ioe);
     }
@@ -186,14 +190,21 @@ public class DistributedRowMatrix implem
   public Vector times(Vector v) {
     try {
       Configuration initialConf = getConf() == null ? new Configuration() : 
getConf();
+      Path outputVectorTmpPath = new Path(outputTmpBasePath,
+                                          new 
Path(Long.toString(System.nanoTime())));
       Configuration conf =
           TimesSquaredJob.createTimesJobConf(initialConf, 
                                              v,
                                              numRows,
                                              rowPath,
-                                             new Path(outputTmpPath, 
Long.toString(System.nanoTime())));
+                                             outputVectorTmpPath);
       JobClient.runJob(new JobConf(conf));
-      return TimesSquaredJob.retrieveTimesSquaredOutputVector(conf);
+      Vector result = TimesSquaredJob.retrieveTimesSquaredOutputVector(conf);
+      if (removeTempDirs) {
+        FileSystem fs = outputVectorTmpPath.getFileSystem(conf);
+        fs.delete(outputVectorTmpPath, true);
+      }
+      return result;
     } catch (IOException ioe) {
       throw new IllegalStateException(ioe);
     }
@@ -203,14 +214,20 @@ public class DistributedRowMatrix implem
   public Vector timesSquared(Vector v) {
     try {
       Configuration initialConf = getConf() == null ? new Configuration() : 
getConf();
+      Path outputVectorTmpPath = new Path(outputTmpBasePath,
+               new Path(Long.toString(System.nanoTime())));
       Configuration conf =
           TimesSquaredJob.createTimesSquaredJobConf(initialConf,
                                                     v,
                                                     rowPath,
-                                                    new Path(outputTmpBasePath,
-                                                             new 
Path(Long.toString(System.nanoTime()))));
+                                                    outputVectorTmpPath);
       JobClient.runJob(new JobConf(conf));
-      return TimesSquaredJob.retrieveTimesSquaredOutputVector(conf);
+      Vector result = TimesSquaredJob.retrieveTimesSquaredOutputVector(conf);
+      if (removeTempDirs) {
+        FileSystem fs = outputVectorTmpPath.getFileSystem(conf);
+        fs.delete(outputVectorTmpPath, true);
+      }
+      return result;
     } catch (IOException ioe) {
       throw new IllegalStateException(ioe);
     }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java?rev=1091292&r1=1091291&r2=1091292&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java Tue Apr 12 05:05:06 2011
@@ -23,6 +23,7 @@ import java.util.Iterator;
 import java.util.Map;
 
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.mahout.clustering.ClusteringTestUtils;
@@ -200,12 +201,81 @@ public final class TestDistributedRowMat
     assertEquals(TEST_PROPERTY_VALUE, 
customTimesSquaredJobConf3.get(TEST_PROPERTY_KEY));
   }
   
+  @Test
+  public void testTimesVectorTempDirDeletion() throws Exception {
+    Configuration conf = new Configuration();
+    Vector v = new RandomAccessSparseVector(50);
+    v.assign(1.0);
+    DistributedRowMatrix dm = randomDistributedMatrix(100, 90, 50, 20, 1.0, 
false);
+
+    Path outputPath = dm.getOutputTempPath();
+    FileSystem fs = outputPath.getFileSystem(conf);
+
+    deleteContentsOfPath(conf, outputPath);
+
+    assertEquals(0, fs.listStatus(outputPath).length);
+
+    Vector result1 = dm.times(v);
+
+    assertEquals(1, fs.listStatus(outputPath).length);
+    
+    deleteContentsOfPath(conf, outputPath);
+    assertEquals(0, fs.listStatus(outputPath).length);
+    
+    conf.setBoolean(DistributedRowMatrix.REMOVE_TEMP_DIRS, true);
+    dm.setConf(conf);
+    
+    Vector result2 = dm.times(v);
+
+    assertEquals(0, fs.listStatus(outputPath).length);
+    assertEquals(0.0, result1.getDistanceSquared(result2), EPSILON);
+  }
+
+  @Test
+  public void testTimesSquaredVectorTempDirDeletion() throws Exception {
+    Configuration conf = new Configuration();
+    Vector v = new RandomAccessSparseVector(50);
+    v.assign(1.0);
+    DistributedRowMatrix dm = randomDistributedMatrix(100, 90, 50, 20, 1.0, 
false);
+
+    Path outputPath = dm.getOutputTempPath();
+    FileSystem fs = outputPath.getFileSystem(conf);
+
+    deleteContentsOfPath(conf, outputPath);
+
+    assertEquals(0, fs.listStatus(outputPath).length);
+
+    Vector result1 = dm.timesSquared(v);
+
+    assertEquals(1, fs.listStatus(outputPath).length);
+    
+    deleteContentsOfPath(conf, outputPath);
+    assertEquals(0, fs.listStatus(outputPath).length);
+    
+    conf.setBoolean(DistributedRowMatrix.REMOVE_TEMP_DIRS, true);
+    dm.setConf(conf);
+    
+    Vector result2 = dm.timesSquared(v);
+
+    assertEquals(0, fs.listStatus(outputPath).length);
+    assertEquals(0.0, result1.getDistanceSquared(result2), EPSILON);
+  }
+
   public Configuration createInitialConf() {
     Configuration initialConf = new Configuration();
     initialConf.set(TEST_PROPERTY_KEY, TEST_PROPERTY_VALUE);
     return initialConf;
   }
   
+  private void deleteContentsOfPath(Configuration conf, Path path) throws 
Exception {
+    FileSystem fs = path.getFileSystem(conf);
+    
+    FileStatus[] statuses = fs.listStatus(path);
+    for (FileStatus status : statuses) {
+      fs.delete(status.getPath(), true);
+    }    
+  }
+    
   public DistributedRowMatrix randomDistributedMatrix(int numRows,
                                                       int nonNullRows,
                                                       int numCols,


Reply via email to