[llvm-branch-commits] [mlir] [mlir][SCF][NFC] `ValueBoundsConstraintSet`: Simplify `scf.for` implementation (PR #86239)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Mar 21 21:00:32 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/86239

This commit simplifies the implementation of the `ValueBoundsOpInterface` for `scf.for` based on the newly added `ValueBoundsConstraintSet::compare` API and expands the documentation.

Previously, the interface implementation created a new constraint set just to check if the yielded value and iter_arg are equal. This was inefficient because constraints were added multiple times (to two different constraint sets) for ops that are inside the loop.

From 6f5ad656892a65eb0a0d5db889dbf47e4cb9929c Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 22 Mar 2024 03:50:48 +0000
Subject: [PATCH] [mlir][SCF][NFC] `ValueBoundsConstraintSet`: Simplify
 `scf.for` implementation

This commit simplifies the implementation of the `ValueBoundsOpInterface` for `scf.for` based on the newly added `ValueBoundsConstraintSet::compare` API and expands the documentation.

Previously, the interface implementation created a new constraint set just to check if the yielded value and iter_arg are equal. This was inefficient because constraints were added multiple times (to two different constraint sets) for ops that are inside the loop.
---
 .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp     | 80 +++++++++----------
 1 file changed, 36 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 72a25d0f0b30b0..9400e7abf87805 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -12,7 +12,6 @@
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 using namespace mlir;
-using presburger::BoundType;
 
 namespace mlir {
 namespace scf {
@@ -21,7 +20,28 @@ namespace {
 struct ForOpInterface
     : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
 
-  /// Populate bounds of values/dimensions for iter_args/OpResults.
+  /// Populate bounds of values/dimensions for iter_args/OpResults. If the
+  /// value/dimension size does not change in an iteration, we can deduce that
+  /// it is the same as the initial value/dimension.
+  ///
+  /// Example 1:
+  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
+  ///   ...
+  ///   %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
+  ///   scf.yield %1 : tensor<?xf32>
+  /// }
+  /// --> bound(%0)[0] == bound(%t)[0]
+  /// --> bound(%arg0)[0] == bound(%t)[0]
+  ///
+  /// Example 2:
+  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
+  ///   %sz = tensor.dim %arg0, %c0 : tensor<?xf32>
+  ///   %incr = arith.addi %sz, %c1 : index
+  ///   %1 = tensor.empty(%incr) : tensor<?xf32>
+  ///   scf.yield %1 : tensor<?xf32>
+  /// }
+  /// --> The yielded tensor dimension size changes with each iteration. Such
+  ///     loops are not supported and no constraints are added.
   static void populateIterArgBounds(scf::ForOp forOp, Value value,
                                     std::optional<int64_t> dim,
                                     ValueBoundsConstraintSet &cstr) {
@@ -33,59 +53,31 @@ struct ForOpInterface
       iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
     }
 
-    // An EQ constraint can be added if the yielded value (dimension size)
-    // equals the corresponding block argument (dimension size).
     Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
                              .getOperand(iterArgIdx);
     Value iterArg = forOp.getRegionIterArg(iterArgIdx);
     Value initArg = forOp.getInitArgs()[iterArgIdx];
 
-    auto addEqBound = [&]() {
+    // Populate constraints for the yielded value.
+    cstr.populateConstraints(yieldedValue, dim);
+    // Populate constraints for the iter_arg. This is just to ensure that the
+    // iter_arg is mapped in the constraint set, which is a prerequisite for
+    // `compare`. It may lead to a recursive call to this function in case the
+    // iter_arg was not visited when the constraints for the yielded value were
+    // populated, but no additional work is done.
+    cstr.populateConstraints(iterArg, dim);
+
+    // An EQ constraint can be added if the yielded value (dimension size)
+    // equals the corresponding block argument (dimension size).
+    if (cstr.compare(yieldedValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
+                     dim)) {
       if (dim.has_value()) {
         cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
       } else {
         cstr.bound(value) == initArg;
       }
-    };
-
-    if (yieldedValue == iterArg) {
-      addEqBound();
-      return;
-    }
-
-    // Compute EQ bound for yielded value.
-    AffineMap bound;
-    ValueDimList boundOperands;
-    LogicalResult status = ValueBoundsConstraintSet::computeBound(
-        bound, boundOperands, BoundType::EQ, yieldedValue, dim,
-        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
-          // Stop when reaching a block argument of the loop body.
-          if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
-            return bbArg.getOwner()->getParentOp() == forOp;
-          // Stop when reaching a value that is defined outside of the loop. It
-          // is impossible to reach an iter_arg from there.
-          Operation *op = v.getDefiningOp();
-          return forOp.getRegion().findAncestorOpInRegion(*op) == nullptr;
-        });
-    if (failed(status))
-      return;
-    if (bound.getNumResults() != 1)
-      return;
-
-    // Check if computed bound equals the corresponding iter_arg.
-    Value singleValue = nullptr;
-    std::optional<int64_t> singleDim;
-    if (auto dimExpr = dyn_cast<AffineDimExpr>(bound.getResult(0))) {
-      int64_t idx = dimExpr.getPosition();
-      singleValue = boundOperands[idx].first;
-      singleDim = boundOperands[idx].second;
-    } else if (auto symExpr = dyn_cast<AffineSymbolExpr>(bound.getResult(0))) {
-      int64_t idx = symExpr.getPosition() + bound.getNumDims();
-      singleValue = boundOperands[idx].first;
-      singleDim = boundOperands[idx].second;
     }
-    if (singleValue == iterArg && singleDim == dim)
-      addEqBound();
   }
 
   void populateBoundsForIndexValue(Operation *op, Value value,



More information about the llvm-branch-commits mailing list