[Mlir-commits] [mlir] [mlir] Expose linearize/delinearize lowering transforms (PR #144156)
llvmlistbot at llvm.org
Fri Jun 13 13:26:09 PDT 2025
https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/144156
Moves the transformation logic from the AffineLinearizeIndexOp and AffineDelinearizeIndexOp lowering patterns into standalone transform functions that can now be called directly. This provides a more controlled way to apply the op lowerings.
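For illustration, a minimal sketch of how the exposed functions might be driven directly, outside the pattern driver. The `lowerAllIndexOps` helper and its surrounding setup are hypothetical; only `affine::lowerAffineDelinearizeIndexOp` and `affine::lowerAffineLinearizeIndexOp` come from this patch:

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: lower every affine.delinearize_index and
// affine.linearize_index in `func` by calling the exposed transforms
// directly. Ops are collected first so that replacing (and erasing) them
// does not invalidate the walk.
static LogicalResult lowerAllIndexOps(func::FuncOp func) {
  SmallVector<Operation *> toLower;
  func.walk([&](Operation *op) {
    if (isa<affine::AffineDelinearizeIndexOp, affine::AffineLinearizeIndexOp>(
            op))
      toLower.push_back(op);
  });

  IRRewriter rewriter(func.getContext());
  for (Operation *op : toLower) {
    rewriter.setInsertionPoint(op);
    if (auto delin = dyn_cast<affine::AffineDelinearizeIndexOp>(op)) {
      if (failed(affine::lowerAffineDelinearizeIndexOp(rewriter, delin)))
        return failure();
    } else if (auto lin = dyn_cast<affine::AffineLinearizeIndexOp>(op)) {
      if (failed(affine::lowerAffineLinearizeIndexOp(rewriter, lin)))
        return failure();
    }
  }
  return success();
}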
From c961bef079a8bed326c6b0537ba3cd32348d68af Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 13 Jun 2025 20:15:50 +0000
Subject: [PATCH] [mlir] Expose linearize/delinearize lowering transforms
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
.../Dialect/Affine/Transforms/Transforms.h | 13 ++
.../Transforms/AffineExpandIndexOps.cpp | 218 +++++++++---------
2 files changed, 124 insertions(+), 107 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index bf830a29613fd..779571e911e1d 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
@@ -33,6 +34,18 @@ enum class BoundType;
namespace affine {
class AffineApplyOp;
+/// Lowers `affine.delinearize_index` into a sequence of division and remainder
+/// operations.
+LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
+ AffineDelinearizeIndexOp op);
+
+/// Lowers `affine.linearize_index` into a sequence of multiplications and
+/// additions. Makes a best effort to sort the input indices so that
+/// the most loop-invariant terms are at the left of the additions
+/// to enable loop-invariant code motion.
+LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
+ AffineLinearizeIndexOp op);
+
/// Populate patterns that expand affine index operations into more fundamental
/// operations (not necessarily restricted to Affine dialect).
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 35205a6ca2eee..c0ef28c648ac5 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -84,126 +84,130 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
return result;
}
+LogicalResult
+affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
+ AffineDelinearizeIndexOp op) {
+ Location loc = op.getLoc();
+ Value linearIdx = op.getLinearIndex();
+ unsigned numResults = op.getNumResults();
+ ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+ if (numResults == staticBasis.size())
+ staticBasis = staticBasis.drop_front();
+
+ if (numResults == 1) {
+ rewriter.replaceOp(op, linearIdx);
+ return success();
+ }
+
+ SmallVector<Value> results;
+ results.reserve(numResults);
+ SmallVector<Value> strides =
+ computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+ /*knownNonNegative=*/true);
+
+ Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+ Value initialPart =
+ rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
+ results.push_back(initialPart);
+
+ auto emitModTerm = [&](Value stride) -> Value {
+ Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
+ Value remainderNegative = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::slt, remainder, zero);
+ // If the correction is relevant, this term is <= stride, which is known
+ // to be positive in `index`. Otherwise, while 2 * stride might overflow,
+ // this branch won't be taken, so the risk of `poison` is fine.
+ Value corrected = rewriter.create<arith::AddIOp>(
+ loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
+ Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
+ corrected, remainder);
+ return mod;
+ };
+
+ // Generate all the intermediate parts
+ for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
+ Value thisStride = strides[i];
+ Value nextStride = strides[i + 1];
+ Value modulus = emitModTerm(thisStride);
+ // We know both inputs are positive, so floorDiv == div.
+ // This could potentially be a divui, but it's not clear if that would
+ // cause issues.
+ Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
+ results.push_back(divided);
+ }
+
+ results.push_back(emitModTerm(strides.back()));
+
+ rewriter.replaceOp(op, results);
+ return success();
+}
+
+LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
+ AffineLinearizeIndexOp op) {
+ // Should be folded away, included here for safety.
+ if (op.getMultiIndex().empty()) {
+ rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+ return success();
+ }
+
+ Location loc = op.getLoc();
+ ValueRange multiIndex = op.getMultiIndex();
+ size_t numIndexes = multiIndex.size();
+ ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+ if (numIndexes == staticBasis.size())
+ staticBasis = staticBasis.drop_front();
+
+ SmallVector<Value> strides =
+ computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+ /*knownNonNegative=*/op.getDisjoint());
+ SmallVector<std::pair<Value, int64_t>> scaledValues;
+ scaledValues.reserve(numIndexes);
+
+ // Note: strides doesn't contain a value for the final element (stride 1)
+ // and everything else lines up. We use the "mutable" accessor so we can get
+ // our hands on an `OpOperand&` for the loop invariant counting function.
+ for (auto [stride, idxOp] :
+ llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
+ Value scaledIdx = rewriter.create<arith::MulIOp>(
+ loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
+ int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
+ scaledValues.emplace_back(scaledIdx, numHoistableLoops);
+ }
+ scaledValues.emplace_back(
+ multiIndex.back(),
+ numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
+
+ // Sort by how many enclosing loops there are, ties implicitly broken by
+ // size of the stride.
+ llvm::stable_sort(scaledValues,
+ [&](auto l, auto r) { return l.second > r.second; });
+
+ Value result = scaledValues.front().first;
+ for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
+ std::ignore = numHoistableLoops;
+ result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
+ arith::IntegerOverflowFlags::nsw);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+}
+
namespace {
-/// Lowers `affine.delinearize_index` into a sequence of division and remainder
-/// operations.
struct LowerDelinearizeIndexOps
: public OpRewritePattern<AffineDelinearizeIndexOp> {
using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- Value linearIdx = op.getLinearIndex();
- unsigned numResults = op.getNumResults();
- ArrayRef<int64_t> staticBasis = op.getStaticBasis();
- if (numResults == staticBasis.size())
- staticBasis = staticBasis.drop_front();
-
- if (numResults == 1) {
- rewriter.replaceOp(op, linearIdx);
- return success();
- }
-
- SmallVector<Value> results;
- results.reserve(numResults);
- SmallVector<Value> strides =
- computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
- /*knownNonNegative=*/true);
-
- Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
-
- Value initialPart =
- rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
- results.push_back(initialPart);
-
- auto emitModTerm = [&](Value stride) -> Value {
- Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
- Value remainderNegative = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, remainder, zero);
- // If the correction is relevant, this term is <= stride, which is known
- // to be positive in `index`. Otherwise, while 2 * stride might overflow,
- // this branch won't be taken, so the risk of `poison` is fine.
- Value corrected = rewriter.create<arith::AddIOp>(
- loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
- Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
- corrected, remainder);
- return mod;
- };
-
- // Generate all the intermediate parts
- for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
- Value thisStride = strides[i];
- Value nextStride = strides[i + 1];
- Value modulus = emitModTerm(thisStride);
- // We know both inputs are positive, so floorDiv == div.
- // This could potentially be a divui, but it's not clear if that would
- // cause issues.
- Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
- results.push_back(divided);
- }
-
- results.push_back(emitModTerm(strides.back()));
-
- rewriter.replaceOp(op, results);
- return success();
+ return affine::lowerAffineDelinearizeIndexOp(rewriter, op);
}
};
-/// Lowers `affine.linearize_index` into a sequence of multiplications and
-/// additions. Make a best effort to sort the input indices so that
-/// the most loop-invariant terms are at the left of the additions
-/// to enable loop-invariant code motion.
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
PatternRewriter &rewriter) const override {
- // Should be folded away, included here for safety.
- if (op.getMultiIndex().empty()) {
- rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
- return success();
- }
-
- Location loc = op.getLoc();
- ValueRange multiIndex = op.getMultiIndex();
- size_t numIndexes = multiIndex.size();
- ArrayRef<int64_t> staticBasis = op.getStaticBasis();
- if (numIndexes == staticBasis.size())
- staticBasis = staticBasis.drop_front();
-
- SmallVector<Value> strides =
- computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
- /*knownNonNegative=*/op.getDisjoint());
- SmallVector<std::pair<Value, int64_t>> scaledValues;
- scaledValues.reserve(numIndexes);
-
- // Note: strides doesn't contain a value for the final element (stride 1)
- // and everything else lines up. We use the "mutable" accessor so we can get
- // our hands on an `OpOperand&` for the loop invariant counting function.
- for (auto [stride, idxOp] :
- llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
- Value scaledIdx = rewriter.create<arith::MulIOp>(
- loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
- int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
- scaledValues.emplace_back(scaledIdx, numHoistableLoops);
- }
- scaledValues.emplace_back(
- multiIndex.back(),
- numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
-
- // Sort by how many enclosing loops there are, ties implicitly broken by
- // size of the stride.
- llvm::stable_sort(scaledValues,
- [&](auto l, auto r) { return l.second > r.second; });
-
- Value result = scaledValues.front().first;
- for (auto [scaledValue, numHoistableLoops] :
- llvm::drop_begin(scaledValues)) {
- std::ignore = numHoistableLoops;
- result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
- arith::IntegerOverflowFlags::nsw);
- }
- rewriter.replaceOp(op, result);
- return success();
+ return affine::lowerAffineLinearizeIndexOp(rewriter, op);
}
};
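
For reference, and not part of the patch: a plain-integer C++ mirror of the op sequence the delinearize lowering emits, assuming a nonnegative linear index (so the select-based sign correction never fires) and at least two results:

#include <cassert>
#include <cstdint>
#include <vector>

// Plain-integer mirror of lowerAffineDelinearizeIndexOp: strides are suffix
// products of the basis (skipping its first element), the first result is a
// floor division by the largest stride, and each later result is a remainder
// divided by the next stride.
std::vector<int64_t> delinearize(int64_t linearIdx,
                                 const std::vector<int64_t> &basis) {
  size_t n = basis.size();
  // For basis (b0, ..., b{n-1}), strides are (b1*...*b{n-1}, ..., b{n-1}).
  std::vector<int64_t> strides(n - 1);
  int64_t running = 1;
  for (size_t i = n - 1; i >= 1; --i) {
    running *= basis[i];
    strides[i - 1] = running;
  }
  std::vector<int64_t> results;
  // First result: floor division by the largest stride.
  results.push_back(linearIdx / strides.front());
  // Intermediate results: remainder by this stride, divided by the next one.
  for (size_t i = 0; i + 1 < strides.size(); ++i)
    results.push_back((linearIdx % strides[i]) / strides[i + 1]);
  // Last result: remainder by the smallest stride.
  results.push_back(linearIdx % strides.back());
  return results;
}

int main() {
  // 53 = 3*15 + 1*5 + 3 with basis (2, 3, 5), i.e. strides (15, 5).
  std::vector<int64_t> r = delinearize(53, {2, 3, 5});
  assert(r[0] == 3 && r[1] == 1 && r[2] == 3);
}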
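Likewise, a plain-integer mirror of the linearize lowering (also hypothetical, not part of the patch); the loop-invariance-driven reordering of the additions is omitted, since it changes only the order of the terms, not the result:

#include <cassert>
#include <cstdint>
#include <vector>

// result = idx0*s0 + idx1*s1 + ... + idx{n-1}, with the same suffix-product
// strides as in delinearize() above.
int64_t linearize(const std::vector<int64_t> &multiIndex,
                  const std::vector<int64_t> &basis) {
  int64_t result = 0;
  int64_t stride = 1;
  // Accumulate right-to-left so `stride` is a running suffix product.
  for (size_t i = multiIndex.size(); i-- > 0;) {
    result += multiIndex[i] * stride;
    stride *= basis[i];
  }
  return result;
}

int main() {
  // (3, 1, 3) with basis (2, 3, 5) maps back to 3*15 + 1*5 + 3 = 53.
  assert(linearize({3, 1, 3}, {2, 3, 5}) == 53);
}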