Github user HyukjinKwon commented on a diff in the pull request: https://github.com/apache/spark/pull/22365#discussion_r216233575 --- Diff: python/pyspark/sql/dataframe.py --- @@ -880,18 +880,23 @@ def sampleBy(self, col, fractions, seed=None): | 0| 5| | 1| 9| +---+-----+ + >>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count() + 33 """ - if not isinstance(col, basestring): - raise ValueError("col must be a string, but got %r" % type(col)) + if isinstance(col, basestring): + col = Column(col) + elif not isinstance(col, Column): + raise ValueError("col must be a string or a column, but got %r" % type(col)) if not isinstance(fractions, dict): raise ValueError("fractions must be a dict but got %r" % type(fractions)) for k, v in fractions.items(): if not isinstance(k, (float, int, long, basestring)): raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) fractions[k] = float(v) seed = seed if seed is not None else random.randint(0, sys.maxsize) - return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) + return DataFrame(self._jdf.stat() + .sampleBy(col._jc, self._jmap(fractions), seed), self.sql_ctx) --- End diff -- I would just do `col = col._jc`
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org