[llvm-branch-commits] [mlir] [mlir][Arith] `ValueBoundsOpInterface`: Support `arith.select` (PR #86383)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sat Mar 23 01:34:00 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`. The implementation is almost identical to the one for `scf.if` (#85895), with one special case: if the condition is a shaped value, the selection is applied element-wise, so the result shape can be inferred from any of the operands (the condition, the true value, or the false value).
---
Full diff: https://github.com/llvm/llvm-project/pull/86383.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp (+70)
- (modified) mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir (+31)
``````````diff
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index 9c6b50e767ea26..bb7b9c939fcb09 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -66,6 +66,75 @@ struct MulIOpInterface
}
};
+/// ValueBoundsOpInterface external model for `arith.select`.
+struct SelectOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
+                                                   SelectOp> {
+
+  /// Adds bounds for the (single) result of `selectOp` to `cstr`.
+  ///
+  /// `dim` is std::nullopt for the index-value hook and holds the queried
+  /// dimension for the shaped-value hook.
+  ///
+  /// * Shaped condition: the selection is element-wise, so the queried
+  ///   dimension of the result is equal to the same dimension of the true
+  ///   value, the false value and the condition.
+  /// * Scalar condition: the result is one of the two operands. If the
+  ///   operands can be ordered (true <= false and/or false <= true), the
+  ///   result is bounded below by the smaller and above by the larger one.
+  ///   If they cannot be ordered, no bound is added for the result.
+  static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
+                             ValueBoundsConstraintSet &cstr) {
+    Value value = selectOp.getResult();
+    Value condition = selectOp.getCondition();
+    Value trueValue = selectOp.getTrueValue();
+    Value falseValue = selectOp.getFalseValue();
+
+    if (isa<ShapedType>(condition.getType())) {
+      // If the condition is a shaped type, the condition is applied
+      // element-wise. All three operands must have the same shape.
+      // NOTE: `*dim` assumes that a shaped condition implies a shaped result,
+      // so this branch is only reached via populateBoundsForShapedValueDim.
+      cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
+      return;
+    }
+
+    // Populate constraints for the true/false values (and all values on the
+    // backward slice, as long as the current stop condition is not satisfied).
+    cstr.populateConstraints(trueValue, dim);
+    cstr.populateConstraints(falseValue, dim);
+
+    // One builder for the result (or for its queried dimension). It is reused
+    // for every constraint below, so the scalar and the shaped case share the
+    // same code path: `cstr.getExpr(x, dim)` handles both `dim` states.
+    auto boundsBuilder = cstr.bound(value);
+    if (dim)
+      boundsBuilder[*dim];
+
+    // Compare the true/false operands.
+    // If trueValue <= falseValue:
+    // * result <= falseValue
+    // * result >= trueValue
+    if (cstr.compare(trueValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     falseValue, dim)) {
+      boundsBuilder >= cstr.getExpr(trueValue, dim);
+      boundsBuilder <= cstr.getExpr(falseValue, dim);
+    }
+    // If falseValue <= trueValue:
+    // * result <= trueValue
+    // * result >= falseValue
+    // (When both comparisons hold, the two operands are equal and the result
+    // ends up constrained to that value.)
+    if (cstr.compare(falseValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     trueValue, dim)) {
+      boundsBuilder >= cstr.getExpr(falseValue, dim);
+      boundsBuilder <= cstr.getExpr(trueValue, dim);
+    }
+  }
+
+  /// Interface hook for an index-valued result. `value` is not needed: the op
+  /// has a single result, which populateBounds obtains via getResult().
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
+  }
+
+  /// Interface hook for dimension `dim` of a shaped result.
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), dim, cstr);
+  }
+};
} // namespace
} // namespace arith
} // namespace mlir
@@ -77,5 +146,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
+ arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
});
}
diff --git a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
index 83d5f1c9c9e86c..8fb3ba1a1eccef 100644
--- a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
@@ -74,3 +74,34 @@ func.func @arith_const() -> index {
%0 = "test.reify_bound"(%c5) : (index) -> (index)
return %0 : index
}
+
+// -----
+
+// Scalar i1 condition selecting between the index constants 5 and 9. The
+// expected lower bound is the smaller operand (5). The expected upper bound is
+// the larger operand plus one (10), i.e. the reified UB is exclusive.
+// CHECK-LABEL: func @arith_select(
+func.func @arith_select(%c: i1) -> (index, index) {
+ // CHECK: arith.constant 5 : index
+ %c5 = arith.constant 5 : index
+ // CHECK: arith.constant 9 : index
+ %c9 = arith.constant 9 : index
+ %r = arith.select %c, %c5, %c9 : index
+ // The reified bounds are materialized as new index constants.
+ // CHECK: %[[c5:.*]] = arith.constant 5 : index
+ // CHECK: %[[c10:.*]] = arith.constant 10 : index
+ %0 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+ %1 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+ // CHECK: return %[[c5]], %[[c10]]
+ return %0, %1 : index, index
+}
+
+// -----
+
+// Element-wise select: the condition is a tensor<?xi1>, so dimension 0 of the
+// result is equal (EQ bound) to dimension 0 of every operand. The reified bound
+// is expected to be a tensor.dim of the first operand %a.
+// CHECK-LABEL: func @arith_select_elementwise(
+// CHECK-SAME: %[[a:.*]]: tensor<?xf32>, %[[b:.*]]: tensor<?xf32>, %[[c:.*]]: tensor<?xi1>)
+func.func @arith_select_elementwise(%a: tensor<?xf32>, %b: tensor<?xf32>, %c: tensor<?xi1>) -> index {
+ %r = arith.select %c, %a, %b : tensor<?xi1>, tensor<?xf32>
+ // CHECK: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK: %[[dim:.*]] = tensor.dim %[[a]], %[[c0]]
+ %0 = "test.reify_bound"(%r) {type = "EQ", dim = 0}
+ : (tensor<?xf32>) -> (index)
+ // CHECK: return %[[dim]]
+ return %0 : index
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/86383
More information about the llvm-branch-commits mailing list