This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 52e2e8e  Fix scalastyle (#14669)
52e2e8e is described below

commit 52e2e8ed630886b5e8edce20b960fa7d62620287
Author: Zach Kimberg <zach...@kimberg.com>
AuthorDate: Thu Apr 11 10:07:52 2019 -0700

    Fix scalastyle (#14669)
---
 .../main/scala/org/apache/mxnet/FeedForward.scala  |  15 +-
 .../org/apache/mxnet/module/BucketingModule.scala  |  54 ++++----
 .../scala/org/apache/mxnet/module/Module.scala     | 153 ++++++++++-----------
 .../org/apache/mxnet/module/SequentialModule.scala | 135 +++++++++---------
 .../org/apache/mxnetexamples/rnn/BucketIo.scala    |   8 +-
 .../org/apache/mxnet/spark/utils/Network.scala     |  22 ++-
 .../apache/mxnet/spark/SharedSparkContext.scala    |  12 +-
 7 files changed, 194 insertions(+), 205 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
index 2ed9d8c..2b17655 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
@@ -180,6 +180,7 @@ class FeedForward private(
 
   // Initialize the predictor module for running prediction.
   private def initPredictor(inputShapes: Map[String, Shape]): Unit = {
+    var shouldInit = true
     if (this.predExec != null) {
       val (argShapes, _, _) = symbol.inferShape(inputShapes)
       require(argShapes != null, "Shape inference failed." +
@@ -187,14 +188,16 @@ class FeedForward private(
         s"and aux states ${symbol.listAuxiliaryStates()}")
       val predShapes = this.predExec.argArrays.map(_.shape)
       if (argShapes.sameElements(predShapes)) {
-        return
+        shouldInit = false
       }
     }
-    // for now only use the first device
-    val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = 
inputShapes)
-    predExec.copyParamsFrom(_argParams, _auxParams)
-    ExecutorManager.checkArguments(symbol)
-    this.predExec = predExec
+    if(shouldInit) {
+      // for now only use the first device
+      val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = 
inputShapes)
+      predExec.copyParamsFrom(_argParams, _auxParams)
+      ExecutorManager.checkArguments(symbol)
+      this.predExec = predExec
+    }
   }
 
   // Initialize the iterator given input.
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala
index 1ac798e..41a6f69 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala
@@ -173,14 +173,13 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
                           allowMissing: Boolean = false,
                           forceInit: Boolean = false,
                           allowExtra: Boolean = false): Unit = {
-    if (paramsInitialized && !forceInit) {
-      return
+    if (!paramsInitialized || forceInit) {
+      require(binded, "call bind before initializing the parameters")
+      this._currModule.initParams(initializer, argParams, auxParams,
+        allowMissing, forceInit, allowExtra)
+      this.paramsDirty = false
+      this.paramsInitialized = true
     }
-    require(binded, "call bind before initializing the parameters")
-    this._currModule.initParams(initializer, argParams, auxParams,
-      allowMissing, forceInit, allowExtra)
-    this.paramsDirty = false
-    this.paramsInitialized = true
   }
 
   /**
@@ -218,28 +217,27 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
 
     if (this.binded) {
       logger.warn("Already bound, ignoring bind()")
-      return
-    }
+    } else {
+      require(sharedModule.isEmpty,
+        "sharedModule for BucketingModule is not supported")
 
-    require(sharedModule.isEmpty,
-      "sharedModule for BucketingModule is not supported")
-
-    this.forTraining = forTraining
-    this.inputsNeedGrad = inputsNeedGrad
-    this.binded = true
-
-    val (sym, dNames, lNames) = this.symGen(this.defaultBucketKey)
-    val module = new Module(sym, dNames, lNames, this.contexts,
-      this.workLoadList, this.fixedParamNames)
-    module.bind(dataShapes, labelShapes, forTraining, inputsNeedGrad,
-      forceRebind = false, sharedModule = None, gradReq)
-    this._currModule = module
-    this._currBucketKey = this.defaultBucketKey
-    this._buckets(this.defaultBucketKey) = module
-
-    // copy back saved params, if already initialized
-    if (this.paramsInitialized) {
-      this.setParams(argParams, auxParams)
+      this.forTraining = forTraining
+      this.inputsNeedGrad = inputsNeedGrad
+      this.binded = true
+
+      val (sym, dNames, lNames) = this.symGen(this.defaultBucketKey)
+      val module = new Module(sym, dNames, lNames, this.contexts,
+        this.workLoadList, this.fixedParamNames)
+      module.bind(dataShapes, labelShapes, forTraining, inputsNeedGrad,
+        forceRebind = false, sharedModule = None, gradReq)
+      this._currModule = module
+      this._currBucketKey = this.defaultBucketKey
+      this._buckets(this.defaultBucketKey) = module
+
+      // copy back saved params, if already initialized
+      if (this.paramsInitialized) {
+        this.setParams(argParams, auxParams)
+      }
     }
   }
 
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
index 97df3dc..3255d93 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
@@ -121,36 +121,35 @@ class Module(symbolVar: Symbol,
                           allowMissing: Boolean = false,
                           forceInit: Boolean = false,
                           allowExtra: Boolean = false): Unit = {
-    if (paramsInitialized && !forceInit) {
-      return
-    }
-    require(binded, "call bind before initializing the parameters")
+    if (!paramsInitialized || forceInit) {
+      require(binded, "call bind before initializing the parameters")
 
-    if (this.argParams == null) {
-      val paramArrays =
-        execGroup.paramArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = 
nds(0).dtype))
-      this.argParams = this.paramNames.zip(paramArrays).toMap
-    }
+      if (this.argParams == null) {
+        val paramArrays =
+          execGroup.paramArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = 
nds(0).dtype))
+        this.argParams = this.paramNames.zip(paramArrays).toMap
+      }
 
-    if (this.auxParams == null) {
-      val auxArrays =
-        execGroup.auxArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = 
nds(0).dtype))
-      this.auxParams = this.auxNames.zip(auxArrays).toMap
-    }
+      if (this.auxParams == null) {
+        val auxArrays =
+          execGroup.auxArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = 
nds(0).dtype))
+        this.auxParams = this.auxNames.zip(auxArrays).toMap
+      }
 
-    this.argParams.foreach { case (name, arr) =>
-      impl(name, arr, allowMissing, Option(initializer), argParams)
-    }
+      this.argParams.foreach { case (name, arr) =>
+        impl(name, arr, allowMissing, Option(initializer), argParams)
+      }
 
-    this.auxParams.foreach { case (name, arr) =>
-      impl(name, arr, allowMissing, Option(initializer), auxParams)
-    }
+      this.auxParams.foreach { case (name, arr) =>
+        impl(name, arr, allowMissing, Option(initializer), auxParams)
+      }
 
-    this.paramsInitialized = true
-    this.paramsDirty = false
+      this.paramsInitialized = true
+      this.paramsDirty = false
 
-    // copy the initialized parameters to devices
-    this.execGroup.setParams(this.argParams, this.auxParams, allowExtra = 
allowExtra)
+      // copy the initialized parameters to devices
+      this.execGroup.setParams(this.argParams, this.auxParams, allowExtra = 
allowExtra)
+    }
   }
 
   // Internal helper for parameter initialization
@@ -246,64 +245,64 @@ class Module(symbolVar: Symbol,
 
     if (binded) {
       logger.warn("Already binded, ignoring bind()")
-      return
-    }
+    } else {
+      this.forTraining = forTraining
+      this.inputsNeedGrad = inputsNeedGrad
+      this.binded = true
 
-    this.forTraining = forTraining
-    this.inputsNeedGrad = inputsNeedGrad
-    this.binded = true
+      if (!forTraining) {
+        require(!inputsNeedGrad, "Invalid inputsNeedGrad (cannot be true if 
not forTraining)")
+      } else {
+        // this is not True, as some module might not contains a loss function
+        // that consumes the labels
+        // require(labelShapes != None)
+      }
 
-    if (!forTraining) {
-      require(!inputsNeedGrad, "Invalid inputsNeedGrad (cannot be true if not 
forTraining)")
-    } else {
-      // this is not True, as some module might not contains a loss function
-      // that consumes the labels
-      // require(labelShapes != None)
-    }
+      this.dataShapesVar = dataShapes
+      this.labelShapesVar = labelShapes
 
-    this.dataShapesVar = dataShapes
-    this.labelShapesVar = labelShapes
-
-    val sharedGroup =
-      sharedModule.map(sharedModuleInst => {
-        require(sharedModuleInst.binded && sharedModuleInst.paramsInitialized,
-          s"bind() and initParams() must be called first on shared module.")
-        sharedModuleInst.execGroup
-      })
-
-    val inputTypes = this.dataShapesVar.map(dataDesc => (dataDesc.name, 
dataDesc.dtype)).toMap ++
-      labelShapes.map(shapes => shapes.map(dataDesc => (dataDesc.name, 
dataDesc.dtype)).toMap)
-                 .getOrElse(Map.empty[String, DType])
-
-    execGroup = new Builder(symbol, contexts, paramNames)
-      .setWorkLoadList(workLoads)
-      .setDataShapes(dataShapes)
-      .setLabelShapes(labelShapes.orNull)
-      .setForTraining(forTraining)
-      .setInputsNeedGrad(inputsNeedGrad)
-      .setSharedGroup(sharedGroup.orNull)
-      .setFixedParamNames(fixedParamNames.orNull)
-      .setGradReq(gradReq)
-      .setInputTypes(inputTypes)
-      .build()
-
-    if (sharedModule.isDefined) {
-      paramsInitialized = true
-      argParams = sharedModule.get.argParams
-      auxParams = sharedModule.get.auxParams
-    } else if (paramsInitialized) {
-      // if the parameters are already initialized, we are re-binding
-      // so automatically copy the already initialized params
-      execGroup.setParams(argParams, auxParams)
-    }
+      val sharedGroup =
+        sharedModule.map(sharedModuleInst => {
+          require(sharedModuleInst.binded && 
sharedModuleInst.paramsInitialized,
+            s"bind() and initParams() must be called first on shared module.")
+          sharedModuleInst.execGroup
+        })
 
-    sharedModule.foreach {
-      case sharedModuleInst: Module =>
-        if (sharedModuleInst.optimizerInitialized) {
-          borrowOptimizer(sharedModuleInst)
-        }
-      case _ =>
+      val inputTypes = this.dataShapesVar.map(dataDesc => (dataDesc.name, 
dataDesc.dtype)).toMap ++
+        labelShapes.map(shapes => shapes.map(dataDesc => (dataDesc.name, 
dataDesc.dtype)).toMap)
+          .getOrElse(Map.empty[String, DType])
+
+      execGroup = new Builder(symbol, contexts, paramNames)
+        .setWorkLoadList(workLoads)
+        .setDataShapes(dataShapes)
+        .setLabelShapes(labelShapes.orNull)
+        .setForTraining(forTraining)
+        .setInputsNeedGrad(inputsNeedGrad)
+        .setSharedGroup(sharedGroup.orNull)
+        .setFixedParamNames(fixedParamNames.orNull)
+        .setGradReq(gradReq)
+        .setInputTypes(inputTypes)
+        .build()
+
+      if (sharedModule.isDefined) {
+        paramsInitialized = true
+        argParams = sharedModule.get.argParams
+        auxParams = sharedModule.get.auxParams
+      } else if (paramsInitialized) {
+        // if the parameters are already initialized, we are re-binding
+        // so automatically copy the already initialized params
+        execGroup.setParams(argParams, auxParams)
+      }
+
+      sharedModule.foreach {
+        case sharedModuleInst: Module =>
+          if (sharedModuleInst.optimizerInitialized) {
+            borrowOptimizer(sharedModuleInst)
+          }
+        case _ =>
+      }
     }
+
   }
 
   /**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala
index 2e506c0..3c3eeb9 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala
@@ -154,38 +154,37 @@ class SequentialModule extends BaseModule {
                           allowMissing: Boolean = false,
                           forceInit: Boolean = false,
                           allowExtra: Boolean = false): Unit = {
-    if (this.paramsInitialized && !forceInit) {
-      return
-    }
-    require(this.binded, "call bind before initializing the parameters")
+    if (!this.paramsInitialized || forceInit) {
+      require(this.binded, "call bind before initializing the parameters")
 
-    for (module <- this.modules) {
-      module.initParams(initializer = initializer, argParams = argParams,
-          auxParams = auxParams, allowMissing = allowMissing,
-          forceInit = forceInit, allowExtra = allowExtra)
-    }
+      for (module <- this.modules) {
+        module.initParams(initializer = initializer, argParams = argParams,
+            auxParams = auxParams, allowMissing = allowMissing,
+            forceInit = forceInit, allowExtra = allowExtra)
+      }
 
-    // Internal function to help checking duplicated names,
-    // make sure we do not have duplicated parameter names.
-    def checkName(knownNames: scala.collection.mutable.Map[String, Int],
-      newNames: Array[String], modules: ArrayBuffer[BaseModule], i: Int): Unit 
= {
-      for (name <- newNames) {
-        require(!knownNames.contains(name), s"Duplicated parameter names: " +
-            s"name $name in layer $i (${modules(i).getClass.getName}) is 
already " +
-            s"used in layer ${knownNames("name")}" +
-            s"(${modules(knownNames("name")).getClass.getName})")
-        knownNames(name) = i
+      // Internal function to help checking duplicated names,
+      // make sure we do not have duplicated parameter names.
+      def checkName(knownNames: scala.collection.mutable.Map[String, Int],
+        newNames: Array[String], modules: ArrayBuffer[BaseModule], i: Int): 
Unit = {
+        for (name <- newNames) {
+          require(!knownNames.contains(name), s"Duplicated parameter names: " +
+              s"name $name in layer $i (${modules(i).getClass.getName}) is 
already " +
+              s"used in layer ${knownNames("name")}" +
+              s"(${modules(knownNames("name")).getClass.getName})")
+          knownNames(name) = i
+        }
       }
-    }
 
-    val argNames = scala.collection.mutable.Map[String, Int]()
-    val auxNames = scala.collection.mutable.Map[String, Int]()
-    for ((module, iLayer) <- this.modules.zipWithIndex) {
-      val (argParams, auxParams) = module.getParams
-      checkName(argNames, argParams.keys.toArray, this.modules, iLayer)
-      checkName(auxNames, auxParams.keys.toArray, this.modules, iLayer)
+      val argNames = scala.collection.mutable.Map[String, Int]()
+      val auxNames = scala.collection.mutable.Map[String, Int]()
+      for ((module, iLayer) <- this.modules.zipWithIndex) {
+        val (argParams, auxParams) = module.getParams
+        checkName(argNames, argParams.keys.toArray, this.modules, iLayer)
+        checkName(auxNames, auxParams.keys.toArray, this.modules, iLayer)
+      }
+      this.paramsInitialized = true
     }
-    this.paramsInitialized = true
   }
 
   /**
@@ -216,54 +215,54 @@ class SequentialModule extends BaseModule {
                     gradReq: String = "write"): Unit = {
     if (this.binded && !forceRebind) {
       logger.warn(s"Already binded, ignoring bind()")
-      return
-    }
-
-    if (inputsNeedGrad) {
-      require(forTraining, "inputsNeedGrad can be set only for training")
-    }
-
-    require(sharedModule == None, "Shared module is not supported")
-    require(this.modules.length > 0, "Attempting to bind an empty 
SequentialModule")
-
-    this.forTraining = forTraining
-    this.inputsNeedGrad = inputsNeedGrad
-    this.binded = true
-
-    // the same label shapes are used for all chained modules
-    this.labelShapesVar = labelShapes
+    } else {
+      if (inputsNeedGrad) {
+        require(forTraining, "inputsNeedGrad can be set only for training")
+      }
 
-    var myDataShapes = dataShapes
-    var myLabelShapes = labelShapes
-    var anybodyEverNeedsLabel = false
-    for ((module, iLayer) <- this.modules.zipWithIndex) {
-      val meta = this.metas(iLayer)
-      if (meta.contains(META_TAKE_LABELS) && meta(META_TAKE_LABELS)) {
-        myLabelShapes = labelShapes
-        anybodyEverNeedsLabel = true
-      } else myLabelShapes = None
-
-      val myInputsNeedGrad = if (inputsNeedGrad || (forTraining && iLayer > 
0)) true else false
-      if (meta.contains(META_AUTO_WIRING) && meta(META_AUTO_WIRING)) {
-        val dataNames = module.dataNames
-        require(dataNames.length == myDataShapes.length,
-          s"dataNmes $dataNames and dataShapes $myDataShapes do not match")
-        myDataShapes = dataNames.zip(myDataShapes).map { case (newName, 
dataDes) =>
-          DataDesc(newName, dataDes.shape)
+      require(sharedModule == None, "Shared module is not supported")
+      require(this.modules.length > 0, "Attempting to bind an empty 
SequentialModule")
+
+      this.forTraining = forTraining
+      this.inputsNeedGrad = inputsNeedGrad
+      this.binded = true
+
+      // the same label shapes are used for all chained modules
+      this.labelShapesVar = labelShapes
+
+      var myDataShapes = dataShapes
+      var myLabelShapes = labelShapes
+      var anybodyEverNeedsLabel = false
+      for ((module, iLayer) <- this.modules.zipWithIndex) {
+        val meta = this.metas(iLayer)
+        if (meta.contains(META_TAKE_LABELS) && meta(META_TAKE_LABELS)) {
+          myLabelShapes = labelShapes
+          anybodyEverNeedsLabel = true
+        } else myLabelShapes = None
+
+        val myInputsNeedGrad = if (inputsNeedGrad || (forTraining && iLayer > 
0)) true else false
+        if (meta.contains(META_AUTO_WIRING) && meta(META_AUTO_WIRING)) {
+          val dataNames = module.dataNames
+          require(dataNames.length == myDataShapes.length,
+            s"dataNmes $dataNames and dataShapes $myDataShapes do not match")
+          myDataShapes = dataNames.zip(myDataShapes).map { case (newName, 
dataDes) =>
+            DataDesc(newName, dataDes.shape)
+          }
         }
-      }
 
-      module.bind(myDataShapes, myLabelShapes, forTraining, myInputsNeedGrad,
+        module.bind(myDataShapes, myLabelShapes, forTraining, myInputsNeedGrad,
           forceRebind, sharedModule = None, gradReq)
-      // the output of the previous module is the data of the next module
-      myDataShapes = module.outputShapes.map{case (name, shape) => 
DataDesc(name, shape)}
-    }
+        // the output of the previous module is the data of the next module
+        myDataShapes = module.outputShapes.map{case (name, shape) => 
DataDesc(name, shape)}
+      }
 
 
-    if (!anybodyEverNeedsLabel) {
-      // then I do not need label either
-      this.labelShapesVar = None
+      if (!anybodyEverNeedsLabel) {
+        // then I do not need label either
+        this.labelShapesVar = None
+      }
     }
+
   }
 
   /**
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
index 6d414bb..350e28c 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
@@ -202,10 +202,10 @@ object BucketIo {
       labelBuf.set(labels.flatten)
 
       iBucket += 1
-      val batchProvideData = { val tmp = ListMap("data" -> dataBuf.shape)
-        tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
-      }
-      val batchProvideLabel = ListMap("softmax_label" -> labelBuf.shape)
+      val batchProvideData = IndexedSeq(DataDesc("data", dataBuf.shape, 
dataBuf.dtype)) ++
+        initStates.map {
+          case (name, shape) => DataDesc(name, Shape(shape._1, shape._2), 
DType.Float32)}
+      val batchProvideLabel = IndexedSeq(DataDesc("softmax_label", 
labelBuf.shape, labelBuf.dtype))
       val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, 
x._2._2))
       new DataBatch(IndexedSeq(dataBuf.copy()) ++ initStateArrays,
         IndexedSeq(labelBuf.copy()),
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/utils/Network.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/utils/Network.scala
index c61229a..836901f 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/utils/Network.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/utils/Network.scala
@@ -20,6 +20,7 @@ package org.apache.mxnet.spark.utils
 import java.io.IOException
 import java.net.{ServerSocket, NetworkInterface}
 import java.util.regex.Pattern
+import scala.collection.JavaConverters._
 
 /**
  * Helper functions to decide ip address / port
@@ -33,19 +34,16 @@ object Network {
       "([01]?\\d\\d?|2[0-4]\\d|25[0-5])$")
 
   def ipAddress: String = {
-    val interfaces = NetworkInterface.getNetworkInterfaces
-    while (interfaces.hasMoreElements) {
-      val interface = interfaces.nextElement
-      val addresses = interface.getInetAddresses
-      while (addresses.hasMoreElements) {
-        val address = addresses.nextElement
-        val ip = address.getHostAddress
-        if (!ip.startsWith("127.") && IPADDRESS_PATTERN.matcher(ip).matches()) 
{
-          return ip
+    val interfaces = NetworkInterface.getNetworkInterfaces.asScala
+    val interface = interfaces.toStream.flatMap(
+      _.getInetAddresses.asScala.toStream.flatMap(
+        address => {
+          val ip = address.getHostAddress
+          Option(ip).filter(ip => !ip.startsWith("127.") && 
IPADDRESS_PATTERN.matcher(ip).matches())
         }
-      }
-    }
-    "127.0.0.1"
+      )
+    ).headOption
+    interface.getOrElse("127.0.0.1")
   }
 
   def availablePort: Int = {
diff --git a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala
index 6d36ca5..293cfa1 100644
--- a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala
+++ b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/SharedSparkContext.scala
@@ -92,20 +92,12 @@ trait SharedSparkContext extends FunSuite with BeforeAndAfterEach with BeforeAnd
 
   private def getJarFilePath(root: String): String = {
     val jarFiles = findJars(s"$root/target/")
-    if (jarFiles != null && jarFiles.nonEmpty) {
-      jarFiles.head.getAbsolutePath
-    } else {
-      null
-    }
+    Option(jarFiles).flatMap(_.headOption).map(_.getAbsolutePath).orNull
   }
 
   private def getSparkJar: String = {
     val jarFiles = findJars(s"$composeWorkingDirPath/target/")
-    if (jarFiles != null && jarFiles.nonEmpty) {
-      jarFiles.head.getAbsolutePath
-    } else {
-      null
-    }
+    Option(jarFiles).flatMap(_.headOption).map(_.getAbsolutePath).orNull
   }
 
   private def getNativeJars(root: String): String =

Reply via email to