This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e5f483cd5f [REFACTOR][NODE] Remove node redirect headers (#18829)
e5f483cd5f is described below
commit e5f483cd5f52481c35e728819ed450268bdc3ea3
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Feb 27 06:36:57 2026 -0500
[REFACTOR][NODE] Remove node redirect headers (#18829)
---
include/tvm/ir/attrs.h | 4 +-
include/tvm/ir/env_func.h | 2 +-
include/tvm/ir/expr.h | 4 +-
include/tvm/ir/instrument.h | 2 +-
include/tvm/ir/module.h | 1 +
include/tvm/{node => ir}/serialization.h | 24 +++--
include/tvm/ir/source_map.h | 2 +-
include/tvm/ir/type.h | 1 -
include/tvm/node/node.h | 63 ------------
include/tvm/node/reflection.h | 41 --------
include/tvm/node/serialization.h | 27 +-----
include/tvm/node/structural_equal.h | 90 +----------------
include/tvm/node/structural_hash.h | 108 +--------------------
include/tvm/relax/distributed/axis_group_graph.h | 4 +-
include/tvm/relax/exec_builder.h | 5 +-
include/tvm/relax/expr.h | 2 +-
include/tvm/relax/struct_info.h | 3 +-
include/tvm/runtime/object.h | 3 +
include/tvm/s_tir/meta_schedule/arg_info.h | 1 -
include/tvm/script/ir_builder/base.h | 2 +-
include/tvm/script/ir_builder/ir/frame.h | 1 -
include/tvm/script/ir_builder/ir/ir.h | 1 -
include/tvm/script/printer/doc.h | 1 -
include/tvm/script/printer/ir_docsifier.h | 2 +-
include/tvm/script/printer/ir_docsifier_functor.h | 2 +-
include/tvm/target/tag.h | 1 -
include/tvm/target/target.h | 2 +-
include/tvm/target/target_kind.h | 2 +-
include/tvm/target/virtual_device.h | 2 +-
include/tvm/tir/expr.h | 1 -
include/tvm/tir/function.h | 1 +
include/tvm/tir/index_map.h | 1 +
include/tvm/tir/stmt.h | 1 +
include/tvm/tir/var.h | 1 -
python/tvm/runtime/_tensor.py | 2 +-
.../tvm/s_tir/dlight/analysis/common_analysis.py | 1 -
src/arith/conjunctive_normal_form.cc | 4 +-
src/arith/iter_affine_map.cc | 6 +-
src/arith/solve_linear_inequality.cc | 4 +-
src/arith/transitive_comparison_analyzer.cc | 2 +-
src/contrib/msc/core/printer/msc_base_printer.cc | 1 +
src/ir/module.cc | 2 +-
src/{node => ir}/serialization.cc | 2 +-
src/{node => ir}/structural_equal.cc | 10 +-
src/{node => ir}/structural_hash.cc | 8 +-
src/ir/transform.cc | 6 +-
src/node/reflection.cc | 1 -
src/node/script_printer.cc | 1 +
src/relax/analysis/struct_info_analysis.cc | 8 +-
.../backend/adreno/annotate_custom_storage.cc | 2 +-
.../backend/adreno/fold_vdevice_scope_change.cc | 2 +-
src/relax/backend/vm/vm_shape_lower.cc | 2 +-
.../distributed/transform/legalize_redistribute.cc | 2 +-
src/relax/distributed/transform/lower_distir.cc | 3 +-
.../distributed/transform/propagate_sharding.cc | 9 +-
src/relax/ir/block_builder.cc | 4 +-
src/relax/ir/dataflow_block_rewriter.cc | 2 +-
src/relax/ir/dataflow_expr_rewriter.cc | 8 +-
src/relax/ir/dataflow_matcher.cc | 15 +--
src/relax/ir/expr_functor.cc | 2 +-
src/relax/op/distributed/utils.cc | 4 +-
src/relax/op/tensor/manipulate.cc | 2 +-
src/relax/training/utils.cc | 2 +-
src/relax/transform/alter_op_impl.cc | 2 +-
src/relax/transform/canonicalize_bindings.cc | 5 +-
src/relax/transform/convert_layout.cc | 2 +-
src/relax/transform/eliminate_common_subexpr.cc | 6 +-
src/relax/transform/fold_constant.cc | 3 +-
src/relax/transform/fuse_tir.cc | 2 +-
src/relax/transform/kill_after_last_use.cc | 2 +-
src/relax/transform/lift_transform_params.cc | 4 +-
src/relax/transform/remove_unused_outputs.cc | 1 +
.../specialize_primfunc_based_on_callsite.cc | 2 +-
src/s_tir/meta_schedule/module_equality.cc | 6 +-
src/s_tir/meta_schedule/utils.h | 7 +-
src/s_tir/schedule/primitive/compute_inline.cc | 2 +-
.../schedule/primitive/layout_transformation.cc | 3 +-
src/s_tir/schedule/utils.h | 2 +-
src/s_tir/transform/inject_software_pipeline.cc | 3 +-
.../transform/using_assume_to_reduce_branches.cc | 7 +-
src/script/printer/relax/expr.cc | 1 +
src/script/printer/relax/utils.h | 2 +-
src/script/printer/utils.h | 2 +-
src/support/ffi_testing.cc | 1 +
src/support/scalars.h | 1 +
src/target/source/codegen_metal.cc | 1 +
src/te/operation/compute_op.cc | 2 +-
src/te/operation/create_primfunc.cc | 2 +-
src/tir/transform/common_subexpr_elim_tools.cc | 2 +-
src/tir/transform/common_subexpr_elim_tools.h | 12 +--
src/tir/transform/vectorize_loop.cc | 3 +-
tests/cpp/arith_simplify_test.cc | 5 +-
tests/cpp/expr_test.cc | 2 +-
tests/cpp/nested_msg_test.cc | 12 ++-
tests/cpp/target/virtual_device_test.cc | 10 +-
tests/python/relax/test_group_gemm_flashinfer.py | 2 +-
tests/scripts/release/make_notes.py | 1 -
97 files changed, 177 insertions(+), 468 deletions(-)
diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index 3da7f8d1c1..cf52ec32ea 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -29,12 +29,12 @@
#define TVM_IR_ATTRS_H_
#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/extra/structural_equal.h>
+#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/expr.h>
-#include <tvm/node/structural_equal.h>
-#include <tvm/node/structural_hash.h>
#include <functional>
#include <string>
diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h
index 264198333e..65d9a6f387 100644
--- a/include/tvm/ir/env_func.h
+++ b/include/tvm/ir/env_func.h
@@ -26,7 +26,7 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/node/node.h>
+#include <tvm/runtime/object.h>
#include <string>
#include <utility>
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index faf2c18c1c..7c4b8e7cb2 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -28,7 +28,9 @@
#include <tvm/ffi/string.h>
#include <tvm/ir/source_map.h>
#include <tvm/ir/type.h>
-#include <tvm/node/node.h>
+#include <tvm/node/cast.h>
+#include <tvm/node/repr_printer.h>
+#include <tvm/node/script_printer.h>
#include <tvm/runtime/object.h>
#include <algorithm>
diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h
index c14549f412..cfd859406d 100644
--- a/include/tvm/ir/instrument.h
+++ b/include/tvm/ir/instrument.h
@@ -28,7 +28,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
-#include <tvm/node/node.h>
+#include <tvm/runtime/object.h>
#include <utility>
#include <vector>
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index becd19ed70..543c895ce5 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -33,6 +33,7 @@
#include <tvm/ir/global_info.h>
#include <tvm/ir/source_map.h>
#include <tvm/ir/type.h>
+#include <tvm/node/script_printer.h>
#include <string>
#include <unordered_map>
diff --git a/include/tvm/node/serialization.h b/include/tvm/ir/serialization.h
similarity index 67%
copy from include/tvm/node/serialization.h
copy to include/tvm/ir/serialization.h
index 5a8e098cfd..59bdb87067 100644
--- a/include/tvm/node/serialization.h
+++ b/include/tvm/ir/serialization.h
@@ -18,29 +18,33 @@
*/
/*!
- * Utility functions for serialization.
- * \file tvm/node/serialization.h
+ * \file tvm/ir/serialization.h
+ * \brief Utility functions for serialization.
+ *
+ * This is a thin forwarding header to ffi/extra/serialization.h.
+ * Prefer using ffi::ToJSONGraph / ffi::FromJSONGraph directly.
*/
-#ifndef TVM_NODE_SERIALIZATION_H_
-#define TVM_NODE_SERIALIZATION_H_
+#ifndef TVM_IR_SERIALIZATION_H_
+#define TVM_IR_SERIALIZATION_H_
+#include <tvm/ffi/extra/json.h>
+#include <tvm/ffi/extra/serialization.h>
#include <tvm/runtime/base.h>
-#include <tvm/runtime/object.h>
#include <string>
namespace tvm {
+
/*!
- * \brief save the node as well as all the node it depends on as json.
- * This can be used to serialize any TVM object
+ * \brief Save the node, as well as all the nodes it depends on, as JSON.
+ * This can be used to serialize any TVM object.
*
* \return the string representation of the node.
*/
TVM_DLL std::string SaveJSON(ffi::Any node);
/*!
- * \brief Internal implementation of LoadJSON
- * Load tvm Node object from json and return a shared_ptr of Node.
+ * \brief Load a TVM Node object from JSON and return it as an ffi::Any.
* \param json_str The json string to load from.
*
* \return The shared_ptr of the Node.
@@ -48,4 +52,4 @@ TVM_DLL std::string SaveJSON(ffi::Any node);
TVM_DLL ffi::Any LoadJSON(std::string json_str);
} // namespace tvm
-#endif // TVM_NODE_SERIALIZATION_H_
+#endif // TVM_IR_SERIALIZATION_H_
diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h
index 60a30ffe17..19aba9461c 100644
--- a/include/tvm/ir/source_map.h
+++ b/include/tvm/ir/source_map.h
@@ -23,9 +23,9 @@
#ifndef TVM_IR_SOURCE_MAP_H_
#define TVM_IR_SOURCE_MAP_H_
+#include <tvm/ffi/container/array.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/node/node.h>
#include <tvm/runtime/object.h>
#include <fstream>
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index 5e38f38769..902778c3db 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -52,7 +52,6 @@
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/source_map.h>
-#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h
deleted file mode 100644
index 734a28c133..0000000000
--- a/include/tvm/node/node.h
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-/*!
- * \file tvm/node/node.h
- * \brief Definitions and helper macros for IR/AST nodes.
- *
- * The node folder contains base utilities for IR/AST nodes,
- * invariant of which specific language dialect.
- *
- * We implement AST/IR nodes as sub-classes of runtime::Object.
- * The base class Node is just an alias of runtime::Object.
- *
- * Besides the runtime type checking provided by Object,
- * node folder contains additional functionalities such as
- * reflection and serialization, which are important features
- * for building a compiler infra.
- */
-#ifndef TVM_NODE_NODE_H_
-#define TVM_NODE_NODE_H_
-
-#include <tvm/ffi/memory.h>
-#include <tvm/node/cast.h>
-#include <tvm/node/repr_printer.h>
-#include <tvm/node/structural_equal.h>
-#include <tvm/node/structural_hash.h>
-#include <tvm/runtime/base.h>
-#include <tvm/runtime/object.h>
-
-#include <string>
-#include <type_traits>
-#include <utility>
-#include <vector>
-
-namespace tvm {
-
-using ffi::Any;
-using ffi::AnyView;
-using ffi::Object;
-using ffi::ObjectPtr;
-using ffi::ObjectPtrEqual;
-using ffi::ObjectPtrHash;
-using ffi::ObjectRef;
-using ffi::PackedArgs;
-using ffi::TypeIndex;
-
-} // namespace tvm
-#endif // TVM_NODE_NODE_H_
diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h
deleted file mode 100644
index d5716f96f6..0000000000
--- a/include/tvm/node/reflection.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-/*!
- * \file tvm/node/reflection.h
- * \brief Reflection utilities for IR/AST nodes.
- */
-#ifndef TVM_NODE_REFLECTION_H_
-#define TVM_NODE_REFLECTION_H_
-
-#include <tvm/ffi/container/map.h>
-#include <tvm/ffi/container/string.h>
-
-namespace tvm {
-
-/*!
- * \brief Create an object from a type key and a map of fields.
- * \param type_key The type key of the object.
- * \param fields The fields of the object.
- * \return The created object.
- */
-TVM_DLL ffi::Any CreateObject(const ffi::String& type_key,
- const ffi::Map<ffi::String, ffi::Any>& fields);
-
-} // namespace tvm
-#endif // TVM_NODE_REFLECTION_H_
diff --git a/include/tvm/node/serialization.h b/include/tvm/node/serialization.h
index 5a8e098cfd..892ff5fdf1 100644
--- a/include/tvm/node/serialization.h
+++ b/include/tvm/node/serialization.h
@@ -18,34 +18,13 @@
*/
/*!
- * Utility functions for serialization.
* \file tvm/node/serialization.h
+ * \brief Forwarding header. Use tvm/ir/serialization.h instead.
*/
#ifndef TVM_NODE_SERIALIZATION_H_
#define TVM_NODE_SERIALIZATION_H_
-#include <tvm/runtime/base.h>
-#include <tvm/runtime/object.h>
+// This header has moved to tvm/ir/serialization.h
+#include <tvm/ir/serialization.h>
-#include <string>
-
-namespace tvm {
-/*!
- * \brief save the node as well as all the node it depends on as json.
- * This can be used to serialize any TVM object
- *
- * \return the string representation of the node.
- */
-TVM_DLL std::string SaveJSON(ffi::Any node);
-
-/*!
- * \brief Internal implementation of LoadJSON
- * Load tvm Node object from json and return a shared_ptr of Node.
- * \param json_str The json string to load from.
- *
- * \return The shared_ptr of the Node.
- */
-TVM_DLL ffi::Any LoadJSON(std::string json_str);
-
-} // namespace tvm
#endif // TVM_NODE_SERIALIZATION_H_
diff --git a/include/tvm/node/structural_equal.h
b/include/tvm/node/structural_equal.h
index 4f00e1770b..cbf7652b80 100644
--- a/include/tvm/node/structural_equal.h
+++ b/include/tvm/node/structural_equal.h
@@ -18,96 +18,12 @@
*/
/*!
* \file tvm/node/structural_equal.h
- * \brief Structural equality comparison.
+ * \brief Forwarding header. Use tvm/ffi/extra/structural_equal.h instead.
*/
#ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
#define TVM_NODE_STRUCTURAL_EQUAL_H_
-#include <tvm/ffi/container/array.h>
-#include <tvm/ffi/reflection/access_path.h>
-#include <tvm/node/functor.h>
-#include <tvm/runtime/data_type.h>
+// This header has moved to tvm/ffi/extra/structural_equal.h
+#include <tvm/ffi/extra/structural_equal.h>
-#include <cmath>
-#include <string>
-
-namespace tvm {
-
-/*!
- * \brief Equality definition of base value class.
- */
-class BaseValueEqual {
- public:
- bool operator()(const double& lhs, const double& rhs) const {
- if (std::isnan(lhs) && std::isnan(rhs)) {
- // IEEE floats do not compare as equivalent to each other.
- // However, for the purpose of comparing IR representation, two
- // NaN values are equivalent.
- return true;
- } else if (std::isnan(lhs) || std::isnan(rhs)) {
- return false;
- } else if (lhs == rhs) {
- return true;
- } else {
- // fuzzy float pt comparison
- constexpr double atol = 1e-9;
- double diff = lhs - rhs;
- return diff > -atol && diff < atol;
- }
- }
-
- bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs
== rhs; }
- bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs
== rhs; }
- bool operator()(const ffi::Optional<int64_t>& lhs, const
ffi::Optional<int64_t>& rhs) const {
- return lhs == rhs;
- }
- bool operator()(const ffi::Optional<double>& lhs, const
ffi::Optional<double>& rhs) const {
- return lhs == rhs;
- }
- bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; }
- bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs;
}
- bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs; }
- bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs
== rhs; }
- template <typename ENum, typename = typename
std::enable_if<std::is_enum<ENum>::value>::type>
- bool operator()(const ENum& lhs, const ENum& rhs) const {
- return lhs == rhs;
- }
-};
-
-/*!
- * \brief Content-aware structural equality comparator for objects.
- *
- * The structural equality is recursively defined in the DAG of IR nodes via
SEqual.
- * There are two kinds of nodes:
- *
- * - Graph node: a graph node in lhs can only be mapped as equal to
- * one and only one graph node in rhs.
- * - Normal node: equality is recursively defined without the restriction
- * of graph nodes.
- *
- * Vars(tir::Var, relax::Var) nodes are graph nodes.
- *
- * A var-type node(e.g. tir::Var) can be mapped as equal to another var
- * with the same type if one of the following condition holds:
- *
- * - They appear in a same definition point(e.g. function argument).
- * - They points to the same VarNode via the same_as relation.
- * - They appear in a same usage point, and map_free_vars is set to be True.
- */
-class StructuralEqual : public BaseValueEqual {
- public:
- // inheritate operator()
- using BaseValueEqual::operator();
- /*!
- * \brief Compare objects via strutural equal.
- * \param lhs The left operand.
- * \param rhs The right operand.
- * \param map_free_params Whether or not to map free variables.
- * \return The comparison result.
- */
- TVM_DLL bool operator()(const ffi::Any& lhs, const ffi::Any& rhs,
- const bool map_free_params = false) const;
-};
-
-} // namespace tvm
#endif // TVM_NODE_STRUCTURAL_EQUAL_H_
diff --git a/include/tvm/node/structural_hash.h
b/include/tvm/node/structural_hash.h
index ba7cbaf88a..8f90820b15 100644
--- a/include/tvm/node/structural_hash.h
+++ b/include/tvm/node/structural_hash.h
@@ -17,113 +17,13 @@
* under the License.
*/
/*!
- * \file tvm/node/structural_equal.h
- * \brief Structural hash class.
+ * \file tvm/node/structural_hash.h
+ * \brief Forwarding header. Use tvm/ffi/extra/structural_hash.h instead.
*/
#ifndef TVM_NODE_STRUCTURAL_HASH_H_
#define TVM_NODE_STRUCTURAL_HASH_H_
-#include <tvm/node/functor.h>
-#include <tvm/runtime/data_type.h>
-#include <tvm/runtime/tensor.h>
+// This header has moved to tvm/ffi/extra/structural_hash.h
+#include <tvm/ffi/extra/structural_hash.h>
-#include <cmath>
-#include <functional>
-#include <limits>
-#include <string>
-
-namespace tvm {
-
-/*!
- * \brief Hash definition of base value classes.
- */
-class BaseValueHash {
- protected:
- template <typename T, typename U>
- uint64_t Reinterpret(T value) const {
- union Union {
- T a;
- U b;
- } u;
- static_assert(sizeof(Union) == sizeof(T), "sizeof(Union) != sizeof(T)");
- static_assert(sizeof(Union) == sizeof(U), "sizeof(Union) != sizeof(U)");
- u.b = 0;
- u.a = value;
- return u.b;
- }
-
- public:
- uint64_t operator()(const float& key) const { return Reinterpret<float,
uint32_t>(key); }
- uint64_t operator()(const double& key) const {
- if (std::isnan(key)) {
- // The IEEE format defines more than one bit-pattern that
- // represents NaN. For the purpose of comparing IR
- // representations, all NaN values are considered equivalent.
- return Reinterpret<double,
uint64_t>(std::numeric_limits<double>::quiet_NaN());
- } else {
- return Reinterpret<double, uint64_t>(key);
- }
- }
- uint64_t operator()(const int64_t& key) const { return Reinterpret<int64_t,
uint64_t>(key); }
- uint64_t operator()(const uint64_t& key) const { return key; }
- uint64_t operator()(const int& key) const { return Reinterpret<int,
uint32_t>(key); }
- uint64_t operator()(const bool& key) const { return key; }
- uint64_t operator()(const runtime::DataType& key) const {
- return Reinterpret<DLDataType, uint32_t>(key);
- }
- template <typename ENum, typename = typename
std::enable_if<std::is_enum<ENum>::value>::type>
- uint64_t operator()(const ENum& key) const {
- return Reinterpret<int64_t, uint64_t>(static_cast<int64_t>(key));
- }
- uint64_t operator()(const std::string& key) const {
- return tvm::ffi::details::StableHashBytes(key.data(), key.length());
- }
- uint64_t operator()(const ffi::Optional<int64_t>& key) const {
- if (key.has_value()) {
- return Reinterpret<int64_t, uint64_t>(*key);
- } else {
- return 0;
- }
- }
- uint64_t operator()(const ffi::Optional<double>& key) const {
- if (key.has_value()) {
- return Reinterpret<double, uint64_t>(*key);
- } else {
- return 0;
- }
- }
- /*!
- * \brief Compute structural hash value for a POD value in Any.
- * \param key The Any object.
- * \return The hash value.
- */
- TVM_FFI_INLINE uint64_t HashPODValueInAny(const ffi::Any& key) const {
- return ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(key)->v_uint64;
- }
-};
-
-/*!
- * \brief Content-aware structural hashing.
- *
- * The structural hash value is recursively defined in the DAG of IRNodes.
- * There are two kinds of nodes:
- *
- * - Normal node: the hash value is defined by its content and type only.
- * - Graph node: each graph node will be assigned a unique index ordered by
the
- * first occurrence during the visit. The hash value of a graph node is
- * combined from the hash values of its contents and the index.
- */
-class StructuralHash : public BaseValueHash {
- public:
- // inherit operator()
- using BaseValueHash::operator();
- /*!
- * \brief Compute structural hashing value for an object.
- * \param key The left operand.
- * \return The hash value.
- */
- TVM_DLL uint64_t operator()(const ffi::Any& key) const;
-};
-
-} // namespace tvm
#endif // TVM_NODE_STRUCTURAL_HASH_H_
diff --git a/include/tvm/relax/distributed/axis_group_graph.h
b/include/tvm/relax/distributed/axis_group_graph.h
index 6ea322938f..6dc2022d5f 100644
--- a/include/tvm/relax/distributed/axis_group_graph.h
+++ b/include/tvm/relax/distributed/axis_group_graph.h
@@ -251,7 +251,7 @@ using AxisShardingSpec = std::pair<DeviceMesh, int>;
class AxisShardingSpecEqual {
public:
bool operator()(const AxisShardingSpec& lhs, const AxisShardingSpec& rhs)
const {
- return StructuralEqual()(lhs.first, rhs.first) && lhs.second == rhs.second;
+ return ffi::StructuralEqual()(lhs.first, rhs.first) && lhs.second ==
rhs.second;
}
};
@@ -259,7 +259,7 @@ class AxisShardingSpecHash {
public:
size_t operator()(const AxisShardingSpec& sharding_spec) const {
size_t seed = 0;
- seed ^= StructuralHash()(sharding_spec.first);
+ seed ^= ffi::StructuralHash()(sharding_spec.first);
seed ^= std::hash<int>()(sharding_spec.second) << 1;
return seed;
}
diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h
index 222dea3fb1..66bae5411a 100644
--- a/include/tvm/relax/exec_builder.h
+++ b/include/tvm/relax/exec_builder.h
@@ -23,6 +23,8 @@
#ifndef TVM_RELAX_EXEC_BUILDER_H_
#define TVM_RELAX_EXEC_BUILDER_H_
+#include <tvm/ffi/extra/structural_equal.h>
+#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/expr.h>
@@ -178,7 +180,8 @@ class ExecBuilderNode : public Object {
/*! \brief The mutable internal executable. */
ObjectPtr<vm::VMExecutable> exec_; // mutable
/*! \brief internal dedup map when creating index for a new constant */
- std::unordered_map<ffi::Any, vm::Index, StructuralHash, StructuralEqual>
const_dedup_map_;
+ std::unordered_map<ffi::Any, vm::Index, ffi::StructuralHash,
ffi::StructuralEqual>
+ const_dedup_map_;
};
class ExecBuilder : public ObjectRef {
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 9b4fa91379..f8cebafa55 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -25,9 +25,9 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/source_map.h>
-#include <tvm/node/node.h>
#include <tvm/relax/type.h>
#include <tvm/runtime/object.h>
+#include <tvm/runtime/tensor.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
index 12b97e20c2..c51a6db5a2 100644
--- a/include/tvm/relax/struct_info.h
+++ b/include/tvm/relax/struct_info.h
@@ -22,10 +22,11 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/env_func.h>
#include <tvm/ir/source_map.h>
-#include <tvm/node/node.h>
+#include <tvm/node/cast.h>
#include <tvm/relax/block_builder.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/type.h>
+#include <tvm/runtime/object.h>
#include <utility>
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index 80279a4862..e186e85b9d 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -139,7 +139,10 @@
static_assert(static_cast<int>(TypeIndex::kCustomStaticIndex) >=
} // namespace runtime
+using tvm::ffi::Object;
using tvm::ffi::ObjectPtr;
+using tvm::ffi::ObjectPtrEqual;
+using tvm::ffi::ObjectPtrHash;
using tvm::ffi::ObjectRef;
} // namespace tvm
#endif // TVM_RUNTIME_OBJECT_H_
diff --git a/include/tvm/s_tir/meta_schedule/arg_info.h
b/include/tvm/s_tir/meta_schedule/arg_info.h
index cf70550874..ae2c3c9057 100644
--- a/include/tvm/s_tir/meta_schedule/arg_info.h
+++ b/include/tvm/s_tir/meta_schedule/arg_info.h
@@ -22,7 +22,6 @@
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/module.h>
-#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/function.h>
diff --git a/include/tvm/script/ir_builder/base.h
b/include/tvm/script/ir_builder/base.h
index 47ed628da0..e5679c6064 100644
--- a/include/tvm/script/ir_builder/base.h
+++ b/include/tvm/script/ir_builder/base.h
@@ -22,7 +22,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
-#include <tvm/node/node.h>
+#include <tvm/node/cast.h>
#include <vector>
diff --git a/include/tvm/script/ir_builder/ir/frame.h
b/include/tvm/script/ir_builder/ir/frame.h
index 53efc9df7f..cff1242c88 100644
--- a/include/tvm/script/ir_builder/ir/frame.h
+++ b/include/tvm/script/ir_builder/ir/frame.h
@@ -23,7 +23,6 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/module.h>
-#include <tvm/node/node.h>
#include <tvm/script/ir_builder/base.h>
#include <vector>
diff --git a/include/tvm/script/ir_builder/ir/ir.h
b/include/tvm/script/ir_builder/ir/ir.h
index 9fe3d7e1ac..761e2a995f 100644
--- a/include/tvm/script/ir_builder/ir/ir.h
+++ b/include/tvm/script/ir_builder/ir/ir.h
@@ -21,7 +21,6 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
-#include <tvm/node/node.h>
#include <tvm/script/ir_builder/ir/frame.h>
#include <vector>
diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index 9ce980d268..9d63ae08e3 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -22,7 +22,6 @@
#include <tvm/ffi/reflection/access_path.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/expr.h>
-#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
diff --git a/include/tvm/script/printer/ir_docsifier.h
b/include/tvm/script/printer/ir_docsifier.h
index cf8c72daf8..bd8c37780c 100644
--- a/include/tvm/script/printer/ir_docsifier.h
+++ b/include/tvm/script/printer/ir_docsifier.h
@@ -22,7 +22,7 @@
#include <tvm/ffi/reflection/access_path.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/module.h>
-#include <tvm/node/node.h>
+#include <tvm/node/cast.h>
#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/ir_docsifier_functor.h>
diff --git a/include/tvm/script/printer/ir_docsifier_functor.h
b/include/tvm/script/printer/ir_docsifier_functor.h
index 68caa5ff4d..2cc1782d92 100644
--- a/include/tvm/script/printer/ir_docsifier_functor.h
+++ b/include/tvm/script/printer/ir_docsifier_functor.h
@@ -20,8 +20,8 @@
#define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
#include <tvm/ffi/function.h>
-#include <tvm/node/node.h>
#include <tvm/runtime/logging.h>
+#include <tvm/runtime/object.h>
#include <optional>
#include <string>
diff --git a/include/tvm/target/tag.h b/include/tvm/target/tag.h
index b8de3fffba..c6c828e2df 100644
--- a/include/tvm/target/tag.h
+++ b/include/tvm/target/tag.h
@@ -26,7 +26,6 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/node/attr_registry_map.h>
-#include <tvm/node/node.h>
#include <tvm/target/target.h>
#include <utility>
diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index 9a0bedd1cc..b71a4952b5 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -27,7 +27,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
-#include <tvm/node/node.h>
+#include <tvm/node/cast.h>
#include <tvm/runtime/device_api.h>
#include <tvm/support/with.h>
#include <tvm/target/target_kind.h>
diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index 02ac88a9af..86f55a2410 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -28,7 +28,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/config_schema.h>
#include <tvm/node/attr_registry_map.h>
-#include <tvm/node/node.h>
+#include <tvm/runtime/object.h>
#include <memory>
#include <unordered_map>
diff --git a/include/tvm/target/virtual_device.h
b/include/tvm/target/virtual_device.h
index 889d0eff89..52512edda8 100644
--- a/include/tvm/target/virtual_device.h
+++ b/include/tvm/target/virtual_device.h
@@ -365,7 +365,7 @@ class VirtualDeviceCache {
private:
/*! \brief Already constructed VirtualDevices. */
- std::unordered_set<VirtualDevice, StructuralHash, StructuralEqual> cache_;
+ std::unordered_set<VirtualDevice, ffi::StructuralHash, ffi::StructuralEqual>
cache_;
};
/*! brief The attribute key for the virtual device. This key will be promoted
to first class on
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 5297654691..34c11bdd3e 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -30,7 +30,6 @@
#include <tvm/ffi/string.h>
#include <tvm/ir/expr.h>
#include <tvm/node/functor.h>
-#include <tvm/node/node.h>
#include <tvm/runtime/base.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tir/buffer.h>
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 97dfbb1330..31c6e3bc5c 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -27,6 +27,7 @@
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/variant.h>
#include <tvm/ir/function.h>
+#include <tvm/node/script_printer.h>
#include <tvm/runtime/tensor.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h
index 6866431ee4..c4e716cbe1 100644
--- a/include/tvm/tir/index_map.h
+++ b/include/tvm/tir/index_map.h
@@ -29,6 +29,7 @@
#include <tvm/ffi/container/array.h>
#include <tvm/ir/expr.h>
#include <tvm/runtime/object.h>
+#include <tvm/runtime/tensor.h>
#include <tvm/tir/var.h>
#include <utility>
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 4fcb91403f..b41c92d66a 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -25,6 +25,7 @@
#define TVM_TIR_STMT_H_
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/node/script_printer.h>
#include <tvm/tir/expr.h>
#include <string>
diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h
index 521b03a472..e83064b864 100644
--- a/include/tvm/tir/var.h
+++ b/include/tvm/tir/var.h
@@ -25,7 +25,6 @@
#define TVM_TIR_VAR_H_
#include <tvm/ir/expr.h>
-#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>
#include <functional>
diff --git a/python/tvm/runtime/_tensor.py b/python/tvm/runtime/_tensor.py
index 9eea8a36d2..13557f9bd9 100644
--- a/python/tvm/runtime/_tensor.py
+++ b/python/tvm/runtime/_tensor.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import, redefined-outer-name
-# ruff: noqa: E722, F401, RUF005
+# ruff: noqa: F401, RUF005
"""Runtime Tensor API"""
import ctypes
diff --git a/python/tvm/s_tir/dlight/analysis/common_analysis.py b/python/tvm/s_tir/dlight/analysis/common_analysis.py
index fb44d3fcea..3a14836109 100644
--- a/python/tvm/s_tir/dlight/analysis/common_analysis.py
+++ b/python/tvm/s_tir/dlight/analysis/common_analysis.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# ruff: noqa: E722
# pylint: disable=missing-function-docstring, missing-class-docstring
# pylint: disable=unused-argument, unused-variable
diff --git a/src/arith/conjunctive_normal_form.cc b/src/arith/conjunctive_normal_form.cc
index 7a87a1fbab..47b6156cfa 100644
--- a/src/arith/conjunctive_normal_form.cc
+++ b/src/arith/conjunctive_normal_form.cc
@@ -133,10 +133,10 @@ class AndOfOrs {
std::vector<std::vector<Key>> chunks_;
/*! \brief Mapping from internal Key to PrimExpr */
- std::unordered_map<Key, PrimExpr, StructuralHash, StructuralEqual>
key_to_expr_;
+ std::unordered_map<Key, PrimExpr, ffi::StructuralHash, ffi::StructuralEqual>
key_to_expr_;
/*! \brief Mapping from PrimExpr to internal Key */
- std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual>
expr_to_key_;
+ std::unordered_map<PrimExpr, Key, ffi::StructuralHash, ffi::StructuralEqual>
expr_to_key_;
/*! \brief Cached key representing tir::Bool(true) */
Key key_true_;
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 4d08c79072..1779c42583 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -458,10 +458,12 @@ class IterMapRewriter : public ExprMutator {
// usage of an input iterator. (e.g. (i-1) occurring in the
// expressions [(i-1)%8, ((i-1)//8)%4, (i-1)//32] should be
// left-padded by 31 for each occurrence.)
- std::unordered_map<IterMark, IterPaddingInfo, StructuralHash,
StructuralEqual> padded_iter_map_;
+ std::unordered_map<IterMark, IterPaddingInfo, ffi::StructuralHash,
ffi::StructuralEqual>
+ padded_iter_map_;
// Map from padded iter mark to it's origin mark
- std::unordered_map<IterMark, IterMark, StructuralHash, StructuralEqual>
padded_origin_map_;
+ std::unordered_map<IterMark, IterMark, ffi::StructuralHash,
ffi::StructuralEqual>
+ padded_origin_map_;
/* If update_iterator_padding_ is true, allow the extents of the IterMap to
be
* padded beyond the original iterators.
diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc
index 3b8e96773b..6c932ea522 100644
--- a/src/arith/solve_linear_inequality.cc
+++ b/src/arith/solve_linear_inequality.cc
@@ -103,7 +103,7 @@ void AddInequality(std::vector<PrimExpr>* inequality_set,
const PrimExpr& new_in
Analyzer* analyzer) {
if (analyzer->CanProve(new_ineq) ||
std::find_if(inequality_set->begin(), inequality_set->end(), [&](const
PrimExpr& e) {
- return StructuralEqual()(e, new_ineq);
+ return ffi::StructuralEqual()(e, new_ineq);
}) != inequality_set->end()) {
// redundant: follows from the vranges
// or has already been added
@@ -175,7 +175,7 @@ void MoveEquality(std::vector<PrimExpr>* upper_bounds,
std::vector<PrimExpr>* lo
// those exist in both upper & lower bounds will be moved to equalities
for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) {
auto lb = std::find_if(lower_bounds->begin(), lower_bounds->end(),
- [&](const PrimExpr& e) { return
StructuralEqual()(e, *ub); });
+ [&](const PrimExpr& e) { return
ffi::StructuralEqual()(e, *ub); });
if (lb != lower_bounds->end()) {
equalities->push_back(*lb);
lower_bounds->erase(lb);
diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc
index 3794ff150b..23aaf2140c 100644
--- a/src/arith/transitive_comparison_analyzer.cc
+++ b/src/arith/transitive_comparison_analyzer.cc
@@ -139,7 +139,7 @@ class TransitiveComparisonAnalyzer::Impl {
* \see ExprToKey
* \see ExprToPreviousKey
*/
- std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual>
expr_to_key;
+ std::unordered_map<PrimExpr, Key, ffi::StructuralHash, ffi::StructuralEqual>
expr_to_key;
/*! \brief Internal representation of a comparison operator */
struct Comparison {
diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc
index fa6fc378f2..aeecd79750 100644
--- a/src/contrib/msc/core/printer/msc_base_printer.cc
+++ b/src/contrib/msc/core/printer/msc_base_printer.cc
@@ -23,6 +23,7 @@
#include "msc_base_printer.h"
+#include <cmath>
#include <utility>
#include "../utils.h"
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 04e3026b0f..935d9e0ccd 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -21,13 +21,13 @@
* \brief The global module in TVM.
*/
#include <tvm/ffi/container/variant.h>
+#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/rvalue_ref.h>
#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/module.h>
#include <tvm/ir/type_functor.h>
-#include <tvm/node/structural_equal.h>
#include <algorithm>
#include <fstream>
diff --git a/src/node/serialization.cc b/src/ir/serialization.cc
similarity index 97%
rename from src/node/serialization.cc
rename to src/ir/serialization.cc
index 2faf8d170b..4d9074e98c 100644
--- a/src/node/serialization.cc
+++ b/src/ir/serialization.cc
@@ -18,7 +18,7 @@
*/
/*!
- * \file node/serialization.cc
+ * \file src/ir/serialization.cc
* \brief Utilities to serialize TVM AST/IR objects.
*/
#include <tvm/ffi/extra/json.h>
diff --git a/src/node/structural_equal.cc b/src/ir/structural_equal.cc
similarity index 90%
rename from src/node/structural_equal.cc
rename to src/ir/structural_equal.cc
index e33d7c7746..1d7cbd23d0 100644
--- a/src/node/structural_equal.cc
+++ b/src/ir/structural_equal.cc
@@ -17,7 +17,7 @@
* under the License.
*/
/*!
- * \file src/node/structural_equal.cc
+ * \file src/ir/structural_equal.cc
*/
#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/function.h>
@@ -25,8 +25,8 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/module.h>
#include <tvm/node/functor.h>
-#include <tvm/node/node.h>
-#include <tvm/node/structural_equal.h>
+#include <tvm/node/repr_printer.h>
+#include <tvm/node/script_printer.h>
#include <optional>
#include <unordered_map>
@@ -80,8 +80,4 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def("node.GetFirstStructuralMismatch",
ffi::StructuralEqual::GetFirstMismatch);
}
-bool StructuralEqual::operator()(const ffi::Any& lhs, const ffi::Any& rhs,
- bool map_free_params) const {
- return ffi::StructuralEqual::Equal(lhs, rhs, map_free_params);
-}
} // namespace tvm
diff --git a/src/node/structural_hash.cc b/src/ir/structural_hash.cc
similarity index 96%
rename from src/node/structural_hash.cc
rename to src/ir/structural_hash.cc
index f32f0756c0..ad74742e51 100644
--- a/src/node/structural_hash.cc
+++ b/src/ir/structural_hash.cc
@@ -17,7 +17,7 @@
* under the License.
*/
/*!
- * \file src/node/structural_hash.cc
+ * \file src/ir/structural_hash.cc
*/
#include <tvm/ffi/extra/base64.h>
#include <tvm/ffi/extra/module.h>
@@ -26,8 +26,6 @@
#include <tvm/ffi/reflection/access_path.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/node/functor.h>
-#include <tvm/node/node.h>
-#include <tvm/node/structural_hash.h>
#include <tvm/runtime/profiling.h>
#include <tvm/support/io.h>
#include <tvm/target/codegen.h>
@@ -81,10 +79,6 @@ TVM_FFI_STATIC_INIT_BLOCK() {
});
}
-uint64_t StructuralHash::operator()(const ffi::Any& object) const {
- return ffi::StructuralHash::Hash(object, false);
-}
-
struct RefToObjectPtr : public ObjectRef {
static ObjectPtr<Object> Get(const ObjectRef& ref) {
return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(ref);
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 148918be8e..9d2dec6b1a 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -21,12 +21,12 @@
* \file src/ir/transform.cc
* \brief Infrastructure for transformation passes.
*/
+#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/rvalue_ref.h>
#include <tvm/ir/transform.h>
#include <tvm/node/repr_printer.h>
-#include <tvm/node/structural_hash.h>
#include <tvm/relax/expr.h>
#include <tvm/runtime/device_api.h>
@@ -312,10 +312,10 @@ IRModule Pass::operator()(IRModule mod, const
PassContext& pass_ctx) const {
IRModule Pass::AssertImmutableModule(const IRModule& mod, const PassNode* node,
const PassContext& pass_ctx) {
- size_t before_pass_hash = tvm::StructuralHash()(mod);
+ size_t before_pass_hash = ffi::StructuralHash()(mod);
IRModule copy_mod = mod;
IRModule ret = node->operator()(mod, pass_ctx);
- size_t after_pass_hash = tvm::StructuralHash()(copy_mod);
+ size_t after_pass_hash = ffi::StructuralHash()(copy_mod);
if (before_pass_hash != after_pass_hash) {
// The chance of getting a hash conflict between a module and the same
module but mutated
// must be very low.
diff --git a/src/node/reflection.cc b/src/node/reflection.cc
index 2565a02b64..a5f4ffb84d 100644
--- a/src/node/reflection.cc
+++ b/src/node/reflection.cc
@@ -25,7 +25,6 @@
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/attrs.h>
-#include <tvm/node/node.h>
namespace tvm {
diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc
index 09413ba007..1774ba4b4b 100644
--- a/src/node/script_printer.cc
+++ b/src/node/script_printer.cc
@@ -19,6 +19,7 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/expr.h>
+#include <tvm/node/cast.h>
#include <tvm/node/repr_printer.h>
#include <tvm/node/script_printer.h>
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index cd951896d8..101e8e8b74 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -485,7 +485,7 @@ class StructInfoBaseChecker
// analyzer
arith::Analyzer* analyzer_;
// struct equal checker
- StructuralEqual struct_equal_;
+ ffi::StructuralEqual struct_equal_;
// customizable functions.
/*!
@@ -742,7 +742,7 @@ class StructInfoBasePreconditionCollector
return Bool(false);
}
- StructuralEqual struct_equal;
+ ffi::StructuralEqual struct_equal;
if (!struct_equal(lhs->device_mesh, rhs->device_mesh) ||
!struct_equal(lhs->placement, rhs->placement)) {
return Bool(false);
@@ -1154,7 +1154,7 @@ class StructInfoLCAFinder
// analyzer
arith::Analyzer* analyzer_;
// struct equal checker
- StructuralEqual struct_equal_;
+ ffi::StructuralEqual struct_equal_;
// check arrays
ffi::Optional<ffi::Array<StructInfo>> UnifyArray(const
ffi::Array<StructInfo>& lhs,
@@ -1303,7 +1303,7 @@ class NonNegativeExpressionCollector :
relax::StructInfoVisitor {
}
ffi::Array<PrimExpr> expressions_;
- std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> dedup_lookup_;
+ std::unordered_set<PrimExpr, ffi::StructuralHash, ffi::StructuralEqual>
dedup_lookup_;
};
ffi::Array<PrimExpr> CollectNonNegativeExpressions(const StructInfo& sinfo) {
diff --git a/src/relax/backend/adreno/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc
index 861e57aeb7..396e5f9cbb 100644
--- a/src/relax/backend/adreno/annotate_custom_storage.cc
+++ b/src/relax/backend/adreno/annotate_custom_storage.cc
@@ -236,7 +236,7 @@
*
*/
-#include <tvm/node/serialization.h>
+#include <tvm/ir/serialization.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/backend/adreno/transform.h>
#include <tvm/relax/dataflow_matcher.h>
diff --git a/src/relax/backend/adreno/fold_vdevice_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc
index a7103cde95..af34a3ac10 100644
--- a/src/relax/backend/adreno/fold_vdevice_scope_change.cc
+++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc
@@ -23,7 +23,7 @@
* store into global scope avoiding unnecessary device copy.
*/
-#include <tvm/node/serialization.h>
+#include <tvm/ir/serialization.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/backend/adreno/transform.h>
#include <tvm/relax/dataflow_matcher.h>
diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index 68c266c5dd..4cca5d7b6c 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -69,7 +69,7 @@ struct MatchShapeTodoItem {
/*! \brief Slot map used for shape lowering. */
using PrimExprSlotMap =
- std::unordered_map<PrimExpr, PrimExprSlot*, StructuralHash,
tir::ExprDeepEqual>;
+ std::unordered_map<PrimExpr, PrimExprSlot*, ffi::StructuralHash,
tir::ExprDeepEqual>;
// Collector to collect PrimExprSlotMap
class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor {
diff --git a/src/relax/distributed/transform/legalize_redistribute.cc b/src/relax/distributed/transform/legalize_redistribute.cc
index 7faa469787..88fdb14ffd 100644
--- a/src/relax/distributed/transform/legalize_redistribute.cc
+++ b/src/relax/distributed/transform/legalize_redistribute.cc
@@ -74,7 +74,7 @@ class RedistributeLegalizer : public ExprMutator {
// and the device mesh must be 1d
// todo: extend the ccl ops so that it can support 2d device mesh, and
different sharding
// dimension
- TVM_FFI_ICHECK(StructuralEqual()(input_sinfo->device_mesh,
attrs->device_mesh));
+ TVM_FFI_ICHECK(ffi::StructuralEqual()(input_sinfo->device_mesh,
attrs->device_mesh));
TVM_FFI_ICHECK(input_sinfo->device_mesh->shape.size() == 1);
// only support "S[x]"-> "R" and "R" -> "S[x]"
PlacementSpec input_spec = input_sinfo->placement->dim_specs[0];
diff --git a/src/relax/distributed/transform/lower_distir.cc b/src/relax/distributed/transform/lower_distir.cc
index 83300f80ac..49fc366a53 100644
--- a/src/relax/distributed/transform/lower_distir.cc
+++ b/src/relax/distributed/transform/lower_distir.cc
@@ -256,7 +256,8 @@ class DistIRSharder : public ExprMutator {
Function func_;
ffi::Array<Var> new_params_;
- std::unordered_map<TupleGetItem, Var, StructuralHash, StructuralEqual>
tuple_getitem_remap_;
+ std::unordered_map<TupleGetItem, Var, ffi::StructuralHash,
ffi::StructuralEqual>
+ tuple_getitem_remap_;
};
namespace transform {
diff --git a/src/relax/distributed/transform/propagate_sharding.cc b/src/relax/distributed/transform/propagate_sharding.cc
index 703857da91..1123b1db25 100644
--- a/src/relax/distributed/transform/propagate_sharding.cc
+++ b/src/relax/distributed/transform/propagate_sharding.cc
@@ -283,7 +283,7 @@ class ShardingConflictHandler : public ExprVisitor {
}
if (device_mesh.defined()) {
- TVM_FFI_ICHECK(StructuralEqual()(device_mesh.value(),
sharding_spec.first))
+ TVM_FFI_ICHECK(ffi::StructuralEqual()(device_mesh.value(),
sharding_spec.first))
<< "Sharding conflict detected for tensor " << var->name_hint()
<< ": Device Mesh mismatch"
<< ". Conflict Handling logic will be added in the future.";
@@ -561,7 +561,7 @@ class DistributedIRBuilder : public ExprMutator {
if (const auto* inferred_dtensor_sinfo =
new_call->struct_info_.as<DTensorStructInfoNode>()) {
Expr new_value = RemoveAnnotateSharding(new_call);
- if (!StructuralEqual()(
+ if (!ffi::StructuralEqual()(
DTensorStructInfo(inferred_dtensor_sinfo->tensor_sinfo,
device_mesh, placements[0]),
new_call->struct_info_)) {
new_value = InsertRedistribute(new_value, device_mesh, placements[0]);
@@ -577,7 +577,7 @@ class DistributedIRBuilder : public ExprMutator {
Var new_var = builder_->Emit(new_call);
var_remap_[binding->var->vid] = new_var;
for (int i = 0; i <
static_cast<int>(inferred_tuple_sinfo->fields.size()); i++) {
- if (!StructuralEqual()(
+ if (!ffi::StructuralEqual()(
DTensorStructInfo(
Downcast<DTensorStructInfo>(inferred_tuple_sinfo->fields[i])->tensor_sinfo,
device_mesh, placements[i]),
@@ -607,7 +607,8 @@ class DistributedIRBuilder : public ExprMutator {
}
ffi::Map<Var, Var> input_tensor_remap_;
- std::unordered_map<TupleGetItem, Var, StructuralHash, StructuralEqual>
tuple_getitem_remap_;
+ std::unordered_map<TupleGetItem, Var, ffi::StructuralHash,
ffi::StructuralEqual>
+ tuple_getitem_remap_;
AxisGroupGraph axis_group_graph_;
};
namespace transform {
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 057351e3d0..c2af644fba 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -445,7 +445,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
*/
std::unique_ptr<
std::unordered_map<BaseFunc, std::unordered_set<GlobalVar,
ObjectPtrHash, ObjectPtrEqual>,
- StructuralHashIgnoreNDarray, StructuralEqual>>
+ StructuralHashIgnoreNDarray, ffi::StructuralEqual>>
ctx_func_dedup_map_ = nullptr;
/*!
@@ -455,7 +455,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
if (ctx_func_dedup_map_ != nullptr) return;
ctx_func_dedup_map_ = std::make_unique<
std::unordered_map<BaseFunc, std::unordered_set<GlobalVar,
ObjectPtrHash, ObjectPtrEqual>,
- StructuralHashIgnoreNDarray, StructuralEqual>>();
+ StructuralHashIgnoreNDarray,
ffi::StructuralEqual>>();
for (const auto& kv : context_mod_->functions) {
const GlobalVar gv = kv.first;
const BaseFunc func = kv.second;
diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc
index 2f2d1dac9a..b13fb84105 100644
--- a/src/relax/ir/dataflow_block_rewriter.cc
+++ b/src/relax/ir/dataflow_block_rewriter.cc
@@ -23,8 +23,8 @@
*/
#include <tvm/arith/analyzer.h>
+#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/node/structural_equal.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/dataflow_matcher.h>
#include <tvm/relax/dataflow_pattern.h>
diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc
index 72f62041db..a95b51745d 100644
--- a/src/relax/ir/dataflow_expr_rewriter.cc
+++ b/src/relax/ir/dataflow_expr_rewriter.cc
@@ -22,9 +22,9 @@
* \brief A transform to match a Relax Expr and rewrite
*/
+#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
-#include <tvm/node/structural_equal.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/dataflow_matcher.h>
#include <tvm/relax/dataflow_pattern.h>
@@ -543,7 +543,7 @@ std::optional<std::vector<Expr>>
TupleRewriterNode::TryMatchByBindingIndex(
for (size_t i = 1; i < indices.size(); i++) {
for (const auto& [pat, expr] : info_vec[indices[i]].matches[i].value()) {
if (auto it = merged_matches.find(pat); it != merged_matches.end()) {
- if (!StructuralEqual()(expr, (*it).second)) {
+ if (!ffi::StructuralEqual()(expr, (*it).second)) {
return std::nullopt;
}
} else {
@@ -698,7 +698,7 @@ PatternMatchingRewriter
PatternMatchingRewriter::FromModule(IRModule mod) {
auto sinfo_pattern = GetStructInfo(func_pattern);
auto sinfo_replacement = GetStructInfo(func_replacement);
- TVM_FFI_CHECK(StructuralEqual()(sinfo_pattern, sinfo_replacement),
ValueError)
+ TVM_FFI_CHECK(ffi::StructuralEqual()(sinfo_pattern, sinfo_replacement),
ValueError)
<< "The pattern and replacement must have the same signature, "
<< "but the pattern has struct info " << sinfo_pattern
<< ", while the replacement has struct info " << sinfo_replacement;
@@ -832,7 +832,7 @@ class PatternMatchingMutator : public ExprMutator {
Expr VisitExpr_(const SeqExprNode* seq) override {
SeqExpr prev = Downcast<SeqExpr>(ExprMutator::VisitExpr_(seq));
- StructuralEqual struct_equal;
+ ffi::StructuralEqual struct_equal;
while (auto opt = TryRewriteSeqExpr(prev)) {
SeqExpr next = Downcast<SeqExpr>(builder_->Normalize(opt.value()));
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 2f7099937f..3c0e57dc07 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -25,7 +25,7 @@
#include "dataflow_matcher.h"
#include <tvm/arith/analyzer.h>
-#include <tvm/node/structural_equal.h>
+#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/dataflow_matcher.h>
#include <tvm/relax/dataflow_pattern.h>
@@ -67,7 +67,7 @@ bool MatchAttrs(const Any& attrs, const ffi::Map<ffi::String,
ffi::Any>& attribu
auto attr_name = kv.first;
auto attr_value = kv.second;
if (dict_attrs->dict.count(attr_name)) {
- if (!StructuralEqual()(attr_value, dict_attrs->dict[attr_name])) {
+ if (!ffi::StructuralEqual()(attr_value, dict_attrs->dict[attr_name])) {
return false;
}
} else {
@@ -89,7 +89,7 @@ bool MatchAttrs(const Any& attrs, const ffi::Map<ffi::String,
ffi::Any>& attribu
if (attributes.count(field_name)) {
ffi::reflection::FieldGetter field_getter(field_info);
ffi::Any field_value = field_getter(obj);
- if (!StructuralEqual()(attributes[field_name], field_value)) {
+ if (!ffi::StructuralEqual()(attributes[field_name], field_value)) {
success = false;
return true;
}
@@ -194,7 +194,7 @@ bool DFPatternMatcher::VisitDFPattern_(const
AttrPatternNode* attr_pattern, cons
if (Op::HasAttrMap(attr_name)) {
auto op_map = Op::GetAttrMap<ffi::Any>(attr_name);
if (op_map.count(op)) {
- matches &= StructuralEqual()(attr_value, op_map[op]);
+ matches &= ffi::StructuralEqual()(attr_value, op_map[op]);
} else {
matches = false;
}
@@ -208,7 +208,7 @@ bool DFPatternMatcher::VisitDFPattern_(const
AttrPatternNode* attr_pattern, cons
matches = true;
for (auto kv : attributes) {
if (matches && op->attrs.defined() && op->attrs->dict.count(kv.first)) {
- matches &= StructuralEqual()(kv.second, op->attrs->dict[kv.first]);
+ matches &= ffi::StructuralEqual()(kv.second,
op->attrs->dict[kv.first]);
} else {
matches = false;
break;
@@ -332,7 +332,7 @@ bool DFPatternMatcher::VisitDFPattern_(const
CallPatternNode* op, const Expr& ex
bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr&
expr0) {
auto expr = UnwrapBindings(expr0, var2val_);
- return StructuralEqual()(op->expr, expr);
+ return ffi::StructuralEqual()(op->expr, expr);
}
bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const
Expr& expr0) {
@@ -570,7 +570,8 @@ bool DFPatternMatcher::VisitDFPattern_(const
DataTypePatternNode* op, const Expr
// no need to jump, as var.dtype == value.dtype
auto expr_sinfo = expr.as<ExprNode>()->struct_info_;
if (const TensorStructInfoNode* tensor_sinfo =
expr_sinfo.as<TensorStructInfoNode>()) {
- return (StructuralEqual()(op->dtype, tensor_sinfo->dtype)) &&
VisitDFPattern(op->pattern, expr);
+ return (ffi::StructuralEqual()(op->dtype, tensor_sinfo->dtype)) &&
+ VisitDFPattern(op->pattern, expr);
}
return false;
}
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index 13ef41eede..fdedf80911 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -848,7 +848,7 @@ Var ExprMutator::WithStructInfo(Var var, StructInfo
struct_info) {
if (var->struct_info_.defined()) {
// use same-as as a quick path
if (var->struct_info_.same_as(struct_info) ||
- StructuralEqual()(var->struct_info_, struct_info)) {
+ ffi::StructuralEqual()(var->struct_info_, struct_info)) {
return var;
} else {
Var new_var = var.as<DataflowVarNode>() ? DataflowVar(var->vid,
struct_info, var->span)
diff --git a/src/relax/op/distributed/utils.cc b/src/relax/op/distributed/utils.cc
index d8a23da382..57c80abbf6 100644
--- a/src/relax/op/distributed/utils.cc
+++ b/src/relax/op/distributed/utils.cc
@@ -45,8 +45,8 @@ StructInfo InferShardingSpec(const Call& call, const
BlockBuilder& ctx,
ffi::Array<distributed::DTensorStructInfo> input_dtensor_sinfos =
GetInputDTensorStructInfo(call, ctx);
for (int i = 1; i < static_cast<int>(input_dtensor_sinfos.size()); i++) {
- TVM_FFI_ICHECK(StructuralEqual()(input_dtensor_sinfos[0]->device_mesh,
- input_dtensor_sinfos[i]->device_mesh));
+ TVM_FFI_ICHECK(ffi::StructuralEqual()(input_dtensor_sinfos[0]->device_mesh,
+
input_dtensor_sinfos[i]->device_mesh));
}
distributed::DeviceMesh device_mesh = input_dtensor_sinfos[0]->device_mesh;
Var output_var("output", orig_output_sinfo);
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index e5f3d19e8d..fc6ec6b8aa 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -163,7 +163,7 @@ ffi::Optional<ffi::Array<PrimExpr>> CheckConcatOutputShape(
// For the specified axis, we compute the sum of shape value over each
tensor.
// Special case, if all concatenated values have the same shape
- StructuralEqual structural_equal;
+ ffi::StructuralEqual structural_equal;
PrimExpr first_concat_dim = shape_values[0][axis];
bool all_same = std::all_of(shape_values.begin(), shape_values.end(),
[&](const auto& a) {
return structural_equal(a[axis], first_concat_dim);
diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc
index 614f20dba7..000a0d4b7d 100644
--- a/src/relax/training/utils.cc
+++ b/src/relax/training/utils.cc
@@ -150,7 +150,7 @@ class AppendLossMutator : private ExprMutator {
* sets up var_remap_ from loss parameter Vars to backbone returned Vars.
*/
void CheckAndRemapLossParams(const ffi::Array<Var>& loss_func_params) {
- static StructuralEqual checker;
+ static ffi::StructuralEqual checker;
TVM_FFI_ICHECK(static_cast<int>(loss_func_params.size()) >=
num_backbone_outputs_)
<< "The number of parameters of the loss function is " <<
loss_func_params.size()
<< ", which is less than the given num_backbone_outputs " <<
num_backbone_outputs_;
diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc
index f066bd02da..93db755059 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -26,7 +26,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/attrs.h>
-#include <tvm/node/serialization.h>
+#include <tvm/ir/serialization.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/manipulate.h>
#include <tvm/relax/expr_functor.h>
diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc
index 05c86d9263..98fd075f55 100644
--- a/src/relax/transform/canonicalize_bindings.cc
+++ b/src/relax/transform/canonicalize_bindings.cc
@@ -92,7 +92,7 @@ class SymbolicVarCanonicalizer : public ExprMutator {
// within each branch.
auto new_sinfo =
VisitExprDepStructInfoField(Downcast<StructInfo>(op->struct_info_));
- StructuralEqual struct_equal;
+ ffi::StructuralEqual struct_equal;
if (!struct_equal(new_sinfo, GetStructInfo(true_b))) {
auto output_var = Var("then_branch_with_dyn", new_sinfo);
@@ -351,7 +351,8 @@ class CanonicalizePlanner : public ExprVisitor {
if (binding.as<VarBindingNode>()) {
return true;
} else if (auto match_cast = binding.as<MatchCastNode>()) {
- return StructuralEqual()(GetStructInfo(binding->var),
GetStructInfo(match_cast->value));
+ return ffi::StructuralEqual()(GetStructInfo(binding->var),
+ GetStructInfo(match_cast->value));
} else {
TVM_FFI_THROW(InternalError) << "Invalid binding type: " <<
binding->GetTypeKey();
}
diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc
index c71675bb26..888554df67 100644
--- a/src/relax/transform/convert_layout.cc
+++ b/src/relax/transform/convert_layout.cc
@@ -22,7 +22,7 @@
*/
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/node/serialization.h>
+#include <tvm/ir/serialization.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/nested_msg.h>
#include <tvm/relax/op_attr_types.h>
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 7e7f069cdd..20e4ce4f59 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -24,6 +24,8 @@
*
* Currently it removes common subexpressions within a Function.
*/
+#include <tvm/ffi/extra/structural_equal.h>
+#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
@@ -58,7 +60,7 @@ struct ReplacementKey {
}
friend bool operator==(const ReplacementKey& a, const ReplacementKey& b) {
- tvm::StructuralEqual eq;
+ ffi::StructuralEqual eq;
return eq(a.bound_value, b.bound_value) && eq(a.match_cast, b.match_cast);
}
};
@@ -76,7 +78,7 @@ struct ReplacementKey {
template <>
struct std::hash<tvm::relax::ReplacementKey> {
std::size_t operator()(const tvm::relax::ReplacementKey& key) const {
- tvm::StructuralHash hasher;
+ tvm::ffi::StructuralHash hasher;
return tvm::support::HashCombine(hasher(key.bound_value),
hasher(key.match_cast));
}
};
diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index 3a289ebfff..5194941c26 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -412,7 +412,8 @@ class ConstantFolder : public ExprMutator {
}
// cache for function build, via structural equality
- std::unordered_map<tir::PrimFunc, ffi::Optional<ffi::Function>,
StructuralHash, StructuralEqual>
+ std::unordered_map<tir::PrimFunc, ffi::Optional<ffi::Function>,
ffi::StructuralHash,
+ ffi::StructuralEqual>
func_build_cache_;
};
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 4a36047906..3f739cd243 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -492,7 +492,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor {
// structurally equal to the `new_buf` passed
auto ValidateBufferCompatibility = [this](tir::Buffer new_buf, Expr expr) {
if (auto it = relax_to_tir_var_map_.find(expr); it !=
relax_to_tir_var_map_.end()) {
- TVM_FFI_ICHECK(StructuralEqual()((*it).second, new_buf))
+ TVM_FFI_ICHECK(ffi::StructuralEqual()((*it).second, new_buf))
<< "Inconsistent buffers " << (*it).second << " and " << new_buf
<< " mapped to the same relax var: " << expr;
}
diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc
index bae9794ecc..969319063c 100644
--- a/src/relax/transform/kill_after_last_use.cc
+++ b/src/relax/transform/kill_after_last_use.cc
@@ -52,7 +52,7 @@ class UnusedTrivialBindingRemover : public ExprMutator {
}
void VisitBinding_(const MatchCastNode* binding) override {
if (binding->value.as<VarNode>() &&
- StructuralEqual()(GetStructInfo(binding->var),
GetStructInfo(binding->value))) {
+ ffi::StructuralEqual()(GetStructInfo(binding->var),
GetStructInfo(binding->value))) {
has_trivial_binding.insert(binding->var.get());
}
ExprVisitor::VisitBinding_(binding);
diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc
index 15ba2b82e8..0e9cc204ca 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -525,7 +525,7 @@ class ParamRemapper : private ExprFunctor<void(const Expr&,
const Expr&)> {
int index_i = j + num_inputs_i;
int index_0 = j + num_inputs_0;
mapper.VisitExpr(functions[i]->params[index_i],
functions[0]->params[index_0]);
- StructuralEqual eq;
+ ffi::StructuralEqual eq;
eq(functions[i]->params[index_i]->struct_info_,
functions[0]->params[index_0]->struct_info_);
}
@@ -642,7 +642,7 @@ class GlobalLiftableBindingCollector : public
BaseLiftableBindingCollector {
// The mapping between the unified bindings and the original bindings in
different functions.
// The unified binding is the binding with all variables replaced by the
unified variables as
// defined in var_remap_.
- std::unordered_map<Expr, std::vector<Binding>, StructuralHash,
StructuralEqual>
+ std::unordered_map<Expr, std::vector<Binding>, ffi::StructuralHash,
ffi::StructuralEqual>
original_bindings_;
}; // namespace
diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc
index 9de26d8b1a..192dc7acef 100644
--- a/src/relax/transform/remove_unused_outputs.cc
+++ b/src/relax/transform/remove_unused_outputs.cc
@@ -24,6 +24,7 @@
#include <tvm/relax/utils.h>
#include <algorithm>
+#include <cmath>
#include <optional>
#include <tuple>
diff --git a/src/relax/transform/specialize_primfunc_based_on_callsite.cc
b/src/relax/transform/specialize_primfunc_based_on_callsite.cc
index 10fc575e72..8a38baedd7 100644
--- a/src/relax/transform/specialize_primfunc_based_on_callsite.cc
+++ b/src/relax/transform/specialize_primfunc_based_on_callsite.cc
@@ -21,7 +21,7 @@
* \brief Update PrimFunc buffers based on updated scope (or structure) info.
*/
-#include <tvm/node/serialization.h>
+#include <tvm/ir/serialization.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/nested_msg.h>
diff --git a/src/s_tir/meta_schedule/module_equality.cc
b/src/s_tir/meta_schedule/module_equality.cc
index 6973ba8096..fff1a88c33 100644
--- a/src/s_tir/meta_schedule/module_equality.cc
+++ b/src/s_tir/meta_schedule/module_equality.cc
@@ -21,8 +21,6 @@
#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ir/module.h>
-#include <tvm/node/structural_equal.h>
-#include <tvm/node/structural_hash.h>
#include <tvm/tir/analysis.h>
#include <memory>
@@ -33,8 +31,8 @@ namespace meta_schedule {
class ModuleEqualityStructural : public ModuleEquality {
public:
- size_t Hash(IRModule mod) const { return tvm::StructuralHash()(mod); }
- bool Equal(IRModule lhs, IRModule rhs) const { return
tvm::StructuralEqual()(lhs, rhs); }
+ size_t Hash(IRModule mod) const { return ffi::StructuralHash()(mod); }
+ bool Equal(IRModule lhs, IRModule rhs) const { return
ffi::StructuralEqual()(lhs, rhs); }
ffi::String GetName() const { return "structural"; }
};
diff --git a/src/s_tir/meta_schedule/utils.h b/src/s_tir/meta_schedule/utils.h
index 6b2dd3c96f..d5569d07ec 100644
--- a/src/s_tir/meta_schedule/utils.h
+++ b/src/s_tir/meta_schedule/utils.h
@@ -21,8 +21,9 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/optional.h>
-#include <tvm/node/node.h>
-#include <tvm/node/serialization.h>
+#include <tvm/ir/serialization.h>
+#include <tvm/node/cast.h>
+#include <tvm/runtime/object.h>
#include <tvm/s_tir/meta_schedule/arg_info.h>
#include <tvm/s_tir/meta_schedule/builder.h>
#include <tvm/s_tir/meta_schedule/cost_model.h>
@@ -228,7 +229,7 @@ inline ffi::String SHash2Hex(const ObjectRef& obj) {
std::ostringstream os;
size_t hash_code = 0;
if (obj.defined()) {
- hash_code = StructuralHash()(obj);
+ hash_code = ffi::StructuralHash()(obj);
}
os << "0x" << std::setw(16) << std::setfill('0') << std::hex << hash_code;
return os.str();
diff --git a/src/s_tir/schedule/primitive/compute_inline.cc
b/src/s_tir/schedule/primitive/compute_inline.cc
index 4ceb444ecd..17f804514d 100644
--- a/src/s_tir/schedule/primitive/compute_inline.cc
+++ b/src/s_tir/schedule/primitive/compute_inline.cc
@@ -744,7 +744,7 @@ class ReverseComputeInliner : public BaseInliner {
if (const auto* if_ = producer_block->body.as<IfThenElseNode>()) {
if (!if_->else_case.defined()) {
PrimExpr if_predicate = analyzer_.Simplify(if_->condition);
- if (!StructuralEqual()(predicate, if_predicate)) {
+ if (!ffi::StructuralEqual()(predicate, if_predicate)) {
predicate = analyzer_.Simplify(predicate && if_->condition);
producer_block.CopyOnWrite()->body = if_->then_case;
}
diff --git a/src/s_tir/schedule/primitive/layout_transformation.cc
b/src/s_tir/schedule/primitive/layout_transformation.cc
index f608a4b0a3..2d8629c06f 100644
--- a/src/s_tir/schedule/primitive/layout_transformation.cc
+++ b/src/s_tir/schedule/primitive/layout_transformation.cc
@@ -18,7 +18,8 @@
*/
#include <tvm/arith/analyzer.h>
-#include <tvm/node/node.h>
+#include <tvm/node/cast.h>
+#include <tvm/runtime/object.h>
#include <optional>
#include <variant>
diff --git a/src/s_tir/schedule/utils.h b/src/s_tir/schedule/utils.h
index 715e34b09f..d8aebb2f6d 100644
--- a/src/s_tir/schedule/utils.h
+++ b/src/s_tir/schedule/utils.h
@@ -22,7 +22,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_set.h>
#include <tvm/arith/iter_affine_map.h>
-#include <tvm/node/serialization.h>
+#include <tvm/ir/serialization.h>
#include <tvm/s_tir/schedule/instruction.h>
#include <tvm/s_tir/schedule/schedule.h>
#include <tvm/s_tir/schedule/state.h>
diff --git a/src/s_tir/transform/inject_software_pipeline.cc
b/src/s_tir/transform/inject_software_pipeline.cc
index 6e749dbe64..1e1bb446e4 100644
--- a/src/s_tir/transform/inject_software_pipeline.cc
+++ b/src/s_tir/transform/inject_software_pipeline.cc
@@ -21,6 +21,7 @@
* \file inject_software_pipeline.cc
* \brief Transform annotated loops into pipelined one that parallelize
producers and consumers
*/
+#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
@@ -784,7 +785,7 @@ class PipelineRewriter : public StmtExprMutator {
auto stage_id = commit_group_indices[i];
auto predicate = new_blocks[i].predicate;
for (; i < commit_group_indices.size() && commit_group_indices[i] ==
stage_id; ++i) {
- TVM_FFI_ICHECK(tvm::StructuralEqual()(predicate,
new_blocks[i].predicate))
+ TVM_FFI_ICHECK(ffi::StructuralEqual()(predicate,
new_blocks[i].predicate))
<< "Predicates in the same stage are expected to be identical";
group_bodies.push_back(new_blocks[i].block->body);
}
diff --git a/src/s_tir/transform/using_assume_to_reduce_branches.cc
b/src/s_tir/transform/using_assume_to_reduce_branches.cc
index 2c356c8f8e..e506d19854 100644
--- a/src/s_tir/transform/using_assume_to_reduce_branches.cc
+++ b/src/s_tir/transform/using_assume_to_reduce_branches.cc
@@ -204,7 +204,7 @@ class ParseAssumeAndOvercompute : public
IRMutatorWithAnalyzer {
PrimExpr current_predicate_and_context = CurrentScopePredicate();
PrimExpr buffer_predicate_and_context =
buffer_assumption.buffer_context &&
buffer_assumption.buffer_predicate;
- bool current_context_and_buffer_constraint_is_same = StructuralEqual()(
+ bool current_context_and_buffer_constraint_is_same =
ffi::StructuralEqual::Equal(
current_predicate_and_context, buffer_predicate_and_context,
/*map_free_vars=*/true);
if (current_context_and_buffer_constraint_is_same) {
@@ -251,10 +251,11 @@ class ParseAssumeAndOvercompute : public
IRMutatorWithAnalyzer {
}
auto n = this->CopyOnWrite(op);
- if (StructuralEqual()(then_clause_in_then_context,
else_clause_in_then_context)) {
+ if (ffi::StructuralEqual()(then_clause_in_then_context,
else_clause_in_then_context)) {
n->value = analyzer_->Simplify(else_clause);
return Stmt(n);
- } else if (StructuralEqual()(then_clause_in_else_context,
else_clause_in_else_context)) {
+ } else if (ffi::StructuralEqual()(then_clause_in_else_context,
+ else_clause_in_else_context)) {
n->value = analyzer_->Simplify(then_clause);
return Stmt(n);
} else {
diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc
index 0c8cd3c123..99d9618639 100644
--- a/src/script/printer/relax/expr.cc
+++ b/src/script/printer/relax/expr.cc
@@ -19,6 +19,7 @@
#include <tvm/relax/distributed/struct_info.h>
+#include <cmath>
#include <limits>
#include "./utils.h"
diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h
index 7dddfaecbb..558abaef33 100644
--- a/src/script/printer/relax/utils.h
+++ b/src/script/printer/relax/utils.h
@@ -125,7 +125,7 @@ inline ffi::Optional<ExprDoc> StructInfoAsAnn(const
relax::Var& v, const AccessP
inferred_sinfo = trivial_binding->struct_info_.as<relax::StructInfo>();
}
- if (inferred_sinfo && StructuralEqual()(inferred_sinfo, v->struct_info_)) {
+ if (inferred_sinfo && ffi::StructuralEqual()(inferred_sinfo,
v->struct_info_)) {
return std::nullopt;
}
}
diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h
index 2ea588c5ee..c0dbd2e46c 100644
--- a/src/script/printer/utils.h
+++ b/src/script/printer/utils.h
@@ -19,7 +19,7 @@
#ifndef TVM_SCRIPT_PRINTER_UTILS_H_
#define TVM_SCRIPT_PRINTER_UTILS_H_
-#include <tvm/node/serialization.h>
+#include <tvm/ir/serialization.h>
#include <tvm/script/printer/ir_docsifier.h>
#include <string>
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
index 0f8806a117..d82c96f263 100644
--- a/src/support/ffi_testing.cc
+++ b/src/support/ffi_testing.cc
@@ -26,6 +26,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/env_func.h>
+#include <tvm/runtime/device_api.h>
#include <tvm/runtime/module.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/expr.h>
diff --git a/src/support/scalars.h b/src/support/scalars.h
index fa5a3482f5..069ed62445 100644
--- a/src/support/scalars.h
+++ b/src/support/scalars.h
@@ -25,6 +25,7 @@
#ifndef TVM_SUPPORT_SCALARS_H_
#define TVM_SUPPORT_SCALARS_H_
+#include <cmath>
#include <string>
#include "tvm/ir/expr.h"
diff --git a/src/target/source/codegen_metal.cc
b/src/target/source/codegen_metal.cc
index 8dfdd977ac..517dbe07b5 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -26,6 +26,7 @@
#include <tvm/tir/transform.h>
#include <algorithm>
+#include <cmath>
#include <sstream>
#include <string>
#include <unordered_map>
diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc
index 61abb61018..5ee7feb116 100644
--- a/src/te/operation/compute_op.cc
+++ b/src/te/operation/compute_op.cc
@@ -63,7 +63,7 @@ static inline void AssertReduceEqual(const tir::ReduceNode*
a, const tir::Reduce
"each reduction must be structurally identical, "
"except for the ReduceNode::value_index. ";
- StructuralEqual eq;
+ ffi::StructuralEqual eq;
TVM_FFI_ICHECK(a->combiner.same_as(b->combiner))
<< shared_text << "However, the reduction operation " << a->combiner <<
" does not match "
diff --git a/src/te/operation/create_primfunc.cc
b/src/te/operation/create_primfunc.cc
index 831abb9299..3d2536e423 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -250,7 +250,7 @@ ffi::Array<Buffer> GenerateOutputBuffers(const
te::ComputeOp& compute_op, Create
ffi::Array<te::Tensor> tensors;
if (compute_op->body[0]->IsInstance<ReduceNode>()) {
auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) ->
bool {
- StructuralEqual eq;
+ ffi::StructuralEqual eq;
return eq(a->combiner, b->combiner) && //
eq(a->source, b->source) && //
eq(a->axis, b->axis) && //
diff --git a/src/tir/transform/common_subexpr_elim_tools.cc
b/src/tir/transform/common_subexpr_elim_tools.cc
index 1c52c6f97f..4aa4cbbe76 100644
--- a/src/tir/transform/common_subexpr_elim_tools.cc
+++ b/src/tir/transform/common_subexpr_elim_tools.cc
@@ -797,7 +797,7 @@ std::vector<std::pair<PrimExpr, size_t>>
SyntacticToSemanticComputations(
// normalized. This normalized table will keep the count for each set of
equivalent terms
// (i.e. each equivalence class), together with a term that did appear in
this equivalence class
// (in practice, the first term of the equivalence class that was
encoutered).
- support::OrderedMap<PrimExpr, std::pair<PrimExpr, size_t>, StructuralHash,
ExprDeepEqual>
+ support::OrderedMap<PrimExpr, std::pair<PrimExpr, size_t>,
ffi::StructuralHash, ExprDeepEqual>
norm_table;
// In order to avoid frequent rehashing if the norm_table becomes big, we
immediately ask for
diff --git a/src/tir/transform/common_subexpr_elim_tools.h
b/src/tir/transform/common_subexpr_elim_tools.h
index b9c056dcf2..cd548ec0ed 100644
--- a/src/tir/transform/common_subexpr_elim_tools.h
+++ b/src/tir/transform/common_subexpr_elim_tools.h
@@ -26,6 +26,7 @@
#ifndef TVM_TIR_TRANSFORM_COMMON_SUBEXPR_ELIM_TOOLS_H_
#define TVM_TIR_TRANSFORM_COMMON_SUBEXPR_ELIM_TOOLS_H_
+#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ffi/string.h>
#include <tvm/tir/analysis.h> // For the ExprDeepEqual analysis
#include <tvm/tir/expr.h>
@@ -46,13 +47,12 @@ namespace tir {
/*!
* \brief A computation table is a hashtable which associates to each
expression being computed
a number (which is the number of time that it is computed)
- It is important to note that the hash used is a StructuralHash (and
not an ObjectPtrHash)
- as we need to hash similarly deeply equal terms.
- The comparison used is ExprDeepEqual, which is stricter than
StructuralEqual (as it does
- not do variables remapping), so it is compatible with StructuralHash
(intended to be used
- with StructuralEqual).
+ It is important to note that the hash used is a ffi::StructuralHash
(and not an
+ ObjectPtrHash) as we need to hash similarly deeply equal terms. The
comparison used is
+ ExprDeepEqual, which is stricter than ffi::StructuralEqual (as it does not do
variables remapping),
+ so it is compatible with ffi::StructuralHash (intended to be used with
ffi::StructuralEqual).
*/
-using ComputationTable = support::OrderedMap<PrimExpr, size_t, StructuralHash,
ExprDeepEqual>;
+using ComputationTable = support::OrderedMap<PrimExpr, size_t,
ffi::StructuralHash, ExprDeepEqual>;
/*!
* \brief A cache of computations is made of a pair of two hashtables, which
respectively associate
diff --git a/src/tir/transform/vectorize_loop.cc
b/src/tir/transform/vectorize_loop.cc
index 2e8f181199..719d27e743 100644
--- a/src/tir/transform/vectorize_loop.cc
+++ b/src/tir/transform/vectorize_loop.cc
@@ -22,6 +22,7 @@
*/
// Loop vectorizer as in Halide pipeline.
#include <tvm/arith/analyzer.h>
+#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
@@ -168,7 +169,7 @@ class TryPredicateBufferAccesses : public StmtExprMutator {
Ramp ramp = Downcast<Ramp>(node->indices[0]);
// The vectorized access pattern must match the base of the predicate
- if (!tvm::StructuralEqual()(ramp->base, base_)) {
+ if (!ffi::StructuralEqual()(ramp->base, base_)) {
return node;
}
diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc
index 495739d766..9f61086176 100644
--- a/tests/cpp/arith_simplify_test.cc
+++ b/tests/cpp/arith_simplify_test.cc
@@ -19,6 +19,7 @@
#include <gtest/gtest.h>
#include <tvm/arith/analyzer.h>
+#include <tvm/node/structural_equal.h>
#include <tvm/runtime/logging.h>
#include <tvm/te/operation.h>
@@ -55,7 +56,7 @@ TEST(Simplify, Mod) {
}
TEST(ConstantFold, Broadcast) {
- tvm::StructuralEqual checker;
+ tvm::ffi::StructuralEqual checker;
auto i32x4 = tvm::tir::Broadcast(tvm::IntImm(tvm::DataType::Int(32), 10), 4);
auto i64x4 = tvm::cast(i32x4->dtype.with_bits(64), i32x4);
auto i64x4_expected =
tvm::tir::Broadcast(tvm::IntImm(tvm::DataType::Int(64), 10), 4);
@@ -63,7 +64,7 @@ TEST(ConstantFold, Broadcast) {
}
TEST(ConstantFold, Ramp) {
- tvm::StructuralEqual checker;
+ tvm::ffi::StructuralEqual checker;
auto i32x4 = tvm::tir::Ramp(tvm::IntImm(tvm::DataType::Int(32), 10),
tvm::IntImm(tvm::DataType::Int(32), 1), 4);
auto i64x4 = tvm::cast(i32x4->dtype.with_bits(64), i32x4);
diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc
index 67c9fe99cf..1d3aa62f66 100644
--- a/tests/cpp/expr_test.cc
+++ b/tests/cpp/expr_test.cc
@@ -40,7 +40,7 @@ TEST(Expr, VarTypeAnnotation) {
using namespace tvm::tir;
Var x("x", DataType::Float(32));
Var y("y", PrimType(DataType::Float(32)));
- StructuralEqual checker;
+ tvm::ffi::StructuralEqual checker;
TVM_FFI_ICHECK(checker(x->dtype, y->dtype));
TVM_FFI_ICHECK(checker(x->type_annotation, y->type_annotation));
}
diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc
index b1f7b80c99..02b662875c 100644
--- a/tests/cpp/nested_msg_test.cc
+++ b/tests/cpp/nested_msg_test.cc
@@ -18,6 +18,7 @@
*/
#include <gtest/gtest.h>
+#include <tvm/node/structural_equal.h>
#include <tvm/relax/nested_msg.h>
#include <tvm/relax/struct_info.h>
#include <tvm/runtime/data_type.h>
@@ -215,10 +216,11 @@ TEST(NestedMsg, MapToNestedMsgBySInfo) {
auto arr1 = arr[1].NestedArray();
EXPECT_TRUE(arr1[0].IsLeaf());
- EXPECT_TRUE(StructuralEqual()(arr1[0].LeafValue(),
TupleGetItem(TupleGetItem(x, 1), 0)));
+ EXPECT_TRUE(
+ tvm::ffi::StructuralEqual()(arr1[0].LeafValue(),
TupleGetItem(TupleGetItem(x, 1), 0)));
EXPECT_TRUE(arr[2].IsLeaf());
- EXPECT_TRUE(StructuralEqual()(arr[2].LeafValue(), TupleGetItem(x, 2)));
+ EXPECT_TRUE(tvm::ffi::StructuralEqual()(arr[2].LeafValue(), TupleGetItem(x,
2)));
}
TEST(NestedMsg, NestedMsgToExpr) {
@@ -246,13 +248,13 @@ TEST(NestedMsg, NestedMsgToExpr) {
});
Expr expected = Tuple({x, Tuple({x, y}), Tuple({x, Tuple({y, z})})});
- EXPECT_TRUE(StructuralEqual()(expr, expected));
+ EXPECT_TRUE(tvm::ffi::StructuralEqual()(expr, expected));
// test simplified
relax::Var t("t", sf1);
NestedMsg<Expr> msg1 = {TupleGetItem(t, 0), TupleGetItem(t, 1)};
auto expr1 = NestedMsgToExpr<Expr>(msg1, [](ffi::Optional<Expr> leaf) {
return leaf.value(); });
- EXPECT_TRUE(StructuralEqual()(expr1, t));
+ EXPECT_TRUE(tvm::ffi::StructuralEqual()(expr1, t));
}
TEST(NestedMsg, CombineNestedMsg) {
@@ -323,7 +325,7 @@ TEST(NestedMsg, TransformTupleLeaf) {
Expr expected = Tuple({y, Tuple({y, z}), x, Tuple({y, Tuple({y, z})})});
- EXPECT_TRUE(StructuralEqual()(
+ EXPECT_TRUE(tvm::ffi::StructuralEqual()(
TransformTupleLeaf(expr, std::array<NInt, 2>({msg1, msg2}), ftransleaf),
expected));
EXPECT_TRUE(
diff --git a/tests/cpp/target/virtual_device_test.cc
b/tests/cpp/target/virtual_device_test.cc
index 60b643396c..8e2000852d 100644
--- a/tests/cpp/target/virtual_device_test.cc
+++ b/tests/cpp/target/virtual_device_test.cc
@@ -32,7 +32,7 @@ TEST(VirtualDevice, Join_Defined) {
ffi::Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
EXPECT_TRUE(actual.operator bool());
VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global");
- EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
+ EXPECT_TRUE(tvm::ffi::StructuralEqual()(actual.value(), expected));
}
{
Target target_a = Target("cuda");
@@ -41,7 +41,7 @@ TEST(VirtualDevice, Join_Defined) {
ffi::Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
EXPECT_TRUE(actual.operator bool());
VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global");
- EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
+ EXPECT_TRUE(tvm::ffi::StructuralEqual()(actual.value(), expected));
}
{
Target target_a = Target("cuda");
@@ -50,7 +50,7 @@ TEST(VirtualDevice, Join_Defined) {
ffi::Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
EXPECT_TRUE(actual.operator bool());
VirtualDevice expected = VirtualDevice(kDLCUDA, 2, target_a);
- EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
+ EXPECT_TRUE(tvm::ffi::StructuralEqual()(actual.value(), expected));
}
{
Target target_a = Target("cuda");
@@ -59,7 +59,7 @@ TEST(VirtualDevice, Join_Defined) {
ffi::Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
EXPECT_TRUE(actual.operator bool());
VirtualDevice expected = rhs;
- EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
+ EXPECT_TRUE(tvm::ffi::StructuralEqual()(actual.value(), expected));
}
}
@@ -96,7 +96,7 @@ TEST(VirtualDevice, Default) {
VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, target_a, "local");
VirtualDevice actual = VirtualDevice::Default(lhs, rhs);
VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global");
- EXPECT_TRUE(StructuralEqual()(actual, expected));
+ EXPECT_TRUE(tvm::ffi::StructuralEqual()(actual, expected));
}
TEST(VirtualDevice, Constructor_Invalid) {
diff --git a/tests/python/relax/test_group_gemm_flashinfer.py
b/tests/python/relax/test_group_gemm_flashinfer.py
index c128fdb35b..2d15758490 100644
--- a/tests/python/relax/test_group_gemm_flashinfer.py
+++ b/tests/python/relax/test_group_gemm_flashinfer.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# ruff: noqa: E501, E722, F401, F841, RUF005
+# ruff: noqa: E501, F401, F841, RUF005
"""Test for FlashInfer GroupedGemm TVM integration"""
diff --git a/tests/scripts/release/make_notes.py
b/tests/scripts/release/make_notes.py
index b599f4b4f3..82e5a4372b 100644
--- a/tests/scripts/release/make_notes.py
+++ b/tests/scripts/release/make_notes.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# ruff: noqa: E722
import argparse
import csv