This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 4ed161be4 [VL] RAS: Incorporate query plan's logical link into metadata model (#6165)
4ed161be4 is described below
commit 4ed161be4e044322c7b3267d48dc6dffa40cae72
Author: Hongze Zhang <[email protected]>
AuthorDate: Mon Jun 24 08:54:15 2024 +0800
[VL] RAS: Incorporate query plan's logical link into metadata model (#6165)
---
.../columnar/enumerated/RemoveFilter.scala | 1 +
.../gluten/planner/metadata/GlutenMetadata.scala | 36 +++------------
.../planner/metadata/GlutenMetadataModel.scala | 21 +++++----
.../gluten/planner/metadata/LogicalLink.scala | 53 ++++++++++++++++++++++
.../{GlutenMetadataModel.scala => Schema.scala} | 50 +++++++++++++-------
.../org/apache/gluten/planner/property/Conv.scala | 1 +
.../gluten/ras/best/GroupBasedBestFinder.scala | 14 ++++--
.../org/apache/gluten/ras/OperationSuite.scala | 4 +-
8 files changed, 119 insertions(+), 61 deletions(-)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
index 5d7209dfb..e2b8439fd 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
@@ -42,6 +42,7 @@ object RemoveFilter extends RasRule[SparkPlan] {
val filter = node.asInstanceOf[FilterExecTransformerBase]
if (filter.isNoop()) {
val out = NoopFilter(filter.child, filter.output)
+ out.copyTagsFrom(filter)
return List(out)
}
List.empty
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadata.scala
b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadata.scala
index e25f0a1f1..f66c5290e 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadata.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadata.scala
@@ -18,42 +18,18 @@ package org.apache.gluten.planner.metadata
import org.apache.gluten.ras.Metadata
-import org.apache.spark.sql.catalyst.expressions.Attribute
-
sealed trait GlutenMetadata extends Metadata {
- import GlutenMetadata._
def schema(): Schema
+ def logicalLink(): LogicalLink
}
object GlutenMetadata {
- def apply(schema: Schema): Metadata = {
- Impl(schema)
+ def apply(schema: Schema, logicalLink: LogicalLink): Metadata = {
+ Impl(schema, logicalLink)
}
- private case class Impl(override val schema: Schema) extends GlutenMetadata
-
- case class Schema(output: Seq[Attribute]) {
- private val hash = output.map(_.semanticHash()).hashCode()
-
- override def hashCode(): Int = {
- hash
- }
-
- override def equals(obj: Any): Boolean = obj match {
- case other: Schema =>
- semanticEquals(other)
- case _ =>
- false
- }
-
- private def semanticEquals(other: Schema): Boolean = {
- if (output.size != other.output.size) {
- return false
- }
- output.zip(other.output).forall {
- case (left, right) =>
- left.semanticEquals(right)
- }
- }
+ private case class Impl(override val schema: Schema, override val
logicalLink: LogicalLink)
+ extends GlutenMetadata {
+ override def toString: String = s"$schema,$logicalLink"
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadataModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadataModel.scala
index 6d1baa79d..7b95f1383 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadataModel.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadataModel.scala
@@ -16,7 +16,6 @@
*/
package org.apache.gluten.planner.metadata
-import org.apache.gluten.planner.metadata.GlutenMetadata.Schema
import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
import org.apache.gluten.ras.{Metadata, MetadataModel}
@@ -31,18 +30,22 @@ object GlutenMetadataModel extends Logging {
private object MetadataModelImpl extends MetadataModel[SparkPlan] {
override def metadataOf(node: SparkPlan): Metadata = node match {
case g: GroupLeafExec => throw new UnsupportedOperationException()
- case other => GlutenMetadata(Schema(other.output))
+ case other =>
+ GlutenMetadata(
+ Schema(other.output),
+
other.logicalLink.map(LogicalLink(_)).getOrElse(LogicalLink.notFound))
}
- override def dummy(): Metadata = GlutenMetadata(Schema(List()))
+ override def dummy(): Metadata = GlutenMetadata(Schema(List()),
LogicalLink.notFound)
override def verify(one: Metadata, other: Metadata): Unit = (one, other)
match {
- case (left: GlutenMetadata, right: GlutenMetadata) if left.schema() !=
right.schema() =>
- // We apply loose restriction on schema. Since Gluten still have some
customized
- // logics causing schema of an operator to change after being
transformed.
- // For example: https://github.com/apache/incubator-gluten/pull/5171
- logWarning(s"Warning: Schema mismatch: one: ${left.schema()}, other:
${right.schema()}")
- case (left: GlutenMetadata, right: GlutenMetadata) if left == right =>
+ case (left: GlutenMetadata, right: GlutenMetadata) =>
+ implicitly[Verifier[Schema]].verify(left.schema(), right.schema())
+ implicitly[Verifier[LogicalLink]].verify(left.logicalLink(),
right.logicalLink())
case _ => throw new IllegalStateException(s"Metadata mismatch: one:
$one, other $other")
}
}
+
+ trait Verifier[T <: Any] {
+ def verify(one: T, other: T): Unit
+ }
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/LogicalLink.scala
b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/LogicalLink.scala
new file mode 100644
index 000000000..4c3bffd47
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/LogicalLink.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.planner.metadata
+
+import org.apache.gluten.planner.metadata.GlutenMetadataModel.Verifier
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+
+case class LogicalLink(plan: LogicalPlan) {
+ override def hashCode(): Int = System.identityHashCode(plan)
+ override def equals(obj: Any): Boolean = obj match {
+ // LogicalLink's comparison is based on ref equality of the logical plans
being compared.
+ case LogicalLink(otherPlan) => plan eq otherPlan
+ case _ => false
+ }
+
+ override def toString: String =
s"${plan.nodeName}[${plan.stats.simpleString}]"
+}
+
+object LogicalLink {
+ private case class LogicalLinkNotFound() extends logical.LeafNode {
+ override def output: Seq[Attribute] = List.empty
+ override def canEqual(that: Any): Boolean = throw new
UnsupportedOperationException()
+ override def computeStats(): Statistics = Statistics(sizeInBytes = 0)
+ }
+
+ val notFound = new LogicalLink(LogicalLinkNotFound())
+ implicit val verifier: Verifier[LogicalLink] = new Verifier[LogicalLink]
with Logging {
+ override def verify(one: LogicalLink, other: LogicalLink): Unit = {
+ // LogicalLink's comparison is based on ref equality of the logical
plans being compared.
+ if (one != other) {
+ logWarning(s"Warning: Logical link mismatch: one: $one, other: $other")
+ }
+ }
+ }
+}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadataModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/Schema.scala
similarity index 50%
copy from
gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadataModel.scala
copy to
gluten-core/src/main/scala/org/apache/gluten/planner/metadata/Schema.scala
index 6d1baa79d..969d34d5c 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadataModel.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/Schema.scala
@@ -16,33 +16,49 @@
*/
package org.apache.gluten.planner.metadata
-import org.apache.gluten.planner.metadata.GlutenMetadata.Schema
-import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
-import org.apache.gluten.ras.{Metadata, MetadataModel}
+import org.apache.gluten.planner.metadata.GlutenMetadataModel.Verifier
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.catalyst.expressions.Attribute
-object GlutenMetadataModel extends Logging {
- def apply(): MetadataModel[SparkPlan] = {
- MetadataModelImpl
+case class Schema(output: Seq[Attribute]) {
+ private val hash = output.map(_.semanticHash()).hashCode()
+
+ override def hashCode(): Int = {
+ hash
+ }
+
+ override def equals(obj: Any): Boolean = obj match {
+ case other: Schema =>
+ semanticEquals(other)
+ case _ =>
+ false
}
- private object MetadataModelImpl extends MetadataModel[SparkPlan] {
- override def metadataOf(node: SparkPlan): Metadata = node match {
- case g: GroupLeafExec => throw new UnsupportedOperationException()
- case other => GlutenMetadata(Schema(other.output))
+ private def semanticEquals(other: Schema): Boolean = {
+ if (output.size != other.output.size) {
+ return false
+ }
+ output.zip(other.output).forall {
+ case (left, right) =>
+ left.semanticEquals(right)
}
+ }
+
+ override def toString: String = {
+ output.toString()
+ }
+}
- override def dummy(): Metadata = GlutenMetadata(Schema(List()))
- override def verify(one: Metadata, other: Metadata): Unit = (one, other)
match {
- case (left: GlutenMetadata, right: GlutenMetadata) if left.schema() !=
right.schema() =>
+object Schema {
+ implicit val verifier: Verifier[Schema] = new Verifier[Schema] with Logging {
+ override def verify(one: Schema, other: Schema): Unit = {
+ if (one != other) {
// We apply loose restriction on schema. Since Gluten still have some
customized
// logics causing schema of an operator to change after being
transformed.
// For example: https://github.com/apache/incubator-gluten/pull/5171
- logWarning(s"Warning: Schema mismatch: one: ${left.schema()}, other:
${right.schema()}")
- case (left: GlutenMetadata, right: GlutenMetadata) if left == right =>
- case _ => throw new IllegalStateException(s"Metadata mismatch: one:
$one, other $other")
+ logWarning(s"Warning: Schema mismatch: one: $one, other: $other")
+ }
}
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala
b/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala
index 475f62920..18db0f959 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala
@@ -99,6 +99,7 @@ case class ConvEnforcerRule(reqConv: Conv) extends
RasRule[SparkPlan] {
}
val transition = Conv.findTransition(conv, reqConv)
val after = transition.apply(node)
+ after.copyTagsFrom(node)
List(after)
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
index effebd41b..1128ab8de 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
@@ -82,15 +82,23 @@ private object GroupBasedBestFinder {
return Some(KnownCostPath(ras, path))
}
val childrenGroups = can.getChildrenGroups(allGroups).map(gn =>
allGroups(gn.groupId()))
- val maybeBestChildrenPaths: Seq[Option[RasPath[T]]] = childrenGroups.map
{
- childGroup => childrenGroupsOutput(childGroup).map(kcg =>
kcg.best().rasPath)
+ val maybeBestChildrenPaths: Seq[Option[KnownCostPath[T]]] =
childrenGroups.map {
+ childGroup => childrenGroupsOutput(childGroup).map(kcg => kcg.best())
}
if (maybeBestChildrenPaths.exists(_.isEmpty)) {
// Node should only be solved when all children outputs exist.
return None
}
val bestChildrenPaths = maybeBestChildrenPaths.map(_.get)
- Some(KnownCostPath(ras, path.RasPath(ras, can, bestChildrenPaths).get))
+ val kcp = KnownCostPath(ras, path.RasPath(ras, can,
bestChildrenPaths.map(_.rasPath)).get)
+ // Cost should be in monotonically increasing basis.
+ bestChildrenPaths.map(_.cost).foreach {
+ childCost =>
+ assert(
+ ras.costModel.costComparator().gteq(kcp.cost, childCost),
+ "Illegal decreasing cost")
+ }
+ Some(kcp)
}
override def solveGroup(
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
index 60ec2eedd..e1ccfa1f4 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
@@ -230,7 +230,7 @@ class OperationSuite extends AnyFunSuite {
48,
Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48,
Unary3(48, Leaf(30))))))))))))
assert(costModel.costOfCount == 32) // TODO reduce this for performance
- assert(costModel.costCompareCount == 20) // TODO reduce this for
performance
+ assert(costModel.costCompareCount == 50) // TODO reduce this for
performance
}
test("Cost evaluation count - max cost") {
@@ -292,7 +292,7 @@ class OperationSuite extends AnyFunSuite {
48,
Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48,
Unary3(48, Leaf(30))))))))))))
assert(costModel.costOfCount == 32) // TODO reduce this for performance
- assert(costModel.costCompareCount == 20) // TODO reduce this for
performance
+ assert(costModel.costCompareCount == 50) // TODO reduce this for
performance
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]