This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 993e96afe [CORE][VL] RAS: Refactor memo cache to look up on cluster-canonical node rather than on group-canonical node (#5305)
993e96afe is described below

commit 993e96afe81ff85e2928151b4a9b9d45baf4b79f
Author: Hongze Zhang <[email protected]>
AuthorDate: Mon Apr 8 11:42:51 2024 +0800

    [CORE][VL] RAS: Refactor memo cache to look up on cluster-canonical node rather than on group-canonical node (#5305)
---
 .../planner/property/GlutenPropertyModel.scala     |   2 +
 .../org/apache/gluten/ras/MetadataModel.scala      |   3 +-
 .../org/apache/gluten/ras/PropertyModel.scala      |   1 +
 .../src/main/scala/org/apache/gluten/ras/Ras.scala |  48 ++++-
 .../scala/org/apache/gluten/ras/RasCluster.scala   |  16 +-
 .../main/scala/org/apache/gluten/ras/RasNode.scala |  53 ++++-
 .../scala/org/apache/gluten/ras/RasPlanner.scala   |  23 ++-
 .../org/apache/gluten/ras/best/BestFinder.scala    |  21 +-
 .../gluten/ras/best/GroupBasedBestFinder.scala     |  23 ++-
 .../org/apache/gluten/ras/dp/DpClusterAlgo.scala   |   2 +-
 .../org/apache/gluten/ras/dp/DpGroupAlgo.scala     |   2 +-
 .../scala/org/apache/gluten/ras/dp/DpPlanner.scala |   4 +-
 .../org/apache/gluten/ras/dp/DpZipperAlgo.scala    |   2 -
 .../gluten/ras/exaustive/ExhaustivePlanner.scala   |  24 ++-
 .../apache/gluten/ras/memo/ForwardMemoTable.scala  |  19 +-
 .../scala/org/apache/gluten/ras/memo/Memo.scala    | 214 +++++++++++++--------
 .../org/apache/gluten/ras/memo/MemoTable.scala     |   1 +
 .../scala/org/apache/gluten/ras/path/RasPath.scala |  43 ++---
 .../org/apache/gluten/ras/rule/RuleApplier.scala   |  51 ++---
 .../scala/org/apache/gluten/ras/util/NodeMap.scala |  60 ------
 .../apache/gluten/ras/vis/GraphvizVisualizer.scala |   4 +-
 .../org/apache/gluten/ras/OperationSuite.scala     |   8 +-
 .../org/apache/gluten/ras/PropertySuite.scala      |  34 +++-
 .../scala/org/apache/gluten/ras/RasSuite.scala     |   9 +-
 .../org/apache/gluten/ras/path/RasPathSuite.scala  |  60 ++++--
 .../gluten/ras/specific/DistributedSuite.scala     |   4 +
 26 files changed, 453 insertions(+), 278 deletions(-)

diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
index 54f4e3b84..07dd3fe02 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
@@ -53,6 +53,8 @@ object GlutenProperties {
         val conv = getProperty(plan)
         plan.children.map(_ => conv)
     }
+
+    override def any(): Convention = Conventions.ANY
   }
 
   case class ConventionEnforcerRule(reqConv: Convention) extends 
RasRule[SparkPlan] {
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala
index d2056746c..a81ac31cb 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala
@@ -21,8 +21,9 @@ package org.apache.gluten.ras
  */
 trait MetadataModel[T <: AnyRef] {
   def metadataOf(node: T): Metadata
-  def dummy(): Metadata
   def verify(one: Metadata, other: Metadata): Unit
+
+  def dummy(): Metadata
 }
 
 trait Metadata {}
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala
index e2ba99136..e764631e7 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala
@@ -26,6 +26,7 @@ trait Property[T <: AnyRef] {
 }
 
 trait PropertyDef[T <: AnyRef, P <: Property[T]] {
+  def any(): P
   def getProperty(plan: T): P
   def getChildrenConstraints(constraint: Property[T], plan: T): Seq[P]
 }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
index 6832d07c5..9910fab6f 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
@@ -79,9 +79,10 @@ class Ras[T <: AnyRef] private (
       ruleFactory)
   }
 
-  // Normal groups start with ID 0, so it's safe to use -1 to do validation.
+  private val propSetFactory: PropertySetFactory[T] = 
PropertySetFactory(propertyModel, planModel)
+  // Normal groups start with ID 0, so it's safe to use Int.MinValue to do 
validation.
   private val dummyGroup: T =
-    planModel.newGroupLeaf(-1, metadataModel.dummy(), PropertySet(Seq.empty))
+    planModel.newGroupLeaf(Int.MinValue, metadataModel.dummy(), 
propSetFactory.any())
   private val infCost: Cost = costModel.makeInfCost()
 
   validateModels()
@@ -123,8 +124,6 @@ class Ras[T <: AnyRef] private (
     }
   }
 
-  private val propSetFactory: PropertySetFactory[T] = PropertySetFactory(this)
-
   override def newPlanner(
       plan: T,
       constraintSet: PropertySet[T],
@@ -171,6 +170,8 @@ class Ras[T <: AnyRef] private (
   private[ras] def getInfCost(): Cost = infCost
 
   private[ras] def isInfCost(cost: Cost) = 
costModel.costComparator().equiv(cost, infCost)
+
+  private[ras] def toUnsafeKey(node: T): UnsafeKey[T] = UnsafeKey(this, node)
 }
 
 object Ras {
@@ -192,16 +193,29 @@ object Ras {
   }
 
   trait PropertySetFactory[T <: AnyRef] {
+    def any(): PropertySet[T]
     def get(node: T): PropertySet[T]
     def childrenConstraintSets(constraintSet: PropertySet[T], node: T): 
Seq[PropertySet[T]]
   }
 
   private object PropertySetFactory {
-    def apply[T <: AnyRef](ras: Ras[T]): PropertySetFactory[T] = new 
PropertySetFactoryImpl[T](ras)
-
-    private class PropertySetFactoryImpl[T <: AnyRef](val ras: Ras[T])
+    def apply[T <: AnyRef](
+        propertyModel: PropertyModel[T],
+        planModel: PlanModel[T]): PropertySetFactory[T] =
+      new PropertySetFactoryImpl[T](propertyModel, planModel)
+
+    private class PropertySetFactoryImpl[T <: AnyRef](
+        propertyModel: PropertyModel[T],
+        planModel: PlanModel[T])
       extends PropertySetFactory[T] {
-      private val propDefs: Seq[PropertyDef[T, _ <: Property[T]]] = 
ras.propertyModel.propertyDefs
+      private val propDefs: Seq[PropertyDef[T, _ <: Property[T]]] = 
propertyModel.propertyDefs
+      private val anyConstraint = {
+        val m: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
+          propDefs.map(propDef => (propDef, propDef.any())).toMap
+        PropertySet[T](m)
+      }
+
+      override def any(): PropertySet[T] = anyConstraint
 
       override def get(node: T): PropertySet[T] = {
         val m: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
@@ -213,7 +227,7 @@ object Ras {
           constraintSet: PropertySet[T],
           node: T): Seq[PropertySet[T]] = {
         val builder: Seq[mutable.Map[PropertyDef[T, _ <: Property[T]], 
Property[T]]] =
-          ras.planModel
+          planModel
             .childrenOf(node)
             .map(_ => mutable.Map[PropertyDef[T, _ <: Property[T]], 
Property[T]]())
 
@@ -236,4 +250,20 @@ object Ras {
       }
     }
   }
+
+  trait UnsafeKey[T]
+
+  private object UnsafeKey {
+    def apply[T <: AnyRef](ras: Ras[T], self: T): UnsafeKey[T] = new 
UnsafeKeyImpl(ras, self)
+    private class UnsafeKeyImpl[T <: AnyRef](ras: Ras[T], val self: T) extends 
UnsafeKey[T] {
+      override def hashCode(): Int = ras.planModel.hashCode(self)
+      override def equals(other: Any): Boolean = {
+        other match {
+          case that: UnsafeKeyImpl[T] => ras.planModel.equals(self, that.self)
+          case _ => false
+        }
+      }
+      override def toString: String = ras.explain.describeNode(self)
+    }
+  }
 }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
index 63b8b1e68..1b30e1242 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.gluten.ras
 
+import org.apache.gluten.ras.Ras.UnsafeKey
 import org.apache.gluten.ras.memo.MemoTable
 import org.apache.gluten.ras.property.PropertySet
 
@@ -54,16 +55,19 @@ object RasCluster {
         override val ras: Ras[T],
         metadata: Metadata)
       extends MutableRasCluster[T] {
-      private val buffer: mutable.Set[CanonicalNode[T]] =
-        mutable.Set()
+      private val deDup: mutable.Set[UnsafeKey[T]] = mutable.Set()
+      private val buffer: mutable.ListBuffer[CanonicalNode[T]] =
+        mutable.ListBuffer()
 
       override def contains(t: CanonicalNode[T]): Boolean = {
-        buffer.contains(t)
+        deDup.contains(t.toUnsafeKey())
       }
 
       override def add(t: CanonicalNode[T]): Unit = {
+        val key = t.toUnsafeKey()
+        assert(!deDup.contains(key))
         ras.metadataModel.verify(metadata, 
ras.metadataModel.metadataOf(t.self()))
-        assert(!buffer.contains(t))
+        deDup += key
         buffer += t
       }
 
@@ -75,12 +79,12 @@ object RasCluster {
 
   case class ImmutableRasCluster[T <: AnyRef] private (
       ras: Ras[T],
-      override val nodes: Set[CanonicalNode[T]])
+      override val nodes: Seq[CanonicalNode[T]])
     extends RasCluster[T]
 
   object ImmutableRasCluster {
     def apply[T <: AnyRef](ras: Ras[T], cluster: RasCluster[T]): 
ImmutableRasCluster[T] = {
-      ImmutableRasCluster(ras, cluster.nodes().toSet)
+      ImmutableRasCluster(ras, cluster.nodes().toVector)
     }
   }
 }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
index 5f18f96a7..878020391 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.gluten.ras
 
+import org.apache.gluten.ras.Ras.UnsafeKey
 import org.apache.gluten.ras.property.PropertySet
 
 trait RasNode[T <: AnyRef] {
@@ -41,6 +42,8 @@ object RasNode {
     def asGroup(): GroupNode[T] = {
       node.asInstanceOf[GroupNode[T]]
     }
+
+    def toUnsafeKey(): UnsafeKey[T] = node.ras().toUnsafeKey(node.self())
   }
 }
 
@@ -53,7 +56,7 @@ object CanonicalNode {
     assert(ras.isCanonical(canonical))
     val propSet = ras.propSetsOf(canonical)
     val children = ras.planModel.childrenOf(canonical)
-    CanonicalNodeImpl[T](ras, canonical, propSet, children.size)
+    new CanonicalNodeImpl[T](ras, canonical, propSet, children.size)
   }
 
   // We put RasNode's API methods that accept mutable input in implicit 
definition.
@@ -74,12 +77,16 @@ object CanonicalNode {
     }
   }
 
-  private case class CanonicalNodeImpl[T <: AnyRef](
-      ras: Ras[T],
+  private class CanonicalNodeImpl[T <: AnyRef](
+      override val ras: Ras[T],
       override val self: T,
       override val propSet: PropertySet[T],
       override val childrenCount: Int)
-    extends CanonicalNode[T]
+    extends CanonicalNode[T] {
+    override def toString: String = ras.explain.describeNode(self)
+    override def hashCode(): Int = throw new UnsupportedOperationException()
+    override def equals(obj: Any): Boolean = throw new 
UnsupportedOperationException()
+  }
 }
 
 trait GroupNode[T <: AnyRef] extends RasNode[T] {
@@ -88,15 +95,19 @@ trait GroupNode[T <: AnyRef] extends RasNode[T] {
 
 object GroupNode {
   def apply[T <: AnyRef](ras: Ras[T], group: RasGroup[T]): GroupNode[T] = {
-    GroupNodeImpl[T](ras, group.self(), group.propSet(), group.id())
+    new GroupNodeImpl[T](ras, group.self(), group.propSet(), group.id())
   }
 
-  private case class GroupNodeImpl[T <: AnyRef](
-      ras: Ras[T],
+  private class GroupNodeImpl[T <: AnyRef](
+      override val ras: Ras[T],
       override val self: T,
       override val propSet: PropertySet[T],
       override val groupId: Int)
-    extends GroupNode[T] {}
+    extends GroupNode[T] {
+    override def toString: String = ras.explain.describeNode(self)
+    override def hashCode(): Int = throw new UnsupportedOperationException()
+    override def equals(obj: Any): Boolean = throw new 
UnsupportedOperationException()
+  }
 
   // We put RasNode's API methods that accept mutable input in implicit 
definition.
   // Do not break this rule during further development.
@@ -116,8 +127,21 @@ object InGroupNode {
   def apply[T <: AnyRef](groupId: Int, node: CanonicalNode[T]): InGroupNode[T] 
= {
     InGroupNodeImpl(groupId, node)
   }
+
   private case class InGroupNodeImpl[T <: AnyRef](groupId: Int, can: 
CanonicalNode[T])
     extends InGroupNode[T]
+
+  trait HashKey extends Any
+
+  implicit class InGroupNodeImplicits[T <: AnyRef](n: InGroupNode[T]) {
+    import InGroupNodeImplicits._
+    def toHashKey: HashKey =
+      InGroupNodeHashKeyImpl(n.groupId, System.identityHashCode(n.can))
+  }
+
+  private object InGroupNodeImplicits {
+    private case class InGroupNodeHashKeyImpl(gid: Int, cid: Int) extends 
HashKey
+  }
 }
 
 trait InClusterNode[T <: AnyRef] {
@@ -129,8 +153,21 @@ object InClusterNode {
   def apply[T <: AnyRef](clusterId: RasClusterKey, node: CanonicalNode[T]): 
InClusterNode[T] = {
     InClusterNodeImpl(clusterId, node)
   }
+
   private case class InClusterNodeImpl[T <: AnyRef](
       clusterKey: RasClusterKey,
       can: CanonicalNode[T])
     extends InClusterNode[T]
+
+  trait HashKey extends Any
+
+  implicit class InClusterNodeImplicits[T <: AnyRef](n: InClusterNode[T]) {
+    import InClusterNodeImplicits._
+    def toHashKey: HashKey =
+      InClusterNodeHashKeyImpl(n.clusterKey, System.identityHashCode(n.can))
+  }
+
+  private object InClusterNodeImplicits {
+    private case class InClusterNodeHashKeyImpl(clusterKey: RasClusterKey, 
cid: Int) extends HashKey
+  }
 }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
index 0665d3661..74793a3d0 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
@@ -49,8 +49,8 @@ object RasPlanner {
 trait Best[T <: AnyRef] {
   import Best._
   def rootGroupId(): Int
-  def bestNodes(): Set[InGroupNode[T]]
-  def winnerNodes(): Set[InGroupNode[T]]
+  def bestNodes(): InGroupNode[T] => Boolean
+  def winnerNodes(): InGroupNode[T] => Boolean
   def costs(): InGroupNode[T] => Option[Cost]
   def path(): KnownCostPath[T]
 }
@@ -62,11 +62,11 @@ object Best {
       bestPath: KnownCostPath[T],
       winnerNodes: Seq[InGroupNode[T]],
       costs: InGroupNode[T] => Option[Cost]): Best[T] = {
-    val bestNodes = mutable.Set[InGroupNode[T]]()
+    val bestNodes = mutable.Set[InGroupNode.HashKey]()
 
     def dfs(groupId: Int, cursor: RasPath.PathNode[T]): Unit = {
       val can = cursor.self().asCanonical()
-      bestNodes += InGroupNode(groupId, can)
+      bestNodes += InGroupNode(groupId, can).toHashKey
       cursor.zipChildrenWithGroupIds().foreach {
         case (childPathNode, childGroupId) =>
           dfs(childGroupId, childPathNode)
@@ -75,17 +75,24 @@ object Best {
 
     dfs(rootGroupId, bestPath.rasPath.node())
 
-    val winnerNodeSet = winnerNodes.toSet
+    val bestNodeSet = bestNodes.toSet
+    val winnerNodeSet = winnerNodes.map(_.toHashKey).toSet
 
-    BestImpl(ras, rootGroupId, bestPath, bestNodes.toSet, winnerNodeSet, costs)
+    BestImpl(
+      ras,
+      rootGroupId,
+      bestPath,
+      n => bestNodeSet.contains(n.toHashKey),
+      n => winnerNodeSet.contains(n.toHashKey),
+      costs)
   }
 
   private case class BestImpl[T <: AnyRef](
       ras: Ras[T],
       override val rootGroupId: Int,
       override val path: KnownCostPath[T],
-      override val bestNodes: Set[InGroupNode[T]],
-      override val winnerNodes: Set[InGroupNode[T]],
+      override val bestNodes: InGroupNode[T] => Boolean,
+      override val winnerNodes: InGroupNode[T] => Boolean,
       override val costs: InGroupNode[T] => Option[Cost])
     extends Best[T]
 
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
index 4ec7e09f5..0912ab536 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
@@ -40,9 +40,12 @@ object BestFinder {
   }
 
   case class KnownCostGroup[T <: AnyRef](
-      nodeToCost: Map[CanonicalNode[T], KnownCostPath[T]],
+      nodes: Iterable[CanonicalNode[T]],
+      nodeToCost: CanonicalNode[T] => Option[KnownCostPath[T]],
       bestNode: CanonicalNode[T]) {
-    def best(): KnownCostPath[T] = nodeToCost(bestNode)
+    def best(): KnownCostPath[T] = {
+      nodeToCost(bestNode).get
+    }
   }
 
   case class KnownCostCluster[T <: AnyRef](groupToCost: Map[Int, 
KnownCostGroup[T]])
@@ -52,17 +55,21 @@ object BestFinder {
       allGroups: Seq[RasGroup[T]],
       group: RasGroup[T],
       groupToCosts: Map[Int, KnownCostGroup[T]]): Best[T] = {
+
     val bestPath = groupToCosts(group.id()).best()
     val bestRoot = bestPath.rasPath.node()
     val winnerNodes = groupToCosts.map { case (id, g) => InGroupNode(id, 
g.bestNode) }.toSeq
-    val costsMap = mutable.Map[InGroupNode[T], Cost]()
+    val costsMap = mutable.Map[InGroupNode.HashKey, Cost]()
     groupToCosts.foreach {
       case (gid, g) =>
-        g.nodeToCost.foreach {
-          case (n, c) =>
-            costsMap += (InGroupNode(gid, n) -> c.cost)
+        g.nodes.foreach {
+          n =>
+            val c = g.nodeToCost(n)
+            if (c.nonEmpty) {
+              costsMap += (InGroupNode(gid, n).toHashKey -> c.get.cost)
+            }
         }
     }
-    Best(ras, group.id(), bestPath, winnerNodes, costsMap.get)
+    Best(ras, group.id(), bestPath, winnerNodes, ign => 
costsMap.get(ign.toHashKey))
   }
 }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
index 7d2d807ff..6db3600de 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
@@ -23,6 +23,8 @@ import org.apache.gluten.ras.dp.{DpGroupAlgo, DpGroupAlgoDef}
 import org.apache.gluten.ras.memo.MemoState
 import org.apache.gluten.ras.path.{PathKeySet, RasPath}
 
+import java.util
+
 // The best path's each sub-path is considered optimal in its own group.
 private class GroupBasedBestFinder[T <: AnyRef](
     ras: Ras[T],
@@ -94,21 +96,34 @@ private object GroupBasedBestFinder {
     override def solveGroup(
         group: RasGroup[T],
         nodesOutput: InGroupNode[T] => Option[KnownCostPath[T]]): 
Option[KnownCostGroup[T]] = {
+      import scala.collection.JavaConverters._
+
       val nodes = group.nodes(memoState)
       // Allow unsolved children nodes while solving group.
-      val flatNodesOutput =
-        nodes.flatMap(n => nodesOutput(InGroupNode(group.id(), n)).map(kcp => 
n -> kcp)).toMap
+      val flatNodesOutput = new util.IdentityHashMap[CanonicalNode[T], 
KnownCostPath[T]]()
+
+      nodes
+        .flatMap(n => nodesOutput(InGroupNode(group.id(), n)).map(kcp => n -> 
kcp))
+        .foreach {
+          case (n, kcp) =>
+            assert(!flatNodesOutput.containsKey(n))
+            flatNodesOutput.put(n, kcp)
+        }
 
       if (flatNodesOutput.isEmpty) {
         return None
       }
-      val bestPath = flatNodesOutput.values.reduce {
+      val bestPath = flatNodesOutput.values.asScala.reduce {
         (left, right) =>
           Ordering
             .by((cp: KnownCostPath[T]) => cp.cost)(costComparator)
             .min(left, right)
       }
-      Some(KnownCostGroup(flatNodesOutput, 
bestPath.rasPath.node().self().asCanonical()))
+      Some(
+        KnownCostGroup(
+          nodes,
+          n => Option(flatNodesOutput.get(n)),
+          bestPath.rasPath.node().self().asCanonical()))
     }
 
     override def solveNodeOnCycle(node: InGroupNode[T]): 
Option[KnownCostPath[T]] =
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
index 95f453f47..e90ba448b 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
@@ -73,7 +73,7 @@ object DpClusterAlgo {
       clusterAlgoDef: DpClusterAlgoDef[T, NodeOutput, ClusterOutput])
     extends DpZipperAlgoDef[InClusterNode[T], RasClusterKey, NodeOutput, 
ClusterOutput] {
     override def idOfX(x: InClusterNode[T]): Any = {
-      x
+      x.toHashKey
     }
 
     override def idOfY(y: RasClusterKey): Any = {
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
index 6c1e998b6..c824fda8e 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
@@ -66,7 +66,7 @@ object DpGroupAlgo {
       groupAlgoDef: DpGroupAlgoDef[T, NodeOutput, GroupOutput])
     extends DpZipperAlgoDef[InGroupNode[T], RasGroup[T], NodeOutput, 
GroupOutput] {
     override def idOfX(x: InGroupNode[T]): Any = {
-      x
+      x.toHashKey
     }
 
     override def idOfY(y: RasGroup[T]): Any = {
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
index 391e7f196..1be728ae6 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.ras.Best.KnownCostPath
 import org.apache.gluten.ras.best.BestFinder
 import org.apache.gluten.ras.dp.DpZipperAlgo.Adjustment.Panel
 import org.apache.gluten.ras.memo.{Memo, MemoTable}
-import org.apache.gluten.ras.path.{PathFinder, RasPath}
+import org.apache.gluten.ras.path.{InClusterPath, PathFinder, RasPath}
 import org.apache.gluten.ras.property.PropertySet
 import org.apache.gluten.ras.rule.{EnforcerRuleSet, RuleApplier, Shape}
 
@@ -172,7 +172,7 @@ object DpPlanner {
         rule: RuleApplier[T],
         path: RasPath[T]): Unit = {
       val probe = memoTable.probe()
-      rule.apply(path)
+      rule.apply(InClusterPath(thisClusterKey, path))
       val diff = probe.toDiff()
       val changedClusters = diff.changedClusters()
       if (changedClusters.isEmpty) {
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala
index 821009982..f28edd0dc 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala
@@ -608,7 +608,6 @@ object DpZipperAlgo {
   }
 
   private object XKey {
-    // Keep argument "ele" although it is unused. To give compiler type hint.
     def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
         algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput],
         x: X): XKey[X, Y, XOutput, YOutput] = {
@@ -631,7 +630,6 @@ object DpZipperAlgo {
   }
 
   private object YKey {
-    // Keep argument "ele" although it is unused. To give compiler type hint.
     def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
         algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput],
         y: Y): YKey[X, Y, XOutput, YOutput] = {
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
index a9737eb02..47945fc14 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
@@ -107,8 +107,8 @@ object ExhaustivePlanner {
       finder.find(canonical).foreach(path => onFound(path))
     }
 
-    private def applyRule(rule: RuleApplier[T], path: RasPath[T]): Unit = {
-      rule.apply(path)
+    private def applyRule(rule: RuleApplier[T], icp: InClusterPath[T]): Unit = 
{
+      rule.apply(icp)
     }
 
     private def applyRules(): Unit = {
@@ -116,10 +116,17 @@ object ExhaustivePlanner {
         return
       }
       val shapes = rules.map(_.shape())
-      allClusters
-        .flatMap(c => c.nodes())
-        .foreach(
-          node => findPaths(node, shapes)(path => rules.foreach(rule => 
applyRule(rule, path))))
+      memoState
+        .clusterLookup()
+        .foreach {
+          case (cKey, cluster) =>
+            cluster
+              .nodes()
+              .foreach(
+                node =>
+                  findPaths(node, shapes)(
+                    path => rules.foreach(rule => applyRule(rule, 
InClusterPath(cKey, path)))))
+        }
     }
 
     private def applyEnforcerRules(): Unit = {
@@ -129,10 +136,11 @@ object ExhaustivePlanner {
           val enforcerRules = enforcerRuleSet.rulesOf(constraintSet)
           if (enforcerRules.nonEmpty) {
             val shapes = enforcerRules.map(_.shape())
-            memoState.clusterLookup()(group.clusterKey()).nodes().foreach {
+            val cKey = group.clusterKey()
+            memoState.clusterLookup()(cKey).nodes().foreach {
               node =>
                 findPaths(node, shapes)(
-                  path => enforcerRules.foreach(rule => applyRule(rule, path)))
+                  path => enforcerRules.foreach(rule => applyRule(rule, 
InClusterPath(cKey, path))))
             }
           }
       }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
index 945e653eb..e3ae03ebf 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
@@ -33,6 +33,8 @@ class ForwardMemoTable[T <: AnyRef] private (override val 
ras: Ras[T])
   private val clusterKeyBuffer: mutable.ArrayBuffer[IntClusterKey] = 
mutable.ArrayBuffer()
   private val clusterBuffer: mutable.ArrayBuffer[MutableRasCluster[T]] = 
mutable.ArrayBuffer()
   private val clusterDisjointSet: IndexDisjointSet = IndexDisjointSet()
+  private val clusterDummyGroupBuffer = mutable.ArrayBuffer[RasGroup[T]]()
+
   private val groupLookup: mutable.ArrayBuffer[mutable.Map[PropertySet[T], 
RasGroup[T]]] =
     mutable.ArrayBuffer()
 
@@ -46,14 +48,22 @@ class ForwardMemoTable[T <: AnyRef] private (override val 
ras: Ras[T])
 
   override def newCluster(metadata: Metadata): RasClusterKey = {
     checkBufferSizes()
-    val key = IntClusterKey(clusterBuffer.size, metadata)
+    val clusterId = clusterBuffer.size
+    val key = IntClusterKey(clusterId, metadata)
     clusterKeyBuffer += key
     clusterBuffer += MutableRasCluster(ras, metadata)
     clusterDisjointSet.grow()
     groupLookup += mutable.Map()
+    // Normal groups start with ID 0, so it's safe to use negative IDs for 
dummy groups.
+    clusterDummyGroupBuffer += RasGroup(ras, key, -clusterId, 
ras.propertySetFactory().any())
     key
   }
 
+  override def dummyGroupOf(key: RasClusterKey): RasGroup[T] = {
+    val ancestor = ancestorClusterIdOf(key)
+    clusterDummyGroupBuffer(ancestor)
+  }
+
   override def groupOf(key: RasClusterKey, propSet: PropertySet[T]): 
RasGroup[T] = {
     val ancestor = ancestorClusterIdOf(key)
     val lookup = groupLookup(ancestor)
@@ -75,7 +85,11 @@ class ForwardMemoTable[T <: AnyRef] private (override val 
ras: Ras[T])
   }
 
   override def addToCluster(key: RasClusterKey, node: CanonicalNode[T]): Unit 
= {
-    getCluster(key).add(node)
+    val cluster = getCluster(key)
+    if (cluster.contains(node)) {
+      return
+    }
+    cluster.add(node)
     memoWriteCount += 1
   }
 
@@ -142,6 +156,7 @@ class ForwardMemoTable[T <: AnyRef] private (override val 
ras: Ras[T])
     assert(clusterKeyBuffer.size == clusterBuffer.size)
     assert(clusterKeyBuffer.size == clusterDisjointSet.size)
     assert(clusterKeyBuffer.size == groupLookup.size)
+    assert(clusterKeyBuffer.size == clusterDummyGroupBuffer.size)
   }
 
   override def probe(): MemoTable.Probe[T] = new 
ForwardMemoTable.Probe[T](this)
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
index a77293586..66626b756 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
@@ -17,17 +17,19 @@
 package org.apache.gluten.ras.memo
 
 import org.apache.gluten.ras._
+import org.apache.gluten.ras.Ras.UnsafeKey
 import org.apache.gluten.ras.RasCluster.ImmutableRasCluster
 import org.apache.gluten.ras.property.PropertySet
-import org.apache.gluten.ras.util.CanonicalNodeMap
 import org.apache.gluten.ras.vis.GraphvizVisualizer
 
+import scala.collection.mutable
+
 trait MemoLike[T <: AnyRef] {
   def memorize(node: T, constraintSet: PropertySet[T]): RasGroup[T]
 }
 
 trait Closure[T <: AnyRef] {
-  def openFor(node: CanonicalNode[T]): MemoLike[T]
+  def openFor(cKey: RasClusterKey): MemoLike[T]
 }
 
 trait Memo[T <: AnyRef] extends Closure[T] with MemoLike[T] {
@@ -51,82 +53,61 @@ object Memo {
   private class RasMemo[T <: AnyRef](val ras: Ras[T]) extends UnsafeMemo[T] {
     import RasMemo._
     private val memoTable: MemoTable.Writable[T] = MemoTable.create(ras)
-    private val cache: NodeToClusterMap[T] = new NodeToClusterMap(ras)
+    private val cache = mutable.Map[MemoCacheKey[T], RasClusterKey]()
 
     private def newCluster(metadata: Metadata): RasClusterKey = {
       memoTable.newCluster(metadata)
     }
 
     private def addToCluster(clusterKey: RasClusterKey, can: CanonicalNode[T]): Unit = {
-      assert(!cache.contains(can))
-      cache.put(can, clusterKey)
       memoTable.addToCluster(clusterKey, can)
     }
 
-    // Replace node's children with node groups. When a group doesn't exist, create it.
-    private def canonizeUnsafe(node: T, constraintSet: PropertySet[T], depth: Int): T = {
-      assert(depth >= 1)
-      if (depth > 1) {
-        return ras.withNewChildren(
-          node,
-          ras.planModel
-            .childrenOf(node)
-            .zip(ras.propertySetFactory().childrenConstraintSets(constraintSet, node))
-            .map {
-              case (child, constraintSet) =>
-                canonizeUnsafe(child, constraintSet, depth - 1)
-            }
-        )
+    private def clusterOfUnsafe(metadata: Metadata, cacheKey: MemoCacheKey[T]): RasClusterKey = {
+      if (cache.contains(cacheKey)) {
+        cache(cacheKey)
+      } else {
+        // Node not yet added to cluster.
+        val cluster = newCluster(metadata)
+        cache += (cacheKey -> cluster)
+        cluster
       }
-      assert(depth == 1)
-      val childrenGroups: Seq[RasGroup[T]] = ras.planModel
-        .childrenOf(node)
-        .zip(ras.propertySetFactory().childrenConstraintSets(constraintSet, node))
-        .map {
-          case (child, childConstraintSet) =>
-            memorize(child, childConstraintSet)
-        }
-      val newNode =
-        ras.withNewChildren(node, childrenGroups.map(group => group.self()))
-      newNode
     }
 
-    private def canonize(node: T, constraintSet: PropertySet[T]): CanonicalNode[T] = {
-      CanonicalNode(ras, canonizeUnsafe(node, constraintSet, 1))
+    private def dummyGroupOf(clusterKey: RasClusterKey): RasGroup[T] = {
+      memoTable.dummyGroupOf(clusterKey)
+    }
+
+    private def toCacheKeyUnsafe(n: T): MemoCacheKey[T] = {
+      MemoCacheKey(ras, n)
     }
 
-    private def insert(n: T, constraintSet: PropertySet[T]): RasClusterKey = {
-      if (ras.planModel.isGroupLeaf(n)) {
-        val plainGroup = memoTable.allGroups()(ras.planModel.getGroupId(n))
-        return plainGroup.clusterKey()
+    private def prepareInsert(n: T): Prepare[T] = {
+      if (ras.isGroupLeaf(n)) {
+        val group = memoTable.allGroups()(ras.planModel.getGroupId(n))
+        return Prepare.cluster(this, group.clusterKey())
       }
 
-      val node = canonize(n, constraintSet)
+      val childrenPrepares = ras.planModel.childrenOf(n).map(child => prepareInsert(child))
 
-      if (cache.contains(node)) {
-        cache.get(node)
-      } else {
-        // Node not yet added to cluster.
-        val meta = ras.metadataModel.metadataOf(node.self())
-        val clusterKey = newCluster(meta)
-        addToCluster(clusterKey, node)
-        clusterKey
-      }
+      val canUnsafe = ras.withNewChildren(
+        n,
+        childrenPrepares.map(childPrepare => dummyGroupOf(childPrepare.clusterKey()).self()))
+
+      val cacheKey = toCacheKeyUnsafe(canUnsafe)
+
+      val clusterKey = clusterOfUnsafe(ras.metadataModel.metadataOf(n), cacheKey)
+
+      Prepare.tree(this, clusterKey, childrenPrepares)
     }
 
     override def memorize(node: T, constraintSet: PropertySet[T]): RasGroup[T] = {
-      val clusterKey = insert(node, constraintSet)
-      val prevGroupCount = memoTable.allGroups().size
-      val out = memoTable.groupOf(clusterKey, constraintSet)
-      val newGroupCount = memoTable.allGroups().size
-      assert(newGroupCount >= prevGroupCount)
-      out
+      val prepare = prepareInsert(node)
+      prepare.doInsert(node, constraintSet)
     }
 
-    override def openFor(node: CanonicalNode[T]): MemoLike[T] = {
-      assert(cache.contains(node))
-      val targetCluster = cache.get(node)
-      new InCusterMemo[T](this, targetCluster)
+    override def openFor(cKey: RasClusterKey): MemoLike[T] = {
+      new InCusterMemo[T](this, cKey)
     }
 
     override def newState(): MemoState[T] = {
@@ -141,37 +122,116 @@ object Memo {
   }
 
   private object RasMemo {
-    private class InCusterMemo[T <: AnyRef](parent: RasMemo[T], preparedCluster: RasClusterKey)
+    private class InCusterMemo[T <: AnyRef](parent: RasMemo[T], targetCluster: RasClusterKey)
       extends MemoLike[T] {
+      private val ras = parent.ras
+
+      private def prepareInsert(node: T): Prepare[T] = {
+        assert(!ras.isGroupLeaf(node))
+
+        val childrenPrepares =
+          ras.planModel.childrenOf(node).map(child => parent.prepareInsert(child))
+
+        val canUnsafe = ras.withNewChildren(
+          node,
+          childrenPrepares.map {
+            childPrepare => parent.dummyGroupOf(childPrepare.clusterKey()).self()
+          })
+
+        val cacheKey = parent.toCacheKeyUnsafe(canUnsafe)
+
+        if (!parent.cache.contains(cacheKey)) {
+          // The new node was not added to memo yet. Add it to the target cluster.
+          parent.cache += (cacheKey -> targetCluster)
+          return Prepare.tree(parent, targetCluster, childrenPrepares)
+        }
+
+        // The new node already memorized to memo.
 
-      private def insert(node: T, constraintSet: PropertySet[T]): Unit = {
-        val can = parent.canonize(node, constraintSet)
-        if (parent.cache.contains(can)) {
-          val cachedCluster = parent.cache.get(can)
-          if (cachedCluster == preparedCluster) {
-            return
-          }
-          // The new node already memorized to memo, but in the different cluster
-          // with the input node. Merge the two clusters.
-          //
-          // TODO: Traversal up the tree to do more merges.
-          parent.memoTable.mergeClusters(cachedCluster, preparedCluster)
-          // Since new node already memorized, we don't have to add it to either of the clusters
-          // anymore.
-          return
+        val cachedCluster = parent.cache(cacheKey)
+        if (cachedCluster == targetCluster) {
+          // The new node already memorized to memo and in the target cluster.
+          return Prepare.tree(parent, targetCluster, childrenPrepares)
         }
-        parent.addToCluster(preparedCluster, can)
+        // The new node already memorized to memo, but in the different cluster.
+        // Merge the two clusters.
+        //
+        // TODO: Traverse up the tree to do more merges.
+        parent.memoTable.mergeClusters(cachedCluster, targetCluster)
+        Prepare.tree(parent, targetCluster, childrenPrepares)
       }
 
       override def memorize(node: T, constraintSet: PropertySet[T]): RasGroup[T] = {
-        insert(node, constraintSet)
-        parent.memoTable.groupOf(preparedCluster, constraintSet)
+        val prepare = prepareInsert(node)
+        prepare.doInsert(node, constraintSet)
       }
     }
+
+    private trait Prepare[T <: AnyRef] {
+      def clusterKey(): RasClusterKey
+      def doInsert(node: T, constraintSet: PropertySet[T]): RasGroup[T]
+    }
+
+    private object Prepare {
+      def tree[T <: AnyRef](
+          memo: RasMemo[T],
+          cKey: RasClusterKey,
+          children: Seq[Prepare[T]]): Prepare[T] = {
+        new TreePrepare[T](memo, cKey, children)
+      }
+
+      def cluster[T <: AnyRef](memo: RasMemo[T], cKey: RasClusterKey): Prepare[T] = {
+        new ClusterPrepare[T](memo, cKey)
+      }
+
+      private class TreePrepare[T <: AnyRef](
+          memo: RasMemo[T],
+          override val clusterKey: RasClusterKey,
+          children: Seq[Prepare[T]])
+        extends Prepare[T] {
+        private val ras = memo.ras
+
+        override def doInsert(node: T, constraintSet: PropertySet[T]): RasGroup[T] = {
+          assert(!ras.isGroupLeaf(node))
+          val childrenGroups = children
+            .zip(ras.planModel.childrenOf(node))
+            .zip(ras.propertySetFactory().childrenConstraintSets(constraintSet, node))
+            .map {
+              case ((childPrepare, child), childConstraintSet) =>
+                childPrepare.doInsert(child, childConstraintSet)
+            }
+
+          val canUnsafe = ras.withNewChildren(node, childrenGroups.map(group => group.self()))
+          val can = CanonicalNode(ras, canUnsafe)
+
+          memo.addToCluster(clusterKey, can)
+
+          val group = memo.memoTable.groupOf(clusterKey, constraintSet)
+          group
+        }
+      }
+
+      private class ClusterPrepare[T <: AnyRef](memo: RasMemo[T], cKey: RasClusterKey)
+        extends Prepare[T] {
+        private val ras = memo.ras
+        override def doInsert(node: T, constraintSet: PropertySet[T]): RasGroup[T] = {
+          assert(ras.isGroupLeaf(node))
+          memo.memoTable.groupOf(cKey, constraintSet)
+        }
+
+        override def clusterKey(): RasClusterKey = cKey
+      }
+    }
+  }
+
+  private object MemoCacheKey {
+    def apply[T <: AnyRef](ras: Ras[T], self: T): MemoCacheKey[T] = {
+      assert(ras.isCanonical(self))
+      MemoCacheKey[T](ras.toUnsafeKey(self))
+    }
   }
 
-  private class NodeToClusterMap[T <: AnyRef](ras: Ras[T])
-    extends CanonicalNodeMap[T, RasClusterKey](ras)
+  private case class MemoCacheKey[T <: AnyRef] private (delegate: UnsafeKey[T])
 }
 
 trait MemoStore[T <: AnyRef] {
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
index b54bd8811..3baba8eae 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
@@ -44,6 +44,7 @@ object MemoTable {
   trait Writable[T <: AnyRef] extends MemoTable[T] {
     def newCluster(metadata: Metadata): RasClusterKey
     def groupOf(key: RasClusterKey, propertySet: PropertySet[T]): RasGroup[T]
+    def dummyGroupOf(key: RasClusterKey): RasGroup[T]
 
     def addToCluster(key: RasClusterKey, node: CanonicalNode[T]): Unit
     def mergeClusters(one: RasClusterKey, other: RasClusterKey): Unit
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala
index ca712cec4..61fa22e5e 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala
@@ -37,7 +37,7 @@ object RasPath {
 
   object PathNode {
     def apply[T <: AnyRef](node: RasNode[T], children: Seq[PathNode[T]]): PathNode[T] = {
-      PathNodeImpl(node, children)
+      new PathNodeImpl(node, children)
     }
   }
 
@@ -61,7 +61,7 @@ object RasPath {
       keys: PathKeySet,
       height: Int,
       node: RasPath.PathNode[T]): RasPath[T] = {
-    RasPathImpl(ras, keys, height, node)
+    new RasPathImpl(ras, keys, height, node)
   }
 
   // Returns none if children doesn't share at least one path key.
@@ -103,25 +103,6 @@ object RasPath {
       PathNode(canonical, canonical.getChildrenGroups(allGroups).map(g => PathNode(g, List.empty))))
   }
 
-  // Aggregates paths that have same shape but different keys together.
-  // Currently not in use because of bad performance.
-  def aggregate[T <: AnyRef](ras: Ras[T], paths: Iterable[RasPath[T]]): Iterable[RasPath[T]] = {
-    // Scala has specialized optimization against small set of input of group-by.
-    // So it's better only to pass small inputs to this method if possible.
-    val grouped = paths.groupBy(_.node())
-    grouped.map {
-      case (node, paths) =>
-        val heights = paths.map(_.height()).toSeq.distinct
-        assert(heights.size == 1)
-        val height = heights.head
-        val keys = paths.map(_.keys().keys()).reduce[Set[PathKey]] {
-          case (one, other) =>
-            one.union(other)
-        }
-        RasPath(ras, PathKeySet(keys), height, node)
-    }
-  }
-
   def cartesianProduct[T <: AnyRef](
       ras: Ras[T],
       canonical: CanonicalNode[T],
@@ -171,12 +152,12 @@ object RasPath {
     }
   }
 
-  private case class PathNodeImpl[T <: AnyRef](
+  private class PathNodeImpl[T <: AnyRef](
       override val self: RasNode[T],
       override val children: Seq[PathNode[T]])
     extends PathNode[T]
 
-  private case class RasPathImpl[T <: AnyRef](
+  private class RasPathImpl[T <: AnyRef](
       override val ras: Ras[T],
       override val keys: PathKeySet,
       override val height: Int,
@@ -193,3 +174,19 @@ object RasPath {
     override def plan(): T = built
   }
 }
+
+trait InClusterPath[T <: AnyRef] {
+  def cluster(): RasClusterKey
+  def path(): RasPath[T]
+}
+
+object InClusterPath {
+  def apply[T <: AnyRef](cluster: RasClusterKey, path: RasPath[T]): InClusterPath[T] = {
+    new InClusterPathImpl(cluster, path)
+  }
+
+  private class InClusterPathImpl[T <: AnyRef](
+      override val cluster: RasClusterKey,
+      override val path: RasPath[T])
+    extends InClusterPath[T]
+}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
index b99001e93..0a7bf0c76 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
@@ -17,14 +17,14 @@
 package org.apache.gluten.ras.rule
 
 import org.apache.gluten.ras._
+import org.apache.gluten.ras.Ras.UnsafeKey
 import org.apache.gluten.ras.memo.Closure
-import org.apache.gluten.ras.path.RasPath
-import org.apache.gluten.ras.util.CanonicalNodeMap
+import org.apache.gluten.ras.path.InClusterPath
 
 import scala.collection.mutable
 
 trait RuleApplier[T <: AnyRef] {
-  def apply(path: RasPath[T]): Unit
+  def apply(icp: InClusterPath[T]): Unit
   def shape(): Shape[T]
 }
 
@@ -42,25 +42,27 @@ object RuleApplier {
 
   private class RegularRuleApplier[T <: AnyRef](ras: Ras[T], closure: Closure[T], rule: RasRule[T])
     extends RuleApplier[T] {
-    private val cache = new CanonicalNodeMap[T, mutable.Set[T]](ras)
+    private val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeKey[T]]]()
 
-    override def apply(path: RasPath[T]): Unit = {
-      val can = path.node().self().asCanonical()
+    override def apply(icp: InClusterPath[T]): Unit = {
+      val cKey = icp.cluster()
+      val path = icp.path()
       val plan = path.plan()
-      val appliedPlans = cache.getOrElseUpdate(can, mutable.Set())
-      if (appliedPlans.contains(plan)) {
+      val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set())
+      val pKey = ras.toUnsafeKey(plan)
+      if (appliedPlans.contains(pKey)) {
         return
       }
-      apply0(can, plan)
-      appliedPlans += plan
+      apply0(cKey, plan)
+      appliedPlans += pKey
     }
 
-    private def apply0(can: CanonicalNode[T], plan: T): Unit = {
+    private def apply0(cKey: RasClusterKey, plan: T): Unit = {
       val equivalents = rule.shift(plan)
       equivalents.foreach {
         equiv =>
           closure
-            .openFor(can)
+            .openFor(cKey)
             .memorize(equiv, ras.propertySetFactory().get(equiv))
       }
     }
@@ -73,32 +75,35 @@ object RuleApplier {
       closure: Closure[T],
       rule: EnforcerRule[T])
     extends RuleApplier[T] {
-    private val cache = new CanonicalNodeMap[T, mutable.Set[T]](ras)
+    private val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeKey[T]]]()
     private val constraint = rule.constraint()
     private val constraintDef = constraint.definition()
 
-    override def apply(path: RasPath[T]): Unit = {
+    override def apply(icp: InClusterPath[T]): Unit = {
+      val cKey = icp.cluster()
+      val path = icp.path()
       val can = path.node().self().asCanonical()
       if (can.propSet().get(constraintDef).satisfies(constraint)) {
         return
       }
       val plan = path.plan()
-      val appliedPlans = cache.getOrElseUpdate(can, mutable.Set())
-      if (appliedPlans.contains(plan)) {
+      val pKey = ras.toUnsafeKey(plan)
+      val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set())
+      if (appliedPlans.contains(pKey)) {
         return
       }
-      apply0(can, plan)
-      appliedPlans += plan
+      apply0(cKey, plan)
+      appliedPlans += pKey
     }
 
-    private def apply0(can: CanonicalNode[T], plan: T): Unit = {
+    private def apply0(cKey: RasClusterKey, plan: T): Unit = {
       val propSet = ras.propertySetFactory().get(plan)
       val constraintSet = propSet.withProp(constraint)
       val equivalents = rule.shift(plan)
       equivalents.foreach {
         equiv =>
           closure
-            .openFor(can)
+            .openFor(cKey)
             .memorize(equiv, constraintSet)
       }
     }
@@ -110,11 +115,11 @@ object RuleApplier {
     extends RuleApplier[T] {
     private val ruleShape = rule.shape()
 
-    override def apply(path: RasPath[T]): Unit = {
-      if (!ruleShape.identify(path)) {
+    override def apply(icp: InClusterPath[T]): Unit = {
+      if (!ruleShape.identify(icp.path())) {
         return
       }
-      rule.apply(path)
+      rule.apply(icp)
     }
 
     override def shape(): Shape[T] = ruleShape
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/util/NodeMap.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/util/NodeMap.scala
deleted file mode 100644
index 887e00bdc..000000000
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/util/NodeMap.scala
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.ras.util
-
-import org.apache.gluten.ras.{CanonicalNode, Ras}
-
-import scala.collection.mutable
-
-// Arbitrary node key.
-class NodeKey[T <: AnyRef](ras: Ras[T], val node: T) {
-  override def hashCode(): Int = ras.planModel.hashCode(node)
-
-  override def equals(obj: Any): Boolean = {
-    obj match {
-      case other: NodeKey[T] => ras.planModel.equals(node, other.node)
-      case _ => false
-    }
-  }
-
-  override def toString(): String = s"NodeKey($node)"
-}
-
-// Canonical node map.
-class CanonicalNodeMap[T <: AnyRef, V](ras: Ras[T]) {
-  private val map: mutable.Map[NodeKey[T], V] = mutable.Map()
-
-  def contains(node: CanonicalNode[T]): Boolean = {
-    map.contains(keyOf(node))
-  }
-
-  def put(node: CanonicalNode[T], value: V): Unit = {
-    map.put(keyOf(node), value)
-  }
-
-  def get(node: CanonicalNode[T]): V = {
-    map(keyOf(node))
-  }
-
-  def getOrElseUpdate(node: CanonicalNode[T], op: => V): V = {
-    map.getOrElseUpdate(keyOf(node), op)
-  }
-
-  private def keyOf(node: CanonicalNode[T]): NodeKey[T] = {
-    new NodeKey(ras, node.self())
-  }
-}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
index 600a61edc..11f6051b0 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
@@ -43,13 +43,13 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
 
     object IsBestNode {
       def unapply(nodeAndGroupToTest: (CanonicalNode[T], RasGroup[T])): Boolean = {
-        bestNodes.contains(InGroupNode(nodeAndGroupToTest._2.id(), nodeAndGroupToTest._1))
+        bestNodes(InGroupNode(nodeAndGroupToTest._2.id(), nodeAndGroupToTest._1))
       }
     }
 
     object IsWinnerNode {
       def unapply(nodeAndGroupToTest: (CanonicalNode[T], RasGroup[T])): Boolean = {
-        winnerNodes.contains(InGroupNode(nodeAndGroupToTest._2.id(), nodeAndGroupToTest._1))
+        winnerNodes(InGroupNode(nodeAndGroupToTest._2.id(), nodeAndGroupToTest._1))
       }
     }
 
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
index acd96442c..f1c319873 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
@@ -97,7 +97,7 @@ class OperationSuite extends AnyFunSuite {
 
     val ras =
       Ras[TestNode](
-        PlanModelImpl,
+        planModel,
         CostModelImpl,
         MetadataModelImpl,
         PropertyModelImpl,
@@ -108,7 +108,7 @@ class OperationSuite extends AnyFunSuite {
     val optimized = planner.plan()
     assert(optimized == Unary2(49, Leaf2(29)))
 
-    planModel.assertPlanOpsLte((200, 50, 50, 50))
+    planModel.assertPlanOpsLte((200, 50, 100, 50))
 
     val state = planner.newState()
     val allPaths = state.memoState().collectAllPaths(RasPath.INF_DEPTH).toSeq
@@ -127,7 +127,7 @@ class OperationSuite extends AnyFunSuite {
 
     val ras =
       Ras[TestNode](
-        PlanModelImpl,
+        planModel,
         CostModelImpl,
         MetadataModelImpl,
         PropertyModelImpl,
@@ -163,7 +163,7 @@ class OperationSuite extends AnyFunSuite {
 
     val ras =
       Ras[TestNode](
-        PlanModelImpl,
+        planModel,
         CostModelImpl,
         MetadataModelImpl,
         PropertyModelImpl,
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
index 8a68bbba8..e48604116 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
@@ -19,6 +19,7 @@ package org.apache.gluten.ras
 import org.apache.gluten.ras.Best.BestNotFoundException
 import org.apache.gluten.ras.RasConfig.PlannerType
 import org.apache.gluten.ras.RasSuiteBase._
+import org.apache.gluten.ras.memo.Memo
 import org.apache.gluten.ras.property.PropertySet
 import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
 
@@ -37,6 +38,30 @@ abstract class PropertySuite extends AnyFunSuite {
 
   protected def conf: RasConfig
 
+  test("Group memo - cache") {
+    val ras =
+      Ras[TestNode](
+        PlanModelImpl,
+        CostModelImpl,
+        MetadataModelImpl,
+        NodeTypePropertyModelWithOutEnforcerRules,
+        ExplainImpl,
+        RasRule.Factory.none())
+        .withNewConfig(_ => conf)
+
+    val memo = Memo(ras)
+
+    memo.memorize(ras, PassNodeType(1, PassNodeType(1, PassNodeType(1, TypedLeaf(TypeA, 1)))))
+    val leafGroup = memo.memorize(ras, TypedLeaf(TypeA, 1))
+    memo
+      .openFor(leafGroup.clusterKey())
+      .memorize(ras, TypedLeaf(TypeB, 1))
+    memo.memorize(ras, PassNodeType(1, PassNodeType(1, PassNodeType(1, TypedLeaf(TypeB, 1)))))
+    val state = memo.newState()
+    assert(state.allClusters().size == 4)
+    assert(state.getGroupCount() == 8)
+  }
+
   test(s"Get property") {
     val leaf = PLeaf(10, DummyProperty(0))
     val unary = PUnary(5, DummyProperty(0), leaf)
@@ -112,7 +137,7 @@ abstract class PropertySuite extends AnyFunSuite {
         TypedLeaf(TypeB, 10)))
   }
 
-  ignore(s"Memo cache hit - (A, B)") {
+  test(s"Memo cache hit - (A, B)") {
     object ReplaceLeafAByLeafBRule extends RasRule[TestNode] {
       override def shift(node: TestNode): Iterable[TestNode] = {
         node match {
@@ -163,8 +188,8 @@ abstract class PropertySuite extends AnyFunSuite {
     val out = planner.plan()
     assert(out == TypedLeaf(TypeA, 1))
 
-    // FIXME: Cluster 2 and 1 are currently able to merge but it's better to
-    //  have them identified as the same right after HitCacheOp is applied
+    // Cluster 2 and 1 are able to merge but we'd make sure
+    // they are identified as the same right after HitCacheOp is applied
     val clusterCount = planner.newState().memoState().allClusters().size
     assert(clusterCount == 2)
   }
@@ -531,6 +556,7 @@ object PropertySuite {
   }
 
   object DummyPropertyDef extends PropertyDef[TestNode, DummyProperty] {
+    override def any(): DummyProperty = DummyProperty(Int.MinValue)
     override def getProperty(plan: TestNode): DummyProperty = {
       plan match {
         case Group(_, _, _) => throw new IllegalStateException()
@@ -669,6 +695,8 @@ object PropertySuite {
     }
 
     override def toString: String = "NodeTypeDef"
+
+    override def any(): NodeType = TypeAny
   }
 
   trait NodeType extends Property[TestNode] {
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
index abb8bdecd..0ad825181 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
@@ -66,8 +66,7 @@ abstract class RasSuite extends AnyFunSuite {
     val group = memo.memorize(ras, Unary(50, Unary(50, Leaf(30))))
     val state = memo.newState()
     assert(group.nodes(state).size == 1)
-    val can = group.nodes(state).head.asCanonical()
-    memo.openFor(can).memorize(ras, Unary(30, Leaf(90)))
+    memo.openFor(group.clusterKey()).memorize(ras, Unary(30, Leaf(90)))
     assert(memo.newState().allGroups().size == 4)
   }
 
@@ -87,8 +86,7 @@ abstract class RasSuite extends AnyFunSuite {
     assert(group.nodes(state).size == 1)
     val leaf40Group = memo.memorize(ras, Leaf(40))
     assert(leaf40Group.nodes(state).size == 1)
-    val can = leaf40Group.nodes(state).head.asCanonical()
-    memo.openFor(can).memorize(ras, Leaf(30))
+    memo.openFor(leaf40Group.clusterKey()).memorize(ras, Leaf(30))
     assert(memo.newState().allGroups().size == 3)
   }
 
@@ -108,8 +106,7 @@ abstract class RasSuite extends AnyFunSuite {
     assert(group.nodes(state).size == 1)
     val leaf40Group = memo.memorize(ras, Leaf(40))
     assert(leaf40Group.nodes(state).size == 1)
-    val can = leaf40Group.nodes(state).head.asCanonical()
-    memo.openFor(can).memorize(ras, Leaf(30))
+    memo.openFor(leaf40Group.clusterKey()).memorize(ras, Leaf(30))
     assert(memo.newState().allGroups().size == 5)
   }
 
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/RasPathSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/RasPathSuite.scala
index e092ea4f2..8158aec16 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/RasPathSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/RasPathSuite.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.ras.path
 
-import org.apache.gluten.ras.Ras
+import org.apache.gluten.ras.{CanonicalNode, Ras}
 import org.apache.gluten.ras.RasSuiteBase._
 import org.apache.gluten.ras.mock.MockRasPath
 import org.apache.gluten.ras.rule.RasRule
@@ -26,7 +26,7 @@ import org.scalatest.funsuite.AnyFunSuite
 class RasPathSuite extends AnyFunSuite {
   import RasPathSuite._
 
-  test("Path aggregate - empty") {
+  test("Cartesian product - empty") {
     val ras =
       Ras[TestNode](
         PlanModelImpl,
@@ -35,10 +35,21 @@ class RasPathSuite extends AnyFunSuite {
         PropertyModelImpl,
         ExplainImpl,
         RasRule.Factory.reuse(List.empty))
-    assert(RasPath.aggregate(ras, List.empty) == List.empty)
+    assert(
+      RasPath.cartesianProduct(
+        ras,
+        CanonicalNode(ras, Binary("b", ras.dummyGroupLeaf(), ras.dummyGroupLeaf())),
+        List(
+          List.empty,
+          List(
+            MockRasPath.mock(
+              ras,
+              Leaf("l", 1),
+              PathKeySet(Set(DummyPathKey(3)))
+            )))
+      ) == List.empty)
   }
-
-  test("Path aggregate") {
+  test("Cartesian product") {
     val ras =
       Ras[TestNode](
         PlanModelImpl,
@@ -54,6 +65,7 @@ class RasPathSuite extends AnyFunSuite {
     val n4 = "n4"
     val n5 = "n5"
     val n6 = "n6"
+
     val path1 = MockRasPath.mock(
       ras,
       Unary(n5, Leaf(n6, 1)),
@@ -66,31 +78,37 @@ class RasPathSuite extends AnyFunSuite {
     )
     val path3 = MockRasPath.mock(
       ras,
-      Unary(n1, Unary(n2, Leaf(n3, 1))),
-      PathKeySet(Set(DummyPathKey(1), DummyPathKey(2)))
+      Leaf(n6, 1),
+      PathKeySet(Set(DummyPathKey(1)))
     )
     val path4 = MockRasPath.mock(
       ras,
-      Unary(n1, Unary(n2, Leaf(n3, 1))),
-      PathKeySet(Set(DummyPathKey(4)))
+      Leaf(n3, 1),
+      PathKeySet(Set(DummyPathKey(3)))
     )
+
     val path5 = MockRasPath.mock(
       ras,
-      Unary(n5, Leaf(n6, 1)),
+      Unary(n2, Leaf(n3, 1)),
       PathKeySet(Set(DummyPathKey(4)))
     )
-    val out = RasPath
-      .aggregate(ras, List(path1, path2, path3, path4, path5))
-      .toSeq
-      .sortBy(_.height())
-    assert(out.size == 2)
-    assert(out.head.height() == 2)
-    assert(out.head.plan() == Unary(n5, Leaf(n6, 1)))
-    assert(out.head.keys() == PathKeySet(Set(DummyPathKey(1), DummyPathKey(3), DummyPathKey(4))))
 
-    assert(out(1).height() == 3)
-    assert(out(1).plan() == Unary(n1, Unary(n2, Leaf(n3, 1))))
-    assert(out(1).keys() == PathKeySet(Set(DummyPathKey(1), DummyPathKey(2), DummyPathKey(4))))
+    val product = RasPath.cartesianProduct(
+      ras,
+      CanonicalNode(ras, Binary(n4, ras.dummyGroupLeaf(), ras.dummyGroupLeaf())),
+      List(
+        List(path1, path2),
+        List(path3, path4, path5)
+      ))
+
+    val out = product.toList
+    assert(out.size == 3)
+
+    assert(
+      out.map(_.plan()) == List(
+        Binary(n4, Unary(n5, Leaf(n6, 1)), Leaf(n6, 1)),
+        Binary(n4, Unary(n5, Leaf(n6, 1)), Leaf(n3, 1)),
+        Binary(n4, Unary(n1, Unary(n2, Leaf(n3, 1))), Leaf(n6, 1))))
   }
 }
 
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
index cab3d1818..de71cba5b 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
@@ -262,6 +262,8 @@ object DistributedSuite {
       case (d: Distribution, p: DNode) => p.getDistributionConstraints(d)
       case _ => throw new UnsupportedOperationException()
     }
+
+    override def any(): Distribution = AnyDistribution
   }
 
   trait Ordering extends Property[TestNode]
@@ -315,6 +317,8 @@ object DistributedSuite {
         case (o: Ordering, p: DNode) => p.getOrderingConstraints(o)
         case _ => throw new UnsupportedOperationException()
       }
+
+    override def any(): Ordering = AnyOrdering
   }
 
   private class EnforceDistribution(distribution: Distribution) extends RasRule[TestNode] {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to