[Mlir-commits] [mlir] [mlir][Affine] Expand affine.[de]linearize_index without affine maps (PR #116703)

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Nov 18 14:30:06 PST 2024


https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/116703

As the documentation for -affine-expand-index-ops says, affine.delinearize_index and affine.linearize_index don't need to be expanded into the affine dialect.

Expanding these operations into affine.apply operations can introduce unwanted "simplifications", mainly translations of `(dN mod C + ...)` to `(dN + ... - (dN floordiv C) * C)` and similar, which create worse generated code. This commit resolves this issue by expanding affine.delinearize_index and affine.linearize_index directly into arith operations.

In addition, the lowering of affine.linearize_index now sorts the operands by loop-independence, allowing an increased amount of loop-invariant code motion after lowering.

The old behavior is preserved as -affine-expand-index-ops-as-affine, but it is no longer the default.

>From ba4ae62b1ab03531b4bc94d08b7c5579d8fdf5c2 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 13 Nov 2024 23:44:29 +0000
Subject: [PATCH] [mlir][Affine] Expand affine.[de]linearize_index without
 affine maps

As the documentation for -affine-expand-index-ops says,
affine.delinearize_index and affine.linearize_index don't need to be
expanded into the affine dialect.

Expanding these operations into affine.apply operations can introduce
unwanted "simplifications", mainly translations of `(dN mod C + ...)`
to `(dN + ... - (dN floordiv C) * C)` and similar, which create worse
generated code. This commit resolves this issue by expanding
affine.delinearize_index and affine.linearize_index directly into
arith operations.

In addition, the lowering of affine.linearize_index now sorts the
operands by loop-independence, allowing an increased amount of
loop-invariant code motion after lowering.

The old behavior is preserved as -affine-expand-index-ops-as-affine,
but it is no longer the default.
---
 mlir/include/mlir/Dialect/Affine/LoopUtils.h  |   5 +
 mlir/include/mlir/Dialect/Affine/Passes.h     |   4 +
 mlir/include/mlir/Dialect/Affine/Passes.td    |   5 +
 .../Dialect/Affine/Transforms/Transforms.h    |   4 +
 .../Transforms/AffineExpandIndexOps.cpp       | 148 ++++++++++++++++--
 .../AffineExpandIndexOpsAsAffine.cpp          |  98 ++++++++++++
 .../Dialect/Affine/Transforms/CMakeLists.txt  |   1 +
 mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp   |  12 ++
 .../AffineToStandard/lower-affine.mlir        |  66 --------
 .../affine-expand-index-ops-as-affine.mlir    |  70 +++++++++
 .../Affine/affine-expand-index-ops.mlir       | 101 ++++++++----
 11 files changed, 405 insertions(+), 109 deletions(-)
 create mode 100644 mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
 create mode 100644 mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir

diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 380c742b5224cb..7fe1f6d48ceebb 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -301,6 +301,11 @@ separateFullTiles(MutableArrayRef<AffineForOp> nest,
 /// Walk an affine.for to find a band to coalesce.
 LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op);
 
+/// Return how many consecutive loops enclosing `operand`'s owner have the
+/// operand's value defined outside of them, i.e. how far it could be hoisted.
+/// Counting walks outward and stops at the first loop inside which the value
+/// is defined. Any LoopLikeOpInterface op counts, not just affine.for.
+int64_t numEnclosingInvariantLoops(OpOperand &operand);
 } // namespace affine
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index 61f24255f305f7..e152101236dc7a 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -116,6 +116,10 @@ std::unique_ptr<OperationPass<func::FuncOp>> createPipelineDataTransferPass();
 /// operations (not necessarily restricted to Affine dialect).
 std::unique_ptr<Pass> createAffineExpandIndexOpsPass();
 
+/// Creates a pass that expands affine.delinearize_index and
+/// affine.linearize_index into equivalent `affine.apply` operations.
+std::unique_ptr<Pass> createAffineExpandIndexOpsAsAffinePass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index b08e803345f76e..77073aa29da73e 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -408,4 +408,9 @@ def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> {
   let constructor = "mlir::affine::createAffineExpandIndexOpsPass()";
 }
 
+def AffineExpandIndexOpsAsAffine : Pass<"affine-expand-index-ops-as-affine"> {
+  let summary = "Lower affine operations operating on indices into affine.apply operations";
+  let constructor = "mlir::affine::createAffineExpandIndexOpsAsAffinePass()";
+}
+
 #endif // MLIR_DIALECT_AFFINE_PASSES
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index b244d37c0707f2..bf830a29613fdd 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -37,6 +37,10 @@ class AffineApplyOp;
 /// operations (not necessarily restricted to Affine dialect).
 void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
 
+/// Populate `patterns` with rewrites that lower affine.delinearize_index and
+/// affine.linearize_index into equivalent `affine.apply` operations.
+void populateAffineExpandIndexOpsAsAffinePatterns(RewritePatternSet &patterns);
+
 /// Helper function to rewrite `op`'s affine map and reorder its operands such
 /// that they are in increasing order of hoistability (i.e. the least hoistable)
 /// operands come first in the operand list.
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 15478e0e1e3a5b..d7b218225bc9ab 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -10,6 +10,7 @@
 // fundamental operations.
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Affine/Passes.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -28,6 +29,50 @@ namespace affine {
 using namespace mlir;
 using namespace mlir::affine;
 
+/// Compute the strides of a mixed static/dynamic basis: element `i` of the
+/// result is the product of basis elements `i` through the end, so the first
+/// element is the product of the whole basis. The basis must already have had
+/// its outer bound (if any) removed.
+///
+/// Dynamic sizes are consumed from the back of `dynamicBasis`, so surplus
+/// leading values (such as a dropped dynamic outer bound) are ignored.
+static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
+                                         ValueRange dynamicBasis,
+                                         ArrayRef<int64_t> staticBasis) {
+  SmallVector<Value> strides;
+  if (staticBasis.empty())
+    return strides;
+  strides.reserve(staticBasis.size());
+  // The running product is kept as a folded constant plus an SSA product of
+  // the dynamic sizes seen so far (null until the first dynamic size).
+  int64_t constProduct = 1;
+  Value dynProduct;
+  auto nextDynamic = dynamicBasis.end();
+  for (int64_t size : llvm::reverse(staticBasis)) {
+    if (!ShapedType::isDynamic(size)) {
+      constProduct *= size;
+    } else {
+      Value dim = *--nextDynamic;
+      if (dynProduct)
+        dynProduct = rewriter.create<arith::MulIOp>(loc, dynProduct, dim);
+      else
+        dynProduct = dim;
+    }
+    // Skip materializing a constant 1 when a dynamic product already exists.
+    if (dynProduct && constProduct == 1) {
+      strides.push_back(dynProduct);
+      continue;
+    }
+    Value stride =
+        rewriter.createOrFold<arith::ConstantIndexOp>(loc, constProduct);
+    if (dynProduct)
+      stride = rewriter.create<arith::MulIOp>(loc, dynProduct, stride);
+    strides.push_back(stride);
+  }
+  std::reverse(strides.begin(), strides.end());
+  return strides;
+}
+
 namespace {
 /// Lowers `affine.delinearize_index` into a sequence of division and remainder
 /// operations.
@@ -36,18 +81,62 @@ struct LowerDelinearizeIndexOps
   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<SmallVector<Value>> multiIndex =
-        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
-                         op.getEffectiveBasis(), /*hasOuterBound=*/false);
-    if (failed(multiIndex))
-      return failure();
-    rewriter.replaceOp(op, *multiIndex);
+    Location loc = op.getLoc();
+    Value linearIdx = op.getLinearIndex();
+    unsigned numResults = op.getNumResults();
+    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+    if (numResults == staticBasis.size()) // basis includes an outer bound
+      staticBasis = staticBasis.drop_front();
+
+    if (numResults == 1) {
+      rewriter.replaceOp(op, linearIdx);
+      return success();
+    }
+    // strides[i] is the product of the basis sizes of all results after i.
+    SmallVector<Value> results;
+    results.reserve(numResults);
+    SmallVector<Value> strides =
+        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
+
+    Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+    // The outermost result is unbounded; floor division rounds toward -inf.
+    Value initialPart =
+        rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
+    results.push_back(initialPart);
+    // linearIdx mod stride, shifted into [0, stride) when stride > 0.
+    auto emitModTerm = [&](Value stride) -> Value {
+      Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
+      Value remainderNegative = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::slt, remainder, zero);
+      Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride);
+      Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
+                                                   corrected, remainder);
+      return mod;
+    };
+
+    // Middle results: (linearIdx mod strides[i]) div strides[i + 1].
+    for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
+      Value thisStride = strides[i];
+      Value nextStride = strides[i + 1];
+      Value modulus = emitModTerm(thisStride);
+      // `modulus` is non-negative (it may be zero), so for a positive stride
+      // floordiv and divsi agree. divui would presumably also work, but signed
+      // division is kept since it is not clear whether that could cause issues.
+      Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
+      results.push_back(divided);
+    }
+    // The innermost result is linearIdx mod the last stride.
+    results.push_back(emitModTerm(strides.back()));
+
+    rewriter.replaceOp(op, results);
     return success();
   }
 };
 
 /// Lowers `affine.linearize_index` into a sequence of multiplications and
-/// additions.
+/// additions. Make a best effort to sort the input indices so that
+/// the most loop-invariant terms are at the left of the additions
+/// to enable loop-invariant code motion.
 struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
@@ -58,13 +147,44 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
       return success();
     }
 
-    SmallVector<OpFoldResult> multiIndex =
-        getAsOpFoldResult(op.getMultiIndex());
-    OpFoldResult linearIndex =
-        linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
-    Value linearIndexValue =
-        getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
-    rewriter.replaceOp(op, linearIndexValue);
+    Location loc = op.getLoc();
+    ValueRange multiIndex = op.getMultiIndex();
+    size_t numIndexes = multiIndex.size();
+    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+    if (numIndexes == staticBasis.size())
+      staticBasis = staticBasis.drop_front();
+
+    SmallVector<Value> strides =
+        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
+    SmallVector<std::pair<Value, int64_t>> scaledValues;
+    scaledValues.reserve(numIndexes);
+
+    // Note: strides doesn't contain a value for the final element (stride 1)
+    // and everything else lines up. We use the "mutable" accessor so we can get
+    // our hands on an `OpOperand&` for the loop invariant counting function.
+    for (auto [stride, idxOp] :
+         llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
+      Value scaledIdx =
+          rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride);
+      int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
+      scaledValues.emplace_back(scaledIdx, numHoistableLoops);
+    }
+    scaledValues.emplace_back(
+        multiIndex.back(),
+        numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
+
+    // Sort by how many enclosing loops there are, ties implicitly broken by
+    // size of the stride.
+    llvm::stable_sort(scaledValues,
+                      [&](auto l, auto r) { return l.second > r.second; });
+
+    Value result = scaledValues.front().first;
+    for (auto [scaledValue, numHoistableLoops] :
+         llvm::drop_begin(scaledValues)) {
+      std::ignore = numHoistableLoops;
+      result = rewriter.create<arith::AddIOp>(loc, result, scaledValue);
+    }
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
new file mode 100644
index 00000000000000..bfcc1ddf91653a
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
@@ -0,0 +1,98 @@
+//===- AffineExpandIndexOpsAsAffine.cpp - Expand index ops to apply pass --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to expand affine index ops into equivalent
+// `affine.apply` operations.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPSASAFFINE
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::affine;
+
+namespace {
+/// Lowers `affine.delinearize_index` into `affine.apply` operations by calling
+/// the affine `delinearizeIndex` utility on the op's effective basis.
+struct LowerDelinearizeIndexOps
+    : public OpRewritePattern<AffineDelinearizeIndexOp> {
+  using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<SmallVector<Value>> multiIndex =
+        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
+                         op.getEffectiveBasis(), /*hasOuterBound=*/false);
+    if (failed(multiIndex))
+      return failure();
+    rewriter.replaceOp(op, *multiIndex);
+    return success();
+  }
+};
+
+/// Lowers `affine.linearize_index` into an `affine.apply` operation via the
+/// affine `linearizeIndex` utility.
+struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    // Should be folded away, included here for safety.
+    if (op.getMultiIndex().empty()) {
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+      return success();
+    }
+
+    SmallVector<OpFoldResult> multiIndex =
+        getAsOpFoldResult(op.getMultiIndex());
+    OpFoldResult linearIndex =
+        linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
+    Value linearIndexValue =
+        getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
+    rewriter.replaceOp(op, linearIndexValue);
+    return success();
+  }
+};
+
+class ExpandAffineIndexOpsAsAffinePass
+    : public affine::impl::AffineExpandIndexOpsAsAffineBase<
+          ExpandAffineIndexOpsAsAffinePass> {
+public:
+  ExpandAffineIndexOpsAsAffinePass() = default;
+  /// Greedily applies the affine.apply lowerings of affine.[de]linearize_index.
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateAffineExpandIndexOpsAsAffinePatterns(patterns);
+    LogicalResult result =
+        applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    if (failed(result))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::affine::populateAffineExpandIndexOpsAsAffinePatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(ctx);
+}
+
+std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsAsAffinePass() {
+  return std::make_unique<ExpandAffineIndexOpsAsAffinePass>();
+}
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index 772f15335d907f..c42789b01bc9fa 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRAffineTransforms
   AffineDataCopyGeneration.cpp
   AffineExpandIndexOps.cpp
+  AffineExpandIndexOpsAsAffine.cpp
   AffineLoopInvariantCodeMotion.cpp
   AffineLoopNormalize.cpp
   AffineParallelize.cpp
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index d6fc4ed07bfab3..e75d1c571d08cc 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -2772,3 +2772,15 @@ LogicalResult affine::coalescePerfectlyNestedAffineLoops(AffineForOp op) {
   }
   return result;
 }
+
+int64_t mlir::affine::numEnclosingInvariantLoops(OpOperand &operand) {
+  // Walk outward through enclosing loops; stop at the first loop inside
+  // which the operand's value is defined.
+  Value value = operand.get();
+  int64_t numLoops = 0;
+  for (auto loop = operand.getOwner()->getParentOfType<LoopLikeOpInterface>();
+       loop && loop.isDefinedOutsideOfLoop(value);
+       loop = loop->getParentOfType<LoopLikeOpInterface>())
+    ++numLoops;
+  return numLoops;
+}
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 3be42661f63ee7..00d7b6b8d65f67 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -927,69 +927,3 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
 // CHECK:      scf.reduce.return %[[RES]] : i64
 // CHECK:    }
 // CHECK:  }
-
-///////////////////////////////////////////////////////////////////////
-
-func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index) {
-  %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index
-  return %1#0, %1#1, %1#2 : index, index, index
-}
-// CHECK-LABEL:   func.func @test_dilinearize_index(
-// CHECK-SAME:                                      %[[VAL_0:.*]]: index) -> (index, index, index) {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_2:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant -1 : index
-// CHECK:           %[[VAL_6:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_4]] : index
-// CHECK:           %[[VAL_7:.*]] = arith.subi %[[VAL_5]], %[[VAL_0]] : index
-// CHECK:           %[[VAL_8:.*]] = arith.select %[[VAL_6]], %[[VAL_7]], %[[VAL_0]] : index
-// CHECK:           %[[VAL_9:.*]] = arith.divsi %[[VAL_8]], %[[VAL_3]] : index
-// CHECK:           %[[VAL_10:.*]] = arith.subi %[[VAL_5]], %[[VAL_9]] : index
-// CHECK:           %[[VAL_11:.*]] = arith.select %[[VAL_6]], %[[VAL_10]], %[[VAL_9]] : index
-// CHECK:           %[[VAL_12:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_13:.*]] = arith.remsi %[[VAL_0]], %[[VAL_12]] : index
-// CHECK:           %[[VAL_14:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_13]], %[[VAL_14]] : index
-// CHECK:           %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
-// CHECK:           %[[VAL_17:.*]] = arith.select %[[VAL_15]], %[[VAL_16]], %[[VAL_13]] : index
-// CHECK:           %[[VAL_18:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_19:.*]] = arith.remsi %[[VAL_0]], %[[VAL_18]] : index
-// CHECK:           %[[VAL_20:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_21:.*]] = arith.cmpi slt, %[[VAL_19]], %[[VAL_20]] : index
-// CHECK:           %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index
-// CHECK:           %[[VAL_23:.*]] = arith.select %[[VAL_21]], %[[VAL_22]], %[[VAL_19]] : index
-// CHECK:           %[[VAL_24:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_25:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_26:.*]] = arith.constant -1 : index
-// CHECK:           %[[VAL_27:.*]] = arith.cmpi slt, %[[VAL_23]], %[[VAL_25]] : index
-// CHECK:           %[[VAL_28:.*]] = arith.subi %[[VAL_26]], %[[VAL_23]] : index
-// CHECK:           %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_23]] : index
-// CHECK:           %[[VAL_30:.*]] = arith.divsi %[[VAL_29]], %[[VAL_24]] : index
-// CHECK:           %[[VAL_31:.*]] = arith.subi %[[VAL_26]], %[[VAL_30]] : index
-// CHECK:           %[[VAL_32:.*]] = arith.select %[[VAL_27]], %[[VAL_31]], %[[VAL_30]] : index
-// CHECK:           %[[VAL_33:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_34:.*]] = arith.remsi %[[VAL_0]], %[[VAL_33]] : index
-// CHECK:           %[[VAL_35:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_36:.*]] = arith.cmpi slt, %[[VAL_34]], %[[VAL_35]] : index
-// CHECK:           %[[VAL_37:.*]] = arith.addi %[[VAL_34]], %[[VAL_33]] : index
-// CHECK:           %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_37]], %[[VAL_34]] : index
-// CHECK:           return %[[VAL_11]], %[[VAL_32]], %[[VAL_38]] : index, index, index
-// CHECK:         }
-
-/////////////////////////////////////////////////////////////////////
-
-func.func @test_linearize_index(%arg0: index, %arg1: index, %arg2: index) -> index {
-  %ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 3, 5) : index
-  return %ret : index
-}
-
-// CHECK-LABEL: @test_linearize_index
-// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
-// CHECK: %[[c15:.+]] = arith.constant 15 : index
-// CHECK-NEXT: %[[tmp0:.+]] = arith.muli %[[arg0]], %[[c15]] : index
-// CHECK-NEXT: %[[c5:.+]] = arith.constant 5 : index
-// CHECK-NEXT: %[[tmp1:.+]] = arith.muli %[[arg1]], %[[c5]] : index
-// CHECK-NEXT: %[[tmp2:.+]] = arith.addi %[[tmp0]], %[[tmp1]] : index
-// CHECK-NEXT: %[[ret:.+]] = arith.addi %[[tmp2]], %[[arg2]] : index
-// CHECK-NEXT: return %[[ret]]
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
new file mode 100644
index 00000000000000..bf9f00da5793aa
--- /dev/null
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s -affine-expand-index-ops-as-affine -split-input-file | FileCheck %s
+
+//   CHECK-DAG:   #[[$map0:.+]] = affine_map<()[s0] -> (s0 floordiv 50176)>
+//   CHECK-DAG:   #[[$map1:.+]] = affine_map<()[s0] -> ((s0 mod 50176) floordiv 224)>
+//   CHECK-DAG:   #[[$map2:.+]] = affine_map<()[s0] -> (s0 mod 224)>
+
+// CHECK-LABEL: @static_basis
+//  CHECK-SAME:    (%[[IDX:.+]]: index)
+//       CHECK:   %[[N:.+]] = affine.apply #[[$map0]]()[%[[IDX]]]
+//       CHECK:   %[[P:.+]] = affine.apply #[[$map1]]()[%[[IDX]]]
+//       CHECK:   %[[Q:.+]] = affine.apply #[[$map2]]()[%[[IDX]]]
+//       CHECK:   return %[[N]], %[[P]], %[[Q]]
+func.func @static_basis(%linear_index: index) -> (index, index, index) {
+  %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+//   CHECK-DAG:   #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s2 floordiv (s0 * s1))>
+//   CHECK-DAG:   #[[$map1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) floordiv s1)>
+//   CHECK-DAG:   #[[$map2:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) mod s1)>
+
+// CHECK-LABEL: @dynamic_basis
+//  CHECK-SAME:    (%[[IDX:.+]]: index, %[[MEMREF:.+]]: memref
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//        CHECK:  %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] :
+//        CHECK:  %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] :
+//       CHECK:   %[[N:.+]] = affine.apply #[[$map0]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
+//       CHECK:   %[[P:.+]] = affine.apply #[[$map1]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
+//       CHECK:   %[[Q:.+]] = affine.apply #[[$map2]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
+//       CHECK:   return %[[N]], %[[P]], %[[Q]]
+func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (index, index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %b1 = memref.dim %src, %c1 : memref<?x?x?xf32>
+  %b2 = memref.dim %src, %c2 : memref<?x?x?xf32>
+  // Note: no outer bound.
+  %1:3 = affine.delinearize_index %linear_index into (%b1, %b2) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>
+
+// CHECK-LABEL: @linearize_static
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
+// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg2]]]
+// CHECK: return %[[val_0]]
+func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index
+  func.return %0 : index
+}
+
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] =  affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))>
+
+// CHECK-LABEL: @linearize_dynamic
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index)
+// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg4]], %[[arg2]], %[[arg3]]]
+// CHECK: return %[[val_0]]
+func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index {
+  // Note: no outer bounds
+  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index
+  func.return %0 : index
+}
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
index 650555cfb5fe13..e4b1b98d1893dc 100644
--- a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
@@ -1,38 +1,48 @@
 // RUN: mlir-opt %s -affine-expand-index-ops -split-input-file | FileCheck %s
 
-//   CHECK-DAG:   #[[$map0:.+]] = affine_map<()[s0] -> (s0 floordiv 50176)>
-//   CHECK-DAG:   #[[$map1:.+]] = affine_map<()[s0] -> ((s0 mod 50176) floordiv 224)>
-//   CHECK-DAG:   #[[$map2:.+]] = affine_map<()[s0] -> (s0 mod 224)>
-
-// CHECK-LABEL: @static_basis
+// CHECK-LABEL: @delinearize_static_basis
 //  CHECK-SAME:    (%[[IDX:.+]]: index)
-//       CHECK:   %[[N:.+]] = affine.apply #[[$map0]]()[%[[IDX]]]
-//       CHECK:   %[[P:.+]] = affine.apply #[[$map1]]()[%[[IDX]]]
-//       CHECK:   %[[Q:.+]] = affine.apply #[[$map2]]()[%[[IDX]]]
+//   CHECK-DAG:   %[[C224:.+]] = arith.constant 224 : index
+//   CHECK-DAG:   %[[C50176:.+]] = arith.constant 50176 : index
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[N:.+]] = arith.floordivsi %[[IDX]], %[[C50176]]
+//   CHECK-DAG:   %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[C50176]]
+//   CHECK-DAG:   %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]]
+//   CHECK-DAG:   %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[C50176]]
+//   CHECK-DAG:   %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]]
+//       CHECK:   %[[P:.+]] = arith.divsi %[[P_MOD]], %[[C224]]
+//   CHECK-DAG:   %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[C224]]
+//   CHECK-DAG:   %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]]
+//   CHECK-DAG:   %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[C224]]
+//       CHECK:   %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]]
 //       CHECK:   return %[[N]], %[[P]], %[[Q]]
-func.func @static_basis(%linear_index: index) -> (index, index, index) {
+func.func @delinearize_static_basis(%linear_index: index) -> (index, index, index) {
   %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index
   return %1#0, %1#1, %1#2 : index, index, index
 }
 
 // -----
 
-//   CHECK-DAG:   #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s2 floordiv (s0 * s1))>
-//   CHECK-DAG:   #[[$map1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) floordiv s1)>
-//   CHECK-DAG:   #[[$map2:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) mod s1)>
-
-// CHECK-LABEL: @dynamic_basis
+// CHECK-LABEL: @delinearize_dynamic_basis
 //  CHECK-SAME:    (%[[IDX:.+]]: index, %[[MEMREF:.+]]: memref
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-//        CHECK:  %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] :
-//        CHECK:  %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] :
-//       CHECK:   %[[N:.+]] = affine.apply #[[$map0]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
-//       CHECK:   %[[P:.+]] = affine.apply #[[$map1]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
-//       CHECK:   %[[Q:.+]] = affine.apply #[[$map2]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
+//       CHECK:  %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] :
+//       CHECK:  %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] :
+//       CHECK:  %[[STRIDE1:.+]] = arith.muli %[[DIM2]], %[[DIM1]]
+//       CHECK:  %[[N:.+]] = arith.floordivsi %[[IDX]], %[[STRIDE1]]
+//   CHECK-DAG:  %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[STRIDE1]]
+//   CHECK-DAG:  %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]]
+//   CHECK-DAG:  %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[STRIDE1]]
+//   CHECK-DAG:  %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]]
+//       CHECK:  %[[P:.+]] = arith.divsi %[[P_MOD]], %[[DIM2]]
+//   CHECK-DAG:  %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[DIM2]]
+//   CHECK-DAG:  %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]]
+//   CHECK-DAG:  %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[DIM2]]
+//       CHECK:  %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]]
 //       CHECK:   return %[[N]], %[[P]], %[[Q]]
-func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (index, index, index) {
-  %c0 = arith.constant 0 : index
+func.func @delinearize_dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (index, index, index) {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
   %b1 = memref.dim %src, %c1 : memref<?x?x?xf32>
@@ -44,12 +54,15 @@ func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (inde
 
 // -----
 
-// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>
-
 // CHECK-LABEL: @linearize_static
 // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
-// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg2]]]
-// CHECK: return %[[val_0]]
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[C15:.+]] = arith.constant 15 : index
+// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[C15]]
+// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[C5]]
+// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]]
+// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]]
+// CHECK: return %[[val_1]]
 func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
   %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index
   func.return %0 : index
@@ -57,14 +70,44 @@ func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
 
 // -----
 
-// CHECK-DAG: #[[$map0:.+]] =  affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))>
-
 // CHECK-LABEL: @linearize_dynamic
 // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index)
-// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg4]], %[[arg2]], %[[arg3]]]
-// CHECK: return %[[val_0]]
+// CHECK: %[[stride_0:.+]] = arith.muli %[[arg4]], %[[arg3]]
+// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[stride_0]]
+// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[arg4]]
+// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]]
+// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]]
+// CHECK: return %[[val_1]]
 func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index {
   // Note: no outer bounds
   %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index
   func.return %0 : index
 }
+
+// -----
+
+// CHECK-LABEL: @linearize_sort_adds
+// CHECK-SAME: (%[[arg0:.+]]: memref<?xi32>, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK: scf.for %[[ARG3:.+]] = %{{.*}} to %[[arg2]] step %{{.*}} {
+// CHECK: scf.for %[[ARG4:.+]] = %{{.*}} to %[[C4]] step %{{.*}} {
+// CHECK: %[[stride_0:.+]] = arith.muli %[[arg2]], %[[C4]]
+// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg1]], %[[stride_0]]
+// CHECK: %[[scaled_1:.+]] = arith.muli %[[ARG4]], %[[arg2]]
+// Note: %arg3 is invariant in the inner loop, so it is added before %arg4's term despite its smaller stride.
+// CHECK: %[[val_0_2:.+]] = arith.addi %[[scaled_0]], %[[ARG3]]
+// CHECK: %[[val_1:.+]] = arith.addi %[[val_0_2]], %[[scaled_1]]
+// CHECK: memref.store %{{.*}}, %[[arg0]][%[[val_1]]]
+func.func @linearize_sort_adds(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0_i32 = arith.constant 0 : i32
+  scf.for %arg3 = %c0 to %arg2 step %c1 {
+    scf.for %arg4 = %c0 to %c4 step %c1 {
+      %idx = affine.linearize_index disjoint [%arg1, %arg4, %arg3] by (4, %arg2) : index
+      memref.store %c0_i32, %arg0[%idx] : memref<?xi32>
+    }
+  }
+  return
+}



More information about the Mlir-commits mailing list