[llvm-branch-commits] [mlir] [mlir][Arith] `ValueBoundsOpInterface`: Support `arith.select` (PR #86383)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Mar 23 01:33:29 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/86383

This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`. The implementation is almost identical to `scf.if` (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.

>From 680d04d71e663aac51ea8f4dc4885d0bfd050b19 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Sat, 23 Mar 2024 08:24:46 +0000
Subject: [PATCH] [mlir][Arith] `ValueBoundsOpInterface`: Support
 `arith.select`

This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`. The implementation is almost identical to `scf.if` (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.
---
 .../Arith/IR/ValueBoundsOpInterfaceImpl.cpp   | 70 +++++++++++++++++++
 .../Arith/value-bounds-op-interface-impl.mlir | 31 ++++++++
 2 files changed, 101 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index 9c6b50e767ea26..bb7b9c939fcb09 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -66,6 +66,75 @@ struct MulIOpInterface
   }
 };
 
+/// Value bounds for `arith.select`: the result (or one of its dimensions) is
+/// bounded by the true/false operands whenever the two can be ordered.
+struct SelectOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
+                                                   SelectOp> {
+  /// Shared implementation; `dim` is std::nullopt for the index-value entry.
+  static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
+                             ValueBoundsConstraintSet &cstr) {
+    Value value = selectOp.getResult();
+    Value condition = selectOp.getCondition();
+    Value trueValue = selectOp.getTrueValue();
+    Value falseValue = selectOp.getFalseValue();
+
+    if (isa<ShapedType>(condition.getType())) {
+      // If the condition is a shaped type, the condition is applied
+      // element-wise. All three operands must have the same shape.
+      // NOTE: `dim` is assumed to be set here (shaped result) - TODO confirm.
+      cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
+      return;
+    }
+
+    // Populate constraints for the true/false values (and all values on the
+    // backward slice, as long as the current stop condition is not satisfied).
+    cstr.populateConstraints(trueValue, dim);
+    cstr.populateConstraints(falseValue, dim);
+
+    // Compare yielded values.
+    // If trueValue <= falseValue:
+    // * result <= falseValue
+    // * result >= trueValue
+    if (cstr.compare(trueValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     falseValue, dim)) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
+      } else {
+        cstr.bound(value) >= trueValue;
+        cstr.bound(value) <= falseValue;
+      }
+    }
+    // If falseValue <= trueValue:
+    // * result <= trueValue
+    // * result >= falseValue
+    if (cstr.compare(falseValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     trueValue, dim)) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
+      } else {
+        cstr.bound(value) >= falseValue;
+        cstr.bound(value) <= trueValue;
+      }
+    }
+  }
+
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), dim, cstr);
+  }
+};
 } // namespace
 } // namespace arith
 } // namespace mlir
@@ -77,5 +146,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
     arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
     arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
     arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
+    arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
index 83d5f1c9c9e86c..8fb3ba1a1eccef 100644
--- a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
@@ -74,3 +74,34 @@ func.func @arith_const() -> index {
   %0 = "test.reify_bound"(%c5) : (index) -> (index)
   return %0 : index
 }
+
+// -----
+// Scalar select of 5 or 9; UB 10 is presumably the open bound 9 + 1 (verify).
+// CHECK-LABEL: func @arith_select(
+func.func @arith_select(%c: i1) -> (index, index) {
+  // CHECK: arith.constant 5 : index
+  %c5 = arith.constant 5 : index
+  // CHECK: arith.constant 9 : index
+  %c9 = arith.constant 9 : index
+  %r = arith.select %c, %c5, %c9 : index
+  // CHECK: %[[c5:.*]] = arith.constant 5 : index
+  // CHECK: %[[c10:.*]] = arith.constant 10 : index
+  %0 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+  %1 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+  // CHECK: return %[[c5]], %[[c10]]
+  return %0, %1 : index, index
+}
+
+// -----
+// Shaped condition: the result's dim 0 is expected to reify to a dim of %a.
+// CHECK-LABEL: func @arith_select_elementwise(
+//  CHECK-SAME:     %[[a:.*]]: tensor<?xf32>, %[[b:.*]]: tensor<?xf32>, %[[c:.*]]: tensor<?xi1>)
+func.func @arith_select_elementwise(%a: tensor<?xf32>, %b: tensor<?xf32>, %c: tensor<?xi1>) -> index {
+  %r = arith.select %c, %a, %b : tensor<?xi1>, tensor<?xf32>
+  // CHECK: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK: %[[dim:.*]] = tensor.dim %[[a]], %[[c0]]
+  %0 = "test.reify_bound"(%r) {type = "EQ", dim = 0}
+      : (tensor<?xf32>) -> (index)
+  // CHECK: return %[[dim]]
+  return %0 : index
+}



More information about the llvm-branch-commits mailing list