[Mlir-commits] [mlir] [mlir][affine] Add folders for delinearize_index and linearize_index (PR #115766)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 11 12:51:10 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Krzysztof Drewniak (krzysz00)
<details>
<summary>Changes</summary>
This commit adds implementations of fold() for delinearize_index and linearize_index to constant-fold them away when they have a fully constant basis and constant argument(s).
This commit also adds a canonicalization pattern to linearize_index that drops a leading constant-zero index together with the corresponding outermost basis element.
---
Full diff: https://github.com/llvm/llvm-project/pull/115766.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+2)
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+68-1)
- (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+80)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dd9b9a440ecc8..753b8951fb084b 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1110,6 +1110,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
}];
let hasVerifier = 1;
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
@@ -1179,6 +1180,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
}];
let hasVerifier = 1;
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3d38de4bf1068e..37316632a6a06f 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4556,6 +4556,26 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
return success();
}
+/// Constant-folds `affine.delinearize_index` when the linear index is a
+/// constant integer and every basis element is static.
+///
+/// Results are produced innermost-first: each one is the Euclidean
+/// (non-negative) remainder of the running value by its basis element, after
+/// which the running value is floor-divided by that element. The outermost
+/// result receives whatever remains, so negative inputs fold consistently
+/// (e.g. -22 into (2, 3, 5) gives (-2, 1, 3)).
+///
+/// Returns failure, leaving the op untouched, when the linear index is not a
+/// constant integer (unknown, or a non-integer constant such as poison), when
+/// any basis element is dynamic, or when a basis element that would be used
+/// as a divisor is non-positive.
+LogicalResult
+AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
+                               SmallVectorImpl<OpFoldResult> &result) {
+  // A null attribute means "not constant"; any other non-IntegerAttr constant
+  // (e.g. poison) must also be rejected rather than hitting a failing cast.
+  auto linearIndexAttr =
+      dyn_cast_if_present<IntegerAttr>(adaptor.getLinearIndex());
+  if (!linearIndexAttr)
+    return failure();
+
+  if (!adaptor.getDynamicBasis().empty())
+    return failure();
+
+  ArrayRef<int64_t> staticBasis = getStaticBasis();
+  if (staticBasis.empty())
+    return failure();
+
+  // The outermost basis element is never a divisor: the leading result simply
+  // absorbs the remaining high part. Every other element is, so reject
+  // non-positive ones instead of dividing by zero inside the compiler.
+  ArrayRef<int64_t> moduli = staticBasis.drop_front();
+  if (llvm::any_of(moduli, [](int64_t modulus) { return modulus <= 0; }))
+    return failure();
+
+  int64_t highPart = linearIndexAttr.getInt();
+  Type attrType = getLinearIndex().getType();
+  for (int64_t modulus : llvm::reverse(moduli)) {
+    result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
+    highPart = llvm::divideFloorSigned(highPart, modulus);
+  }
+  result.push_back(IntegerAttr::get(attrType, highPart));
+  // Restore outermost-first order to match the op's results.
+  std::reverse(result.begin(), result.end());
+  return success();
+}
+
namespace {
// Drops delinearization indices that correspond to unit-extent basis
@@ -4683,6 +4703,26 @@ LogicalResult AffineLinearizeIndexOp::verify() {
return success();
}
+/// Constant-folds `affine.linearize_index` when every index is a constant
+/// integer and every basis element is static.
+///
+/// For indices (i_0, ..., i_{n-1}) and basis (b_0, ..., b_{n-1}) the value is
+/// sum_k i_k * prod_{j > k} b_j; the outermost basis element b_0 never
+/// contributes. An empty index list folds to 0.
+///
+/// Returns nullptr (no fold) when an index is not a constant integer (unknown,
+/// or a non-integer constant such as poison), when any basis element is
+/// dynamic, or when evaluating the value would overflow int64_t -- signed
+/// overflow is undefined behavior in the compiler itself.
+OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+  // Reject both non-constant operands (null attributes) and constants that
+  // are not plain integers, so the cast in the loop below cannot fail.
+  if (!llvm::all_of(adaptor.getMultiIndex(), [](Attribute a) {
+        return isa_and_present<IntegerAttr>(a);
+      }))
+    return nullptr;
+
+  if (!adaptor.getDynamicBasis().empty())
+    return nullptr;
+
+  // Horner's scheme: result = ((0 * b_0 + i_0) * b_1 + i_1) * b_2 + ...
+  // This equals the stride-weighted sum above but never forms the unused
+  // product that includes b_0, which could overflow on its own.
+  int64_t result = 0;
+  for (auto [indexAttr, length] :
+       llvm::zip_equal(adaptor.getMultiIndex(), getStaticBasis())) {
+    int64_t index = cast<IntegerAttr>(indexAttr).getInt();
+    if (llvm::MulOverflow(result, length, result) ||
+        llvm::AddOverflow(result, index, result))
+      return nullptr;
+  }
+
+  return IntegerAttr::get(getResult().getType(), result);
+}
+
namespace {
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -4751,12 +4791,39 @@ struct DropLinearizeOneBasisElement final
return success();
}
};
+
+/// Strip a leading zero index from affine.linearize_index.
+///
+/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
+/// to `affine.linearize_index [...a] by (...b)` in all cases: the leading
+/// index is scaled by the product of `...b`, so a zero contributes nothing,
+/// and the outermost basis element `%x` never scales any index. The
+/// `disjoint` flag is carried over unchanged. When the zero is the only
+/// index, the op is replaced by that zero value directly.
+struct DropLinearizeLeadingZero final
+    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    // With no indices there is nothing to strip; also avoids calling front()
+    // on an empty operand range.
+    if (op.getMultiIndex().empty())
+      return failure();
+
+    Value leadingIdx = op.getMultiIndex().front();
+    if (!matchPattern(leadingIdx, m_Zero()))
+      return failure();
+
+    if (op.getMultiIndex().size() == 1) {
+      rewriter.replaceOp(op, leadingIdx);
+      return success();
+    }
+
+    SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
+    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+        op, op.getMultiIndex().drop_front(),
+        ArrayRef<OpFoldResult>(mixedBasis).drop_front(), op.getDisjoint());
+    return success();
+  }
+};
} // namespace
+/// Registers the linearize_index canonicalization patterns: this change adds
+/// DropLinearizeLeadingZero alongside the two pre-existing patterns.
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
 RewritePatternSet &patterns, MLIRContext *context) {
 patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero,
- DropLinearizeOneBasisElement>(context);
+ DropLinearizeOneBasisElement, DropLinearizeLeadingZero>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index fa179744094c67..8ed6ab5c965ca0 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1469,6 +1469,45 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
// -----
+// A fully constant delinearize_index folds away: 22 = 1*15 + 1*5 + 2 with
+// basis (2, 3, 5), so the results are (1, 1, 2).
+// CHECK-LABEL: @delinearize_fold_constant
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C1]], %[[C1]], %[[C2]]
+func.func @delinearize_fold_constant() -> (index, index, index) {
+ %c22 = arith.constant 22 : index
+ %0:3 = affine.delinearize_index %c22 into (2, 3, 5) : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// Negative inputs fold with floor semantics: -22 = (-2)*15 + 1*5 + 3, so the
+// inner results are non-negative remainders and only the outermost is negative.
+// CHECK-LABEL: @delinearize_fold_negative_constant
+// CHECK-DAG: %[[C_2:.+]] = arith.constant -2 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C_2]], %[[C1]], %[[C3]]
+func.func @delinearize_fold_negative_constant() -> (index, index, index) {
+ %c_22 = arith.constant -22 : index
+ %0:3 = affine.delinearize_index %c_22 into (2, 3, 5) : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// A dynamic basis element (%arg0) blocks constant folding even though the
+// linear index is constant, so the delinearize_index op must survive.
+// CHECK-LABEL: @delinearize_dont_fold_constant_dynamic_basis
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK: %[[RET:.+]]:3 = affine.delinearize_index %[[C22]]
+// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2
+func.func @delinearize_dont_fold_constant_dynamic_basis(%arg0: index) -> (index, index, index) {
+ %c22 = arith.constant 22 : index
+ %0:3 = affine.delinearize_index %c22 into (2, %arg0, 5) : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
(index, index, index, index, index, index) {
%c1 = arith.constant 1 : index
@@ -1535,6 +1574,33 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
// -----
+// A fully constant linearize_index folds to 1*15 + 1*5 + 2 = 22 for basis
+// (2, 3, 5); this is the inverse of the delinearize_fold_constant case.
+// CHECK-LABEL: @linearize_fold_constants
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[C22]]
+func.func @linearize_fold_constants() -> index {
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+
+ %ret = affine.linearize_index [%c1, %c1, %c2] by (2, 3, 5) : index
+ return %ret : index
+}
+
+// -----
+
+// A dynamic basis element (%arg0) blocks constant folding of linearize_index
+// even when every index is constant.
+// CHECK-LABEL: @linearize_dont_fold_dynamic_basis
+// CHECK: %[[RET:.+]] = affine.linearize_index
+// CHECK: return %[[RET]]
+func.func @linearize_dont_fold_dynamic_basis(%arg0: index) -> index {
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+
+ %ret = affine.linearize_index [%c1, %c1, %c2] by (2, %arg0, 5) : index
+ return %ret : index
+}
+
+// -----
+
// CHECK-LABEL: @linearize_unit_basis_disjoint
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
@@ -1577,3 +1643,17 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
%ret = affine.linearize_index [%arg0] by (%arg1) : index
return %ret : index
}
+
+// -----
+
+// A leading constant-zero index is dropped together with the outermost basis
+// element (2), leaving a linearize_index over the remaining two indices.
+// CHECK-LABEL: func @affine_leading_zero(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[RET:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (3, 5)
+// CHECK: return %[[RET]]
+func.func @affine_leading_zero(%arg0: index, %arg1: index) -> index {
+ %c0 = arith.constant 0 : index
+ %ret = affine.linearize_index [%c0, %arg0, %arg1] by (2, 3, 5) : index
+ return %ret : index
+}
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/115766
More information about the Mlir-commits mailing list