[Mlir-commits] [mlir] [mlir][Affine] Split off delinearize parts that depend on last component (PR #117015)
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Nov 25 11:43:06 PST 2024
https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/117015
>From 9b1890c7a2fe43388b7f7f4ec25aed0a45f20d3d Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 20 Nov 2024 17:36:19 +0000
Subject: [PATCH 1/4] [mlir][Affine] Split off delinearize parts that depend on
last component
If we have
%0 = affine.linearize_index disjoint [%a, %b] by (A, B)
%1:3 = affine.delinearize_index %0 into (A, B1, B2)
where B = B1 * B2 (or some more complex product), we can simplify this
to
%0 = affine.linearize_index disjoint [%a] by (A)
%1a:1 = affine.delinearize_index %0 into (A)
%1b:2 = affine.delinearize_index %b into (B1, B2)
This rewrite (which also handles more complex cases) keeps us from adding
terms together only to divide them apart again.
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 87 +++++++++++++++++++++-
mlir/test/Dialect/Affine/canonicalize.mlir | 66 ++++++++++++++++
2 files changed, 151 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 4cf07bc167eab9..b13331abc32ada 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4694,13 +4694,96 @@ struct CancelDelinearizeOfLinearizeDisjointExact
return success();
}
};
+
+/// If the input to a delinearization is a disjoint linearization, and the
+/// last k > 1 components of the delinearization basis multiply to the
+/// last component of the linearization basis, break the linearization and
+/// delinearization into two parts, peeling off the last input to linearization.
+/// The last linearization basis element and the k peeled delinearization basis
+/// elements must all be static.
+///
+/// For example:
+///   %0 = affine.linearize_index disjoint [%z, %y, %x] by (3, 2, 32) : index
+///   %1:4 = affine.delinearize_index %0 into (2, 3, 8, 4) : index, ...
+/// becomes
+///   %0 = affine.linearize_index disjoint [%z, %y] by (3, 2) : index
+///   %1:2 = affine.delinearize_index %0 into (2, 3) : index, index
+///   %2:2 = affine.delinearize_index %x into (8, 4) : index, index
+/// where the original %1:4 is replaced by %1:2 ++ %2:2
+struct SplitDelinearizeSpanningLastLinearizeArg final
+    : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto linearizeOp = delinearizeOp.getLinearIndex()
+                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
+    if (!linearizeOp)
+      return rewriter.notifyMatchFailure(delinearizeOp,
+                                         "index doesn't come from linearize");
+
+    if (!linearizeOp.getDisjoint())
+      return rewriter.notifyMatchFailure(linearizeOp,
+                                         "linearize isn't disjoint");
+
+    int64_t target = linearizeOp.getStaticBasis().back();
+    if (ShapedType::isDynamic(target))
+      return rewriter.notifyMatchFailure(
+          linearizeOp, "linearize ends with dynamic basis value");
+
+    // Walk the delinearize basis from the innermost element outwards until the
+    // running product reaches `target`; those elements are the ones split off.
+    int64_t sizeToSplit = 1;
+    size_t elemsToSplit = 0;
+    ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
+    for (int64_t basisElem : llvm::reverse(basis)) {
+      if (ShapedType::isDynamic(basisElem))
+        return rewriter.notifyMatchFailure(
+            delinearizeOp, "dynamic basis element while scanning for split");
+      // Compare via division so the running product can never overflow. This
+      // assumes positive static basis elements, which keeps sizeToSplit >= 1.
+      if (basisElem > target / sizeToSplit)
+        return rewriter.notifyMatchFailure(delinearizeOp,
+                                           "overshot last argument size");
+      sizeToSplit *= basisElem;
+      elemsToSplit += 1;
+      if (sizeToSplit == target)
+        break;
+    }
+
+    if (sizeToSplit < target)
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "product of known basis elements doesn't reach last "
+                         "linearize argument");
+
+    if (elemsToSplit < 2)
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "don't have a non-trivial basis product");
+
+    // The dropped linearize basis element is static (checked above), so the
+    // dynamic basis operands carry over unchanged.
+    Value linearizeWithoutBack =
+        rewriter.create<affine::AffineLinearizeIndexOp>(
+            linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
+            linearizeOp.getDynamicBasis(),
+            linearizeOp.getStaticBasis().drop_back(),
+            linearizeOp.getDisjoint());
+    // Likewise, every split-off delinearize basis element was checked to be
+    // static in the loop, so all dynamic basis operands belong to the prefix.
+    auto delinearizeWithoutSplitPart =
+        rewriter.create<affine::AffineDelinearizeIndexOp>(
+            delinearizeOp.getLoc(), linearizeWithoutBack,
+            delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
+            delinearizeOp.hasOuterBound());
+    // The split-off basis multiplies exactly to `target`, the basis bound of
+    // the last linearize input, so it is used as an explicit outer bound.
+    auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
+        basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
+    SmallVector<Value> results = llvm::to_vector(
+        llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
+                            delinearizeBack.getResults()));
+    rewriter.replaceOp(delinearizeOp, results);
+
+    return success();
+  }
+};
} // namespace
+/// Registers the canonicalization patterns for affine.delinearize_index,
+/// including SplitDelinearizeSpanningLastLinearizeArg defined above.
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns
- .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
- context);
+ .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis,
+ SplitDelinearizeSpanningLastLinearizeArg>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index b54a13cffe7771..efeea7eb2af530 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1777,6 +1777,72 @@ func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1:
// -----
+// The trailing delinearize basis elements (8, 8) multiply to 64, the last
+// linearize basis element, so %arg2 is expected to be delinearized on its own.
+// CHECK-LABEL: func @split_delinearize_spanning_final_part
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 4)
+// CHECK: %[[DELIN1:.+]]:2 = affine.delinearize_index %[[LIN]] into (2)
+// CHECK: %[[DELIN2:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
+// CHECK: return %[[DELIN1]]#0, %[[DELIN1]]#1, %[[DELIN2]]#0, %[[DELIN2]]#1
+func.func @split_delinearize_spanning_final_part(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+ %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+ %1:4 = affine.delinearize_index %0 into (2, 8, 8)
+ : index, index, index, index
+ return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// After (8, 8) is split off, the remaining delinearize into (2, 4) matches the
+// remaining linearize by (2, 4), so that pair is expected to cancel entirely.
+// CHECK-LABEL: func @split_delinearize_spanning_final_part_and_cancel
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
+// CHECK: return %[[ARG0]], %[[ARG1]], %[[DELIN]]#0, %[[DELIN]]#1
+func.func @split_delinearize_spanning_final_part_and_cancel(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+ %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+ %1:4 = affine.delinearize_index %0 into (2, 4, 8, 8)
+ : index, index, index, index
+ return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// No suffix of the delinearize basis multiplies exactly to the last linearize
+// basis element (64): the running product goes 8, then 128, overshooting it,
+// so don't simplify.
+// CHECK-LABEL: func @dont_split_delinearize_overshooting_target
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 4, 64)
+// CHECK: %[[DELIN:.+]]:4 = affine.delinearize_index %[[LIN]] into (2, 16, 8)
+// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2, %[[DELIN]]#3
+func.func @dont_split_delinearize_overshooting_target(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+ %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+ %1:4 = affine.delinearize_index %0 into (2, 16, 8)
+ : index, index, index, index
+ return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// The whole delinearize basis (4, 8) multiplies to only 32, which never reaches
+// the last linearize basis element (64), so don't simplify.
+// CHECK-LABEL: func @dont_split_delinearize_undershooting_target
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 64)
+// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (4, 8)
+// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @dont_split_delinearize_undershooting_target(%arg0: index, %arg1: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 64) : index
+  %1:3 = affine.delinearize_index %0 into (4, 8)
+    : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
// CHECK-LABEL: @linearize_unit_basis_disjoint
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
>From c90f3f3bd7381ff846046d6216231100708ef4c6 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Fri, 22 Nov 2024 18:23:54 +0000
Subject: [PATCH 2/4] clang-format
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index bb9f1d72e611cb..28d27b0b2810f4 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4816,8 +4816,10 @@ struct SplitDelinearizeSpanningLastLinearizeArg final
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
- DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(context);
+ // Register the affine.delinearize_index canonicalization patterns.
+ patterns
+ .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
+ DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
+ context);
}
//===----------------------------------------------------------------------===//
>From 808e0a4bc26809da9ad400b747a5e67bf9ed844f Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 25 Nov 2024 13:41:05 -0600
Subject: [PATCH 3/4] Update debug message wording
Co-authored-by: Abhishek Varma <abhvarma at amd.com>
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 28d27b0b2810f4..ba259876ec18ca 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4788,7 +4788,7 @@ struct SplitDelinearizeSpanningLastLinearizeArg final
if (elemsToSplit < 2)
return rewriter.notifyMatchFailure(
- delinearizeOp, "don't have a non-trivial basis product");
+ delinearizeOp, "need at least two elements to form the basis product");
Value linearizeWithoutBack =
rewriter.create<affine::AffineLinearizeIndexOp>(
>From 62f24e096017465b3e771304887a7a65c79091b1 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 25 Nov 2024 19:42:52 +0000
Subject: [PATCH 4/4] Clang-format of suggestion
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index ba259876ec18ca..1c5466730a5589 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4788,7 +4788,8 @@ struct SplitDelinearizeSpanningLastLinearizeArg final
if (elemsToSplit < 2)
return rewriter.notifyMatchFailure(
- delinearizeOp, "need at least two elements to form the basis product");
+ delinearizeOp,
+ "need at least two elements to form the basis product");
Value linearizeWithoutBack =
rewriter.create<affine::AffineLinearizeIndexOp>(
More information about the Mlir-commits
mailing list