[Mlir-commits] [mlir] [mlir][SCF] Use Affine ops for indexing math. (PR #108450)

Han-Chung Wang llvmlistbot at llvm.org
Tue Sep 17 04:47:29 PDT 2024


================
@@ -4534,6 +4534,140 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
   return success();
 }
 
+namespace {
+
+/// Canonicalization pattern that removes unit-extent entries from the basis
+/// of an `affine.delinearize_index` op. Every result that corresponds to a
+/// unit-extent basis element is replaced with the index constant 0; the other
+/// results are taken from a new `affine.delinearize_index` built over the
+/// remaining basis values.
+struct DropUnitExtentBasis
+    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = delinearizeOp->getLoc();
+
+    // Shared zero constant, materialized lazily so that no op is created when
+    // the pattern does not apply.
+    Value cachedZero;
+    auto zeroIndex = [&]() -> Value {
+      if (!cachedZero)
+        cachedZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+      return cachedZero;
+    };
+
+    // One slot per result of the original op. A slot that is still null after
+    // this loop belongs to a basis element that is kept.
+    SmallVector<Value> results(delinearizeOp->getNumResults(), nullptr);
+    SmallVector<Value> keptBasis;
+    for (auto [pos, extent] : llvm::enumerate(delinearizeOp.getBasis())) {
+      if (!matchPattern(extent, m_One())) {
+        keptBasis.push_back(extent);
+        continue;
+      }
+      results[pos] = zeroIndex();
+    }
+
+    // Nothing was dropped, so there is nothing to simplify.
+    if (keptBasis.size() == delinearizeOp.getBasis().size())
+      return failure();
+
+    if (!keptBasis.empty()) {
+      auto reducedOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
+          loc, delinearizeOp.getLinearIndex(), keptBasis);
+      // Fill the open slots, in order, with the results of the reduced op.
+      int nextResult = 0;
+      for (Value &slot : results) {
+        if (!slot)
+          slot = reducedOp->getResult(nextResult++);
+      }
+    }
+
+    rewriter.replaceOp(delinearizeOp, results);
+    return success();
+  }
+};
+
+/// Drop delinearization pattern related to loops in the following way
+///
+/// ```
+/// <loop>(%iv) = (%c0) to (%ub) step (%c1) {
+///   %0 = affine.delinearize_index %iv into (%ub) : index
+///   <some_use>(%0)
+/// }
+/// ```
+///
+/// can be canonicalized to
+///
+/// ```
+/// <loop>(%iv) = (%c0) to (%ub) step (%c1) {
+///   <some_use>(%iv)
+/// }
+/// ```
+struct DropDelinearizeOfSingleLoop
+    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto basis = delinearizeOp.getBasis();
+    if (basis.size() != 1) {
+      return failure();
+    }
+
+    // Check that the `linear_index` is an induction variable.
+    auto inductionVar = cast<BlockArgument>(delinearizeOp.getLinearIndex());
+    if (!inductionVar)
+      return failure();
+
+    // Check that the parent is a `LoopLikeOpInterface`.
+    auto loopLikeOp = cast<LoopLikeOpInterface>(
+        inductionVar.getParentRegion()->getParentOp());
+    if (!loopLikeOp) {
+      return failure();
+    }
+
+    // Check that loop is unit-rank and that the `linear_index` is the induction
+    // variable.
+    auto inductionVars = loopLikeOp.getLoopInductionVars();
+    if (!inductionVars || inductionVars->size() != 1 ||
+        inductionVars->front() != inductionVar) {
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "`linear_index` is not loop induction variable");
+    }
+
+    // Check that the upper-bound is the basis.
+    auto upperBounds = loopLikeOp.getLoopUpperBounds();
+    if (!upperBounds || upperBounds->size() != 1 ||
+        upperBounds->front() != getAsOpFoldResult(basis.front())) {
+      return rewriter.notifyMatchFailure(delinearizeOp,
+                                         "`basis` is not upper bound");
+    }
+
+    // Check that the lower bound is zero.
+    auto lowerBounds = loopLikeOp.getLoopLowerBounds();
+    if (!lowerBounds || lowerBounds->size() != 1 ||
+        !isConstantIntValue(lowerBounds->front(), 0)) {
----------------
hanhanW wrote:

Optional nit: this can be written as `!isZeroIndex(lowerBounds->front())`, which is shorter.

https://github.com/llvm/llvm-project/pull/108450


More information about the Mlir-commits mailing list