https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/79626
>From ab475c9ffb7c3562bad4772389e97b82e9f110c0 Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Fri, 26 Jan 2024 11:55:06 -0600 Subject: [PATCH 1/3] Add elementwise criteria to match.structured.body --- .../Linalg/TransformOps/LinalgMatchOps.td | 4 +++ .../Linalg/TransformOps/LinalgMatchOps.cpp | 9 ++++- .../Dialect/Linalg/match-ops-interpreter.mlir | 34 +++++++++++++++++++ .../Dialect/Linalg/match-ops-invalid.mlir | 2 +- 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td index 162dd05f93030f2..dfeb8ae5d5ddbcb 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td @@ -106,6 +106,9 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [ * `passthrough`: the body of the structured payload op only forwards inputs to the outputs (copy or broadcast). + * `elementwise`: the body of the structured payload op represents an + elementwise operation. 
+ * `contraction`: the body of the structured payload op is a contraction of the form `<red>(<elem>(bbarg0, bbarg1), bbarg2)` where `<elem>` and `<red>` are binary operations whose names are specified in the attribute @@ -123,6 +126,7 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [ let arguments = (ins TransformHandleTypeInterface:$operand_handle, OptionalAttr<I64Attr>:$reduction_position, UnitAttr:$passthrough, + UnitAttr:$elementwise, OptionalAttr<StrArrayAttr>:$contraction); let assemblyFormat = "$operand_handle attr-dict `:` type($operand_handle)"; let extraClassDeclaration = SingleOpMatcher.extraDeclaration; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp index 115da4b90e063ac..fb18886c16b16d5 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/TransformOps/Syntax.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Interfaces/FunctionImplementation.h" @@ -187,6 +188,11 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation( } return DiagnosedSilenceableFailure::success(); } + if (getElementwise()) { + if (!isElementwise(linalgOp)) + return emitSilenceableError() << "not elementwise"; + return DiagnosedSilenceableFailure::success(); + } if (std::optional<ArrayAttr> contractionOps = getContraction()) { Block &body = linalgOp->getRegion(0).front(); std::string message; @@ -209,13 +215,14 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation( LogicalResult transform::MatchStructuredBodyOp::verify() { int64_t numOptions = getReductionPosition().has_value() + getPassthrough() + - 
getContraction().has_value(); + getElementwise() + getContraction().has_value(); if (numOptions > 1) { std::string attributeNames; llvm::raw_string_ostream os(attributeNames); llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(), getPassthroughAttrName(), + getElementwiseAttrName(), getContractionAttrName()}, os); return emitOpError() << "only one of {" << os.str() << "} is allowed"; diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir index a7353a4c38881e4..0efe70a7b9ae1eb 100644 --- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir +++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir @@ -180,6 +180,40 @@ module attributes { transform.with_named_sequence } { // ----- +module attributes { transform.with_named_sequence } { + transform.named_sequence @print_elementwise(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "elementwise" : !transform.any_op + transform.yield + } + + transform.named_sequence @match_structured_body_elementwise(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op { + ^bb0(%arg1: !transform.any_op): + transform.match.structured.body %arg1 { elementwise } : !transform.any_op + transform.match.structured.yield %arg1 : !transform.any_op + } + transform.yield %0 : !transform.any_op + } + + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) { + transform.foreach_match in %arg0 + @match_structured_body_elementwise -> @print_elementwise + : (!transform.any_op) -> !transform.any_op + transform.yield + } + + func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %out: tensor<2xf32>) -> tensor<2xf32> attributes { transform.target_tag = "start_here" } { + %cst0 = arith.constant 0.0 : f32 + // expected-remark @below {{elementwise}} + %fill = linalg.fill 
ins(%cst0: f32) outs(%out: tensor<2xf32>) -> tensor<2xf32> + // expected-remark @below {{elementwise}} + %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<2xf32>, tensor<2xf32>) outs(%fill: tensor<2xf32>) + return %add : tensor<2xf32> + } +} + +// ----- + module attributes { transform.with_named_sequence } { transform.named_sequence @print_reduction(%arg0: !transform.any_op {transform.readonly}) { transform.debug.emit_remark_at %arg0, "reduction" : !transform.any_op diff --git a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir index ec99e205090c4cb..9ff430a35036063 100644 --- a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir +++ b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir @@ -64,7 +64,7 @@ transform.sequence failures(suppress) { ^bb0(%arg0: !transform.any_op): transform.match.structured %arg0 : !transform.any_op { ^bb1(%arg1: !transform.any_op): - // expected-error @below {{only one of {"reduction_position", "passthrough", "contraction"} is allowed}} + // expected-error @below {{only one of {"reduction_position", "passthrough", "elementwise", "contraction"} is allowed}} transform.match.structured.body %arg1 { passthrough, reduction_position = 0 } : !transform.any_op transform.match.structured.yield } >From a1cb4dfafcc64c51409d67e6396b93320508af99 Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Fri, 26 Jan 2024 14:48:19 -0600 Subject: [PATCH 2/3] Add broadcast elementwise test --- .../Dialect/Linalg/match-ops-interpreter.mlir | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir index 0efe70a7b9ae1eb..6e05c6e17de18bf 100644 --- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir +++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir @@ -182,7 +182,7 @@ module attributes { transform.with_named_sequence } { module attributes { transform.with_named_sequence
} { transform.named_sequence @print_elementwise(%arg0: !transform.any_op {transform.readonly}) { - transform.test_print_remark_at_operand %arg0, "elementwise" : !transform.any_op + transform.debug.emit_remark_at %arg0, "elementwise" : !transform.any_op transform.yield } @@ -202,13 +202,22 @@ module attributes { transform.with_named_sequence } { transform.yield } - func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %out: tensor<2xf32>) -> tensor<2xf32> attributes { transform.target_tag = "start_here" } { + func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %in3: tensor<2x3xf32>, %out: tensor<2xf32>, %out2: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } { %cst0 = arith.constant 0.0 : f32 // expected-remark @below {{elementwise}} %fill = linalg.fill ins(%cst0: f32) outs(%out: tensor<2xf32>) -> tensor<2xf32> // expected-remark @below {{elementwise}} %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<2xf32>, tensor<2xf32>) outs(%fill: tensor<2xf32>) - return %add : tensor<2xf32> + // expected-remark @below {{elementwise}} + %add_bcast = linalg.generic + {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg3: f32): + %0 = arith.addf %arg0, %arg1 : f32 + linalg.yield %0 : f32 + } -> tensor<2x3xf32> + return %add, %add_bcast : tensor<2xf32>, tensor<2x3xf32> } } >From bd1a89f888060d94c3326e2218bbfb2d9bda24c1 Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Fri, 26 Jan 2024 15:02:29 -0600 Subject: [PATCH 3/3] Add non-elementwise test --- .../Dialect/Linalg/match-ops-interpreter.mlir | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir 
b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir index 6e05c6e17de18bf..24c7bdd9e1050ee 100644 --- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir +++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir @@ -202,12 +202,26 @@ module attributes { transform.with_named_sequence } { transform.yield } - func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %in3: tensor<2x3xf32>, %out: tensor<2xf32>, %out2: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } { + func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %in3: tensor<2x3xf32>, %out: tensor<2xf32>, %out2: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } { %cst0 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index // expected-remark @below {{elementwise}} %fill = linalg.fill ins(%cst0: f32) outs(%out: tensor<2xf32>) -> tensor<2xf32> // expected-remark @below {{elementwise}} %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<2xf32>, tensor<2xf32>) outs(%fill: tensor<2xf32>) + %non_elementwise = linalg.generic + {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg3: f32): + %0 = arith.addf %arg0, %arg1 : f32 + %1 = tensor.dim %add, %c0 : tensor<2xf32> + %2 = arith.subi %1, %c1 : index + %3 = tensor.extract %add[%2] : tensor<2xf32> + %4 = arith.mulf %0, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<2x3xf32> // expected-remark @below {{elementwise}} %add_bcast = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], @@ -217,7 +231,7 @@ module attributes { transform.with_named_sequence } { %0 = arith.addf %arg0, %arg1 : f32 
linalg.yield %0 : f32 } -> tensor<2x3xf32> - return %add, %add_bcast : tensor<2xf32>, tensor<2x3xf32> + return %add, %add_bcast, %non_elementwise : tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32> } } _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits