comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r650300650



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into 
fp16 form.
+ */
+#include "to_mixed_precision.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the 
wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, 
DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<MixedTypeConversionCategory(const CallNode*)>;
+
+// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted 
accumulation and output dtypes
+using OutputDtypeFunc = std::function<MixedPrecisionOpOutDType(const 
CallNode*)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+  const DataType mixed_precision_type;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) 
const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate 
the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = 
new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& 
accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = 
accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) 
const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = 
accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) 
const {
+    /* Returns whether t is a type with only target mixed precision type 
elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const 
DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's 
already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an 
integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;

Review comment:
       I reviewed the cache mechanism and I think I got the idea. Here is the 
example I went through:
   
   Consider the op `A (out: fp32, want: fp16)`, the cache will look like the 
following after processing A's output:
   ```
   (A, fp16): cast_to_fp16
   (cast, fp32): A
   ```
   
   Now consider the following op `B`:
   Case 1. If `B` wants fp32, then like you mentioned before, we query `(cast, 
fp32)` and get `A`, so it becomes `A -> B`.
   Case 2. If `B` wants fp16, then we query `(cast, fp16)`, which misses; a 
new entry `(cast, fp16): cast` is created and returned, so it becomes `A -> 
cast -> B`.
   
   This mechanism seems to work well, and the cache size should be reasonable 
as it only keeps pointers. Two possible improvements:
   1. Apparently, the cache entry `(cast, fp16): cast` in the example is not 
necessary. I think we can simply return `expr` when `expr_dtype == 
wanted_dtype`?
   2. The created `cast` ops may be useless, such as the one in case 1. Is it 
possible to create this op lazily? For example, when casting the output, we 
only create a cache entry but don't really create the node. Once the entry is 
queried by the following ops for the first time, we create the cast node and 
update the cache.
   
   

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into 
fp16 form.
+ */
+#include "to_mixed_precision.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the 
wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, 
DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<MixedTypeConversionCategory(const CallNode*)>;
+
+// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted 
accumulation and output dtypes
+using OutputDtypeFunc = std::function<MixedPrecisionOpOutDType(const 
CallNode*)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+  const DataType mixed_precision_type;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) 
const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate 
the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = 
new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& 
accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = 
accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) 
const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = 
accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) 
const {
+    /* Returns whether t is a type with only target mixed precision type 
elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const 
DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's 
already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an 
integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;

Review comment:
       I reviewed the cache mechanism and I think I got the idea. Here is the 
example I went through:
   
   Consider the op `A (out: fp32, want: fp16)`, the cache will look like the 
following after processing A's output:
   ```
   (A, fp16): cast
   (cast, fp32): A
   ```
   
   Now consider the following op `B`:
   Case 1. If `B` wants fp32, then like you mentioned before, we query `(cast, 
fp32)` and get `A`, so it becomes `A -> B`.
   Case 2. If `B` wants fp16, then we query `(cast, fp16)`, which misses; a 
new entry `(cast, fp16): cast` is created and returned, so it becomes `A -> 
cast -> B`.
   
   This mechanism seems to work well, and the cache size should be reasonable 
as it only keeps pointers. Two possible improvements:
   1. Apparently, the cache entry `(cast, fp16): cast` in the example is not 
necessary. I think we can simply return `expr` when `expr_dtype == 
wanted_dtype`?
   2. The created `cast` ops may be useless, such as the one in case 1. Is it 
possible to create this op lazily? For example, when casting the output, we 
only create a cache entry but don't really create the node. Once the entry is 
queried by the following ops for the first time, we create the cast node and 
update the cache.
   
   

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into 
fp16 form.
+ */
+#include "to_mixed_precision.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the 
wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, 
DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<MixedTypeConversionCategory(const CallNode*)>;
+
+// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted 
accumulation and output dtypes
+using OutputDtypeFunc = std::function<MixedPrecisionOpOutDType(const 
CallNode*)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+  const DataType mixed_precision_type;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) 
const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate 
the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = 
new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& 
accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = 
accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) 
const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = 
accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) 
const {
+    /* Returns whether t is a type with only target mixed precision type 
elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const 
DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's 
already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an 
integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;

Review comment:
       I reviewed the cache mechanism and I think I got the idea. Here is the 
example I went through:
   
   Consider the op `A (out: fp32, want: fp16)`, the cache will look like the 
following after processing A's output:
   ```
   (A, fp16): cast
   (cast, fp32): A
   ```
   
   Now consider the following op `B`:
   Case 1. If `B` wants fp32, then like you mentioned before, we query `(cast, 
fp32)` and get `A`, so it becomes `A -> B`.
   Case 2. If `B` wants fp16, then we query `(cast, fp16)`, which misses; a 
new entry `(cast, fp16): cast` is created and returned, so it becomes `A -> 
cast -> B`.
   
   This mechanism seems to work well, and the cache size should be reasonable 
as it only keeps pointers. Two possible improvements:
   1. Apparently, the cache entry `(cast, fp16): cast` in the example is not 
necessary. I think we can simply return `expr` when `expr_dtype == 
wanted_dtype`?
   2. The created `cast` ops may be useless, such as the one in case 1. Is it 
possible to create this op lazily? For example, when casting the output, we 
only create a cache entry but don't really create the node. Once the entry is 
queried by the following ops for the first time, we create the cast node and 
update the cache.
   
   Another direction is to remove the cache and let this pass generate as many 
cast ops as it wants, then run the SimplifyExpr pass afterward to cancel 
back-to-back cast ops. I would actually recommend this approach due to its 
simple design, if it doesn't hurt the final performance.
   

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into 
fp16 form.
+ */
+#include "to_mixed_precision.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the 
wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, 
DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<MixedTypeConversionCategory(const CallNode*)>;
+
+// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted 
accumulation and output dtypes
+using OutputDtypeFunc = std::function<MixedPrecisionOpOutDType(const 
CallNode*)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+  const DataType mixed_precision_type;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) 
const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate 
the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = 
new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& 
accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = 
accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) 
const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = 
accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) 
const {
+    /* Returns whether t is a type with only target mixed precision type 
elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const 
DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's 
already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an 
integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;

Review comment:
       I reviewed the cache mechanism and I think I got the idea. Here is the 
example I went through:
   
   Consider the op `A (out: fp32, want: fp16)`. After processing A's output, 
the cache will look like the following:
   ```
   (A, fp16): cast
   (cast, fp32): A
   ```
   
   Now consider the op `B` that follows `A`:
   Case 1. If `B` wants fp32, then like you mentioned before, we query `(cast, 
fp32)` and get `A`, so it becomes `A -> B`.
   Case 2. If `B` wants fp16, then we query `(cast, fp16)`, which misses, so a 
new entry `(cast, fp16): cast` is created and returned, and it becomes `A -> 
cast -> B`.
   
   This mechanism seems to work well, and the cache size should be reasonable 
as it only keeps pointers. Two possible improvements:
   1. Apparently, the cache entry `(cast, fp16): cast` in the example is not 
necessary. I think we can simply return `expr` when `expr_dtype == 
wanted_dtype`?
   2. The created `cast` ops may be useless, such as the one in case 1. Is it 
possible to create this op lazily? For example, when casting the output, we 
only create a cache entry but don't actually create the node. Once the entry 
is queried by a following op for the first time, we create the cast node and 
update the cache.
   
   Another direction is to remove the cache, let this pass generate as many 
cast ops as it wants, and run the SimplifyExpr pass afterward to cancel 
back-to-back cast ops. I would actually recommend this approach for its 
simple design, provided it doesn't hurt the final performance.
   

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into 
fp16 form.
+ */
+#include "to_mixed_precision.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the 
wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, 
DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<MixedTypeConversionCategory(const CallNode*)>;
+
+// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted 
accumulation and output dtypes
+using OutputDtypeFunc = std::function<MixedPrecisionOpOutDType(const 
CallNode*)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+  const DataType mixed_precision_type;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) 
const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate 
the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = 
new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& 
accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = 
accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) 
const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = 
accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) 
const {
+    /* Returns whether t is a type with only target mixed precision type 
elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const 
DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's 
already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an 
integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;

Review comment:
       I reviewed the cache mechanism and I think I got the idea. Here is the 
example I went through:
   
   Consider the op `A (out: fp32, want: fp16)`. After processing A's output, 
the cache will look like the following:
   ```
   (A, fp16): cast
   (cast, fp32): A
   ```
   
   Now consider the op `B` that follows `A`:
   Case 1. If `B` wants fp32, then like you mentioned before, we query `(cast, 
fp32)` and get `A`, so it becomes `A -> B`.
   Case 2. If `B` wants fp16, then we query `(cast, fp16)`, which misses, so a 
new entry `(cast, fp16): cast` is created and returned, and it becomes `A -> 
cast -> B`.
   
   This mechanism seems to work well, and the cache size should be reasonable 
as it only keeps pointers. Two possible improvements:
   1. Apparently, the cache entry `(cast, fp16): cast` in the example is not 
necessary. I think we can simply return `expr` when `expr_dtype == 
wanted_dtype`?
   2. The created `cast` ops may be useless, such as the one in case 1. Is it 
possible to create this op lazily? For example, when casting the output, we 
only create a cache entry but don't actually create the node. Once the entry 
is queried by a following op for the first time, we create the cast node and 
update the cache.
   
   Another direction, which I would actually recommend, is to remove the cache, 
let this pass generate as many cast ops as it wants, and run the SimplifyExpr 
pass afterward to cancel back-to-back cast ops. IIUC, this should generate the 
same IR as the current pass, so it shouldn't hurt the final performance 
(please correct me if I missed something).
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to