[Mlir-commits] [mlir] [mlir] Expose linearize/delinearize lowering transforms (PR #144156)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 13 13:26:09 PDT 2025


https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/144156

Moves the transformation logic from the AffineLinearizeIndexOp and AffineDelinearizeIndexOp lowering patterns into standalone transform functions that can be called directly, without going through the rewrite patterns. This provides a more controlled way to apply the op lowerings.

From c961bef079a8bed326c6b0537ba3cd32348d68af Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 13 Jun 2025 20:15:50 +0000
Subject: [PATCH] [mlir] Expose linearize/delinearize lowering transforms

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
 .../Dialect/Affine/Transforms/Transforms.h    |  13 ++
 .../Transforms/AffineExpandIndexOps.cpp       | 218 +++++++++---------
 2 files changed, 124 insertions(+), 107 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index bf830a29613fd..779571e911e1d 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -33,6 +33,21 @@ enum class BoundType;
 namespace affine {
 class AffineApplyOp;
+class AffineDelinearizeIndexOp;
+class AffineLinearizeIndexOp;
 
+/// Lowers `affine.delinearize_index` into a sequence of division and remainder
+/// operations. `op` is replaced (and erased) through `rewriter`.
+LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
+                                            AffineDelinearizeIndexOp op);
+
+/// Lowers `affine.linearize_index` into a sequence of multiplications and
+/// additions. Make a best effort to sort the input indices so that
+/// the most loop-invariant terms are at the left of the additions
+/// to enable loop-invariant code motion. `op` is replaced (and erased)
+/// through `rewriter`.
+LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
+                                          AffineLinearizeIndexOp op);
+
 /// Populate patterns that expand affine index operations into more fundamental
 /// operations (not necessarily restricted to Affine dialect).
 void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 35205a6ca2eee..c0ef28c648ac5 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -84,126 +84,140 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
   return result;
 }
 
+LogicalResult
+affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
+                                      AffineDelinearizeIndexOp op) {
+  // May be called outside of a pattern driver: build the lowering right before
+  // `op` rather than at the caller's insertion point, which is then restored.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  Location loc = op.getLoc();
+  Value linearIdx = op.getLinearIndex();
+  unsigned numResults = op.getNumResults();
+  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+  if (numResults == staticBasis.size())
+    staticBasis = staticBasis.drop_front();
+
+  if (numResults == 1) {
+    rewriter.replaceOp(op, linearIdx);
+    return success();
+  }
+
+  SmallVector<Value> results;
+  results.reserve(numResults);
+  SmallVector<Value> strides =
+      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+                     /*knownNonNegative=*/true);
+
+  Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+  Value initialPart =
+      rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
+  results.push_back(initialPart);
+
+  auto emitModTerm = [&](Value stride) -> Value {
+    Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
+    Value remainderNegative = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, remainder, zero);
+    // If the correction is relevant, this term is <= stride, which is known
+    // to be positive in `index`. Otherwise, while 2 * stride might overflow,
+    // this branch won't be taken, so the risk of `poison` is fine.
+    Value corrected = rewriter.create<arith::AddIOp>(
+        loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
+    Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
+                                                 corrected, remainder);
+    return mod;
+  };
+
+  // Generate all the intermediate parts
+  for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
+    Value thisStride = strides[i];
+    Value nextStride = strides[i + 1];
+    Value modulus = emitModTerm(thisStride);
+    // We know both inputs are positive, so floorDiv == div.
+    // This could potentially be a divui, but it's not clear if that would
+    // cause issues.
+    Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
+    results.push_back(divided);
+  }
+
+  results.push_back(emitModTerm(strides.back()));
+
+  rewriter.replaceOp(op, results);
+  return success();
+}
+
+LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
+                                                  AffineLinearizeIndexOp op) {
+  // May be called outside of a pattern driver: build the lowering right before
+  // `op` rather than at the caller's insertion point, which is then restored.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  // Should be folded away, included here for safety.
+  if (op.getMultiIndex().empty()) {
+    rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+    return success();
+  }
+
+  Location loc = op.getLoc();
+  ValueRange multiIndex = op.getMultiIndex();
+  size_t numIndexes = multiIndex.size();
+  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+  if (numIndexes == staticBasis.size())
+    staticBasis = staticBasis.drop_front();
+
+  SmallVector<Value> strides =
+      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+                     /*knownNonNegative=*/op.getDisjoint());
+  SmallVector<std::pair<Value, int64_t>> scaledValues;
+  scaledValues.reserve(numIndexes);
+
+  // Note: strides doesn't contain a value for the final element (stride 1)
+  // and everything else lines up. We use the "mutable" accessor so we can get
+  // our hands on an `OpOperand&` for the loop invariant counting function.
+  for (auto [stride, idxOp] :
+       llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
+    Value scaledIdx = rewriter.create<arith::MulIOp>(
+        loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
+    int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
+    scaledValues.emplace_back(scaledIdx, numHoistableLoops);
+  }
+  scaledValues.emplace_back(
+      multiIndex.back(),
+      numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
+
+  // Sort by how many enclosing loops there are, ties implicitly broken by
+  // size of the stride.
+  llvm::stable_sort(scaledValues,
+                    [&](auto l, auto r) { return l.second > r.second; });
+
+  Value result = scaledValues.front().first;
+  for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
+    std::ignore = numHoistableLoops;
+    result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
+                                            arith::IntegerOverflowFlags::nsw);
+  }
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 namespace {
-/// Lowers `affine.delinearize_index` into a sequence of division and remainder
-/// operations.
 struct LowerDelinearizeIndexOps
     : public OpRewritePattern<AffineDelinearizeIndexOp> {
   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value linearIdx = op.getLinearIndex();
-    unsigned numResults = op.getNumResults();
-    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
-    if (numResults == staticBasis.size())
-      staticBasis = staticBasis.drop_front();
-
-    if (numResults == 1) {
-      rewriter.replaceOp(op, linearIdx);
-      return success();
-    }
-
-    SmallVector<Value> results;
-    results.reserve(numResults);
-    SmallVector<Value> strides =
-        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
-                       /*knownNonNegative=*/true);
-
-    Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
-
-    Value initialPart =
-        rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
-    results.push_back(initialPart);
-
-    auto emitModTerm = [&](Value stride) -> Value {
-      Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
-      Value remainderNegative = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, remainder, zero);
-      // If the correction is relevant, this term is <= stride, which is known
-      // to be positive in `index`. Otherwise, while 2 * stride might overflow,
-      // this branch won't be taken, so the risk of `poison` is fine.
-      Value corrected = rewriter.create<arith::AddIOp>(
-          loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
-      Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
-                                                   corrected, remainder);
-      return mod;
-    };
-
-    // Generate all the intermediate parts
-    for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
-      Value thisStride = strides[i];
-      Value nextStride = strides[i + 1];
-      Value modulus = emitModTerm(thisStride);
-      // We know both inputs are positive, so floorDiv == div.
-      // This could potentially be a divui, but it's not clear if that would
-      // cause issues.
-      Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
-      results.push_back(divided);
-    }
-
-    results.push_back(emitModTerm(strides.back()));
-
-    rewriter.replaceOp(op, results);
-    return success();
+    return affine::lowerAffineDelinearizeIndexOp(rewriter, op);
   }
 };
 
-/// Lowers `affine.linearize_index` into a sequence of multiplications and
-/// additions. Make a best effort to sort the input indices so that
-/// the most loop-invariant terms are at the left of the additions
-/// to enable loop-invariant code motion.
 struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    // Should be folded away, included here for safety.
-    if (op.getMultiIndex().empty()) {
-      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
-      return success();
-    }
-
-    Location loc = op.getLoc();
-    ValueRange multiIndex = op.getMultiIndex();
-    size_t numIndexes = multiIndex.size();
-    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
-    if (numIndexes == staticBasis.size())
-      staticBasis = staticBasis.drop_front();
-
-    SmallVector<Value> strides =
-        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
-                       /*knownNonNegative=*/op.getDisjoint());
-    SmallVector<std::pair<Value, int64_t>> scaledValues;
-    scaledValues.reserve(numIndexes);
-
-    // Note: strides doesn't contain a value for the final element (stride 1)
-    // and everything else lines up. We use the "mutable" accessor so we can get
-    // our hands on an `OpOperand&` for the loop invariant counting function.
-    for (auto [stride, idxOp] :
-         llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
-      Value scaledIdx = rewriter.create<arith::MulIOp>(
-          loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
-      int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
-      scaledValues.emplace_back(scaledIdx, numHoistableLoops);
-    }
-    scaledValues.emplace_back(
-        multiIndex.back(),
-        numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
-
-    // Sort by how many enclosing loops there are, ties implicitly broken by
-    // size of the stride.
-    llvm::stable_sort(scaledValues,
-                      [&](auto l, auto r) { return l.second > r.second; });
-
-    Value result = scaledValues.front().first;
-    for (auto [scaledValue, numHoistableLoops] :
-         llvm::drop_begin(scaledValues)) {
-      std::ignore = numHoistableLoops;
-      result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
-                                              arith::IntegerOverflowFlags::nsw);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
+    return affine::lowerAffineLinearizeIndexOp(rewriter, op);
   }
 };
 



More information about the Mlir-commits mailing list