[llvm-branch-commits] [mlir] [mlir][SCF] `ValueBoundsConstraintSet`: Support `scf.if` (branches) (PR #85895)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Mar 21 19:04:47 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/85895

>From 8057ddd7f467891b5fec9c1f7426fd06012453fb Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 22 Mar 2024 02:03:32 +0000
Subject: [PATCH] [mlir][SCF] `ValueBoundsConstraintSet`: Add preliminary
 support for branches

This commit adds support for `scf.if` to `ValueBoundsConstraintSet`.

Example:
```
%0 = scf.if ... -> index {
  scf.yield %a : index
} else {
  scf.yield %b : index
}
```

The following constraints hold for %0:
* %0 >= min(%a, %b)
* %0 <= max(%a, %b)

Such constraints cannot be added to the constraint set; min/max is not supported by `IntegerRelation`. However, if we know which one of %a and %b is larger, we can add constraints for %0. E.g., if %a <= %b:
* %0 >= %a
* %0 <= %b

This commit required a few minor changes to the `ValueBoundsConstraintSet` infrastructure, so that values can be compared while we are still in the process of traversing the IR/adding constraints.
---
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  22 ++++
 .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp     |  63 ++++++++++
 .../lib/Interfaces/ValueBoundsOpInterface.cpp |  62 +++++++++
 .../SCF/value-bounds-op-interface-impl.mlir   | 119 +++++++++++++++++-
 4 files changed, 264 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 77e1af070c3fe9..ef074bcfe0be87 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -199,6 +199,28 @@ class ValueBoundsConstraintSet {
                        std::optional<int64_t> dim1 = std::nullopt,
                        std::optional<int64_t> dim2 = std::nullopt);
 
+  /// Traverse the IR starting from the given value/dim and populate
+  /// constraints as long as the currently set stop condition holds. Also
+  /// processes all values/dims that are already on the worklist.
+  void populateConstraints(Value value, std::optional<int64_t> dim);
+
+  /// Comparison operator for `ValueBoundsConstraintSet::compare`.
+  enum ComparisonOperator { LT, LE, EQ, GT, GE };
+
+  /// Try to prove that, based on the current state of this constraint set
+  /// (i.e., without analyzing additional IR or adding new constraints), the
+  /// given comparison (LT/LE/EQ/GT/GE) holds between the first given value/dim
+  /// and the second given value/dim.
+  ///
+  /// Return "true" if the specified relation between the two values/dims was
+  /// proven to hold. Return "false" if the specified relation could not be
+  /// proven. This could be because the specified relation does in fact not hold
+  /// or because there is not enough information in the constraint set. In other
+  /// words, if we do not know for sure, this function returns "false".
+  bool compare(Value value1, std::optional<int64_t> dim1,
+               ComparisonOperator cmp, Value value2,
+               std::optional<int64_t> dim2);
+
   /// Compute whether the given values/dimensions are equal. Return "failure" if
   /// equality could not be determined.
   ///
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 1e13e60068ee7f..72a25d0f0b30b0 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -111,6 +111,68 @@ struct ForOpInterface
   }
 };
 
+struct IfOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
+
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto ifOp = cast<IfOp>(op);
+    unsigned int resultNum = cast<OpResult>(value).getResultNumber();
+    Value thenValue = ifOp.thenYield().getResults()[resultNum];
+    Value elseValue = ifOp.elseYield().getResults()[resultNum];
+
+    // Populate constraints for the yielded value (and all values on the
+    // backward slice, as long as the current stop condition is not satisfied).
+    cstr.populateConstraints(thenValue, /*valueDim=*/std::nullopt);
+    cstr.populateConstraints(elseValue, /*valueDim=*/std::nullopt);
+
+    // Compare yielded values.
+    // If thenValue <= elseValue:
+    // * result <= elseValue
+    // * result >= thenValue
+    if (cstr.compare(thenValue, /*dim1=*/std::nullopt,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     elseValue, /*dim2=*/std::nullopt)) {
+      cstr.bound(value) >= thenValue;
+      cstr.bound(value) <= elseValue;
+    }
+    // If elseValue <= thenValue:
+    // * result <= thenValue
+    // * result >= elseValue
+    if (cstr.compare(elseValue, /*dim1=*/std::nullopt,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     thenValue, /*dim2=*/std::nullopt)) {
+      cstr.bound(value) >= elseValue;
+      cstr.bound(value) <= thenValue;
+    }
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    // See `populateBoundsForIndexValue` for documentation.
+    auto ifOp = cast<IfOp>(op);
+    unsigned int resultNum = cast<OpResult>(value).getResultNumber();
+    Value thenValue = ifOp.thenYield().getResults()[resultNum];
+    Value elseValue = ifOp.elseYield().getResults()[resultNum];
+
+    cstr.populateConstraints(thenValue, dim);
+    cstr.populateConstraints(elseValue, dim);
+
+    if (cstr.compare(thenValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     elseValue, dim)) {
+      cstr.bound(value)[dim] >= cstr.getExpr(thenValue, dim);
+      cstr.bound(value)[dim] <= cstr.getExpr(elseValue, dim);
+    }
+    if (cstr.compare(elseValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     thenValue, dim)) {
+      cstr.bound(value)[dim] >= cstr.getExpr(elseValue, dim);
+      cstr.bound(value)[dim] <= cstr.getExpr(thenValue, dim);
+    }
+  }
+};
+
 } // namespace
 } // namespace scf
 } // namespace mlir
@@ -119,5 +181,6 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
     scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
+    scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
   });
 }
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index ec710bbacc758f..c88532d2325f0c 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -575,6 +575,68 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
                               {{value1, dim1}, {value2, dim2}});
 }
 
+void ValueBoundsConstraintSet::populateConstraints(Value value,
+                                                   std::optional<int64_t> dim) {
+  // `getExpr` pushes the value/dim onto the worklist (unless it was already
+  // analyzed).
+  (void)getExpr(value, dim);
+  // Process all values/dims on the worklist. This may traverse and analyze
+  // additional IR, depending on the current stop function.
+  processWorklist();
+}
+
+bool ValueBoundsConstraintSet::compare(Value value1,
+                                       std::optional<int64_t> dim1,
+                                       ComparisonOperator cmp, Value value2,
+                                       std::optional<int64_t> dim2) {
+  // This function returns "true" if value1/dim1 CMP value2/dim2 is proved to
+  // hold.
+  //
+  // Example for ComparisonOperator::LE and index-typed values: We would like to
+  // prove that value1 <= value2. Proof by contradiction: add the inverse
+  // relation (value1 > value2) to the constraint set and check if the resulting
+  // constraint set is "empty" (i.e. has no solution). In that case,
+  // value1 > value2 must be incorrect and we can deduce that value1 <= value2
+  // holds.
+
+  // We cannot prove anything if the constraint set is already empty.
+  if (cstr.isEmpty()) {
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << "cannot compare value/dims: constraint system is already empty");
+    return false;
+  }
+
+  // EQ can be expressed as LE and GE.
+  if (cmp == EQ)
+    return compare(value1, dim1, ComparisonOperator::LE, value2, dim2) &&
+           compare(value1, dim1, ComparisonOperator::GE, value2, dim2);
+
+  // Construct inequality. For the above example: value1 > value2.
+  // `IntegerRelation` inequalities are expressed in the "flattened" form and
+  // with ">= 0". I.e., value1 - value2 - 1 >= 0.
+  SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
+  if (cmp == LT || cmp == LE) {
+    eq[getPos(value1, dim1)]++;
+    eq[getPos(value2, dim2)]--;
+  } else if (cmp == GT || cmp == GE) {
+    eq[getPos(value1, dim1)]--;
+    eq[getPos(value2, dim2)]++;
+  } else {
+    llvm_unreachable("unsupported comparison operator");
+  }
+  if (cmp == LE || cmp == GE)
+    eq[cstr.getNumDimAndSymbolVars()] -= 1;
+
+  // Add inequality to the constraint set and check if it made the constraint
+  // set empty.
+  int64_t ineqPos = cstr.getNumInequalities();
+  cstr.addInequality(eq);
+  bool isEmpty = cstr.isEmpty();
+  cstr.removeInequality(ineqPos);
+  return isEmpty;
+}
+
 FailureOr<bool>
 ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
                                    std::optional<int64_t> dim1,
diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
index e4d71415924994..0ea06737886d41 100644
--- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
-// RUN:     -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \
+// RUN:     -verify-diagnostics -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @scf_for(
 //  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
@@ -104,3 +104,118 @@ func.func @scf_for_swapping_yield(%t1: tensor<?xf32>, %t2: tensor<?xf32>, %a: in
   "test.some_use"(%reify1) : (index) -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_if_constant(
+func.func @scf_if_constant(%c : i1) {
+  // CHECK: arith.constant 4 : index
+  // CHECK: arith.constant 9 : index
+  %c4 = arith.constant 4 : index
+  %c9 = arith.constant 9 : index
+  %r = scf.if %c -> index {
+    scf.yield %c4 : index
+  } else {
+    scf.yield %c9 : index
+  }
+
+  // CHECK: %[[c4:.*]] = arith.constant 4 : index
+  // CHECK: %[[c10:.*]] = arith.constant 10 : index
+  %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+  %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+  // CHECK: "test.some_use"(%[[c4]], %[[c10]])
+  "test.some_use"(%reify1, %reify2) : (index, index) -> ()
+  return
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: #[[$map1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK-LABEL: func @scf_if_dynamic(
+//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1)
+func.func @scf_if_dynamic(%a: index, %b: index, %c : i1) {
+  %c4 = arith.constant 4 : index
+  %r = scf.if %c -> index {
+    %add1 = arith.addi %a, %b : index
+    scf.yield %add1 : index
+  } else {
+    %add2 = arith.addi %b, %c4 : index
+    %add3 = arith.addi %add2, %a : index
+    scf.yield %add3 : index
+  }
+
+  // CHECK: %[[lb:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
+  // CHECK: %[[ub:.*]] = affine.apply #[[$map1]]()[%[[a]], %[[b]]]
+  %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+  %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+  // CHECK: "test.some_use"(%[[lb]], %[[ub]])
+  "test.some_use"(%reify1, %reify2) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @scf_if_no_affine_bound(%a: index, %b: index, %c : i1) {
+  %r = scf.if %c -> index {
+    scf.yield %a : index
+  } else {
+    scf.yield %b : index
+  }
+  // The reified bound would be min(%a, %b). min/max expressions are not
+  // supported in reified bounds.
+  // expected-error @below{{could not reify bound}}
+  %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+  "test.some_use"(%reify1) : (index) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_tensor_dim(
+func.func @scf_if_tensor_dim(%c : i1) {
+  // CHECK: arith.constant 4 : index
+  // CHECK: arith.constant 9 : index
+  %c4 = arith.constant 4 : index
+  %c9 = arith.constant 9 : index
+  %t1 = tensor.empty(%c4) : tensor<?xf32>
+  %t2 = tensor.empty(%c9) : tensor<?xf32>
+  %r = scf.if %c -> tensor<?xf32> {
+    scf.yield %t1 : tensor<?xf32>
+  } else {
+    scf.yield %t2 : tensor<?xf32>
+  }
+
+  // CHECK: %[[c4:.*]] = arith.constant 4 : index
+  // CHECK: %[[c10:.*]] = arith.constant 10 : index
+  %reify1 = "test.reify_bound"(%r) {type = "LB", dim = 0}
+      : (tensor<?xf32>) -> (index)
+  %reify2 = "test.reify_bound"(%r) {type = "UB", dim = 0}
+      : (tensor<?xf32>) -> (index)
+  // CHECK: "test.some_use"(%[[c4]], %[[c10]])
+  "test.some_use"(%reify1, %reify2) : (index, index) -> ()
+  return
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @scf_if_eq(
+//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1)
+func.func @scf_if_eq(%a: index, %b: index, %c : i1) {
+  %c0 = arith.constant 0 : index
+  %r = scf.if %c -> index {
+    %add1 = arith.addi %a, %b : index
+    scf.yield %add1 : index
+  } else {
+    %add2 = arith.addi %b, %c0 : index
+    %add3 = arith.addi %add2, %a : index
+    scf.yield %add3 : index
+  }
+
+  // CHECK: %[[eq:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
+  %reify1 = "test.reify_bound"(%r) {type = "EQ"} : (index) -> (index)
+  // CHECK: "test.some_use"(%[[eq]])
+  "test.some_use"(%reify1) : (index) -> ()
+  return
+}



More information about the llvm-branch-commits mailing list