[Mlir-commits] [mlir] [mlir][SCF] `ValueBoundsConstraintSet`: Support `scf.if` (branches) (PR #85895)

Matthias Springer llvmlistbot at llvm.org
Thu Mar 21 01:10:36 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/85895

>From 13c9d7f5a70f80de8b2c2ba404b7df2e887f8dba Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 21 Mar 2024 07:58:36 +0000
Subject: [PATCH 1/4] [mlir][Interfaces][NFC] `ValueBoundsConstraintSet`: Add
 columns for constant values/dims

`ValueBoundsConstraintSet` maintains an internal constraint set (`IntegerRelation`), where every analyzed index-typed SSA value or dimension of a shaped type is represented with a dimension/symbol. Prior to this change, index-typed values with a statically known constant value and static shaped type dimensions were not added to the constraint set. Instead, `getExpr` directly returned an affine constant expression.

With this commit, dynamic and static values/dimension sizes are treated in the same way: in either case, a dimension/symbol is added to the constraint set. This is needed for a subsequent commit that adds support for branches.
---
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  5 +-
 .../lib/Interfaces/ValueBoundsOpInterface.cpp | 67 ++++++++++++++-----
 2 files changed, 55 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 28dadfb9ecf868..94a8a8b429c801 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -280,12 +280,13 @@ class ValueBoundsConstraintSet {
 
   /// Insert a value/dimension into the constraint set. If `isSymbol` is set to
   /// "false", a dimension is added. The value/dimension is added to the
-  /// worklist.
+  /// worklist if `addToWorklist` is set.
   ///
   /// Note: There are certain affine restrictions wrt. dimensions. E.g., they
   /// cannot be multiplied. Furthermore, bounds can only be queried for
   /// dimensions but not for symbols.
-  int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true);
+  int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true,
+                 bool addToWorklist = true);
 
   /// Insert an anonymous column into the constraint set. The column is not
   /// bound to any value/dimension. If `isSymbol` is set to "false", a dimension
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 85abc2df894797..02af3a83166dfb 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -105,25 +105,57 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
   assertValidValueDim(value, dim);
 #endif // NDEBUG
 
+  // Helper function that returns an affine expression that represents column
+  // `pos` in the constraint set.
+  auto getPosExpr = [&](int64_t pos) {
+    assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() &&
+           "invalid position");
+    return pos < cstr.getNumDimVars()
+               ? builder.getAffineDimExpr(pos)
+               : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
+  };
+
+  // Check if the value/dim is statically known. In that case, an affine
+  // constant expression should be returned. This allows us to support
+  // multiplications with constants. (Multiplications of two columns in the
+  // constraint set is not supported.)
+  std::optional<int64_t> constSize = std::nullopt;
   auto shapedType = dyn_cast<ShapedType>(value.getType());
   if (shapedType) {
-    // Static dimension: return constant directly.
     if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
-      return builder.getAffineConstantExpr(shapedType.getDimSize(*dim));
-  } else {
-    // Constant index value: return directly.
-    if (auto constInt = ::getConstantIntValue(value))
-      return builder.getAffineConstantExpr(*constInt);
+      constSize = shapedType.getDimSize(*dim);
+  } else if (auto constInt = ::getConstantIntValue(value)) {
+    constSize = *constInt;
   }
 
-  // Dynamic value: add to constraint set.
+  // If the value/dim is already mapped, return the corresponding expression
+  // directly.
   ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
-  if (!valueDimToPosition.contains(valueDim))
-    (void)insert(value, dim);
-  int64_t pos = getPos(value, dim);
-  return pos < cstr.getNumDimVars()
-             ? builder.getAffineDimExpr(pos)
-             : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
+  if (valueDimToPosition.contains(valueDim)) {
+    // If it is a constant, return an affine constant expression. Otherwise,
+    // return an affine expression that represents the respective column in the
+    // constraint set.
+    if (constSize)
+      return builder.getAffineConstantExpr(*constSize);
+    return getPosExpr(getPos(value, dim));
+  }
+
+  if (constSize) {
+    // Constant index value/dim: add column to the constraint set, add EQ bound
+    // and return an affine constant expression without pushing the newly added
+    // column to the worklist.
+    (void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
+    if (shapedType)
+      bound(value)[*dim] == *constSize;
+    else
+      bound(value) == *constSize;
+    return builder.getAffineConstantExpr(*constSize);
+  }
+
+  // Dynamic value/dim: insert column to the constraint set and put it on the
+  // worklist. Return an affine expression that represents the newly inserted
+  // column in the constraint set.
+  return getPosExpr(insert(value, dim, /*isSymbol=*/true));
 }
 
 AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
@@ -140,7 +172,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
 
 int64_t ValueBoundsConstraintSet::insert(Value value,
                                          std::optional<int64_t> dim,
-                                         bool isSymbol) {
+                                         bool isSymbol, bool addToWorklist) {
 #ifndef NDEBUG
   assertValidValueDim(value, dim);
 #endif // NDEBUG
@@ -155,7 +187,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
     if (positionToValueDim[i].has_value())
       valueDimToPosition[*positionToValueDim[i]] = i;
 
-  worklist.push(pos);
+  if (addToWorklist) {
+    LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value
+                            << " (dim: " << dim.value_or(kIndexValue) << ")\n");
+    worklist.push(pos);
+  }
+
   return pos;
 }
 

>From 12e7e8880621483918e7a3e2a99bf3eec13eb899 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 21 Mar 2024 08:01:01 +0000
Subject: [PATCH 2/4] [mlir][Interfaces][NFC] `ValueBoundsConstraintSet`:
 Delete dead code

There is an assertion that the stop condition is not satisfied for the starting point at the beginning of `computeBound`. Therefore, that case does not have to be handled later on in that function.
---
 mlir/lib/Interfaces/ValueBoundsOpInterface.cpp | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 02af3a83166dfb..f2f732f3a21d25 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -321,18 +321,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
   Builder b(value.getContext());
   mapOperands.clear();
 
-  if (stopCondition(value, dim)) {
-    // Special case: If the stop condition is satisfied for the input
-    // value/dimension, directly return it.
-    mapOperands.push_back(std::make_pair(value, dim));
-    AffineExpr bound = b.getAffineDimExpr(0);
-    if (type == BoundType::UB)
-      bound = bound + ubAdjustment;
-    resultMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
-                               b.getAffineDimExpr(0));
-    return success();
-  }
-
   // Process the backward slice of `value` (i.e., reverse use-def chain) until
   // `stopCondition` is met.
   ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));

>From db3dde1d9c6e3eb1b85083d1a3545691f47acb7c Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 21 Mar 2024 08:04:11 +0000
Subject: [PATCH 3/4] [mlir][Interfaces][NFC] `ValueBoundsConstraintSet`: Pass
 stop condition in the constructor

This commit changes the API of `ValueBoundsConstraintSet`: the stop condition is now passed to the constructor instead of `processWorklist`. That makes it easier to add items to the worklist multiple times and process them in a consistent manner. The current `ValueBoundsConstraintSet` is passed as a reference to the stop function, so that the stop function can be defined before the `ValueBoundsConstraintSet` is constructed.

This change is in preparation of adding support for branches.
---
 .../mlir/Interfaces/ValueBoundsOpInterface.h  | 16 +++--
 .../Affine/Transforms/ReifyValueBounds.cpp    |  6 +-
 .../Arith/Transforms/ReifyValueBounds.cpp     |  6 +-
 .../Linalg/Transforms/HoistPadding.cpp        |  2 +-
 .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp     |  2 +-
 .../lib/Interfaces/ValueBoundsOpInterface.cpp | 60 +++++++++++--------
 .../Dialect/Affine/TestReifyValueBounds.cpp   |  9 ++-
 7 files changed, 62 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 94a8a8b429c801..b79c44162ea8ef 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -113,8 +113,9 @@ class ValueBoundsConstraintSet {
   ///
   /// The first parameter of the function is the shaped value/index-typed
   /// value. The second parameter is the dimension in case of a shaped value.
-  using StopConditionFn =
-      function_ref<bool(Value, std::optional<int64_t> /*dim*/)>;
+  /// The third parameter is this constraint set.
+  using StopConditionFn = function_ref<bool(
+      Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
 
   /// Compute a bound for the given index-typed value or shape dimension size.
   /// The computed bound is stored in `resultMap`. The operands of the bound are
@@ -263,12 +264,12 @@ class ValueBoundsConstraintSet {
   /// An index-typed value or the dimension of a shaped-type value.
   using ValueDim = std::pair<Value, int64_t>;
 
-  ValueBoundsConstraintSet(MLIRContext *ctx);
+  ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
 
   /// Iteratively process all elements on the worklist until an index-typed
-  /// value or shaped value meets `stopCondition`. Such values are not processed
-  /// any further.
-  void processWorklist(StopConditionFn stopCondition);
+  /// value or shaped value meets `stopCondition`. Such values are not
+  /// processed any further.
+  void processWorklist();
 
   /// Bound the given column in the underlying constraint set by the given
   /// expression.
@@ -316,6 +317,9 @@ class ValueBoundsConstraintSet {
 
   /// Builder for constructing affine expressions.
   Builder builder;
+
+  /// The current stop condition function.
+  StopConditionFn stopCondition = nullptr;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 37b36f76d4465d..117ee8e8701ad7 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -84,7 +84,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
     bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     // We are trying to reify a bound for `value` in terms of the owning op's
     // operands. Construct a stop condition that evaluates to "true" for any SSA
     // value except for `value`. I.e., the bound will be computed in terms of
@@ -100,7 +101,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
 FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     return v != value;
   };
   return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index 8d9fd1478aa9e6..fad221288f190e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -119,7 +119,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
     bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     // We are trying to reify a bound for `value` in terms of the owning op's
     // operands. Construct a stop condition that evaluates to "true" for any SSA
     // value expect for `value`. I.e., the bound will be computed in terms of
@@ -135,7 +136,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
 FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     return v != value;
   };
   return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index b32ea8eebaecb9..c3a08ce86082a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -468,7 +468,7 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
     FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
         rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
         /*stopCondition=*/
-        [&](Value v, std::optional<int64_t> d) {
+        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
           if (v == forOp.getUpperBound())
             return false;
           // Compute a bound that is independent of any affine op results.
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index cb36e0cecf0d24..1e13e60068ee7f 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -58,7 +58,7 @@ struct ForOpInterface
     ValueDimList boundOperands;
     LogicalResult status = ValueBoundsConstraintSet::computeBound(
         bound, boundOperands, BoundType::EQ, yieldedValue, dim,
-        [&](Value v, std::optional<int64_t> d) {
+        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
           // Stop when reaching a block argument of the loop body.
           if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
             return bbArg.getOwner()->getParentOp() == forOp;
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index f2f732f3a21d25..ec710bbacc758f 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -67,8 +67,9 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   return std::nullopt;
 }
 
-ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx)
-    : builder(ctx) {}
+ValueBoundsConstraintSet::ValueBoundsConstraintSet(
+    MLIRContext *ctx, StopConditionFn stopCondition)
+    : builder(ctx), stopCondition(stopCondition) {}
 
 #ifndef NDEBUG
 static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
@@ -228,7 +229,8 @@ static Operation *getOwnerOfValue(Value value) {
   return value.getDefiningOp();
 }
 
-void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
+void ValueBoundsConstraintSet::processWorklist() {
+  LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
   while (!worklist.empty()) {
     int64_t pos = worklist.front();
     worklist.pop();
@@ -249,13 +251,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
 
     // Do not process any further if the stop condition is met.
     auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
-    if (stopCondition(value, maybeDim))
+    if (stopCondition(value, maybeDim, *this)) {
+      LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
+                              << " (dim: " << maybeDim << ")\n");
       continue;
+    }
 
     // Query `ValueBoundsOpInterface` for constraints. New items may be added to
     // the worklist.
     auto valueBoundsOp =
         dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
+    LLVM_DEBUG(llvm::dbgs()
+               << "Query value bounds for: " << value
+               << " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
     if (valueBoundsOp) {
       if (dim == kIndexValue) {
         valueBoundsOp.populateBoundsForIndexValue(value, *this);
@@ -264,6 +272,7 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
       }
       continue;
     }
+    LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
 
     // If the op does not implement `ValueBoundsOpInterface`, check if it
     // implements the `DestinationStyleOpInterface`. OpResults of such ops are
@@ -313,8 +322,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
     bool closedUB) {
 #ifndef NDEBUG
   assertValidValueDim(value, dim);
-  assert(!stopCondition(value, dim) &&
-         "stop condition should not be satisfied for starting point");
 #endif // NDEBUG
 
   int64_t ubAdjustment = closedUB ? 0 : 1;
@@ -324,9 +331,11 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
   // Process the backward slice of `value` (i.e., reverse use-def chain) until
   // `stopCondition` is met.
   ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
-  ValueBoundsConstraintSet cstr(value.getContext());
+  ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
+  assert(!stopCondition(value, dim, cstr) &&
+         "stop condition should not be satisfied for starting point");
   int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
-  cstr.processWorklist(stopCondition);
+  cstr.processWorklist();
 
   // Project out all variables (apart from `valueDim`) that do not match the
   // stop condition.
@@ -336,7 +345,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
       return false;
     auto maybeDim =
         p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
-    return !stopCondition(p.first, maybeDim);
+    return !stopCondition(p.first, maybeDim, cstr);
   });
 
   // Compute lower and upper bounds for `valueDim`.
@@ -442,7 +451,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
     bool closedUB) {
   return computeBound(
       resultMap, mapOperands, type, value, dim,
-      [&](Value v, std::optional<int64_t> d) {
+      [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
         return llvm::is_contained(dependencies, std::make_pair(v, d));
       },
       closedUB);
@@ -478,7 +487,9 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
   // Reify bounds in terms of any independent values.
   return computeBound(
       resultMap, mapOperands, type, value, dim,
-      [&](Value v, std::optional<int64_t> d) { return isIndependent(v); },
+      [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
+        return isIndependent(v);
+      },
       closedUB);
 }
 
@@ -500,8 +511,18 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
     presburger::BoundType type, AffineMap map, ValueDimList operands,
     StopConditionFn stopCondition, bool closedUB) {
   assert(map.getNumResults() == 1 && "expected affine map with one result");
-  ValueBoundsConstraintSet cstr(map.getContext());
-  int64_t pos = cstr.insert(/*isSymbol=*/false);
+
+  // Default stop condition if none was specified: Keep adding constraints until
+  // a bound could be computed.
+  int64_t pos;
+  auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
+                                  ValueBoundsConstraintSet &cstr) {
+    return cstr.cstr.getConstantBound64(type, pos).has_value();
+  };
+
+  ValueBoundsConstraintSet cstr(
+      map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
+  pos = cstr.insert(/*isSymbol=*/false);
 
   // Add map and operands to the constraint set. Dimensions are converted to
   // symbols. All operands are added to the worklist.
@@ -517,17 +538,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
       map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
 
   // Process the backward slice of `operands` (i.e., reverse use-def chain)
-  // until `stopCondition` is met.
-  if (stopCondition) {
-    cstr.processWorklist(stopCondition);
-  } else {
-    // No stop condition specified: Keep adding constraints until a bound could
-    // be computed.
-    cstr.processWorklist(
-        /*stopCondition=*/[&](Value v, std::optional<int64_t> dim) {
-          return cstr.cstr.getConstantBound64(type, pos).has_value();
-        });
-  }
+  // until the stop condition is met.
+  cstr.processWorklist();
 
   // Compute constant bound for `valueDim`.
   int64_t ubAdjustment = closedUB ? 0 : 1;
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 39671a930f2e21..e99a13cdca2f3c 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -112,14 +112,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
 
       // Prepare stop condition. By default, reify in terms of the op's
       // operands. No stop condition is used when a constant was requested.
-      std::function<bool(Value, std::optional<int64_t>)> stopCondition =
-          [&](Value v, std::optional<int64_t> d) {
+      std::function<bool(Value, std::optional<int64_t>,
+                         ValueBoundsConstraintSet & cstr)>
+          stopCondition = [&](Value v, std::optional<int64_t> d,
+                              ValueBoundsConstraintSet &cstr) {
             // Reify in terms of SSA values that are different from `value`.
             return v != value;
           };
       if (reifyToFuncArgs) {
         // Reify in terms of function block arguments.
-        stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
+        stopCondition = stopCondition = [](Value v, std::optional<int64_t> d,
+                                           ValueBoundsConstraintSet &cstr) {
           auto bbArg = dyn_cast<BlockArgument>(v);
           if (!bbArg)
             return false;

>From 21b4e408486925e7a5b3122bb0fd80de7c248115 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 21 Mar 2024 08:07:24 +0000
Subject: [PATCH 4/4] [mlir][SCF] `ValueBoundsConstraintSet`: Add
 preliminary support for branches

This commit adds support for `scf.if` to `ValueBoundsConstraintSet`.

Example:
```
%0 = scf.if ... -> index {
  scf.yield %a : index
} else {
  scf.yield %b : index
}
```

The following constraints hold for %0:
* %0 >= min(%a, %b)
* %0 <= max(%a, %b)

Such constraints cannot be added to the constraint set; min/max is not supported by `IntegerRelation`. However, if we know which one of %a and %b is larger, we can add constraints for %0. E.g., if %a <= %b:
* %0 >= %a
* %0 <= %b

This commit required a few minor changes to the `ValueBoundsConstraintSet` infrastructure, so that values can be compared while we are still in the process of traversing the IR/adding constraints.
---
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  22 ++++
 .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp     |  63 ++++++++++
 .../lib/Interfaces/ValueBoundsOpInterface.cpp |  62 +++++++++
 .../SCF/value-bounds-op-interface-impl.mlir   | 119 +++++++++++++++++-
 4 files changed, 264 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index b79c44162ea8ef..702991b4e0bb95 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -199,6 +199,28 @@ class ValueBoundsConstraintSet {
                        std::optional<int64_t> dim1 = std::nullopt,
                        std::optional<int64_t> dim2 = std::nullopt);
 
+  /// Traverse the IR starting from the given value/dim and populate
+  /// constraints as long as the currently set stop condition is not satisfied.
+  /// Also processes all values/dims that are already on the worklist.
+  void populateConstraints(Value value, std::optional<int64_t> dim);
+
+  /// Comparison operator for `ValueBoundsConstraintSet::compare`.
+  enum ComparisonOperator { LT, LE, EQ, GT, GE };
+
+  /// Try to prove that, based on the current state of this constraint set
+  /// (i.e., without analyzing additional IR or adding new constraints), it can
+  /// be deduced that the first given value/dim is LE/LT/EQ/GT/GE than the
+  /// second given value/dim.
+  ///
+  /// Return "true" if the specified relation between the two values/dims was
+  /// proven to hold. Return "false" if the specified relation could not be
+  /// proven. This could be because the specified relation does in fact not hold
+  /// or because there is not enough information in the constraint set. In other
+  /// words, if we do not know for sure, this function returns "false".
+  bool compare(Value value1, std::optional<int64_t> dim1,
+               ComparisonOperator cmp, Value value2,
+               std::optional<int64_t> dim2);
+
   /// Compute whether the given values/dimensions are equal. Return "failure" if
   /// equality could not be determined.
   ///
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 1e13e60068ee7f..72a25d0f0b30b0 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -111,6 +111,68 @@ struct ForOpInterface
   }
 };
 
+struct IfOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
+
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto ifOp = cast<IfOp>(op);
+    unsigned int resultNum = cast<OpResult>(value).getResultNumber();
+    Value thenValue = ifOp.thenYield().getResults()[resultNum];
+    Value elseValue = ifOp.elseYield().getResults()[resultNum];
+
+    // Populate constraints for the yielded value (and all values on the
+    // backward slice, as long as the current stop condition is not satisfied).
+    cstr.populateConstraints(thenValue, /*valueDim=*/std::nullopt);
+    cstr.populateConstraints(elseValue, /*valueDim=*/std::nullopt);
+
+    // Compare yielded values.
+    // If thenValue <= elseValue:
+    // * result <= elseValue
+    // * result >= thenValue
+    if (cstr.compare(thenValue, /*dim1=*/std::nullopt,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     elseValue, /*dim2=*/std::nullopt)) {
+      cstr.bound(value) >= thenValue;
+      cstr.bound(value) <= elseValue;
+    }
+    // If elseValue <= thenValue:
+    // * result <= thenValue
+    // * result >= elseValue
+    if (cstr.compare(elseValue, /*dim1=*/std::nullopt,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     thenValue, /*dim2=*/std::nullopt)) {
+      cstr.bound(value) >= elseValue;
+      cstr.bound(value) <= thenValue;
+    }
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    // See `populateBoundsForIndexValue` for documentation.
+    auto ifOp = cast<IfOp>(op);
+    unsigned int resultNum = cast<OpResult>(value).getResultNumber();
+    Value thenValue = ifOp.thenYield().getResults()[resultNum];
+    Value elseValue = ifOp.elseYield().getResults()[resultNum];
+
+    cstr.populateConstraints(thenValue, dim);
+    cstr.populateConstraints(elseValue, dim);
+
+    if (cstr.compare(thenValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     elseValue, dim)) {
+      cstr.bound(value)[dim] >= cstr.getExpr(thenValue, dim);
+      cstr.bound(value)[dim] <= cstr.getExpr(elseValue, dim);
+    }
+    if (cstr.compare(elseValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     thenValue, dim)) {
+      cstr.bound(value)[dim] >= cstr.getExpr(elseValue, dim);
+      cstr.bound(value)[dim] <= cstr.getExpr(thenValue, dim);
+    }
+  }
+};
+
 } // namespace
 } // namespace scf
 } // namespace mlir
@@ -119,5 +181,6 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
     scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
+    scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
   });
 }
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index ec710bbacc758f..c88532d2325f0c 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -575,6 +575,68 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
                               {{value1, dim1}, {value2, dim2}});
 }
 
+void ValueBoundsConstraintSet::populateConstraints(Value value,
+                                                   std::optional<int64_t> dim) {
+  // `getExpr` pushes the value/dim onto the worklist (unless it was already
+  // analyzed). The returned expression itself is not needed here.
+  (void)getExpr(value, dim);
+  // Process all values/dims on the worklist. This may traverse and analyze
+  // additional IR, depending on the current stop function.
+  processWorklist();
+}
+
+bool ValueBoundsConstraintSet::compare(Value value1,
+                                       std::optional<int64_t> dim1,
+                                       ComparisonOperator cmp, Value value2,
+                                       std::optional<int64_t> dim2) {
+  // This function returns "true" if value1/dim1 CMP value2/dim2 is proved to
+  // hold. "false" means "could not be proved", not "proved to be false".
+  //
+  // Example for ComparisonOperator::LE and index-typed values: We would like to
+  // prove that value1 <= value2. Proof by contradiction: add the inverse
+  // relation (value1 > value2) to the constraint set and check if the resulting
+  // constraint set is "empty" (i.e. has no solution). In that case,
+  // value1 > value2 must be incorrect and we can deduce that value1 <= value2
+  // holds.
+
+  // We cannot prove anything if the constraint set is already empty.
+  if (cstr.isEmpty()) {
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << "cannot compare value/dims: constraint system is already empty");
+    return false;
+  }
+
+  // EQ can be expressed as LE and GE.
+  if (cmp == EQ)
+    return compare(value1, dim1, ComparisonOperator::LE, value2, dim2) &&
+           compare(value1, dim1, ComparisonOperator::GE, value2, dim2);
+
+  // Construct the inverse inequality. For the above example: value1 > value2.
+  // `IntegerRelation` inequalities are expressed in the "flattened" form and
+  // with ">= 0". I.e., value1 - value2 - 1 >= 0.
+  SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
+  if (cmp == LT || cmp == LE) {
+    eq[getPos(value1, dim1)]++;
+    eq[getPos(value2, dim2)]--;
+  } else if (cmp == GT || cmp == GE) {
+    eq[getPos(value1, dim1)]--;
+    eq[getPos(value2, dim2)]++;
+  } else {
+    llvm_unreachable("unsupported comparison operator");
+  }
+  if (cmp == LE || cmp == GE)
+    eq[cstr.getNumDimAndSymbolVars()] -= 1; // Inverse is strict (">"/"<").
+
+  // Temporarily add the inequality and check if it made the constraint set
+  // empty. The inequality is removed again afterwards.
+  int64_t ineqPos = cstr.getNumInequalities();
+  cstr.addInequality(eq);
+  bool isEmpty = cstr.isEmpty();
+  cstr.removeInequality(ineqPos);
+  return isEmpty;
+}
+
 FailureOr<bool>
 ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
                                    std::optional<int64_t> dim1,
diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
index e4d71415924994..0ea06737886d41 100644
--- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
-// RUN:     -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \
+// RUN:     -verify-diagnostics -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @scf_for(
 //  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
@@ -104,3 +104,118 @@ func.func @scf_for_swapping_yield(%t1: tensor<?xf32>, %t2: tensor<?xf32>, %a: in
   "test.some_use"(%reify1) : (index) -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_if_constant(
+func.func @scf_if_constant(%c : i1) {
+  // CHECK: arith.constant 4 : index
+  // CHECK: arith.constant 9 : index
+  %c4 = arith.constant 4 : index
+  %c9 = arith.constant 9 : index
+  %r = scf.if %c -> index {
+    scf.yield %c4 : index
+  } else {
+    scf.yield %c9 : index
+  }
+
+  // CHECK: %[[c4:.*]] = arith.constant 4 : index
+  // CHECK: %[[c10:.*]] = arith.constant 10 : index
+  %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+  %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+  // CHECK: "test.some_use"(%[[c4]], %[[c10]])
+  "test.some_use"(%reify1, %reify2) : (index, index) -> ()
+  return
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: #[[$map1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK-LABEL: func @scf_if_dynamic(
+//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1)
+func.func @scf_if_dynamic(%a: index, %b: index, %c : i1) {
+  %c4 = arith.constant 4 : index
+  %r = scf.if %c -> index {
+    %add1 = arith.addi %a, %b : index
+    scf.yield %add1 : index
+  } else {
+    %add2 = arith.addi %b, %c4 : index
+    %add3 = arith.addi %add2, %a : index
+    scf.yield %add3 : index
+  }
+
+  // CHECK: %[[lb:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
+  // CHECK: %[[ub:.*]] = affine.apply #[[$map1]]()[%[[a]], %[[b]]]
+  %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+  %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+  // CHECK: "test.some_use"(%[[lb]], %[[ub]])
+  "test.some_use"(%reify1, %reify2) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @scf_if_no_affine_bound(%a: index, %b: index, %c : i1) {
+  %r = scf.if %c -> index {
+    scf.yield %a : index
+  } else {
+    scf.yield %b : index
+  }
+  // The reified bound would be min(%a, %b). min/max expressions are not
+  // supported in reified bounds.
+  // expected-error @below{{could not reify bound}}
+  %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+  "test.some_use"(%reify1) : (index) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_tensor_dim(
+func.func @scf_if_tensor_dim(%c : i1) {
+  // CHECK: arith.constant 4 : index
+  // CHECK: arith.constant 9 : index
+  %c4 = arith.constant 4 : index
+  %c9 = arith.constant 9 : index
+  %t1 = tensor.empty(%c4) : tensor<?xf32>
+  %t2 = tensor.empty(%c9) : tensor<?xf32>
+  %r = scf.if %c -> tensor<?xf32> {
+    scf.yield %t1 : tensor<?xf32>
+  } else {
+    scf.yield %t2 : tensor<?xf32>
+  }
+
+  // CHECK: %[[c4:.*]] = arith.constant 4 : index
+  // CHECK: %[[c10:.*]] = arith.constant 10 : index
+  %reify1 = "test.reify_bound"(%r) {type = "LB", dim = 0}
+      : (tensor<?xf32>) -> (index)
+  %reify2 = "test.reify_bound"(%r) {type = "UB", dim = 0}
+      : (tensor<?xf32>) -> (index)
+  // CHECK: "test.some_use"(%[[c4]], %[[c10]])
+  "test.some_use"(%reify1, %reify2) : (index, index) -> ()
+  return
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @scf_if_eq(
+//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1)
+func.func @scf_if_eq(%a: index, %b: index, %c : i1) {
+  %c0 = arith.constant 0 : index
+  %r = scf.if %c -> index {
+    %add1 = arith.addi %a, %b : index
+    scf.yield %add1 : index
+  } else {
+    %add2 = arith.addi %b, %c0 : index
+    %add3 = arith.addi %add2, %a : index
+    scf.yield %add3 : index
+  }
+
+  // CHECK: %[[eq:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
+  %reify1 = "test.reify_bound"(%r) {type = "EQ"} : (index) -> (index)
+  // CHECK: "test.some_use"(%[[eq]])
+  "test.some_use"(%reify1) : (index) -> ()
+  return
+}



More information about the Mlir-commits mailing list