https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/86383
This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`. The implementation is almost identical to `scf.if` (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.

>From 680d04d71e663aac51ea8f4dc4885d0bfd050b19 Mon Sep 17 00:00:00 2001
From: Matthias Springer <spring...@google.com>
Date: Sat, 23 Mar 2024 08:24:46 +0000
Subject: [PATCH] [mlir][Arith] `ValueBoundsOpInterface`: Support `arith.select`

This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`.
The implementation is almost identical to `scf.if` (#85895), but there is one
special case: if the condition is a shaped value, the selection is applied
element-wise and the result shape can be inferred from either operand.
---
 .../Arith/IR/ValueBoundsOpInterfaceImpl.cpp   | 61 +++++++++++++++++++
 .../Arith/value-bounds-op-interface-impl.mlir | 31 ++++++++++
 2 files changed, 92 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index 9c6b50e767ea26..bb7b9c939fcb09 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -66,6 +66,66 @@ struct MulIOpInterface
   }
 };
 
+/// `ValueBoundsOpInterface` external model for `arith.select`: the result is
+/// the true value or the false value (element-wise for a shaped condition).
+struct SelectOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
+                                                   SelectOp> {
+
+  /// Populate `cstr` with bounds for the result of `selectOp`. `dim` is set
+  /// for shaped results (bound that dimension) and unset for index results.
+  static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
+                             ValueBoundsConstraintSet &cstr) {
+    Value value = selectOp.getResult();
+    Value condition = selectOp.getCondition();
+    Value trueValue = selectOp.getTrueValue();
+    Value falseValue = selectOp.getFalseValue();
+
+    if (isa<ShapedType>(condition.getType())) {
+      // If the condition is a shaped type, the condition is applied
+      // element-wise. All three operands must have the same shape. The result
+      // is then shaped as well, so this is only reached with a dimension.
+      assert(dim.has_value() && "expected a dimension for shaped select");
+      cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
+      return;
+    }
+
+    // Populate constraints for the true/false values (and all values on the
+    // backward slice, as long as the current stop condition is not satisfied).
+    cstr.populateConstraints(trueValue, dim);
+    cstr.populateConstraints(falseValue, dim);
+
+    // The result is one of the two operands. If `lower <= upper` can be
+    // proven, then `lower <= result <= upper`.
+    auto addBoundsIfOrdered = [&](Value lower, Value upper) {
+      if (!cstr.compare(lower, dim,
+                        ValueBoundsConstraintSet::ComparisonOperator::LE,
+                        upper, dim))
+        return;
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(lower, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(upper, dim);
+      } else {
+        cstr.bound(value) >= lower;
+        cstr.bound(value) <= upper;
+      }
+    };
+    addBoundsIfOrdered(trueValue, falseValue);
+    addBoundsIfOrdered(falseValue, trueValue);
+  }
+
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), dim, cstr);
+  }
+};
 } // namespace
 } // namespace arith
 } // namespace mlir
@@ -77,5 +137,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
     arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
     arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
     arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
+    arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
index 83d5f1c9c9e86c..8fb3ba1a1eccef 100644
--- a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
@@ -74,3 +74,34 @@ func.func @arith_const() -> index {
   %0 = "test.reify_bound"(%c5) : (index) -> (index)
   return %0 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @arith_select(
+func.func @arith_select(%c: i1) -> (index, index) {
+  // CHECK: arith.constant 5 : index
+  %c5 = arith.constant 5 : index
+  // CHECK: arith.constant 9 : index
+  %c9 = arith.constant 9 : index
+  %r = arith.select %c, %c5, %c9 : index
+  // CHECK: %[[c5:.*]] = arith.constant 5 : index
+  // CHECK: %[[c10:.*]] = arith.constant 10 : index
+  %0 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+  %1 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+  // CHECK: return %[[c5]], %[[c10]]
+  return %0, %1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @arith_select_elementwise(
+// CHECK-SAME: %[[a:.*]]: tensor<?xf32>, %[[b:.*]]: tensor<?xf32>, %[[c:.*]]: tensor<?xi1>)
+func.func @arith_select_elementwise(%a: tensor<?xf32>, %b: tensor<?xf32>, %c: tensor<?xi1>) -> index {
+  %r = arith.select %c, %a, %b : tensor<?xi1>, tensor<?xf32>
+  // CHECK: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK: %[[dim:.*]] = tensor.dim %[[a]], %[[c0]]
+  %0 = "test.reify_bound"(%r) {type = "EQ", dim = 0}
+      : (tensor<?xf32>) -> (index)
+  // CHECK: return %[[dim]]
+  return %0 : index
+}
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits