[Mlir-commits] [mlir] cb9481d - [mlir][affine] Add folders for delinearize_index and linearize_index (#115766)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 13 11:05:41 PST 2024
Author: Krzysztof Drewniak
Date: 2024-11-13T13:05:37-06:00
New Revision: cb9481dbf902adc349757eca12a0a09396dc4a23
URL: https://github.com/llvm/llvm-project/commit/cb9481dbf902adc349757eca12a0a09396dc4a23
DIFF: https://github.com/llvm/llvm-project/commit/cb9481dbf902adc349757eca12a0a09396dc4a23.diff
LOG: [mlir][affine] Add folders for delinearize_index and linearize_index (#115766)
This commit adds implementations of fold() for delinearize_index and
linearize_index to constant-fold them away when they have a fully
constant basis and constant argument(s).
This commit also adds a canonicalization pattern to linearize_index that
causes it to drop leading-zero inputs.
Added:
Modified:
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/test/Dialect/Affine/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index c9d9202ae3cf1a..6a495e11ae1ad5 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1110,6 +1110,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
}];
let hasVerifier = 1;
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
@@ -1179,6 +1180,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
}];
let hasVerifier = 1;
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index f1b0fe7e645051..fbc9053a0e273b 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4556,6 +4556,26 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
return success();
}
+/// Constant-folds affine.delinearize_index when the linear index is a constant
+/// and every basis element is static.
+///
+/// Results are computed innermost-first. Each basis element except the
+/// outermost peels off one digit with a non-negative modulo followed by a
+/// floor division, and whatever remains becomes the leading result. Floor
+/// semantics keep `linear == sum(result[i] * stride[i])` true for negative
+/// inputs too (e.g. -22 into (2, 3, 5) yields (-2, 1, 3)).
+LogicalResult
+AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
+                               SmallVectorImpl<OpFoldResult> &result) {
+  if (adaptor.getLinearIndex() == nullptr)
+    return failure();
+
+  if (!adaptor.getDynamicBasis().empty())
+    return failure();
+
+  ArrayRef<int64_t> staticBasis = getStaticBasis();
+  // Guards drop_front() below; with no basis there is nothing to fold.
+  if (staticBasis.empty())
+    return failure();
+
+  // The outermost basis element only bounds the leading result, so only the
+  // remaining elements act as divisors.
+  ArrayRef<int64_t> divisors = staticBasis.drop_front();
+  // llvm::mod requires a positive divisor and divideFloorSigned must not
+  // divide by zero. fold() may run on ops that have not been verified yet
+  // (e.g. via OpBuilder::createOrFold), so reject non-positive divisors here
+  // rather than relying on the verifier.
+  if (llvm::any_of(divisors, [](int64_t divisor) { return divisor <= 0; }))
+    return failure();
+
+  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
+  Type attrType = getLinearIndex().getType();
+  // Build innermost-first in a local buffer, then append outermost-first, so
+  // only the values produced by this fold are reordered.
+  SmallVector<OpFoldResult> innermostFirst;
+  innermostFirst.reserve(staticBasis.size());
+  for (int64_t modulus : llvm::reverse(divisors)) {
+    innermostFirst.push_back(
+        IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
+    highPart = llvm::divideFloorSigned(highPart, modulus);
+  }
+  innermostFirst.push_back(IntegerAttr::get(attrType, highPart));
+  result.append(innermostFirst.rbegin(), innermostFirst.rend());
+  return success();
+}
+
namespace {
// Drops delinearization indices that correspond to unit-extent basis
@@ -4715,6 +4735,26 @@ LogicalResult AffineLinearizeIndexOp::verify() {
return success();
}
+/// Constant-folds affine.linearize_index when every multi-index operand is a
+/// constant and the basis is fully static.
+///
+/// Computes `sum(index[i] * stride[i])`, where `stride[i]` is the product of
+/// the basis elements after position `i`. The outermost basis element never
+/// contributes to a stride, so it is never multiplied in. Folding is refused
+/// when any intermediate product or sum would overflow int64_t, because signed
+/// overflow is undefined behavior.
+OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+  if (llvm::any_of(adaptor.getMultiIndex(),
+                   [](Attribute a) { return a == nullptr; }))
+    return nullptr;
+
+  if (!adaptor.getDynamicBasis().empty())
+    return nullptr;
+
+  int64_t result = 0;
+  int64_t stride = 1;
+  // Basis element seen on the previous (more inner) iteration. It is folded
+  // into `stride` lazily so the outermost element is never multiplied in.
+  int64_t innerLength = 1;
+  for (auto [indexAttr, length] :
+       llvm::zip_equal(llvm::reverse(adaptor.getMultiIndex()),
+                       llvm::reverse(getStaticBasis()))) {
+    int64_t term = 0;
+    if (llvm::MulOverflow(stride, innerLength, stride) ||
+        llvm::MulOverflow(cast<IntegerAttr>(indexAttr).getInt(), stride,
+                          term) ||
+        llvm::AddOverflow(result, term, result))
+      return nullptr;
+    innerLength = length;
+  }
+
+  return IntegerAttr::get(getResult().getType(), result);
+}
+
namespace {
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -4820,11 +4860,39 @@ struct CancelLinearizeOfDelinearizeExact final
return success();
}
};
+
+/// Strip a leading zero index from affine.linearize_index.
+///
+/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
+/// to `affine.linearize_index [...a] by (...b)` in all cases. The leading
+/// index is scaled by the product of the remaining basis elements, so a zero
+/// contributes nothing, and the outermost basis element `%x` only bounds that
+/// leading index. When the zero is the only index, the op is replaced by the
+/// zero value itself.
+struct DropLinearizeLeadingZero final
+    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to strip from an op without indices; this also keeps the
+    // front() call below from touching an empty range.
+    if (op.getMultiIndex().empty())
+      return failure();
+
+    Value leadingIdx = op.getMultiIndex().front();
+    if (!matchPattern(leadingIdx, m_Zero()))
+      return failure();
+
+    if (op.getMultiIndex().size() == 1) {
+      rewriter.replaceOp(op, leadingIdx);
+      return success();
+    }
+
+    SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
+    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+        op, op.getMultiIndex().drop_front(),
+        ArrayRef<OpFoldResult>(mixedBasis).drop_front(), op.getDisjoint());
+    return success();
+  }
+};
} // namespace
+/// Registers the canonicalization patterns for affine.linearize_index.
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeOneBasisElement,
+ patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
+ DropLinearizeOneBasisElement,
DropLinearizeUnitComponentsIfDisjointOrZero>(context);
}
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 8988f779ad02b4..ec00b31258d072 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1469,6 +1469,45 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
// -----
+// A fully constant delinearization folds to constants: 22 = 1 * 15 + 1 * 5 + 2,
+// so 22 into (2, 3, 5) yields (1, 1, 2).
+// CHECK-LABEL: @delinearize_fold_constant
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C1]], %[[C1]], %[[C2]]
+func.func @delinearize_fold_constant() -> (index, index, index) {
+ %c22 = arith.constant 22 : index
+ %0:3 = affine.delinearize_index %c22 into (2, 3, 5) : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// Negative inputs fold with floor-division semantics: -22 = -2 * 15 + 1 * 5 + 3,
+// so the inner results (1, 3) stay non-negative and only the leading one is
+// negative.
+// CHECK-LABEL: @delinearize_fold_negative_constant
+// CHECK-DAG: %[[C_2:.+]] = arith.constant -2 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C_2]], %[[C1]], %[[C3]]
+func.func @delinearize_fold_negative_constant() -> (index, index, index) {
+ %c_22 = arith.constant -22 : index
+ %0:3 = affine.delinearize_index %c_22 into (2, 3, 5) : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// A dynamic basis element (%arg0) prevents constant folding even though the
+// linear index itself is a constant.
+// CHECK-LABEL: @delinearize_dont_fold_constant_dynamic_basis
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK: %[[RET:.+]]:3 = affine.delinearize_index %[[C22]]
+// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2
+func.func @delinearize_dont_fold_constant_dynamic_basis(%arg0: index) -> (index, index, index) {
+ %c22 = arith.constant 22 : index
+ %0:3 = affine.delinearize_index %c22 into (2, %arg0, 5) : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
(index, index, index, index, index, index) {
%c1 = arith.constant 1 : index
@@ -1535,6 +1574,33 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
// -----
+// Constant indices with a static basis fold: [1, 1, 2] by (2, 3, 5) is
+// 1 * 15 + 1 * 5 + 2 = 22.
+// CHECK-LABEL: @linearize_fold_constants
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[C22]]
+func.func @linearize_fold_constants() -> index {
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+
+ %ret = affine.linearize_index [%c1, %c1, %c2] by (2, 3, 5) : index
+ return %ret : index
+}
+
+// -----
+
+// A dynamic basis element (%arg0) prevents folding even though every index is
+// a constant.
+// CHECK-LABEL: @linearize_dont_fold_dynamic_basis
+// CHECK: %[[RET:.+]] = affine.linearize_index
+// CHECK: return %[[RET]]
+func.func @linearize_dont_fold_dynamic_basis(%arg0: index) -> index {
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+
+ %ret = affine.linearize_index [%c1, %c1, %c2] by (2, %arg0, 5) : index
+ return %ret : index
+}
+
+// -----
+
// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_exact(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
@@ -1676,3 +1742,17 @@ func.func @no_cancel_linearize_denearize_different_basis(%arg0: index, %arg1: in
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 8, %arg2) : index
return %1 : index
}
+
+// -----
+
+// A constant-zero leading index is dropped together with the outermost basis
+// element, leaving [%arg0, %arg1] by (3, 5).
+// CHECK-LABEL: func @affine_leading_zero(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[RET:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (3, 5)
+// CHECK: return %[[RET]]
+func.func @affine_leading_zero(%arg0: index, %arg1: index) -> index {
+ %c0 = arith.constant 0 : index
+ %ret = affine.linearize_index [%c0, %arg0, %arg1] by (2, 3, 5) : index
+ return %ret : index
+}
+
More information about the Mlir-commits
mailing list