zhidongqu-db commented on code in PR #55682:
URL: https://github.com/apache/spark/pull/55682#discussion_r3190627051
##########
sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala:
##########
@@ -1569,3 +1614,46 @@ class Dataset[T] private[sql] (
override def queryExecution: QueryExecution =
throw ConnectClientUnsupportedErrors.queryExecution()
}
+
+private[sql] object Dataset {
+ // Acceptance lists for `nearestByJoin`. Must stay aligned with
`NearestByJoinType` /
Review Comment:
How do we keep these in sync? Is there a good way to share the same
validation lists, e.g. by moving them to sql/api so that both sql/connect and
sql/catalyst can reuse them?
##########
python/pyspark/sql/connect/dataframe.py:
##########
@@ -726,6 +726,30 @@ def lateralJoin(
session=self._session,
)
+ def nearestByJoin(
Review Comment:
We need Spark Connect tests for nearestByJoin - see the lateralJoin tests in
DataFrameSubquerySuite and PlanGenerationTestSuite.
##########
sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala:
##########
@@ -1569,3 +1614,46 @@ class Dataset[T] private[sql] (
override def queryExecution: QueryExecution =
throw ConnectClientUnsupportedErrors.queryExecution()
}
+
+private[sql] object Dataset {
+ // Acceptance lists for `nearestByJoin`. Must stay aligned with
`NearestByJoinType` /
+ // `NearestByJoinMode` / `NearestByDirection` in sql/catalyst, which
`sql/connect/common`
+ // cannot import.
+ private val MaxNumResults: Int = 100000
+ private val SupportedJoinTypeDisplay = "'INNER', 'LEFT OUTER'"
+ private val SupportedJoinTypes = Set("inner", "leftouter", "left",
"left_outer")
+ private val SupportedModes = Seq("approx", "exact")
+ private val SupportedDirections = Seq("distance", "similarity")
+
+ private[connect] def validateNearestByJoinArgs(
Review Comment:
Looks like we are missing similar validations in the Python client?
##########
python/pyspark/sql/dataframe.py:
##########
@@ -2865,6 +2865,74 @@ def lateralJoin(
"""
...
+ def nearestByJoin(
+ self,
+ other: "DataFrame",
+ rankingExpression: Column,
+ numResults: int,
+ mode: str,
+ direction: str,
+ *,
+ joinType: str = "inner",
+ ) -> "DataFrame":
+ """
+ Nearest-by top-K ranking join with another :class:`DataFrame`. For
each row on the
+ left (query side), returns up to ``numResults`` rows from ``other``
(base side), ranked
+ by ``rankingExpression``.
+
+ The current implementation evaluates the full cross-product of left
and right and
+ bounds memory per left row by ``numResults``. Index-backed approximate
strategies
+ (transparent to ``approx`` mode) are planned for a future release;
until then,
+ pre-filter ``other`` when it is large. Tie-breaking among rows with
equal ranking
+ values is unspecified.
+
+ .. versionadded:: 4.2.0
+
+ Parameters
+ ----------
+ other : :class:`DataFrame`
+ Right (base side) of the join - the candidate pool searched for
each row of this
+ DataFrame.
+ rankingExpression : :class:`Column`
+ Scalar expression used to rank candidate rows on the right side.
+ numResults : int
+ Maximum number of matches per query row. Must be between 1 and
100000.
+ mode : str
+ Search algorithm contract. Must be one of: ``approx``, ``exact``.
``approx`` allows
+ the optimizer to use indexed or other approximate strategies when
available;
+ ``exact`` forces brute-force evaluation and requires the ranking
expression to be
+ deterministic.
+ direction : str
+ ``"distance"`` (smallest values first) or ``"similarity"``
(largest values first).
+ joinType : str, keyword-only, optional
+ Default ``inner``. Must be one of: ``inner``, ``left``,
``leftouter``,
+ ``left_outer``.
+
+ Returns
+ -------
+ :class:`DataFrame`
+ Joined DataFrame.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> users = spark.createDataFrame(
+ ... [(1, 10.0), (2, 20.0), (3, 30.0)], ["user_id", "score"])
+ >>> products = spark.createDataFrame(
+ ... [("A", 11.0), ("B", 22.0), ("C", 5.0)], ["product", "pscore"])
+ >>> users.nearestByJoin(
+ ... products, -sf.abs(users.score - products.pscore), 1, "exact",
"similarity"
+ ... ).select("user_id", "product").orderBy("user_id").show()
Review Comment:
The doctest will only cover the happy path with the default inner join - can we
add more tests similar to
https://github.com/apache/spark/blob/3a6bc83ff531533cf17b28ebc86a50537d412cf0/python/pyspark/sql/tests/test_subquery.py#L679
##########
sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala:
##########
@@ -1569,3 +1614,46 @@ class Dataset[T] private[sql] (
override def queryExecution: QueryExecution =
throw ConnectClientUnsupportedErrors.queryExecution()
}
+
+private[sql] object Dataset {
+ // Acceptance lists for `nearestByJoin`. Must stay aligned with
`NearestByJoinType` /
+ // `NearestByJoinMode` / `NearestByDirection` in sql/catalyst, which
`sql/connect/common`
+ // cannot import.
+ private val MaxNumResults: Int = 100000
+ private val SupportedJoinTypeDisplay = "'INNER', 'LEFT OUTER'"
+ private val SupportedJoinTypes = Set("inner", "leftouter", "left",
"left_outer")
Review Comment:
why do we need both leftouter and left_outer?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]