AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652142144



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn 
a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
// A callable which hashes std::pair so it can be used as the key of an unordered_map.
struct pair_hash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& pair) const {
    auto h1 = std::hash<T1>()(pair.first);
    auto h2 = std::hash<T2>()(pair.second);

    // Use boost's hash_combine strategy:
    //   seed ^= hash + 0x9e3779b9 + (seed << 6) + (seed >> 2)
    // with seed = h1 and hash = h2. (The previous expression shifted h2 and re-added
    // h1, which mixes the second hash only through its shifted copies and deviates
    // from the cited formula.)
    return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
  }
};
+
// Conversion category assigned to each op:
// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and
// memory savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have
// speedups to justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower
// precision due to numerical reasons.
// Values are explicit ints because the category travels through the ObjectRef array returned
// by FTVMMixedPrecisionConversionType (see below).
enum MixedTypeConversionCategory : int {
  MIXED_PRECISION_ALWAYS = 0,
  MIXED_PRECISION_FOLLOW = 1,
  MIXED_PRECISION_NEVER = 2
};
+
// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype.
// Keyed by (ExprNode*, DataType) via pair_hash; presumably lets the pass reuse an existing
// cast of the same expression rather than emit a duplicate — the consuming code is below
// this hunk, so confirm against the full file.
using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
// Per-op packed function that decides how an op participates in mixed precision
// (ops are registered with one of these; see the missing-op flags below).
// Return array is of type : [MixedTypeConversionCategory (int), String, String]
// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
// Call is a call node, DataType is the mixed precision type
using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
  // Cache of already-created cast nodes, keyed by (parent node, wanted dtype);
  // see CachedCastNodes above.
  CachedCastNodes cast_nodes_cache;

  // The target datatype we want to convert to e.g. FP16
  const DataType mixed_precision_type;

  // If false, throws a fatal error if an op which is not registered with a
  // FTVMMixedPrecisionConversionType is encountered.
  bool ignore_missing_ops;

  // If true, emits a warning if an op which is not registered with a
  // FTVMMixedPrecisionConversionType is encountered.
  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) 
const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate 
the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = 
cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& 
accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = 
accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) 
const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = 
accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }

Review comment:
       Hey Matthew, do you have an example model I could use to understand this 
problem a little bit more?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to