[llvm-branch-commits] [mlir] [mlir][SCF][NFC] `ValueBoundsConstraintSet`: Simplify `scf.for` implementation (PR #86239)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Mar 22 23:04:26 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/86239
>From 3c4adb5458f054634d51e1502736bb3dbebad106 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Sat, 23 Mar 2024 06:02:28 +0000
Subject: [PATCH] [mlir][SCF][NFC] `ValueBoundsConstraintSet`: Simplify
`scf.for` implementation
This commit simplifies the implementation of the `ValueBoundsOpInterface` for `scf.for` based on the newly added `ValueBoundsConstraintSet::compare` API and adds additional documentation.
Previously, the interface implementation created a new constraint set just to check if the yielded value and iter_arg are equal. This was inefficient because constraints were added multiple times (to two different constraint sets) for ops that are inside the loop.
---
.../SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 80 +++++++++----------
1 file changed, 36 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 8e9d1021f93e4b..72c5aaa2306783 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -12,7 +12,6 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
-using presburger::BoundType;
namespace mlir {
namespace scf {
@@ -21,7 +20,28 @@ namespace {
struct ForOpInterface
: public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
- /// Populate bounds of values/dimensions for iter_args/OpResults.
+ /// Populate bounds of values/dimensions for iter_args/OpResults. If the
+ /// value/dimension size does not change in an iteration, we can deduce that
+ /// it is the same as the initial value/dimension.
+ ///
+ /// Example 1:
+ /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
+ /// ...
+ /// %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
+ /// scf.yield %1 : tensor<?xf32>
+ /// }
+ /// --> bound(%0)[0] == bound(%t)[0]
+ /// --> bound(%arg0)[0] == bound(%t)[0]
+ ///
+ /// Example 2:
+ /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
+ /// %sz = tensor.dim %arg0 : tensor<?xf32>
+ /// %incr = arith.addi %sz, %c1 : index
+ /// %1 = tensor.empty(%incr) : tensor<?xf32>
+ /// scf.yield %1 : tensor<?xf32>
+ /// }
+ /// --> The yielded tensor dimension size changes with each iteration. Such
+ /// loops are not supported and no constraints are added.
static void populateIterArgBounds(scf::ForOp forOp, Value value,
std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
@@ -33,59 +53,31 @@ struct ForOpInterface
iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
}
- // An EQ constraint can be added if the yielded value (dimension size)
- // equals the corresponding block argument (dimension size).
Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
.getOperand(iterArgIdx);
Value iterArg = forOp.getRegionIterArg(iterArgIdx);
Value initArg = forOp.getInitArgs()[iterArgIdx];
- auto addEqBound = [&]() {
+ // Populate constraints for the yielded value.
+ cstr.populateConstraints(yieldedValue, dim);
+ // Populate constraints for the iter_arg. This is just to ensure that the
+ // iter_arg is mapped in the constraint set, which is a prerequisite for
+ // `compare`. It may lead to a recursive call to this function in case the
+ // iter_arg was not visited when the constraints for the yielded value were
+ // populated, but no additional work is done.
+ cstr.populateConstraints(iterArg, dim);
+
+ // An EQ constraint can be added if the yielded value (dimension size)
+ // equals the corresponding block argument (dimension size).
+ if (cstr.compare(yieldedValue, dim,
+ ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
+ dim)) {
if (dim.has_value()) {
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
} else {
cstr.bound(value) == initArg;
}
- };
-
- if (yieldedValue == iterArg) {
- addEqBound();
- return;
- }
-
- // Compute EQ bound for yielded value.
- AffineMap bound;
- ValueDimList boundOperands;
- LogicalResult status = ValueBoundsConstraintSet::computeBound(
- bound, boundOperands, BoundType::EQ, yieldedValue, dim,
- [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
- // Stop when reaching a block argument of the loop body.
- if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
- return bbArg.getOwner()->getParentOp() == forOp;
- // Stop when reaching a value that is defined outside of the loop. It
- // is impossible to reach an iter_arg from there.
- Operation *op = v.getDefiningOp();
- return forOp.getRegion().findAncestorOpInRegion(*op) == nullptr;
- });
- if (failed(status))
- return;
- if (bound.getNumResults() != 1)
- return;
-
- // Check if computed bound equals the corresponding iter_arg.
- Value singleValue = nullptr;
- std::optional<int64_t> singleDim;
- if (auto dimExpr = dyn_cast<AffineDimExpr>(bound.getResult(0))) {
- int64_t idx = dimExpr.getPosition();
- singleValue = boundOperands[idx].first;
- singleDim = boundOperands[idx].second;
- } else if (auto symExpr = dyn_cast<AffineSymbolExpr>(bound.getResult(0))) {
- int64_t idx = symExpr.getPosition() + bound.getNumDims();
- singleValue = boundOperands[idx].first;
- singleDim = boundOperands[idx].second;
}
- if (singleValue == iterArg && singleDim == dim)
- addEqBound();
}
void populateBoundsForIndexValue(Operation *op, Value value,
More information about the llvm-branch-commits
mailing list