[Mlir-commits] [mlir] [mlir][Affine] Expand affine.[de]linearize_index without affine maps (PR #116703)
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Nov 18 14:30:06 PST 2024
https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/116703
As the documentation for -affine-expand-index-ops says, affine.delinearize_index and affine.linearize_index don't need to be expanded into the affine dialect.
Expanding these operations into affine.apply operations can introduce unwanted "simplifications", mainly translations of `(dN mod C + ...)` to `(dN + ... - (dN floordiv C) * C)` and similar, which create worse generated code. This commit resolves this issue by expanding affine.delinearize_index directly into arith operations.
In addition, the lowering of affine.linearize_index now sorts the operands by loop-independence, allowing an increased amount of loop-invariant code motion after lowering.
The old behavior is preserved as -affine-expand-index-ops-as-affine but is no longer the default.
>From ba4ae62b1ab03531b4bc94d08b7c5579d8fdf5c2 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 13 Nov 2024 23:44:29 +0000
Subject: [PATCH] [mlir][Affine] Expand affine.[de]linearize_index without
affine maps
As the documentation for -affine-expand-index-ops says,
affine.delinearize_index and affine.linearize_index don't need to be
expanded into the affine dialect.
Expanding these operations into affine.apply operations can introduce
unwanted "simplifications", mainly translations of `(dN mod C + ...)`
to `(dN + ... - (dN floordiv C) * C)` and similar, which create worse
generated code. This commit resolves this issue by expanding
affine.delinearize_index directly into arith operations.
In addition, the lowering of affine.linearize_index now sorts the
operands by loop-independence, allowing an increased amount of
loop-invariant code motion after lowering.
The old behavior is preserved as -affine-expand-index-ops-as-affine
but is no longer the default.
---
mlir/include/mlir/Dialect/Affine/LoopUtils.h | 5 +
mlir/include/mlir/Dialect/Affine/Passes.h | 4 +
mlir/include/mlir/Dialect/Affine/Passes.td | 5 +
.../Dialect/Affine/Transforms/Transforms.h | 4 +
.../Transforms/AffineExpandIndexOps.cpp | 148 ++++++++++++++++--
.../AffineExpandIndexOpsAsAffine.cpp | 98 ++++++++++++
.../Dialect/Affine/Transforms/CMakeLists.txt | 1 +
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 12 ++
.../AffineToStandard/lower-affine.mlir | 66 --------
.../affine-expand-index-ops-as-affine.mlir | 70 +++++++++
.../Affine/affine-expand-index-ops.mlir | 101 ++++++++----
11 files changed, 405 insertions(+), 109 deletions(-)
create mode 100644 mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
create mode 100644 mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 380c742b5224cb..7fe1f6d48ceebb 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -301,6 +301,11 @@ separateFullTiles(MutableArrayRef<AffineForOp> nest,
/// Walk an affine.for to find a band to coalesce.
LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op);
+/// Count the loops enclosing the owner of `operand` that the operand's value
+/// could be hoisted above, i.e. loops for which
+/// `LoopLikeOpInterface::isDefinedOutsideOfLoop` holds for that value.
+/// Loops are visited from the innermost outward, and counting stops at the
+/// first loop over which the value cannot be hoisted, so the result is the
+/// length of the innermost run of invariant loops, not a total over all
+/// enclosing loops. Any LoopLikeOpInterface op counts, not just affine.for.
+int64_t numEnclosingInvariantLoops(OpOperand &operand);
} // namespace affine
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index 61f24255f305f7..e152101236dc7a 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -116,6 +116,10 @@ std::unique_ptr<OperationPass<func::FuncOp>> createPipelineDataTransferPass();
/// operations (not necessarily restricted to Affine dialect).
std::unique_ptr<Pass> createAffineExpandIndexOpsPass();
+/// Creates a pass that expands affine index operations
+/// (`affine.delinearize_index`, `affine.linearize_index`) into `affine.apply`
+/// operations. This preserves the lowering that `affine-expand-index-ops`
+/// performed before it switched to emitting arith operations directly.
+std::unique_ptr<Pass> createAffineExpandIndexOpsAsAffinePass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index b08e803345f76e..77073aa29da73e 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -408,4 +408,9 @@ def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> {
let constructor = "mlir::affine::createAffineExpandIndexOpsPass()";
}
+// Keeps the previous behavior of affine-expand-index-ops available: index
+// ops are lowered through affine.apply instead of directly into arith ops.
+def AffineExpandIndexOpsAsAffine : Pass<"affine-expand-index-ops-as-affine"> {
+  let summary = "Lower affine operations operating on indices into affine.apply operations";
+  let constructor = "mlir::affine::createAffineExpandIndexOpsAsAffinePass()";
+}
+
#endif // MLIR_DIALECT_AFFINE_PASSES
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index b244d37c0707f2..bf830a29613fdd 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -37,6 +37,10 @@ class AffineApplyOp;
/// operations (not necessarily restricted to Affine dialect).
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
+/// Populate `patterns` with rewrites that expand `affine.delinearize_index`
+/// and `affine.linearize_index` into their equivalent `affine.apply`
+/// representations. Contrast with populateAffineExpandIndexOpsPatterns, whose
+/// patterns now emit arith operations instead.
+void populateAffineExpandIndexOpsAsAffinePatterns(RewritePatternSet &patterns);
+
/// Helper function to rewrite `op`'s affine map and reorder its operands such
/// that they are in increasing order of hoistability (i.e. the least hoistable)
/// operands come first in the operand list.
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 15478e0e1e3a5b..d7b218225bc9ab 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -10,6 +10,7 @@
// fundamental operations.
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -28,6 +29,50 @@ namespace affine {
using namespace mlir;
using namespace mlir::affine;
+/// Given a basis (in static and dynamic components), return the sequence of
+/// suffix products of the basis, including the product of the entire basis,
+/// which must **not** contain an outer bound.
+///
+/// For example, a basis of (A, B, C) yields the strides (A*B*C, B*C, C): the
+/// result has exactly one entry per element of `staticBasis`, ordered
+/// outermost first, and an empty `staticBasis` yields an empty result.
+///
+/// If excess dynamic values are provided, the values at the beginning
+/// will be ignored. This allows for dropping the outer bound without
+/// needing to manipulate the dynamic value array.
+///
+/// Static and dynamic factors are accumulated separately. Each stride is
+/// emitted as the running dynamic product alone when the static product is 1.
+/// Otherwise it is an index constant for the static product, multiplied by
+/// the running dynamic product when there is one.
+static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
+ ValueRange dynamicBasis,
+ ArrayRef<int64_t> staticBasis) {
+ if (staticBasis.empty())
+ return {};
+
+ SmallVector<Value> result;
+ result.reserve(staticBasis.size());
+ // Dynamic values are consumed from the back, in step with the reversed walk
+ // over `staticBasis`; any values left unconsumed at the front are ignored.
+ size_t dynamicIndex = dynamicBasis.size();
+ Value dynamicPart = nullptr;
+ int64_t staticPart = 1;
+ for (int64_t elem : llvm::reverse(staticBasis)) {
+ if (ShapedType::isDynamic(elem)) {
+ if (dynamicPart)
+ dynamicPart = rewriter.create<arith::MulIOp>(
+ loc, dynamicPart, dynamicBasis[dynamicIndex - 1]);
+ else
+ dynamicPart = dynamicBasis[dynamicIndex - 1];
+ --dynamicIndex;
+ } else {
+ staticPart *= elem;
+ }
+
+ // Materialize the stride for the suffix ending at this element.
+ if (dynamicPart && staticPart == 1) {
+ result.push_back(dynamicPart);
+ } else {
+ Value stride =
+ rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
+ if (dynamicPart)
+ stride = rewriter.create<arith::MulIOp>(loc, dynamicPart, stride);
+ result.push_back(stride);
+ }
+ }
+ // The strides were produced innermost-first; return them outermost-first.
+ std::reverse(result.begin(), result.end());
+ return result;
+}
+
namespace {
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations.
@@ -36,18 +81,62 @@ struct LowerDelinearizeIndexOps
using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
PatternRewriter &rewriter) const override {
- FailureOr<SmallVector<Value>> multiIndex =
- delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
- op.getEffectiveBasis(), /*hasOuterBound=*/false);
- if (failed(multiIndex))
- return failure();
- rewriter.replaceOp(op, *multiIndex);
+ Location loc = op.getLoc();
+ Value linearIdx = op.getLinearIndex();
+ unsigned numResults = op.getNumResults();
+ ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+ // When the basis has one entry per result, its first entry is an outer
+ // bound, which never contributes to a stride; drop it. The matching
+ // leading dynamic value (if any) is ignored by computeStrides().
+ if (numResults == staticBasis.size())
+ staticBasis = staticBasis.drop_front();
+
+ // A single result is the linear index itself.
+ if (numResults == 1) {
+ rewriter.replaceOp(op, linearIdx);
+ return success();
+ }
+
+ SmallVector<Value> results;
+ results.reserve(numResults);
+ // strides[i] is the product of the basis elements of the results after
+ // result i, so there are numResults - 1 strides, outermost first.
+ SmallVector<Value> strides =
+ computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
+
+ Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+ // The outermost result is floor(linearIdx / strides[0]).
+ Value initialPart =
+ rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
+ results.push_back(initialPart);
+
+ // Emits `linearIdx mod stride` normalized into [0, stride) for a positive
+ // stride: remsi's result takes the sign of the dividend, so a negative
+ // remainder is shifted up by one stride. Emitting this directly avoids the
+ // `x - (x floordiv C) * C` form that expansion via affine.apply produces.
+ auto emitModTerm = [&](Value stride) -> Value {
+ Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
+ Value remainderNegative = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::slt, remainder, zero);
+ Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride);
+ Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
+ corrected, remainder);
+ return mod;
+ };
+
+ // Generate all the intermediate parts: result i + 1 is
+ // (linearIdx mod strides[i]) / strides[i + 1].
+ for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
+ Value thisStride = strides[i];
+ Value nextStride = strides[i + 1];
+ Value modulus = emitModTerm(thisStride);
+ // The modulus is non-negative and the stride is presumably positive, so
+ // floorDiv == div here. This could potentially be a divui, but it's not
+ // clear if that would cause issues.
+ Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
+ results.push_back(divided);
+ }
+
+ // The innermost result is linearIdx mod the innermost stride.
+ results.push_back(emitModTerm(strides.back()));
+
+ rewriter.replaceOp(op, results);
return success();
}
};
/// Lowers `affine.linearize_index` into a sequence of multiplications and
-/// additions.
+/// additions. Make a best effort to sort the input indices so that
+/// the most loop-invariant terms are at the left of the additions
+/// to enable loop-invariant code motion.
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
@@ -58,13 +147,44 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
return success();
}
- SmallVector<OpFoldResult> multiIndex =
- getAsOpFoldResult(op.getMultiIndex());
- OpFoldResult linearIndex =
- linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
- Value linearIndexValue =
- getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
- rewriter.replaceOp(op, linearIndexValue);
+    Location loc = op.getLoc();
+    ValueRange multiIndex = op.getMultiIndex();
+    size_t numIndexes = multiIndex.size();
+    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+    // An outer bound, if present, never contributes to a stride, so drop it.
+    // The matching leading dynamic value (if any) is ignored by
+    // computeStrides().
+    if (numIndexes == staticBasis.size())
+      staticBasis = staticBasis.drop_front();
+
+    // strides[i] scales multiIndex[i]; there is no entry for the last index,
+    // whose stride is 1, and everything else lines up.
+    SmallVector<Value> strides =
+        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
+
+    // Each term of the final sum, paired with the number of enclosing loops
+    // it is invariant in (see numEnclosingInvariantLoops).
+    SmallVector<std::pair<Value, int64_t>> scaledValues;
+    scaledValues.reserve(numIndexes);
+
+    // We use the "mutable" accessor so we can get our hands on an
+    // `OpOperand &` for the loop-invariance counting function.
+    for (auto [stride, idxOp] :
+         llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
+      Value scaledIdx =
+          rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride);
+      int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
+      scaledValues.emplace_back(scaledIdx, numHoistableLoops);
+    }
+    scaledValues.emplace_back(
+        multiIndex.back(),
+        numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
+
+    // Put the terms that are invariant in the most enclosing loops first so
+    // the partial sums built from them can be hoisted by loop-invariant code
+    // motion. stable_sort keeps ties in their original outermost-first
+    // (i.e. decreasing-stride) order.
+    llvm::stable_sort(scaledValues, [](const auto &l, const auto &r) {
+      return l.second > r.second;
+    });
+
+    // Left-fold the sorted terms into a chain of additions.
+    Value result = scaledValues.front().first;
+    for (Value scaledValue :
+         llvm::drop_begin(llvm::make_first_range(scaledValues)))
+      result = rewriter.create<arith::AddIOp>(loc, result, scaledValue);
+    rewriter.replaceOp(op, result);
return success();
}
};
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
new file mode 100644
index 00000000000000..bfcc1ddf91653a
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
@@ -0,0 +1,98 @@
+//===- AffineExpandIndexOpsAsAffine.cpp - Expand index ops to apply pass --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that expands affine index ops
+// (affine.delinearize_index and affine.linearize_index) into equivalent
+// affine.apply operations.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPSASAFFINE
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::affine;
+
+namespace {
+/// Rewrites `affine.delinearize_index` with the shared `delinearizeIndex`
+/// utility, which materializes each result as an `affine.apply` of
+/// `floordiv`/`mod` expressions over the linear index. The pattern fails
+/// (leaving the op untouched) when that utility fails.
+struct LowerDelinearizeIndexOps
+    : public OpRewritePattern<AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    FailureOr<SmallVector<Value>> results =
+        delinearizeIndex(rewriter, loc, op.getLinearIndex(),
+                         op.getEffectiveBasis(), /*hasOuterBound=*/false);
+    if (succeeded(results)) {
+      rewriter.replaceOp(op, *results);
+      return success();
+    }
+    return failure();
+  }
+};
+
+/// Rewrites `affine.linearize_index` with the shared `linearizeIndex` utility,
+/// producing the multiply-add of each index by its stride (as `affine.apply`
+/// or a constant).
+struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    ValueRange indices = op.getMultiIndex();
+    // An empty multi-index should be folded away; handle it here for safety
+    // by producing the constant 0.
+    if (indices.empty()) {
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+      return success();
+    }
+
+    OpFoldResult linearized = linearizeIndex(
+        rewriter, loc, getAsOpFoldResult(indices), op.getMixedBasis());
+    rewriter.replaceOp(op,
+                       getValueOrCreateConstantIntOp(rewriter, loc, linearized));
+    return success();
+  }
+};
+
+/// Pass wrapper that greedily applies the as-affine index-op expansion
+/// patterns to the operation it runs on.
+class ExpandAffineIndexOpsAsAffinePass
+    : public affine::impl::AffineExpandIndexOpsAsAffineBase<
+          ExpandAffineIndexOpsAsAffinePass> {
+public:
+  ExpandAffineIndexOpsAsAffinePass() = default;
+
+  /// Signals pass failure when the greedy driver reports failure.
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateAffineExpandIndexOpsAsAffinePatterns(patterns);
+    LogicalResult result =
+        applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    if (failed(result))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Registers the delinearize and linearize expansion patterns.
+void mlir::affine::populateAffineExpandIndexOpsAsAffinePatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(ctx);
+}
+
+/// Returns a fresh, default-constructed instance of the as-affine pass.
+std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsAsAffinePass() {
+  auto pass = std::make_unique<ExpandAffineIndexOpsAsAffinePass>();
+  return pass;
+}
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index 772f15335d907f..c42789b01bc9fa 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRAffineTransforms
AffineDataCopyGeneration.cpp
AffineExpandIndexOps.cpp
+ AffineExpandIndexOpsAsAffine.cpp
AffineLoopInvariantCodeMotion.cpp
AffineLoopNormalize.cpp
AffineParallelize.cpp
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index d6fc4ed07bfab3..e75d1c571d08cc 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -2772,3 +2772,15 @@ LogicalResult affine::coalescePerfectlyNestedAffineLoops(AffineForOp op) {
}
return result;
}
+
+int64_t mlir::affine::numEnclosingInvariantLoops(OpOperand &operand) {
+  // Walk outward from the owning op one loop-like ancestor at a time and stop
+  // at the first loop over which the operand's value cannot be hoisted.
+  Value value = operand.get();
+  int64_t numInvariantLoops = 0;
+  for (auto loop = operand.getOwner()->getParentOfType<LoopLikeOpInterface>();
+       loop && loop.isDefinedOutsideOfLoop(value);
+       loop = loop->getParentOfType<LoopLikeOpInterface>())
+    ++numInvariantLoops;
+  return numInvariantLoops;
+}
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 3be42661f63ee7..00d7b6b8d65f67 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -927,69 +927,3 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
// CHECK: scf.reduce.return %[[RES]] : i64
// CHECK: }
// CHECK: }
-
-///////////////////////////////////////////////////////////////////////
-
-func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index) {
- %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index
- return %1#0, %1#1, %1#2 : index, index, index
-}
-// CHECK-LABEL: func.func @test_dilinearize_index(
-// CHECK-SAME: %[[VAL_0:.*]]: index) -> (index, index, index) {
-// CHECK: %[[VAL_1:.*]] = arith.constant 224 : index
-// CHECK: %[[VAL_2:.*]] = arith.constant 50176 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 50176 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant -1 : index
-// CHECK: %[[VAL_6:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_4]] : index
-// CHECK: %[[VAL_7:.*]] = arith.subi %[[VAL_5]], %[[VAL_0]] : index
-// CHECK: %[[VAL_8:.*]] = arith.select %[[VAL_6]], %[[VAL_7]], %[[VAL_0]] : index
-// CHECK: %[[VAL_9:.*]] = arith.divsi %[[VAL_8]], %[[VAL_3]] : index
-// CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_5]], %[[VAL_9]] : index
-// CHECK: %[[VAL_11:.*]] = arith.select %[[VAL_6]], %[[VAL_10]], %[[VAL_9]] : index
-// CHECK: %[[VAL_12:.*]] = arith.constant 50176 : index
-// CHECK: %[[VAL_13:.*]] = arith.remsi %[[VAL_0]], %[[VAL_12]] : index
-// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_13]], %[[VAL_14]] : index
-// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
-// CHECK: %[[VAL_17:.*]] = arith.select %[[VAL_15]], %[[VAL_16]], %[[VAL_13]] : index
-// CHECK: %[[VAL_18:.*]] = arith.constant 50176 : index
-// CHECK: %[[VAL_19:.*]] = arith.remsi %[[VAL_0]], %[[VAL_18]] : index
-// CHECK: %[[VAL_20:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_21:.*]] = arith.cmpi slt, %[[VAL_19]], %[[VAL_20]] : index
-// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index
-// CHECK: %[[VAL_23:.*]] = arith.select %[[VAL_21]], %[[VAL_22]], %[[VAL_19]] : index
-// CHECK: %[[VAL_24:.*]] = arith.constant 224 : index
-// CHECK: %[[VAL_25:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_26:.*]] = arith.constant -1 : index
-// CHECK: %[[VAL_27:.*]] = arith.cmpi slt, %[[VAL_23]], %[[VAL_25]] : index
-// CHECK: %[[VAL_28:.*]] = arith.subi %[[VAL_26]], %[[VAL_23]] : index
-// CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_23]] : index
-// CHECK: %[[VAL_30:.*]] = arith.divsi %[[VAL_29]], %[[VAL_24]] : index
-// CHECK: %[[VAL_31:.*]] = arith.subi %[[VAL_26]], %[[VAL_30]] : index
-// CHECK: %[[VAL_32:.*]] = arith.select %[[VAL_27]], %[[VAL_31]], %[[VAL_30]] : index
-// CHECK: %[[VAL_33:.*]] = arith.constant 224 : index
-// CHECK: %[[VAL_34:.*]] = arith.remsi %[[VAL_0]], %[[VAL_33]] : index
-// CHECK: %[[VAL_35:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_36:.*]] = arith.cmpi slt, %[[VAL_34]], %[[VAL_35]] : index
-// CHECK: %[[VAL_37:.*]] = arith.addi %[[VAL_34]], %[[VAL_33]] : index
-// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_37]], %[[VAL_34]] : index
-// CHECK: return %[[VAL_11]], %[[VAL_32]], %[[VAL_38]] : index, index, index
-// CHECK: }
-
-/////////////////////////////////////////////////////////////////////
-
-func.func @test_linearize_index(%arg0: index, %arg1: index, %arg2: index) -> index {
- %ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 3, 5) : index
- return %ret : index
-}
-
-// CHECK-LABEL: @test_linearize_index
-// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
-// CHECK: %[[c15:.+]] = arith.constant 15 : index
-// CHECK-NEXT: %[[tmp0:.+]] = arith.muli %[[arg0]], %[[c15]] : index
-// CHECK-NEXT: %[[c5:.+]] = arith.constant 5 : index
-// CHECK-NEXT: %[[tmp1:.+]] = arith.muli %[[arg1]], %[[c5]] : index
-// CHECK-NEXT: %[[tmp2:.+]] = arith.addi %[[tmp0]], %[[tmp1]] : index
-// CHECK-NEXT: %[[ret:.+]] = arith.addi %[[tmp2]], %[[arg2]] : index
-// CHECK-NEXT: return %[[ret]]
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
new file mode 100644
index 00000000000000..bf9f00da5793aa
--- /dev/null
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s -affine-expand-index-ops-as-affine -split-input-file | FileCheck %s
+
+// Delinearization with a static basis whose first element (16) is an outer
+// bound: each result becomes one affine.apply of floordiv/mod expressions.
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 floordiv 50176)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> ((s0 mod 50176) floordiv 224)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> (s0 mod 224)>
+
+// CHECK-LABEL: @static_basis
+// CHECK-SAME: (%[[IDX:.+]]: index)
+// CHECK: %[[N:.+]] = affine.apply #[[$map0]]()[%[[IDX]]]
+// CHECK: %[[P:.+]] = affine.apply #[[$map1]]()[%[[IDX]]]
+// CHECK: %[[Q:.+]] = affine.apply #[[$map2]]()[%[[IDX]]]
+// CHECK: return %[[N]], %[[P]], %[[Q]]
+func.func @static_basis(%linear_index: index) -> (index, index, index) {
+ %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// Delinearization with a dynamic basis and no outer bound.
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s2 floordiv (s0 * s1))>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) floordiv s1)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) mod s1)>
+
+// CHECK-LABEL: @dynamic_basis
+// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEMREF:.+]]: memref
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] :
+// CHECK: %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] :
+// CHECK: %[[N:.+]] = affine.apply #[[$map0]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
+// CHECK: %[[P:.+]] = affine.apply #[[$map1]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
+// CHECK: %[[Q:.+]] = affine.apply #[[$map2]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
+// CHECK: return %[[N]], %[[P]], %[[Q]]
+func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (index, index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %b1 = memref.dim %src, %c1 : memref<?x?x?xf32>
+ %b2 = memref.dim %src, %c2 : memref<?x?x?xf32>
+ // Note: no outer bound.
+ %1:3 = affine.delinearize_index %linear_index into (%b1, %b2) : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// Linearization with a static basis that includes an outer bound (2).
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>
+
+// CHECK-LABEL: @linearize_static
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
+// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg2]]]
+// CHECK: return %[[val_0]]
+func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index
+ func.return %0 : index
+}
+
+// -----
+
+// Linearization with a dynamic basis and no outer bound.
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))>
+
+// CHECK-LABEL: @linearize_dynamic
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index)
+// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg4]], %[[arg2]], %[[arg3]]]
+// CHECK: return %[[val_0]]
+func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index {
+ // Note: no outer bounds
+ %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index
+ func.return %0 : index
+}
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
index 650555cfb5fe13..e4b1b98d1893dc 100644
--- a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
@@ -1,38 +1,48 @@
// RUN: mlir-opt %s -affine-expand-index-ops -split-input-file | FileCheck %s
-// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 floordiv 50176)>
-// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> ((s0 mod 50176) floordiv 224)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> (s0 mod 224)>
-
-// CHECK-LABEL: @static_basis
+// Static basis with an outer bound (16): the outer bound is dropped, leaving
+// the strides 50176 (224 * 224) and 224. The first result is a floordivsi,
+// the middle one is a normalized remsi followed by divsi, and the last one is
+// a normalized remsi.
+// CHECK-LABEL: @delinearize_static_basis
// CHECK-SAME: (%[[IDX:.+]]: index)
-// CHECK: %[[N:.+]] = affine.apply #[[$map0]]()[%[[IDX]]]
-// CHECK: %[[P:.+]] = affine.apply #[[$map1]]()[%[[IDX]]]
-// CHECK: %[[Q:.+]] = affine.apply #[[$map2]]()[%[[IDX]]]
+// CHECK-DAG: %[[C224:.+]] = arith.constant 224 : index
+// CHECK-DAG: %[[C50176:.+]] = arith.constant 50176 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[N:.+]] = arith.floordivsi %[[IDX]], %[[C50176]]
+// CHECK-DAG: %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[C50176]]
+// CHECK-DAG: %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]]
+// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[C50176]]
+// CHECK-DAG: %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]]
+// CHECK: %[[P:.+]] = arith.divsi %[[P_MOD]], %[[C224]]
+// CHECK-DAG: %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[C224]]
+// CHECK-DAG: %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]]
+// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[C224]]
+// CHECK: %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]]
// CHECK: return %[[N]], %[[P]], %[[Q]]
-func.func @static_basis(%linear_index: index) -> (index, index, index) {
+func.func @delinearize_static_basis(%linear_index: index) -> (index, index, index) {
%1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
// -----
-// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s2 floordiv (s0 * s1))>
-// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) floordiv s1)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) mod s1)>
-
-// CHECK-LABEL: @dynamic_basis
+// Dynamic basis without an outer bound: the strides are %b2 * %b1 and %b2.
+// CHECK-LABEL: @delinearize_dynamic_basis
// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEMREF:.+]]: memref
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] :
-// CHECK: %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] :
-// CHECK: %[[N:.+]] = affine.apply #[[$map0]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
-// CHECK: %[[P:.+]] = affine.apply #[[$map1]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
-// CHECK: %[[Q:.+]] = affine.apply #[[$map2]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]]
+// CHECK: %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] :
+// CHECK: %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] :
+// CHECK: %[[STRIDE1:.+]] = arith.muli %[[DIM2]], %[[DIM1]]
+// CHECK: %[[N:.+]] = arith.floordivsi %[[IDX]], %[[STRIDE1]]
+// CHECK-DAG: %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[STRIDE1]]
+// CHECK-DAG: %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]]
+// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[STRIDE1]]
+// CHECK-DAG: %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]]
+// CHECK: %[[P:.+]] = arith.divsi %[[P_MOD]], %[[DIM2]]
+// CHECK-DAG: %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[DIM2]]
+// CHECK-DAG: %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]]
+// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[DIM2]]
+// CHECK: %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]]
// CHECK: return %[[N]], %[[P]], %[[Q]]
-func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (index, index, index) {
- %c0 = arith.constant 0 : index
+func.func @delinearize_dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (index, index, index) {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%b1 = memref.dim %src, %c1 : memref<?x?x?xf32>
@@ -44,12 +54,15 @@ func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (inde
// -----
-// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>
-
+// Static basis with an outer bound (2): the strides are 15 and 5. With no
+// enclosing loops, every term is equally invariant and keeps its
+// outermost-first position in the sum.
// CHECK-LABEL: @linearize_static
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
-// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg2]]]
-// CHECK: return %[[val_0]]
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[C15:.+]] = arith.constant 15 : index
+// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[C15]]
+// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[C5]]
+// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]]
+// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]]
+// CHECK: return %[[val_1]]
func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index
func.return %0 : index
@@ -57,14 +70,44 @@ func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
// -----
-// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))>
-
+// Dynamic basis without an outer bound: the strides are %arg4 * %arg3 and
+// %arg4.
// CHECK-LABEL: @linearize_dynamic
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index)
-// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg4]], %[[arg2]], %[[arg3]]]
-// CHECK: return %[[val_0]]
+// CHECK: %[[stride_0:.+]] = arith.muli %[[arg4]], %[[arg3]]
+// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[stride_0]]
+// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[arg4]]
+// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]]
+// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]]
+// CHECK: return %[[val_1]]
func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index {
// Note: no outer bounds
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index
func.return %0 : index
}
+
+// -----
+
+// Terms are summed from most to least loop-invariant: %arg1 is invariant in
+// both loops, the outer induction variable %arg3 only in the inner loop, and
+// the inner induction variable %arg4 in neither.
+// CHECK-LABEL: @linearize_sort_adds
+// CHECK-SAME: (%[[arg0:.+]]: memref<?xi32>, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK: scf.for %[[ARG3:.+]] = %{{.*}} to %[[arg2]] step %{{.*}} {
+// CHECK: scf.for %[[ARG4:.+]] = %{{.*}} to %[[C4]] step %{{.*}} {
+// CHECK: %[[stride_0:.+]] = arith.muli %[[arg2]], %[[C4]]
+// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg1]], %[[stride_0]]
+// CHECK: %[[scaled_1:.+]] = arith.muli %[[ARG4]], %[[arg2]]
+// Note: even though the outer induction variable has a lower stride, it is
+// added first because it is invariant in the inner loop.
+// CHECK: %[[val_0_2:.+]] = arith.addi %[[scaled_0]], %[[ARG3]]
+// CHECK: %[[val_1:.+]] = arith.addi %[[val_0_2]], %[[scaled_1]]
+// CHECK: memref.store %{{.*}}, %[[arg0]][%[[val_1]]]
+func.func @linearize_sort_adds(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0_i32 = arith.constant 0 : i32
+  scf.for %arg3 = %c0 to %arg2 step %c1 {
+    scf.for %arg4 = %c0 to %c4 step %c1 {
+      %idx = affine.linearize_index disjoint [%arg1, %arg4, %arg3] by (4, %arg2) : index
+      memref.store %c0_i32, %arg0[%idx] : memref<?xi32>
+    }
+  }
+  return
+}
More information about the Mlir-commits
mailing list