hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1177036733


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -520,54 +515,205 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformTypedMapPartitions(
       fun: proto.CommonInlineUserDefinedFunction,
       child: LogicalPlan): LogicalPlan = {
-    val udf = fun.getScalarScalaUdf
-    val udfPacket =
-      Utils.deserialize[UdfPacket](
-        udf.getPayload.toByteArray,
-        SparkConnectArtifactManager.classLoaderWithArtifacts)
-    assert(udfPacket.inputEncoders.size == 1)
-    val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head)
-    val rEnc = ExpressionEncoder(udfPacket.outputEncoder)
-
-    val deserializer = UnresolvedDeserializer(iEnc.deserializer)
-    val deserialized = DeserializeToObject(deserializer, 
generateObjAttr(iEnc), child)
+    val udf = ScalaUdf(fun)
+    val deserialized = DeserializeToObject(udf.inputDeserializer(), 
udf.inputObjAttr, child)
     val mapped = MapPartitions(
-      udfPacket.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
-      generateObjAttr(rEnc),
+      udf.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
+      udf.outputObjAttr,
       deserialized)
-    SerializeFromObject(rEnc.namedExpressions, mapped)
+    SerializeFromObject(udf.outputNamedExpression, mapped)
   }
 
   private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
-    val pythonUdf = transformPythonUDF(rel.getFunc)
-    val cols =
-      rel.getGroupingExpressionsList.asScala.toSeq.map(expr => 
Column(transformExpression(expr)))
+    val commonUdf = rel.getFunc
+    commonUdf.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF 
=>
+        transformTypedGroupMap(rel, commonUdf)
 
-    Dataset
-      .ofRows(session, transformRelation(rel.getInput))
-      .groupBy(cols: _*)
-      .flatMapGroupsInPandas(pythonUdf)
-      .logicalPlan
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        val pythonUdf = transformPythonUDF(commonUdf)
+        val cols =
+          rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
+            Column(transformExpression(expr)))
+
+        Dataset
+          .ofRows(session, transformRelation(rel.getInput))
+          .groupBy(cols: _*)
+          .flatMapGroupsInPandas(pythonUdf)
+          .logicalPlan
+
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not 
supported")
+    }
+  }
+
+  private def transformTypedGroupMap(
+      rel: proto.GroupMap,
+      commonUdf: proto.CommonInlineUserDefinedFunction): LogicalPlan = {
+    val udf = ScalaUdf(commonUdf)
+    val ds = UntypedKeyValueGroupedDataset(
+      rel.getInput,
+      rel.getGroupingExpressionsList,
+      rel.getSortingExpressionsList)
+
+    val mapped = new MapGroups(
+      udf.function.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
+      udf.inputDeserializer(ds.groupingAttributes),
+      ds.valueDeserializer,
+      ds.groupingAttributes,
+      ds.dataAttributes,
+      ds.sortOrder,
+      udf.outputObjAttr,
+      ds.analyzed)
+    SerializeFromObject(udf.outputNamedExpression, mapped)
   }
 
   private def transformCoGroupMap(rel: proto.CoGroupMap): LogicalPlan = {
-    val pythonUdf = transformPythonUDF(rel.getFunc)
+    val commonUdf = rel.getFunc
+    commonUdf.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF 
=>
+        transformTypedCoGroupMap(rel, commonUdf)
 
-    val inputCols =
-      rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr =>
-        Column(transformExpression(expr)))
-    val otherCols =
-      rel.getOtherGroupingExpressionsList.asScala.toSeq.map(expr =>
-        Column(transformExpression(expr)))
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        val pythonUdf = transformPythonUDF(commonUdf)
 
-    val input = Dataset
-      .ofRows(session, transformRelation(rel.getInput))
-      .groupBy(inputCols: _*)
-    val other = Dataset
-      .ofRows(session, transformRelation(rel.getOther))
-      .groupBy(otherCols: _*)
+        val inputCols =
+          rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr =>
+            Column(transformExpression(expr)))
+        val otherCols =
+          rel.getOtherGroupingExpressionsList.asScala.toSeq.map(expr =>
+            Column(transformExpression(expr)))
 
-    input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
+        val input = Dataset
+          .ofRows(session, transformRelation(rel.getInput))
+          .groupBy(inputCols: _*)
+        val other = Dataset
+          .ofRows(session, transformRelation(rel.getOther))
+          .groupBy(otherCols: _*)
+
+        input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
+
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not 
supported")
+    }
+  }
+
+  private def transformTypedCoGroupMap(
+      rel: proto.CoGroupMap,
+      commonUdf: proto.CommonInlineUserDefinedFunction): LogicalPlan = {
+    val udf = ScalaUdf(commonUdf)
+    val left = UntypedKeyValueGroupedDataset(
+      rel.getInput,
+      rel.getInputGroupingExpressionsList,
+      rel.getInputSortingExpressionsList)
+    val right = UntypedKeyValueGroupedDataset(
+      rel.getOther,
+      rel.getOtherGroupingExpressionsList,
+      rel.getOtherSortingExpressionsList)
+
+    val mapped = CoGroup(
+      udf.function.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => 
TraversableOnce[Any]],
+      // The `leftGroup` and `rightGroup` are guaranteed to be of same schema, 
so it's safe to
+      // resolve the `keyDeserializer` based on either of them, here we pick 
the left one.
+      udf.inputDeserializer(left.groupingAttributes),
+      left.valueDeserializer,
+      right.valueDeserializer,
+      left.groupingAttributes,
+      right.groupingAttributes,
+      left.dataAttributes,
+      right.dataAttributes,
+      left.sortOrder,
+      right.sortOrder,
+      udf.outputObjAttr,
+      left.analyzed,
+      right.analyzed)
+    SerializeFromObject(udf.outputNamedExpression, mapped)
+  }
+
+  /**
+   * This is the untyped version of [[KeyValueGroupedDataset]].
+   */
+  private case class UntypedKeyValueGroupedDataset(
+      kEncoder: ExpressionEncoder[_],
+      vEncoder: ExpressionEncoder[_],
+      valueDeserializer: Expression,
+      analyzed: LogicalPlan,
+      dataAttributes: Seq[Attribute],
+      groupingAttributes: Seq[Attribute],
+      sortOrder: Seq[SortOrder])
+  private object UntypedKeyValueGroupedDataset {
+    def apply(
+        input: proto.Relation,
+        groupingExprs: java.util.List[proto.Expression],
+        sortingExprs: java.util.List[proto.Expression]): 
UntypedKeyValueGroupedDataset = {
+      val logicalPlan = transformRelation(input)
+      assert(groupingExprs.size() == 1)
+      val groupFunc = groupingExprs.asScala.toSeq

Review Comment:
   Nit: this can be written more directly as
`unpackUdf(groupingExprs.get(0).getCommonInlineUserDefinedFunction)`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to