[Mlir-commits] [mlir] 0ba3e96 - [mlir][SCF][NFC] `ValueBoundsConstraintSet`: Simplify `scf.for` implementation (#87862)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 5 23:30:30 PDT 2024


Author: Matthias Springer
Date: 2024-04-06T15:30:26+09:00
New Revision: 0ba3e96be114dcbe0ac6813a1d0e2940d2a88229

URL: https://github.com/llvm/llvm-project/commit/0ba3e96be114dcbe0ac6813a1d0e2940d2a88229
DIFF: https://github.com/llvm/llvm-project/commit/0ba3e96be114dcbe0ac6813a1d0e2940d2a88229.diff

LOG: [mlir][SCF][NFC] `ValueBoundsConstraintSet`: Simplify `scf.for` implementation (#87862)

This commit simplifies the implementation of the
`ValueBoundsOpInterface` for `scf.for` based on the newly added
`ValueBoundsConstraintSet::compare` API and adds additional
documentation.

Previously, the interface implementation created a new constraint set
just to check if the yielded value and iter_arg are equal. This was
inefficient because constraints were added multiple times (to two
different constraint sets) for ops that are inside the loop.

Note: This is a re-upload of #86239.

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 8e9d1021f93e4b..72c5aaa2306783 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -12,7 +12,6 @@
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 using namespace mlir;
-using presburger::BoundType;
 
 namespace mlir {
 namespace scf {
@@ -21,7 +20,28 @@ namespace {
 struct ForOpInterface
     : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
 
-  /// Populate bounds of values/dimensions for iter_args/OpResults.
+  /// Populate bounds of values/dimensions for iter_args/OpResults. If the
+  /// value/dimension size does not change in an iteration, we can deduce that
+  /// it the same as the initial value/dimension.
+  ///
+  /// Example 1:
+  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
+  ///   ...
+  ///   %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
+  ///   scf.yield %1 : tensor<?xf32>
+  /// }
+  /// --> bound(%0)[0] == bound(%t)[0]
+  /// --> bound(%arg0)[0] == bound(%t)[0]
+  ///
+  /// Example 2:
+  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
+  ///   %sz = tensor.dim %arg0 : tensor<?xf32>
+  ///   %incr = arith.addi %sz, %c1 : index
+  ///   %1 = tensor.empty(%incr) : tensor<?xf32>
+  ///   scf.yield %1 : tensor<?xf32>
+  /// }
+  /// --> The yielded tensor dimension size changes with each iteration. Such
+  ///     loops are not supported and no constraints are added.
   static void populateIterArgBounds(scf::ForOp forOp, Value value,
                                     std::optional<int64_t> dim,
                                     ValueBoundsConstraintSet &cstr) {
@@ -33,59 +53,31 @@ struct ForOpInterface
       iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
     }
 
-    // An EQ constraint can be added if the yielded value (dimension size)
-    // equals the corresponding block argument (dimension size).
     Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
                              .getOperand(iterArgIdx);
     Value iterArg = forOp.getRegionIterArg(iterArgIdx);
     Value initArg = forOp.getInitArgs()[iterArgIdx];
 
-    auto addEqBound = [&]() {
+    // Populate constraints for the yielded value.
+    cstr.populateConstraints(yieldedValue, dim);
+    // Populate constraints for the iter_arg. This is just to ensure that the
+    // iter_arg is mapped in the constraint set, which is a prerequisite for
+    // `compare`. It may lead to a recursive call to this function in case the
+    // iter_arg was not visited when the constraints for the yielded value were
+    // populated, but no additional work is done.
+    cstr.populateConstraints(iterArg, dim);
+
+    // An EQ constraint can be added if the yielded value (dimension size)
+    // equals the corresponding block argument (dimension size).
+    if (cstr.compare(yieldedValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
+                     dim)) {
       if (dim.has_value()) {
         cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
       } else {
         cstr.bound(value) == initArg;
       }
-    };
-
-    if (yieldedValue == iterArg) {
-      addEqBound();
-      return;
-    }
-
-    // Compute EQ bound for yielded value.
-    AffineMap bound;
-    ValueDimList boundOperands;
-    LogicalResult status = ValueBoundsConstraintSet::computeBound(
-        bound, boundOperands, BoundType::EQ, yieldedValue, dim,
-        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
-          // Stop when reaching a block argument of the loop body.
-          if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
-            return bbArg.getOwner()->getParentOp() == forOp;
-          // Stop when reaching a value that is defined outside of the loop. It
-          // is impossible to reach an iter_arg from there.
-          Operation *op = v.getDefiningOp();
-          return forOp.getRegion().findAncestorOpInRegion(*op) == nullptr;
-        });
-    if (failed(status))
-      return;
-    if (bound.getNumResults() != 1)
-      return;
-
-    // Check if computed bound equals the corresponding iter_arg.
-    Value singleValue = nullptr;
-    std::optional<int64_t> singleDim;
-    if (auto dimExpr = dyn_cast<AffineDimExpr>(bound.getResult(0))) {
-      int64_t idx = dimExpr.getPosition();
-      singleValue = boundOperands[idx].first;
-      singleDim = boundOperands[idx].second;
-    } else if (auto symExpr = dyn_cast<AffineSymbolExpr>(bound.getResult(0))) {
-      int64_t idx = symExpr.getPosition() + bound.getNumDims();
-      singleValue = boundOperands[idx].first;
-      singleDim = boundOperands[idx].second;
     }
-    if (singleValue == iterArg && singleDim == dim)
-      addEqBound();
   }
 
   void populateBoundsForIndexValue(Operation *op, Value value,


        


More information about the Mlir-commits mailing list