llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

This PR ports all in-tree dialect extensions to use the `PyConcreteType` and
`PyConcreteAttribute` CRTP base classes instead of `mlir_pure_subclass` (via
`mlir_type_subclass` / `mlir_attribute_subclass`). After this PR,
`mlir_pure_subclass` can be soft-deprecated.

This PR depends on https://github.com/llvm/llvm-project/pull/174118

---

The patch is 111.46 KiB and is truncated to 20.00 KiB below; the full version is at
https://github.com/llvm/llvm-project/pull/174156.diff


11 Files Affected:

- (modified) mlir/lib/Bindings/Python/DialectAMDGPU.cpp (+74-36) 
- (modified) mlir/lib/Bindings/Python/DialectGPU.cpp (+87-65) 
- (modified) mlir/lib/Bindings/Python/DialectLLVM.cpp (+164-133) 
- (modified) mlir/lib/Bindings/Python/DialectNVGPU.cpp (+31-18) 
- (modified) mlir/lib/Bindings/Python/DialectPDL.cpp (+145-83) 
- (modified) mlir/lib/Bindings/Python/DialectQuant.cpp (+454-355) 
- (modified) mlir/lib/Bindings/Python/DialectSMT.cpp (+63-26) 
- (modified) mlir/lib/Bindings/Python/DialectSparseTensor.cpp (+125-109) 
- (modified) mlir/lib/Bindings/Python/DialectTransform.cpp (+150-98) 
- (modified) mlir/python/mlir/dialects/transform/extras/__init__.py (+6-5) 
- (modified) mlir/test/python/dialects/pdl_types.py (+107-104) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
index 26ffc0e427e41..26115c3635b7b 100644
--- a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
@@ -8,58 +8,98 @@
 
 #include "mlir-c/Dialect/AMDGPU.h"
 #include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "nanobind/nanobind.h"
 
 namespace nb = nanobind;
 using namespace llvm;
-using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
-  auto amdgpuTDMBaseType =
-      mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType,
-                         mlirAMDGPUTDMBaseTypeGetTypeID);
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace amdgpu {
+struct TDMBaseType : PyConcreteType<TDMBaseType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAMDGPUTDMBaseType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAMDGPUTDMBaseTypeGetTypeID;
+  static constexpr const char *pyClassName = "TDMBaseType";
+  using PyConcreteType::PyConcreteType;
 
-  amdgpuTDMBaseType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirType elementType, MlirContext ctx) {
-        return cls(mlirAMDGPUTDMBaseTypeGet(ctx, elementType));
-      },
-      "Gets an instance of TDMBaseType in the same context", nb::arg("cls"),
-      nb::arg("element_type"), nb::arg("ctx") = nb::none());
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const PyType &elementType, DefaultingPyMlirContext context) {
+          return TDMBaseType(
+              context->getRef(),
+              mlirAMDGPUTDMBaseTypeGet(context.get()->get(), elementType));
+        },
+        "Gets an instance of TDMBaseType in the same context",
+        nb::arg("element_type"), nb::arg("context").none() = nb::none());
+  }
+};
 
-  auto amdgpuTDMDescriptorType = mlir_type_subclass(
-      m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType,
-      mlirAMDGPUTDMDescriptorTypeGetTypeID);
+struct TDMDescriptorType : PyConcreteType<TDMDescriptorType> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAAMDGPUTDMDescriptorType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAMDGPUTDMDescriptorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TDMDescriptorType";
+  using PyConcreteType::PyConcreteType;
 
-  amdgpuTDMDescriptorType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirContext ctx) {
-        return cls(mlirAMDGPUTDMDescriptorTypeGet(ctx));
-      },
-      "Gets an instance of TDMDescriptorType in the same context",
-      nb::arg("cls"), nb::arg("ctx") = nb::none());
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          return TDMDescriptorType(
+              context->getRef(),
+              mlirAMDGPUTDMDescriptorTypeGet(context.get()->get()));
+        },
+        "Gets an instance of TDMDescriptorType in the same context",
+        nb::arg("context").none() = nb::none());
+  }
+};
 
-  auto amdgpuTDMGatherBaseType = mlir_type_subclass(
-      m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType,
-      mlirAMDGPUTDMGatherBaseTypeGetTypeID);
+struct TDMGatherBaseType : PyConcreteType<TDMGatherBaseType> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAAMDGPUTDMGatherBaseType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAMDGPUTDMGatherBaseTypeGetTypeID;
+  static constexpr const char *pyClassName = "TDMGatherBaseType";
+  using PyConcreteType::PyConcreteType;
 
-  amdgpuTDMGatherBaseType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirType elementType, MlirType indexType,
-         MlirContext ctx) {
-        return cls(mlirAMDGPUTDMGatherBaseTypeGet(ctx, elementType, indexType));
-      },
-      "Gets an instance of TDMGatherBaseType in the same context",
-      nb::arg("cls"), nb::arg("element_type"), nb::arg("index_type"),
-      nb::arg("ctx") = nb::none());
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const PyType &elementType, const PyType &indexType,
+           DefaultingPyMlirContext context) {
+          return TDMGatherBaseType(
+              context->getRef(),
+              mlirAMDGPUTDMGatherBaseTypeGet(context.get()->get(), elementType,
+                                             indexType));
+        },
+        "Gets an instance of TDMGatherBaseType in the same context",
+        nb::arg("element_type"), nb::arg("index_type"),
+        nb::arg("context").none() = nb::none());
+  }
 };
 
+static void populateDialectAMDGPUSubmodule(nb::module_ &m) {
+  TDMBaseType::bind(m);
+  TDMDescriptorType::bind(m);
+  TDMGatherBaseType::bind(m);
+}
+} // namespace amdgpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
 NB_MODULE(_mlirDialectsAMDGPU, m) {
   m.doc() = "MLIR AMDGPU dialect.";
 
-  populateDialectAMDGPUSubmodule(m);
+  // Qualify through the domain macro rather than hard-coding `mlir`.
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::amdgpu::
+      populateDialectAMDGPUSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 2568d535edb5a..ea3748cc88b85 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -9,83 +9,108 @@
 #include "mlir-c/Dialect/GPU.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 using namespace nanobind::literals;
-
-using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace gpu {
 // -----------------------------------------------------------------------------
-// Module initialization.
+// AsyncTokenType
 // -----------------------------------------------------------------------------
 
-NB_MODULE(_mlirDialectsGPU, m) {
-  m.doc() = "MLIR GPU Dialect";
-  //===-------------------------------------------------------------------===//
-  // AsyncTokenType
-  //===-------------------------------------------------------------------===//
+struct AsyncTokenType : PyConcreteType<AsyncTokenType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAGPUAsyncTokenType;
+  static constexpr const char *pyClassName = "AsyncTokenType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          return AsyncTokenType(context->getRef(),
+                                mlirGPUAsyncTokenTypeGet(context.get()->get()));
+        },
+        "Gets an instance of AsyncTokenType in the same context",
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// ObjectAttr
+//===-------------------------------------------------------------------===//
+
+struct ObjectAttr : PyConcreteAttribute<ObjectAttr> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAGPUObjectAttr;
+  static constexpr const char *pyClassName = "ObjectAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
 
-  auto mlirGPUAsyncTokenType =
-      mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType);
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](MlirAttribute target, uint32_t format, const nb::bytes &object,
+           std::optional<MlirAttribute> mlirObjectProps,
+           std::optional<MlirAttribute> mlirKernelsAttr,
+           DefaultingPyMlirContext context) {
+          MlirStringRef objectStrRef = mlirStringRefCreate(
+              static_cast<char *>(const_cast<void *>(object.data())),
+              object.size());
+          // The attribute is created in the context that owns `target`, so
+          // wrap it with that context rather than the `context` argument.
+          return ObjectAttr(
+              PyMlirContext::forContext(mlirAttributeGetContext(target)),
+              mlirGPUObjectAttrGetWithKernels(
+                  mlirAttributeGetContext(target), target, format, objectStrRef,
+                  mlirObjectProps.has_value() ? *mlirObjectProps
+                                              : MlirAttribute{nullptr},
+                  mlirKernelsAttr.has_value() ? *mlirKernelsAttr
+                                              : MlirAttribute{nullptr}));
+        },
+        "target"_a, "format"_a, "object"_a, "properties"_a = nb::none(),
+        "kernels"_a = nb::none(), "context"_a = nb::none(),
+        "Gets a gpu.object from parameters.");
 
-  mlirGPUAsyncTokenType.def_classmethod(
-      "get",
-      [](const nb::object &cls, MlirContext ctx) {
-        return cls(mlirGPUAsyncTokenTypeGet(ctx));
-      },
-      "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
-      nb::arg("ctx") = nb::none());
+    c.def_prop_ro("target", [](MlirAttribute self) {
+      return mlirGPUObjectAttrGetTarget(self);
+    });
+    c.def_prop_ro("format", [](MlirAttribute self) {
+      return mlirGPUObjectAttrGetFormat(self);
+    });
+    c.def_prop_ro("object", [](MlirAttribute self) {
+      MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
+      return nb::bytes(stringRef.data, stringRef.length);
+    });
+    c.def_prop_ro("properties", [](MlirAttribute self) -> nb::object {
+      if (mlirGPUObjectAttrHasProperties(self))
+        return nb::cast(mlirGPUObjectAttrGetProperties(self));
+      return nb::none();
+    });
+    c.def_prop_ro("kernels", [](MlirAttribute self) -> nb::object {
+      if (mlirGPUObjectAttrHasKernels(self))
+        return nb::cast(mlirGPUObjectAttrGetKernels(self));
+      return nb::none();
+    });
+  }
+};
+} // namespace gpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
 
-  //===-------------------------------------------------------------------===//
-  // ObjectAttr
-  //===-------------------------------------------------------------------===//
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+NB_MODULE(_mlirDialectsGPU, m) {
+  m.doc() = "MLIR GPU Dialect";
 
-  mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
-      .def_classmethod(
-          "get",
-          [](const nb::object &cls, MlirAttribute target, uint32_t format,
-             const nb::bytes &object,
-             std::optional<MlirAttribute> mlirObjectProps,
-             std::optional<MlirAttribute> mlirKernelsAttr) {
-            MlirStringRef objectStrRef = mlirStringRefCreate(
-                static_cast<char *>(const_cast<void *>(object.data())),
-                object.size());
-            return cls(mlirGPUObjectAttrGetWithKernels(
-                mlirAttributeGetContext(target), target, format, objectStrRef,
-                mlirObjectProps.has_value() ? *mlirObjectProps
-                                            : MlirAttribute{nullptr},
-                mlirKernelsAttr.has_value() ? *mlirKernelsAttr
-                                            : MlirAttribute{nullptr}));
-          },
-          "cls"_a, "target"_a, "format"_a, "object"_a,
-          "properties"_a = nb::none(), "kernels"_a = nb::none(),
-          "Gets a gpu.object from parameters.")
-      .def_property_readonly(
-          "target",
-          [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
-      .def_property_readonly(
-          "format",
-          [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
-      .def_property_readonly(
-          "object",
-          [](MlirAttribute self) {
-            MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
-            return nb::bytes(stringRef.data, stringRef.length);
-          })
-      .def_property_readonly("properties",
-                             [](MlirAttribute self) -> nb::object {
-                               if (mlirGPUObjectAttrHasProperties(self))
-                                 return nb::cast(
-                                     mlirGPUObjectAttrGetProperties(self));
-                               return nb::none();
-                             })
-      .def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
-        if (mlirGPUObjectAttrHasKernels(self))
-          return nb::cast(mlirGPUObjectAttrGetKernels(self));
-        return nb::none();
-      });
+  // Qualify through the domain macro rather than hard-coding `mlir`.
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::gpu::AsyncTokenType::bind(m);
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::gpu::ObjectAttr::bind(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 05681cecf82b3..d4eb078c0f55c 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -13,149 +13,176 @@
 #include "mlir-c/Support.h"
 #include "mlir-c/Target/LLVMIR.h"
 #include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 
 using namespace nanobind::literals;
-
 using namespace llvm;
 using namespace mlir;
-using namespace mlir::python;
 using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
-
-  //===--------------------------------------------------------------------===//
-  // StructType
-  //===--------------------------------------------------------------------===//
-
-  auto llvmStructType = mlir_type_subclass(
-      m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID);
-
-  llvmStructType
-      .def_classmethod(
-          "get_literal",
-          [](const nb::object &cls, const std::vector<MlirType> &elements,
-             bool packed, MlirLocation loc) {
-            CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
-            MlirType type = mlirLLVMStructTypeLiteralGetChecked(
-                loc, elements.size(), elements.data(), packed);
-            if (mlirTypeIsNull(type)) {
-              throw nb::value_error(scope.takeMessage().c_str());
-            }
-            return cls(type);
-          },
-          "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
-          "loc"_a = nb::none())
-      .def_classmethod(
-          "get_literal_unchecked",
-          [](const nb::object &cls, const std::vector<MlirType> &elements,
-             bool packed, MlirContext context) {
-            CollectDiagnosticsToStringScope scope(context);
-
-            MlirType type = mlirLLVMStructTypeLiteralGet(
-                context, elements.size(), elements.data(), packed);
-            if (mlirTypeIsNull(type)) {
-              throw nb::value_error(scope.takeMessage().c_str());
-            }
-            return cls(type);
-          },
-          "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
-          "context"_a = nb::none());
-
-  llvmStructType.def_classmethod(
-      "get_identified",
-      [](const nb::object &cls, const std::string &name, MlirContext context) {
-        return cls(mlirLLVMStructTypeIdentifiedGet(
-            context, mlirStringRefCreate(name.data(), name.size())));
-      },
-      "cls"_a, "name"_a, nb::kw_only(), "context"_a = nb::none());
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace llvm {
+//===--------------------------------------------------------------------===//
+// StructType
+//===--------------------------------------------------------------------===//
+
+struct StructType : PyConcreteType<StructType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMStructType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirLLVMStructTypeGetTypeID;
+  static constexpr const char *pyClassName = "StructType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_literal",
+        [](const std::vector<MlirType> &elements, bool packed, MlirLocation loc,
+           DefaultingPyMlirContext context) {
+          python::CollectDiagnosticsToStringScope scope(
+              mlirLocationGetContext(loc));
+
+          MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+              loc, elements.size(), elements.data(), packed);
+          if (mlirTypeIsNull(type)) {
+            throw nb::value_error(scope.takeMessage().c_str());
+          }
+          return StructType(context->getRef(), type);
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false, "loc"_a = nb::none(),
+        "context"_a = nb::none());
+
+    c.def_static(
+        "get_literal_unchecked",
+        [](const std::vector<MlirType> &elements, bool packed,
+           DefaultingPyMlirContext context) {
+          python::CollectDiagnosticsToStringScope scope(context.get()->get());
+
+          MlirType type = mlirLLVMStructTypeLiteralGet(
+              context.get()->get(), elements.size(), elements.data(), packed);
+          if (mlirTypeIsNull(type)) {
+            throw nb::value_error(scope.takeMessage().c_str());
+          }
+          return StructType(context->getRef(), type);
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false,
+        "context"_a = nb::none());
+
+    c.def_static(
+        "get_identified",
+        [](const std::string &name, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeIdentifiedGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.size())));
+        },
+        "name"_a, nb::kw_only(), "context"_a = nb::none());
+
+    c.def_static(
+        "get_opaque",
+        [](const std::string &name, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeOpaqueGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.size())));
+        },
+        "name"_a, "context"_a = nb::none());
+
+    c.def(
+        "set_body",
+        [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
+          MlirLogicalResult result = mlirLLVMStructTypeSetBody(
+              self, elements.size(), elements.data(), packed);
+          if (!mlirLogicalResultIsSuccess(result)) {
+            throw nb::value_error(
+                "Struct body already set to different content.");
+          }
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false);
+
+    c.def_static(
+        "new_identified",
+        [](const std::string &name, const std::vector<MlirType> &elements,
+           bool packed, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeIdentifiedNewGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.length()),
+                                elements.size(), elements.data(), packed));
+        },
+        "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+        "context"_a = nb::none());
+
+    c.def_prop_ro("name", [](PyType type) -> std::optional<std::string> {
+      if (mlirLLVMStructTypeIsLiteral(type))
+        return std::nullopt;
+
+      MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
+      return StringRef(stringRef.data, stringRef.length).str();
+    });
+
+    c.def_prop_ro("body", [](PyType type) -> nb::object {
+      // Don't crash in absence of a body.
+      if (mlirLLVMStructTypeIsOpaque(type))
+        return nb::none();
+
+      nb::list body;
+      for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type);
+           i < e; ++i) {
+        body.append(mlirLLVMStructTypeGetElementType(type, i));
+      }
+      return body;
+    });
+
+    c.def_prop_ro("packed",
+                  [](PyType type) { return mlirLLVMStructTypeIsPacked(type); })...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/174156
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to