GitHub user HyukjinKwon commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22365#discussion_r216233575
  
    --- Diff: python/pyspark/sql/dataframe.py ---
    @@ -880,18 +880,23 @@ def sampleBy(self, col, fractions, seed=None):
             |  0|    5|
             |  1|    9|
             +---+-----+
    +        >>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
    +        33
     
             """
    -        if not isinstance(col, basestring):
    -            raise ValueError("col must be a string, but got %r" % type(col))
    +        if isinstance(col, basestring):
    +            col = Column(col)
    +        elif not isinstance(col, Column):
    +            raise ValueError("col must be a string or a column, but got %r" % type(col))
             if not isinstance(fractions, dict):
                 raise ValueError("fractions must be a dict but got %r" % type(fractions))
             for k, v in fractions.items():
                 if not isinstance(k, (float, int, long, basestring)):
                     raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
                 fractions[k] = float(v)
             seed = seed if seed is not None else random.randint(0, sys.maxsize)
    -        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
    +        return DataFrame(self._jdf.stat()
    +                         .sampleBy(col._jc, self._jmap(fractions), seed), self.sql_ctx)
    --- End diff --
    
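    I would just do `col = col._jc` after the type check. That way the `return` statement can keep passing `col` directly and does not need to be split across lines.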


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to