Author: Jacques Pienaar Date: 2020-11-28T15:53:59-08:00 New Revision: 6dd9596b19d7679c562f8e866be6d0c3d7c21994
URL: https://github.com/llvm/llvm-project/commit/6dd9596b19d7679c562f8e866be6d0c3d7c21994 DIFF: https://github.com/llvm/llvm-project/commit/6dd9596b19d7679c562f8e866be6d0c3d7c21994.diff LOG: [mlir] Add a shape function library op Op with mapping from ops to corresponding shape functions for those ops in the library and a mechanism to associate shape functions to functions. The mapping of operation to shape function is kept separate from the shape functions themselves as the operation is associated to the shape function and not vice versa, and one could have a common library of shape functions that can be used in different contexts. Use fully qualified names and require a name for shape fn lib ops for now and an explicit print/parse (based around the generated one & GPU module op ones). Differential Revision: https://reviews.llvm.org/D91672 Added: mlir/test/Analysis/test-shape-fn-report.mlir mlir/test/lib/Dialect/Shape/CMakeLists.txt mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp Modified: mlir/include/mlir/Dialect/Shape/IR/Shape.h mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td mlir/lib/Dialect/Shape/IR/Shape.cpp mlir/test/lib/Dialect/CMakeLists.txt mlir/test/lib/Dialect/Test/TestOps.td mlir/tools/mlir-opt/CMakeLists.txt mlir/tools/mlir-opt/mlir-opt.cpp Removed: ################################################################################ diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h index f40d6154544a..cb5ed56e16a2 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -14,6 +14,7 @@ #ifndef MLIR_SHAPE_IR_SHAPE_H #define MLIR_SHAPE_IR_SHAPE_H +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index a852d900cf69..52768e49001d 100644 --- 
a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -18,6 +18,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/SymbolInterfaces.td" //===----------------------------------------------------------------------===// // Shape op definitions @@ -492,7 +493,7 @@ def Shape_WithOp : Shape_Op<"with_shape", [NoSideEffect]> { } def Shape_YieldOp : Shape_Op<"yield", - [HasParent<"ReduceOp">, + [HasParent<"ReduceOp, FunctionLibraryOp">, NoSideEffect, ReturnLike, Terminator]> { @@ -780,4 +781,62 @@ def Shape_CstrRequireOp : Shape_Op<"cstr_require", []> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// Shape collection ops. +//===----------------------------------------------------------------------===// + +def Shape_FunctionLibraryOp : Shape_Op<"function_library", + [AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol, + SingleBlockImplicitTerminator<"ShapeFunctionLibraryTerminatorOp">]> { + let summary = "Represents shape functions and corresponding ops"; + let description = [{ + Represents a list of shape functions and the ops whose shape transfer + functions they represent. + + Example: + + ```mlir + shape.function_library { + func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape { + %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape + return %0 : !shape.shape + } + } mapping { + std.atan = @same_result_shape + } + ``` + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, + OptionalAttr<StrAttr>:$sym_visibility); + let arguments = (ins DictionaryAttr:$mapping); + let regions = (region AnyRegion:$body); + + let extraClassDeclaration = [{ + /// Returns an associated shape function for an operation if defined. 
+ FuncOp getShapeFunction(Operation *op); + }]; + + let builders = [OpBuilderDAG<(ins "StringRef":$name)>]; + let skipDefaultBuilders = 1; + + let printer = [{ ::print(p, *this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +//===----------------------------------------------------------------------===// +// ShapeFunctionLibraryTerminatorOp +//===----------------------------------------------------------------------===// + +def ShapeFunctionLibraryTerminatorOp : Shape_Op<"fn_lib_terminator", + [Terminator, HasParent<"FunctionLibraryOp">]> { + let summary = "A pseudo op that marks the end of a shape function library"; + let description = [{ + `shape_fn_lib_terminator` is a special pseudo terminator operation for the + shape function library. It has no semantic meaning beyond keeping the body + well-formed. + }]; + let assemblyFormat = "attr-dict"; +} + #endif // SHAPE_OPS diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index cfac2abae3e6..d8c7f4c6736d 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Traits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Function.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Transforms/InliningUtils.h" @@ -558,6 +559,65 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { return builder.getIndexTensorAttr(extents); } +//===----------------------------------------------------------------------===// +// FunctionLibraryOp +//===----------------------------------------------------------------------===// + +void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, + StringRef name) { + ensureTerminator(*result.addRegion(), builder, result.location); + result.attributes.push_back(builder.getNamedAttr( + ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 
+} + +FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { + auto attr = mapping() + .get(op->getName().getIdentifier()) + .dyn_cast_or_null<FlatSymbolRefAttr>(); + if (!attr) + return nullptr; + return lookupSymbol<FuncOp>(attr); +} + +ParseResult parseFunctionLibraryOp(OpAsmParser &parser, + OperationState &result) { + // Parse the op name. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + + auto *bodyRegion = result.addRegion(); + if (parser.parseRegion(*bodyRegion)) + return failure(); + + FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(), + result.location); + if (parser.parseKeyword("mapping")) + return failure(); + + DictionaryAttr mappingAttr; + if (parser.parseAttribute(mappingAttr, + parser.getBuilder().getType<NoneType>(), "mapping", + result.attributes)) + return failure(); + return success(); +} + +void print(OpAsmPrinter &p, FunctionLibraryOp op) { + p << op.getOperationName() << ' '; + p.printSymbolName(op.getName()); + p.printOptionalAttrDictWithKeyword( + op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); + p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + p << " mapping "; + p.printAttributeWithoutType(op.mappingAttr()); +} + //===----------------------------------------------------------------------===// // GetExtentOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Analysis/test-shape-fn-report.mlir b/mlir/test/Analysis/test-shape-fn-report.mlir new file mode 100644 index 000000000000..ad5c8e64a1b7 --- /dev/null +++ b/mlir/test/Analysis/test-shape-fn-report.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-opt %s --test-shape-function-report -verify-diagnostics + +// expected-remark@+1 {{associated shape function: 
same_result_shape}} +func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32> + attributes {shape.function = @shape_lib::@same_result_shape} { + // expected-remark@+1 {{no associated way}} + %0 = tanh %arg : tensor<10x20xf32> + // expected-remark@+1 {{associated shape function: same_result_shape}} + %1 = "test.same_operand_result_type"(%0) : (tensor<10x20xf32>) -> tensor<10x20xf32> + return %1 : tensor<10x20xf32> +} + +// The shape function library with some local functions. +shape.function_library @shape_lib { + // Test shape function that returns the shape of input arg as result shape. + func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape { + %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape + return %0 : !shape.shape + } +} mapping { + test.same_operand_result_type = @same_result_shape +} diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt index b220d0d81632..adee9f8a1514 100644 --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(Affine) +add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(Test) add_subdirectory(Tosa) diff --git a/mlir/test/lib/Dialect/Shape/CMakeLists.txt b/mlir/test/lib/Dialect/Shape/CMakeLists.txt new file mode 100644 index 000000000000..6c041ab9c371 --- /dev/null +++ b/mlir/test/lib/Dialect/Shape/CMakeLists.txt @@ -0,0 +1,16 @@ +# Exclude tests from libMLIR.so +add_mlir_library(MLIRShapeTestPasses + TestShapeFunctions.cpp + + EXCLUDE_FROM_LIBMLIR + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shape + ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRShape + MLIRSupport + ) diff --git a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp new file mode 100644 index 000000000000..688f24e5ec47 --- /dev/null +++ b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp @@ -0,0 +1,73 @@ +//===- 
TestShapeFunctions.cpp - Passes to test shape function ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include <queue> + +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +/// This is a pass that reports shape functions associated with ops. +struct ReportShapeFnPass + : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> { + void runOnOperation() override; +}; +} // end anonymous namespace + +void ReportShapeFnPass::runOnOperation() { + auto module = getOperation(); + + // Lookup shape function library. + shape::FunctionLibraryOp shapeFnLib = nullptr; + for (auto lib : module.getOps<shape::FunctionLibraryOp>()) { + if (shapeFnLib) { + lib.emitError("duplicate shape library op") + .attachNote(shapeFnLib.getLoc()) + << "previous mapping"; + return signalPassFailure(); + } + shapeFnLib = lib; + }; + + // Report the shape function available to refine the op. 
+ auto shapeFnId = Identifier::get("shape.function", &getContext()); + auto remarkShapeFn = [&](Operation *op) { + if (op->isKnownTerminator()) + return; + if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) { + op->emitRemark() << "implements InferType op interface"; + } else if (auto fn = shapeFnLib.getShapeFunction(op)) { + op->emitRemark() << "associated shape function: " << fn.getName(); + } else if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) { + auto fn = cast<FuncOp>(SymbolTable::lookupSymbolIn(module, symbol)); + op->emitRemark() << "associated shape function: " << fn.getName(); + } else { + op->emitRemark() << "no associated way to refine shape"; + } + }; + + module.getBodyRegion().walk([&](FuncOp func) { + // Skip ops in the shape function library. + if (isa<shape::FunctionLibraryOp>(func.getParentOp())) + return; + + func.walk([&](Operation *op) { remarkShapeFn(op); }); + }); +} + +namespace mlir { +void registerShapeFunctionTestPasses() { + PassRegistration<ReportShapeFnPass>( + "test-shape-function-report", + "Test pass to report associated shape functions"); +} +} // namespace mlir diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index aef5b5166ae2..5a17eebfd32c 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -134,6 +134,12 @@ def VariadicWithSameOperandsResult : let results = (outs AnySignlessInteger:$result); } +def SameOperandsResultType : TEST_Op< + "same_operand_result_type", [SameOperandsAndResultType]> { + let arguments = (ins AnyTensor:$operand); + let results = (outs AnyTensor:$result); +} + //===----------------------------------------------------------------------===// // Test Results //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index 483dcfec0c0f..e8b0842a9e33 100644 --- 
a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -13,6 +13,7 @@ set(LLVM_LINK_COMPONENTS if(MLIR_INCLUDE_TESTS) set(test_libs MLIRAffineTransformsTestPasses + MLIRShapeTestPasses MLIRSPIRVTestPasses MLIRTestDialect MLIRTestIR diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index a0e36cf82534..4095cc21cbaf 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -32,6 +32,7 @@ namespace mlir { void registerConvertToTargetEnvPass(); void registerPassManagerTestPass(); void registerPrintOpAvailabilityPass(); +void registerShapeFunctionTestPasses(); void registerSideEffectTestPasses(); void registerSliceAnalysisTestPass(); void registerSymbolTestPasses(); @@ -98,6 +99,7 @@ void registerTestPasses() { registerConvertToTargetEnvPass(); registerPassManagerTestPass(); registerPrintOpAvailabilityPass(); + registerShapeFunctionTestPasses(); registerSideEffectTestPasses(); registerSliceAnalysisTestPass(); registerSymbolTestPasses(); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits