Recover metrics values from checkpoint

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/3784b541
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/3784b541
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/3784b541

Branch: refs/heads/master
Commit: 3784b5417e3b6053e70c25606f2a7e5022ba0a6a
Parents: d7d49ce
Author: Aviem Zur <aviem...@gmail.com>
Authored: Wed Feb 1 14:46:51 2017 +0200
Committer: Sela <ans...@paypal.com>
Committed: Wed Feb 15 11:10:52 2017 +0200

----------------------------------------------------------------------
 .../apache/beam/runners/spark/SparkRunner.java  |  10 +-
 .../aggregators/AggregatorsAccumulator.java     |  46 ++-----
 .../spark/aggregators/SparkAggregators.java     |   2 +-
 .../spark/metrics/MetricsAccumulator.java       |  65 ++++++++-
 .../spark/translation/TransformTranslator.java  |   4 +-
 .../spark/translation/streaming/Checkpoint.java | 137 +++++++++++++++++++
 .../translation/streaming/CheckpointDir.java    |  69 ----------
 .../SparkRunnerStreamingContextFactory.java     |   1 +
 .../streaming/StreamingTransformTranslator.java |   4 +-
 .../ResumeFromCheckpointStreamingTest.java      |  43 ++++--
 10 files changed, 256 insertions(+), 125 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/3784b541/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
index 3dc4857..ebac375 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
@@ -30,13 +30,14 @@ import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.metrics.AggregatorMetricSource;
 import org.apache.beam.runners.spark.metrics.CompositeSource;
+import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
 import org.apache.beam.runners.spark.metrics.SparkBeamMetricSource;
 import org.apache.beam.runners.spark.translation.EvaluationContext;
 import org.apache.beam.runners.spark.translation.SparkContextFactory;
 import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
 import org.apache.beam.runners.spark.translation.TransformEvaluator;
 import org.apache.beam.runners.spark.translation.TransformTranslator;
-import org.apache.beam.runners.spark.translation.streaming.CheckpointDir;
+import org.apache.beam.runners.spark.translation.streaming.Checkpoint.CheckpointDir;
 import org.apache.beam.runners.spark.translation.streaming.SparkRunnerStreamingContextFactory;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.io.Read;
@@ -143,6 +144,8 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> {
             : Optional.<CheckpointDir>absent();
     final Accumulator<NamedAggregators> aggregatorsAccumulator =
         SparkAggregators.getOrCreateNamedAggregators(jsc, maybeCheckpointDir);
+    // Instantiate metrics accumulator
+    MetricsAccumulator.init(jsc, maybeCheckpointDir);
     final NamedAggregators initialValue = aggregatorsAccumulator.value();
     if (opts.getEnableSparkMetricSinks()) {
       final MetricsSystem metricsSystem = SparkEnv$.MODULE$.get().metricsSystem();
@@ -180,10 +183,13 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> {
           JavaStreamingContext.getOrCreate(checkpointDir.getSparkCheckpointDir().toString(),
               contextFactory);
 
-      // Checkpoint aggregator values
+      // Checkpoint aggregator/metrics values
       jssc.addStreamingListener(
           new JavaStreamingListenerWrapper(
               new AggregatorsAccumulator.AccumulatorCheckpointingSparkListener()));
+      jssc.addStreamingListener(
+          new JavaStreamingListenerWrapper(
+              new MetricsAccumulator.AccumulatorCheckpointingSparkListener()));
 
       // register listeners.
       for (JavaStreamingListener listener: mOptions.as(SparkContextOptions.class).getListeners()) {

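For context, a minimal sketch (hypothetical class, not part of this commit) of the JavaStreamingListener hook both accumulators now register through JavaStreamingListenerWrapper: onBatchCompleted runs on the driver after each micro-batch, which is where the Beam listeners persist their accumulator values.

    import org.apache.spark.streaming.api.java.JavaStreamingListener;
    import org.apache.spark.streaming.api.java.JavaStreamingListenerBatchCompleted;

    // Hypothetical listener mirroring the two AccumulatorCheckpointingSparkListener
    // implementations this commit registers for aggregators and metrics.
    class CheckpointOnBatchCompleted extends JavaStreamingListener {
      @Override
      public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) {
        // Invoked on the driver once a micro-batch finishes; a real listener
        // writes the accumulator value to the Beam checkpoint dir here.
        System.out.println("micro-batch completed; checkpointing accumulator state");
      }
    }
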
http://git-wip-us.apache.org/repos/asf/beam/blob/3784b541/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AggregatorsAccumulator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AggregatorsAccumulator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AggregatorsAccumulator.java
index 187205b..1b49e91 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AggregatorsAccumulator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AggregatorsAccumulator.java
@@ -21,11 +21,8 @@ package org.apache.beam.runners.spark.aggregators;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Optional;
 import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import org.apache.beam.runners.spark.translation.streaming.CheckpointDir;
-import org.apache.hadoop.fs.FSDataInputStream;
-import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.beam.runners.spark.translation.streaming.Checkpoint;
+import org.apache.beam.runners.spark.translation.streaming.Checkpoint.CheckpointDir;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.spark.Accumulator;
@@ -43,13 +40,11 @@ import org.slf4j.LoggerFactory;
 public class AggregatorsAccumulator {
   private static final Logger LOG = LoggerFactory.getLogger(AggregatorsAccumulator.class);
 
-  private static final String ACCUMULATOR_CHECKPOINT_FILENAME = "beam_aggregators";
+  private static final String ACCUMULATOR_CHECKPOINT_FILENAME = "aggregators";
 
   private static volatile Accumulator<NamedAggregators> instance;
   private static volatile FileSystem fileSystem;
-  private static volatile Path checkpointPath;
-  private static volatile Path tempCheckpointPath;
-  private static volatile Path backupCheckpointPath;
+  private static volatile Path checkpointFilePath;
 
   @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
   static Accumulator<NamedAggregators> getInstance(
@@ -71,23 +66,12 @@ public class AggregatorsAccumulator {
   private static void recoverValueFromCheckpoint(
       JavaSparkContext jsc,
       CheckpointDir checkpointDir) {
-    FSDataInputStream is = null;
     try {
       Path beamCheckpointPath = checkpointDir.getBeamCheckpointDir();
-      checkpointPath = new Path(beamCheckpointPath, ACCUMULATOR_CHECKPOINT_FILENAME);
-      tempCheckpointPath = checkpointPath.suffix(".tmp");
-      backupCheckpointPath = checkpointPath.suffix(".bak");
-      fileSystem = checkpointPath.getFileSystem(jsc.hadoopConfiguration());
-      if (fileSystem.exists(checkpointPath)) {
-        is = fileSystem.open(checkpointPath);
-      } else if (fileSystem.exists(backupCheckpointPath)) {
-        is = fileSystem.open(backupCheckpointPath);
-      }
-      if (is != null) {
-        ObjectInputStream objectInputStream = new ObjectInputStream(is);
-        NamedAggregators recoveredValue =
-            (NamedAggregators) objectInputStream.readObject();
-        objectInputStream.close();
+      checkpointFilePath = new Path(beamCheckpointPath, ACCUMULATOR_CHECKPOINT_FILENAME);
+      fileSystem = checkpointFilePath.getFileSystem(jsc.hadoopConfiguration());
+      NamedAggregators recoveredValue = Checkpoint.readObject(fileSystem, checkpointFilePath);
+      if (recoveredValue != null) {
         LOG.info("Recovered accumulators from checkpoint: " + recoveredValue);
         instance.setValue(recoveredValue);
       } else {
@@ -99,18 +83,8 @@ public class AggregatorsAccumulator {
   }
 
   private static void checkpoint() throws IOException {
-    if (checkpointPath != null) {
-      if (fileSystem.exists(checkpointPath)) {
-        if (fileSystem.exists(backupCheckpointPath)) {
-          fileSystem.delete(backupCheckpointPath, false);
-        }
-        fileSystem.rename(checkpointPath, backupCheckpointPath);
-      }
-      FSDataOutputStream os = fileSystem.create(tempCheckpointPath, true);
-      ObjectOutputStream oos = new ObjectOutputStream(os);
-      oos.writeObject(instance.value());
-      oos.close();
-      fileSystem.rename(tempCheckpointPath, checkpointPath);
+    if (checkpointFilePath != null) {
+      Checkpoint.writeObject(fileSystem, checkpointFilePath, instance.value());
     }
   }
 

http://git-wip-us.apache.org/repos/asf/beam/blob/3784b541/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java
index 326acfe..131b761 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java
@@ -25,7 +25,7 @@ import java.util.Map;
 import org.apache.beam.runners.core.AggregatorFactory;
 import org.apache.beam.runners.core.ExecutionContext;
 import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
-import org.apache.beam.runners.spark.translation.streaming.CheckpointDir;
-import org.apache.beam.runners.spark.translation.streaming.Checkpoint.CheckpointDir;
 import org.apache.beam.sdk.AggregatorValues;
 import org.apache.beam.sdk.transforms.Aggregator;
 import org.apache.beam.sdk.transforms.Combine;

http://git-wip-us.apache.org/repos/asf/beam/blob/3784b541/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java
index effcbe9..f27a826 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java
@@ -19,8 +19,18 @@
 package org.apache.beam.runners.spark.metrics;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Optional;
+import java.io.IOException;
+import org.apache.beam.runners.spark.translation.streaming.Checkpoint;
+import org.apache.beam.runners.spark.translation.streaming.Checkpoint.CheckpointDir;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
 import org.apache.spark.Accumulator;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.streaming.api.java.JavaStreamingListener;
+import org.apache.spark.streaming.api.java.JavaStreamingListenerBatchCompleted;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 
 /**
@@ -28,10 +38,18 @@ import org.apache.spark.api.java.JavaSparkContext;
  * @see <a href="https://spark.apache.org/docs/1.6.3/streaming-programming-guide.html#accumulators-and-broadcast-variables">accumulators</a>
  */
 public class MetricsAccumulator {
+  private static final Logger LOG = LoggerFactory.getLogger(MetricsAccumulator.class);
+
+  private static final String ACCUMULATOR_CHECKPOINT_FILENAME = "metrics";
 
   private static volatile Accumulator<SparkMetricsContainer> instance = null;
+  private static volatile FileSystem fileSystem;
+  private static volatile Path checkpointFilePath;
 
-  public static Accumulator<SparkMetricsContainer> getOrCreateInstance(JavaSparkContext jsc) {
+  @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
+  public static void init(
+      JavaSparkContext jsc,
+      Optional<CheckpointDir> checkpointDir) {
     if (instance == null) {
       synchronized (MetricsAccumulator.class) {
         if (instance == null) {
@@ -41,13 +59,15 @@ public class MetricsAccumulator {
           SparkMetricsContainer initialValue = new SparkMetricsContainer();
           instance = jsc.sc().accumulator(initialValue, "Beam.Metrics",
               new MetricsAccumulatorParam());
+          if (checkpointDir.isPresent()) {
+            recoverValueFromCheckpoint(jsc, checkpointDir.get());
+          }
         }
       }
     }
-    return instance;
   }
 
-  static Accumulator<SparkMetricsContainer> getInstance() {
+  public static Accumulator<SparkMetricsContainer> getInstance() {
     if (instance == null) {
       throw new IllegalStateException("Metrics accumulator has not been instantiated");
     } else {
@@ -55,6 +75,25 @@ public class MetricsAccumulator {
     }
   }
 
+  private static void recoverValueFromCheckpoint(
+      JavaSparkContext jsc,
+      CheckpointDir checkpointDir) {
+    try {
+      Path beamCheckpointPath = checkpointDir.getBeamCheckpointDir();
+      checkpointFilePath = new Path(beamCheckpointPath, ACCUMULATOR_CHECKPOINT_FILENAME);
+      fileSystem = checkpointFilePath.getFileSystem(jsc.hadoopConfiguration());
+      SparkMetricsContainer recoveredValue = Checkpoint.readObject(fileSystem, checkpointFilePath);
+      if (recoveredValue != null) {
+        LOG.info("Recovered metrics from checkpoint: " + recoveredValue);
+        instance.setValue(recoveredValue);
+      } else {
+        LOG.info("No metrics checkpoint found.");
+      }
+    } catch (Exception e) {
+      throw new RuntimeException("Failure while reading metrics checkpoint.", e);
+    }
+  }
+
   @SuppressWarnings("unused")
   @VisibleForTesting
   static void clear() {
@@ -62,4 +101,24 @@ public class MetricsAccumulator {
       instance = null;
     }
   }
+
+  private static void checkpoint() throws IOException {
+    if (checkpointFilePath != null) {
+      Checkpoint.writeObject(fileSystem, checkpointFilePath, instance.value());
+    }
+  }
+
+  /**
+   * Spark Listener which checkpoints {@link SparkMetricsContainer} values for fault-tolerance.
+   */
+  public static class AccumulatorCheckpointingSparkListener extends JavaStreamingListener {
+    @Override
+    public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) {
+      try {
+        checkpoint();
+      } catch (IOException e) {
+        LOG.error("Failed to checkpoint metrics singleton.", e);
+      }
+    }
+  }
 }

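A hedged sketch of the lifecycle this file now enforces (the Optional is Guava's, as in the imports above; the wiring class is hypothetical): init() runs once on the driver, recovering checkpointed values when present, and getInstance() afterwards fails fast if init() was never called.

    import com.google.common.base.Optional;
    import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
    import org.apache.beam.runners.spark.translation.streaming.Checkpoint.CheckpointDir;
    import org.apache.spark.api.java.JavaSparkContext;

    class MetricsWiring {  // hypothetical driver-side wiring
      static void wire(JavaSparkContext jsc, Optional<CheckpointDir> checkpointDir) {
        // One-time init; restores prior metric values when a checkpoint
        // dir is present (streaming), plain initialization otherwise.
        MetricsAccumulator.init(jsc, checkpointDir);
        // Translators call getInstance() with no SparkContext in hand;
        // before init() this throws IllegalStateException instead of
        // lazily creating the accumulator as getOrCreateInstance() did.
        MetricsAccumulator.getInstance();
      }
    }
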
http://git-wip-us.apache.org/repos/asf/beam/blob/3784b541/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 584bcc3..5ce1f77 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -254,7 +254,7 @@ public final class TransformTranslator {
         Accumulator<NamedAggregators> aggAccum =
             SparkAggregators.getNamedAggregators(jsc);
         Accumulator<SparkMetricsContainer> metricsAccum =
-            MetricsAccumulator.getOrCreateInstance(jsc);
+            MetricsAccumulator.getInstance();
         Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
             TranslationUtils.getSideInputs(transform.getSideInputs(), context);
         context.putDataset(transform,
@@ -281,7 +281,7 @@ public final class TransformTranslator {
         Accumulator<NamedAggregators> aggAccum =
             SparkAggregators.getNamedAggregators(jsc);
         Accumulator<SparkMetricsContainer> metricsAccum =
-            MetricsAccumulator.getOrCreateInstance(jsc);
+            MetricsAccumulator.getInstance();
         JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = inRDD
             .mapPartitionsToPair(
                 new MultiDoFnFunction<>(aggAccum, metricsAccum, stepName, doFn,

http://git-wip-us.apache.org/repos/asf/beam/blob/3784b541/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/Checkpoint.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/Checkpoint.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/Checkpoint.java
new file mode 100644
index 0000000..a7427b2
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/Checkpoint.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation.streaming;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.commons.io.IOUtils;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Checkpoint data to make it available in future pipeline runs.
+ */
+public class Checkpoint {
+  private static final String TEMP_FILE_SUFFIX = ".tmp";
+  private static final String BACKUP_FILE_SUFFIX = ".bak";
+
+  public static void write(FileSystem fileSystem, Path checkpointFilePath, byte[] value)
+      throws IOException {
+    Path tmpPath = checkpointFilePath.suffix(TEMP_FILE_SUFFIX);
+    Path backupPath = checkpointFilePath.suffix(BACKUP_FILE_SUFFIX);
+    if (fileSystem.exists(checkpointFilePath)) {
+      if (fileSystem.exists(backupPath)) {
+        fileSystem.delete(backupPath, false);
+      }
+      fileSystem.rename(checkpointFilePath, backupPath);
+    }
+    FSDataOutputStream os = fileSystem.create(tmpPath, true);
+    os.write(value);
+    os.close();
+    fileSystem.rename(tmpPath, checkpointFilePath);
+  }
+
+  public static void writeObject(FileSystem fileSystem, Path checkpointFilePath, Object value)
+      throws IOException {
+    ByteArrayOutputStream bos = new ByteArrayOutputStream();
+    ObjectOutputStream oos = new ObjectOutputStream(bos);
+    oos.writeObject(value);
+    oos.close();
+    write(fileSystem, checkpointFilePath, bos.toByteArray());
+  }
+
+  public static byte[] read(FileSystem fileSystem, Path checkpointFilePath)
+      throws IOException {
+    Path backupCheckpointPath = checkpointFilePath.suffix(BACKUP_FILE_SUFFIX);
+    FSDataInputStream is = null;
+    if (fileSystem.exists(checkpointFilePath)) {
+      is = fileSystem.open(checkpointFilePath);
+    } else if (fileSystem.exists(backupCheckpointPath)) {
+      is = fileSystem.open(backupCheckpointPath);
+    }
+    return is != null ? IOUtils.toByteArray(is) : null;
+  }
+
+  @SuppressWarnings("unchecked")
+  public static <T> T readObject(FileSystem fileSystem, Path checkpointFilePath)
+      throws IOException, ClassNotFoundException {
+    byte[] bytes = read(fileSystem, checkpointFilePath);
+    if (bytes == null) {
+      return null;
+    }
+    ObjectInputStream objectInputStream = new ObjectInputStream(new ByteArrayInputStream(bytes));
+    T value = (T) objectInputStream.readObject();
+    objectInputStream.close();
+    return value;
+  }
+
+  /**
+   * Checkpoint dir tree.
+   *
+   * {@link SparkPipelineOptions} checkpointDir is used as a root directory under which one
+   * directory is created for Spark's checkpoint and another for Beam's Spark runner's
+   * fault-tolerance needs.
+   * Spark's checkpoint relies on Hadoop's {@link org.apache.hadoop.fs.FileSystem} and is used for
+   * Beam as well rather than {@link org.apache.beam.sdk.io.FileSystem} to be consistent with Spark.
+   */
+  public static class CheckpointDir {
+    private static final Logger LOG = LoggerFactory.getLogger(CheckpointDir.class);
+
+    private static final String SPARK_CHECKPOINT_DIR = "spark-checkpoint";
+    private static final String BEAM_CHECKPOINT_DIR = "beam-checkpoint";
+    private static final String KNOWN_RELIABLE_FS_PATTERN = "^(hdfs|s3|gs)";
+
+    private final Path rootCheckpointDir;
+    private final Path sparkCheckpointDir;
+    private final Path beamCheckpointDir;
+
+    public CheckpointDir(String rootCheckpointDir) {
+      if (!rootCheckpointDir.matches(KNOWN_RELIABLE_FS_PATTERN)) {
+        LOG.warn("The specified checkpoint dir {} does not match a reliable filesystem so in case "
+            + "of failures this job may not recover properly or even at all.", rootCheckpointDir);
+      }
+      LOG.info("Checkpoint dir set to: {}", rootCheckpointDir);
+
+      this.rootCheckpointDir = new Path(rootCheckpointDir);
+      this.sparkCheckpointDir = new Path(rootCheckpointDir, SPARK_CHECKPOINT_DIR);
+      this.beamCheckpointDir = new Path(rootCheckpointDir, BEAM_CHECKPOINT_DIR);
+    }
+
+    public Path getRootCheckpointDir() {
+      return rootCheckpointDir;
+    }
+
+    public Path getSparkCheckpointDir() {
+      return sparkCheckpointDir;
+    }
+
+    public Path getBeamCheckpointDir() {
+      return beamCheckpointDir;
+    }
+  }
+}

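To illustrate the crash-safe protocol above, a small usage sketch under stated assumptions (local filesystem, hypothetical path /tmp/beam-checkpoint/demo): write() stages to a .tmp file, preserves the previous value as .bak, then renames, so read() always finds either the newest value or its predecessor.

    import org.apache.beam.runners.spark.translation.streaming.Checkpoint;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.FileSystem;
    import org.apache.hadoop.fs.Path;

    class CheckpointDemo {
      public static void main(String[] args) throws Exception {
        Path file = new Path("/tmp/beam-checkpoint/demo");  // hypothetical path
        FileSystem fs = file.getFileSystem(new Configuration());
        Checkpoint.writeObject(fs, file, "value-1");  // first write creates file via .tmp
        Checkpoint.writeObject(fs, file, "value-2");  // previous value is kept as .bak
        String recovered = Checkpoint.readObject(fs, file);
        System.out.println(recovered);  // prints "value-2"; .bak still holds "value-1"
      }
    }
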
http://git-wip-us.apache.org/repos/asf/beam/blob/3784b541/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/CheckpointDir.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/CheckpointDir.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/CheckpointDir.java
deleted file mode 100644
index 5b192bd..0000000
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/CheckpointDir.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.spark.translation.streaming;
-
-import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.apache.hadoop.fs.Path;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-
-/**
- * Spark checkpoint dir tree.
- *
- * {@link SparkPipelineOptions} checkpointDir is used as a root directory under which one directory
- * is created for Spark's checkpoint and another for Beam's Spark runner's fault tolerance needs.
- * Spark's checkpoint relies on Hadoop's {@link org.apache.hadoop.fs.FileSystem} and is used for
- * Beam as well rather than {@link org.apache.beam.sdk.io.FileSystem} to be consistent with Spark.
- */
-public class CheckpointDir {
-  private static final Logger LOG = LoggerFactory.getLogger(CheckpointDir.class);
-
-  private static final String SPARK_CHECKPOINT_DIR = "spark-checkpoint";
-  private static final String BEAM_CHECKPOINT_DIR = "beam-checkpoint";
-  private static final String KNOWN_RELIABLE_FS_PATTERN = "^(hdfs|s3|gs)";
-
-  private final Path rootCheckpointDir;
-  private final Path sparkCheckpointDir;
-  private final Path beamCheckpointDir;
-
-  public CheckpointDir(String rootCheckpointDir) {
-    if (!rootCheckpointDir.matches(KNOWN_RELIABLE_FS_PATTERN)) {
-      LOG.warn("The specified checkpoint dir {} does not match a reliable filesystem so in case "
-          + "of failures this job may not recover properly or even at all.", rootCheckpointDir);
-    }
-    LOG.info("Checkpoint dir set to: {}", rootCheckpointDir);
-
-    this.rootCheckpointDir = new Path(rootCheckpointDir);
-    this.sparkCheckpointDir = new Path(rootCheckpointDir, SPARK_CHECKPOINT_DIR);
-    this.beamCheckpointDir = new Path(rootCheckpointDir, BEAM_CHECKPOINT_DIR);
-  }
-
-  public Path getRootCheckpointDir() {
-    return rootCheckpointDir;
-  }
-
-  public Path getSparkCheckpointDir() {
-    return sparkCheckpointDir;
-  }
-
-  public Path getBeamCheckpointDir() {
-    return beamCheckpointDir;
-  }
-}

http://git-wip-us.apache.org/repos/asf/beam/blob/3784b541/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java
index b461856..ffa8e69 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java
@@ -27,6 +27,7 @@ import org.apache.beam.runners.spark.translation.EvaluationContext;
 import org.apache.beam.runners.spark.translation.SparkContextFactory;
 import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
 import org.apache.beam.runners.spark.translation.TransformTranslator;
+import org.apache.beam.runners.spark.translation.streaming.Checkpoint.CheckpointDir;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;

http://git-wip-us.apache.org/repos/asf/beam/blob/3784b541/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index f270a99..36cd2f3 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -397,7 +397,7 @@ final class StreamingTransformTranslator {
             final Accumulator<NamedAggregators> aggAccum =
                 SparkAggregators.getNamedAggregators(jsc);
             final Accumulator<SparkMetricsContainer> metricsAccum =
-                MetricsAccumulator.getOrCreateInstance(jsc);
+                MetricsAccumulator.getInstance();
             final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
                 TranslationUtils.getSideInputs(transform.getSideInputs(),
                     jsc, pviews);
@@ -438,7 +438,7 @@ final class StreamingTransformTranslator {
             final Accumulator<NamedAggregators> aggAccum =
                 SparkAggregators.getNamedAggregators(jsc);
             final Accumulator<SparkMetricsContainer> metricsAccum =
-                MetricsAccumulator.getOrCreateInstance(jsc);
+                MetricsAccumulator.getInstance();
             final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
                 TranslationUtils.getSideInputs(transform.getSideInputs(),
                     JavaSparkContext.fromSparkContext(rdd.context()), pviews);

http://git-wip-us.apache.org/repos/asf/beam/blob/3784b541/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
index 5a27b29..62ee672 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
@@ -17,7 +17,9 @@
  */
 package org.apache.beam.runners.spark.translation.streaming;
 
+import static org.apache.beam.sdk.metrics.MetricMatchers.attemptedMetricsResult;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasItem;
 import static org.junit.Assert.assertThat;
 
 import com.google.common.collect.ImmutableList;
@@ -39,6 +41,10 @@ import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipeli
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.kafka.KafkaIO;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.MetricNameFilter;
+import org.apache.beam.sdk.metrics.Metrics;
+import org.apache.beam.sdk.metrics.MetricsFilter;
 import org.apache.beam.sdk.transforms.Aggregator;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -77,7 +83,9 @@ public class ResumeFromCheckpointStreamingTest {
   );
   private static final String[] EXPECTED = {"k1,v1", "k2,v2", "k3,v3", "k4,v4"};
   private static final long EXPECTED_AGG_FIRST = 4L;
+  private static final long EXPECTED_COUNTER_FIRST = 4L;
   private static final long EXPECTED_AGG_SECOND = 8L;
+  private static final long EXPECTED_COUNTER_SECOND = 8L;
 
   @Rule
   public TemporaryFolder checkpointParentDir = new TemporaryFolder();
@@ -131,12 +139,20 @@ public class ResumeFromCheckpointStreamingTest {
     // checkpoint after first (and only) interval.
     options.setCheckpointDurationMillis(options.getBatchIntervalMillis());
 
+    MetricsFilter metricsFilter =
+        MetricsFilter.builder()
+            .addNameFilter(MetricNameFilter.inNamespace(ResumeFromCheckpointStreamingTest.class))
+            .build();
+
     // first run will read from Kafka backlog - "auto.offset.reset=smallest"
     SparkPipelineResult res = run(options);
     long processedMessages1 = res.getAggregatorValue("processedMessages", Long.class);
     assertThat(String.format("Expected %d processed messages count but "
         + "found %d", EXPECTED_AGG_FIRST, processedMessages1), 
processedMessages1,
             equalTo(EXPECTED_AGG_FIRST));
+    assertThat(res.metrics().queryMetrics(metricsFilter).counters(),
+        hasItem(attemptedMetricsResult(ResumeFromCheckpointStreamingTest.class.getName(),
+            "aCounter", "formatKV", EXPECTED_COUNTER_FIRST)));
 
     // recovery should resume from last read offset, and read the second batch of input.
     res = runAgain(options);
@@ -144,6 +160,9 @@ public class ResumeFromCheckpointStreamingTest {
     assertThat(String.format("Expected %d processed messages count but "
         + "found %d", EXPECTED_AGG_SECOND, processedMessages2), 
processedMessages2,
             equalTo(EXPECTED_AGG_SECOND));
+    assertThat(res.metrics().queryMetrics(metricsFilter).counters(),
+        hasItem(attemptedMetricsResult(ResumeFromCheckpointStreamingTest.class.getName(),
+            "aCounter", "formatKV", EXPECTED_COUNTER_SECOND)));
   }
 
   private SparkPipelineResult runAgain(SparkPipelineOptions options) {
@@ -177,16 +196,20 @@ public class ResumeFromCheckpointStreamingTest {
 
     PCollection<String> formattedKV =
         p.apply(read.withoutMetadata())
-          .apply(ParDo.of(new DoFn<KV<String, String>, KV<String, String>>() {
-               @ProcessElement
-               public void process(ProcessContext c) {
-                  // Check side input is passed correctly also after resuming from checkpoint
-                  Assert.assertEquals(c.sideInput(expectedView), Arrays.asList(EXPECTED));
-                  c.output(c.element());
-                }
-          }).withSideInputs(expectedView))
-        .apply(Window.<KV<String, String>>into(FixedWindows.of(windowDuration)))
-        .apply(ParDo.of(new FormatAsText()));
+            .apply("formatKV", ParDo.of(new DoFn<KV<String, String>, KV<String, String>>() {
+              Counter counter =
+                  Metrics.counter(ResumeFromCheckpointStreamingTest.class, "aCounter");
+
+              @ProcessElement
+              public void process(ProcessContext c) {
+                // Check side input is passed correctly also after resuming from checkpoint
+                Assert.assertEquals(c.sideInput(expectedView), Arrays.asList(EXPECTED));
+                counter.inc();
+                c.output(c.element());
+              }
+            }).withSideInputs(expectedView))
+            .apply(Window.<KV<String, String>>into(FixedWindows.of(windowDuration)))
+            .apply(ParDo.of(new FormatAsText()));
 
     // graceful shutdown will make sure first batch (at least) will finish.
     Duration timeout = Duration.standardSeconds(1L);

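The counter assertions above rely on metric values surviving the restart: the first run counts 4 elements, the recovered value seeds the second run, and four more elements bring it to 8. A standalone, hedged sketch of the Beam metrics API the test exercises (class and counter names here are hypothetical, not from this commit):

    import org.apache.beam.sdk.metrics.Counter;
    import org.apache.beam.sdk.metrics.Metrics;
    import org.apache.beam.sdk.transforms.DoFn;

    class CountingFn extends DoFn<String, String> {  // hypothetical DoFn
      private final Counter elements = Metrics.counter(CountingFn.class, "elements");

      @ProcessElement
      public void process(ProcessContext c) {
        // With the Spark runner, this value is checkpointed after each batch
        // and restored on resume, so it keeps growing across pipeline runs.
        elements.inc();
        c.output(c.element());
      }
    }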