Author: Stella Laurenzo Date: 2020-11-29T18:09:07-08:00 New Revision: ba0fe76b7eb87f91499931e76317ddd1cb493aa1
URL: https://github.com/llvm/llvm-project/commit/ba0fe76b7eb87f91499931e76317ddd1cb493aa1 DIFF: https://github.com/llvm/llvm-project/commit/ba0fe76b7eb87f91499931e76317ddd1cb493aa1.diff LOG: [mlir][Python] Add an Operation.result property. * If ODS redefines this, it is fine, but I have found this accessor to be universally useful in the old npcomp bindings and I'm closing gaps that will let me switch. Differential Revision: https://reviews.llvm.org/D92287 Added: Modified: mlir/lib/Bindings/Python/IRModules.cpp mlir/lib/Bindings/Python/IRModules.h mlir/test/Bindings/Python/ir_operation.py Removed: ################################################################################ diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index d34fe998583f..d270e44debae 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -23,6 +23,8 @@ using namespace mlir; using namespace mlir::python; using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; //------------------------------------------------------------------------------ // Docstrings (trivial, non-duplicated docstrings are included inline). @@ -631,7 +633,7 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key, getContext()->get(), {canonKey->data(), canonKey->size()}); if (mlirDialectIsNull(dialect)) { throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, - llvm::Twine("Dialect '") + key + "' not found"); + Twine("Dialect '") + key + "' not found"); } return dialect; } @@ -793,7 +795,7 @@ PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, return created; } -void PyOperation::checkValid() { +void PyOperation::checkValid() const { if (!valid) { throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); } @@ -817,7 +819,7 @@ void PyOperationBase::print(py::object fileObject, bool binary, PyFileAccumulator accum(fileObject, binary); py::gil_scoped_release(); - mlirOperationPrintWithFlags(operation.get(), flags, accum.getCallback(), + mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), accum.getUserData()); mlirOpPrintingFlagsDestroy(flags); } @@ -975,7 +977,7 @@ py::object PyOperation::createOpView() { MlirIdentifier ident = mlirOperationGetName(get()); MlirStringRef identStr = mlirIdentifierStr(ident); auto opViewClass = PyGlobals::get().lookupRawOpViewClass( - llvm::StringRef(identStr.data, identStr.length)); + StringRef(identStr.data, identStr.length)); if (opViewClass) return (*opViewClass)(getRef().getObject()); return py::cast(PyOpView(getRef().getObject())); @@ -1044,7 +1046,7 @@ void PyInsertionPoint::insert(PyOperationBase &operationBase) { (*refOperation)->checkValid(); beforeOp = (*refOperation)->get(); } - mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation.get()); + mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); operation.setAttached(); } @@ -1158,7 +1160,7 @@ class PyConcreteValue : public PyValue { static MlirValue castFrom(PyValue &orig) { if (!DerivedTy::isaFunction(orig.get())) { auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); - throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast value to ") + + throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + DerivedTy::pyClassName + " (from " + origRepr + ")"); } @@ -1416,9 +1418,9 @@ class PyConcreteAttribute : public BaseTy { static MlirAttribute castFrom(PyAttribute &orig) { if (!DerivedTy::isaFunction(orig)) { auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); - throw SetPyError(PyExc_ValueError, - llvm::Twine("Cannot cast attribute to ") + - DerivedTy::pyClassName + " (from " + origRepr + ")"); + throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") + + DerivedTy::pyClassName + + " (from " + origRepr + ")"); } return orig; } @@ -1449,7 +1451,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { // in C API. if (mlirAttributeIsNull(attr)) { throw SetPyError(PyExc_ValueError, - llvm::Twine("invalid '") + + Twine("invalid '") + py::repr(py::cast(type)).cast<std::string>() + "' and expected floating point type."); } @@ -1943,7 +1945,7 @@ class PyConcreteType : public BaseTy { static MlirType castFrom(PyType &orig) { if (!DerivedTy::isaFunction(orig)) { auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); - throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + + throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") + DerivedTy::pyClassName + " (from " + origRepr + ")"); } @@ -2142,7 +2144,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> { } throw SetPyError( PyExc_ValueError, - llvm::Twine("invalid '") + + Twine("invalid '") + py::repr(py::cast(elementType)).cast<std::string>() + "' and expected floating point or integer type."); }, @@ -2247,7 +2249,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> { if (mlirTypeIsNull(t)) { throw SetPyError( PyExc_ValueError, - llvm::Twine("invalid '") + + Twine("invalid '") + py::repr(py::cast(elementType)).cast<std::string>() + "' and expected floating point or integer type."); } @@ -2278,7 +2280,7 @@ class PyRankedTensorType if (mlirTypeIsNull(t)) { throw SetPyError( PyExc_ValueError, - llvm::Twine("invalid '") + + Twine("invalid '") + py::repr(py::cast(elementType)).cast<std::string>() + "' and expected floating point, integer, vector or " "complex " @@ -2309,7 +2311,7 @@ class PyUnrankedTensorType if (mlirTypeIsNull(t)) { throw SetPyError( PyExc_ValueError, - llvm::Twine("invalid '") + + Twine("invalid '") + py::repr(py::cast(elementType)).cast<std::string>() + "' and expected floating point, integer, vector or " "complex " @@ -2344,7 +2346,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> { if (mlirTypeIsNull(t)) { throw SetPyError( PyExc_ValueError, - llvm::Twine("invalid '") + + Twine("invalid '") + py::repr(py::cast(elementType)).cast<std::string>() + "' and expected floating point, integer, vector or " "complex " @@ -2390,7 +2392,7 @@ class PyUnrankedMemRefType if (mlirTypeIsNull(t)) { throw SetPyError( PyExc_ValueError, - llvm::Twine("invalid '") + + Twine("invalid '") + py::repr(py::cast(elementType)).cast<std::string>() + "' and expected floating point, integer, vector or " "complex " @@ -2544,7 +2546,7 @@ void mlir::python::populateIRSubmodule(py::module &m) { self.get(), {name.data(), name.size()}); if (mlirDialectIsNull(dialect)) { throw SetPyError(PyExc_ValueError, - llvm::Twine("Dialect '") + name + "' not found"); + Twine("Dialect '") + name + "' not found"); } return PyDialectDescriptor(self.getRef(), dialect); }, @@ -2763,6 +2765,26 @@ void mlir::python::populateIRSubmodule(py::module &m) { return PyOpResultList(self.getOperation().getRef()); }, "Returns the list of Operation results.") + .def_property_readonly( + "result", + [](PyOperationBase &self) { + auto &operation = self.getOperation(); + auto numResults = mlirOperationGetNumResults(operation); + if (numResults != 1) { + auto name = mlirIdentifierStr(mlirOperationGetName(operation)); + throw SetPyError( + PyExc_ValueError, + Twine("Cannot call .result on operation ") + + StringRef(name.data, name.length) + " which has " + + Twine(numResults) + + " results (it is only valid for operations with a " + "single result)"); + } + return PyOpResult(operation.getRef(), + mlirOperationGetResult(operation, 0)); + }, + "Shortcut to get an op result if it has only one (throws an error " + "otherwise).") .def("__iter__", [](PyOperationBase &self) { return PyRegionIterator(self.getOperation().getRef()); @@ -2931,7 +2953,7 @@ void mlir::python::populateIRSubmodule(py::module &m) { // in C API. if (mlirAttributeIsNull(type)) { throw SetPyError(PyExc_ValueError, - llvm::Twine("Unable to parse attribute: '") + + Twine("Unable to parse attribute: '") + attrSpec + "'"); } return PyAttribute(context->getRef(), type); @@ -3042,8 +3064,8 @@ void mlir::python::populateIRSubmodule(py::module &m) { // in C API. if (mlirTypeIsNull(type)) { throw SetPyError(PyExc_ValueError, - llvm::Twine("Unable to parse type: '") + - typeSpec + "'"); + Twine("Unable to parse type: '") + typeSpec + + "'"); } return PyType(context->getRef(), type); }, diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h index d24607fb02c2..0cdc7e6a66fe 100644 --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -425,7 +425,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject { pybind11::object parentKeepAlive = pybind11::object()); /// Gets the backing operation. - MlirOperation get() { + operator MlirOperation() const { return get(); } + MlirOperation get() const { checkValid(); return operation; } @@ -440,7 +441,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { assert(!attached && "operation already attached"); attached = true; } - void checkValid(); + void checkValid() const; /// Gets the owning block or raises an exception if the operation has no /// owning block. diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py index ddc4c2129844..e3867b99a9b4 100644 --- a/mlir/test/Bindings/Python/ir_operation.py +++ b/mlir/test/Bindings/Python/ir_operation.py @@ -474,6 +474,7 @@ def testOperationPrint(): run(testOperationPrint) +# CHECK-LABEL: TEST: testKnownOpView def testKnownOpView(): with Context(), Location.unknown(): Context.current.allow_unregistered_dialects = True @@ -503,3 +504,36 @@ def testKnownOpView(): print(repr(custom)) run(testKnownOpView) + + +# CHECK-LABEL: TEST: testSingleResultProperty +def testSingleResultProperty(): + with Context(), Location.unknown(): + Context.current.allow_unregistered_dialects = True + module = Module.parse(r""" + "custom.no_result"() : () -> () + %0:2 = "custom.two_result"() : () -> (f32, f32) + %1 = "custom.one_result"() : () -> f32 + """) + print(module) + + try: + module.body.operations[0].result + except ValueError as e: + # CHECK: Cannot call .result on operation custom.no_result which has 0 results + print(e) + else: + assert False, "Expected exception" + + try: + module.body.operations[1].result + except ValueError as e: + # CHECK: Cannot call .result on operation custom.two_result which has 2 results + print(e) + else: + assert False, "Expected exception" + + # CHECK: %1 = "custom.one_result"() : () -> f32 + print(module.body.operations[2]) + +run(testSingleResultProperty) _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits