GitHub user BryanCutler commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21465#discussion_r239245388
  
    --- Diff: python/pyspark/ml/regression.py ---
    @@ -705,12 +710,59 @@ def getNumTrees(self):
             return self.getOrDefault(self.numTrees)
     
     
    -class GBTParams(TreeEnsembleParams):
    +class GBTParams(TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
         """
         Private class to track supported GBT params.
         """
    +
    +    stepSize = Param(Params._dummy(), "stepSize",
    +                     "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
    +                     "the contribution of each estimator.",
    +                     typeConverter=TypeConverters.toFloat)
    +
    +    validationTol = Param(Params._dummy(), "validationTol",
    +                          "Threshold for stopping early when fit with validation is used. " +
    +                          "If the error rate on the validation input changes by less than the " +
    +                          "validationTol, then learning will stop early (before `maxIter`). " +
    +                          "This parameter is ignored when fit without validation is used.",
    +                          typeConverter=TypeConverters.toFloat)
    +
    +    @since("3.0.0")
    +    def setValidationTol(self, value):
    +        """
    +        Sets the value of :py:attr:`validationTol`.
    +        """
    +        return self._set(validationTol=value)
    +
    +    @since("3.0.0")
    +    def getValidationTol(self):
    +        """
    +        Gets the value of validationTol or its default value.
    +        """
    +        return self.getOrDefault(self.validationTol)
    +
    +
    +class GBTRegressorParams(GBTParams, TreeRegressorParams):
    +    """
    +    Private class to track supported GBTRegressor params.
    +
    +    .. versionadded:: 3.0.0
    +    """
    +
         supportedLossTypes = ["squared", "absolute"]
     
    +    lossType = Param(Params._dummy(), "lossType",
    +                     "Loss function which GBT tries to minimize (case-insensitive). " +
    +                     "Supported options: " + ", ".join(supportedLossTypes),
    +                     typeConverter=TypeConverters.toString)
    +
    +    @since("1.4.0")
    +    def setLossType(self, value):
    --- End diff --
    
    `setLossType` should be in the estimator, and `getLossType` should be here.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to