Github user hhbyyh commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19993#discussion_r157870214

    --- Diff: mllib/src/main/scala/org/apache/spark/ml/param/params.scala ---
    @@ -249,6 +250,29 @@ object ParamValidators {
       def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) =>
         value.length > lowerBound
       }
    +
    +  /**
    +   * Checks that either inputCols and outputCols are set or inputCol and outputCol are set. If
    +   * this is not true, an `IllegalArgumentException` is raised.
    +   * @param model
    +   */
    +  def assertColOrCols(model: Params): Unit = {
    +    model match {
    +      case m: HasInputCols with HasInputCol if m.isSet(m.inputCols) && m.isSet(m.inputCol) =>
    +        raiseIncompatibleParamsException("inputCols", "inputCol")
    +      case m: HasOutputCols with HasInputCol if m.isSet(m.outputCols) && m.isSet(m.inputCol) =>
    +        raiseIncompatibleParamsException("outputCols", "inputCol")
    +      case m: HasInputCols with HasOutputCol if m.isSet(m.inputCols) && m.isSet(m.outputCol) =>
    +        raiseIncompatibleParamsException("inputCols", "outputCol")
    +      case m: HasOutputCols with HasOutputCol if m.isSet(m.outputCols) && m.isSet(m.outputCol) =>
    +        raiseIncompatibleParamsException("outputCols", "outputCol")
    +      case _ =>
    +    }
    +  }
    +
    +  def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = {
    --- End diff --

    private[spark]
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org