[Mlir-commits] [mlir] [mlir][affine] Cancel exactly-matching delinearize/linearize pairs (PR #115758)
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Nov 11 11:08:48 PST 2024
https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/115758
If we linearize values (with an assertion that they are disjoint) and then delinearize that linear index with the exact same basis, we know that these operations are exact inverses of each other, so the results of the delinearization can be replaced with the original inputs to the linearization.
Similarly, if we take a linear index, delinearize it with some basis, and then re-linearize it with that same basis (noting that the outputs of the delinearization are guaranteed to be `disjoint`, even if this is not asserted on the linearize_index operation), the re-linearization is the inverse of the delinearization, so those two operations can also be canceled out.
This commit adds canonicalization patterns for these simple cancelations.
>From d904ceca683a95462dafd2755eaf469f08ae02af Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 11 Nov 2024 17:59:21 +0000
Subject: [PATCH] [mlir][affine] Cancel exactly-matching delinearize/linearize
pairs
If we linearize values (with an assertion that they are disjoint) and
then delinearize that linear index with the exact same basis, we know
that these operations are exact inverses of each other and can be
replaced with the original inputs to the linearization.
Similarly, if we take a linear index, delinearize it with some basis,
and then re-linearize it with that same basis (noting that the outputs
of the delinearization are guaranteed to be `disjoint`, even if this
is not asserted on the linearize_index operation), the
re-linearization is the inverse of the delinearization, so those two
operations can also be canceled out.
This commit adds canonicalization patterns for these simple
cancelations.
---
.../mlir/Dialect/Affine/IR/AffineOps.td | 2 +-
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 65 ++++++++++++-
mlir/test/Dialect/Affine/canonicalize.mlir | 96 +++++++++++++++++++
3 files changed, 160 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dd9b9a440ecc8..c9d9202ae3cf1a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1090,7 +1090,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
let results = (outs Variadic<Index>:$multi_index);
let assemblyFormat = [{
- $linear_index `into` ` `
+ $linear_index `into`
custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
attr-dict `:` type($multi_index)
}];
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3d38de4bf1068e..d73d808753ba54 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4625,11 +4625,39 @@ struct DropDelinearizeOneBasisElement
}
};
+/// If an `affine.delinearize_index`'s input is an `affine.linearize_index
+/// disjoint` and the two operations have the same basis, replace the
+/// delinearization results with the inputs of the `affine.linearize_index`
+/// since they are exact inverses of each other.
+///
+/// The `disjoint` flag is needed on the `affine.linearize_index` because
+/// otherwise, there is no guarantee that the inputs to the linearization are
+/// in-bounds the way the outputs of the delinearization would be.
+struct CancelDelinearizeOfLinearizeDisjointExact final
+    : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto linearizeOp = delinearizeOp.getLinearIndex()
+                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
+    if (!linearizeOp)
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "linear index doesn't come from a linearize_index");
+
+    // Without `disjoint`, the linearized inputs may be out of bounds, in
+    // which case delinearizing would not give them back unchanged.
+    if (!linearizeOp.getDisjoint())
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "producing linearize_index is not disjoint");
+
+    // Both ops must use exactly the same (static and dynamic) basis entries.
+    if (linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "basis doesn't match the linearize_index basis");
+
+    rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
+    return success();
+  }
+};
} // namespace
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.insert<DropDelinearizeOneBasisElement, DropUnitExtentBasis>(context);
+  // Use `add` (the preferred spelling, as in the linearize_index patterns)
+  // rather than the legacy `insert`.
+  patterns.add<DropDelinearizeOneBasisElement, DropUnitExtentBasis,
+               CancelDelinearizeOfLinearizeDisjointExact>(context);
}
//===----------------------------------------------------------------------===//
@@ -4751,12 +4779,45 @@ struct DropLinearizeOneBasisElement final
return success();
}
};
+
+/// Cancel out linearize_index(delinearize_index(x, B), B).
+///
+/// That is, rewrite
+/// ```
+/// %0:N = affine.delinearize_index %x into (%b1, %b2, ... %bN)
+/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
+/// %bN)
+/// ```
+/// by replacing `%y` with `%x`. No `disjoint` flag is required on the
+/// linearize_index, since the outputs of the delinearization are already
+/// in bounds for the basis.
+struct CancelLinearizeOfDelinearizeExact final
+    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto multiIndex = linearizeOp.getMultiIndex();
+    // Calling `front()` on an empty multi-index would be invalid, and there
+    // is no delinearization to look through in that case anyway.
+    if (multiIndex.empty())
+      return rewriter.notifyMatchFailure(linearizeOp, "no indices to cancel");
+
+    auto delinearizeOp =
+        multiIndex.front().getDefiningOp<affine::AffineDelinearizeIndexOp>();
+    if (!delinearizeOp)
+      return rewriter.notifyMatchFailure(
+          linearizeOp, "first index doesn't come from a delinearize_index");
+
+    if (linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+      return rewriter.notifyMatchFailure(
+          linearizeOp, "basis doesn't match the delinearize_index basis");
+
+    // Every delinearization result must be re-linearized in its original
+    // position; a permuted or partial use is not an inverse.
+    if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
+      return rewriter.notifyMatchFailure(
+          linearizeOp, "indices aren't the delinearize_index results in order");
+
+    rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex());
+    return success();
+  }
+};
} // namespace
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero,
- DropLinearizeOneBasisElement>(context);
+               DropLinearizeOneBasisElement,
+               CancelLinearizeOfDelinearizeExact>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index fa179744094c67..99c115ba782c01 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1526,6 +1526,59 @@ func.func @delinearize_non_induction_variable(%arg0: memref<?xi32>, %i : index,
// -----
+// A `disjoint` linearization undone by a delinearization over the very same
+// basis folds away to the original indices.
+// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_exact(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index)
+// CHECK:         return %[[ARG0]], %[[ARG1]], %[[ARG2]]
+func.func @cancel_delinearize_linearize_disjoint_exact(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %lin = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+  %delin:3 = affine.delinearize_index %lin into (%arg3, 4, %arg4) : index, index, index
+  return %delin#0, %delin#1, %delin#2 : index, index, index
+}
+
+// -----
+
+// Without `disjoint`, the linearized inputs aren't known to be in bounds, so
+// delinearizing is not guaranteed to give them back and nothing is canceled.
+// CHECK-LABEL: func @no_cancel_delinearize_linearize_exact(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index)
+// CHECK:         %[[LIN:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (%[[ARG3]], 4, %[[ARG4]])
+// CHECK:         %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (%[[ARG3]], 4, %[[ARG4]])
+// CHECK:         return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @no_cancel_delinearize_linearize_exact(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %lin = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+  %delin:3 = affine.delinearize_index %lin into (%arg3, 4, %arg4) : index, index, index
+  return %delin#0, %delin#1, %delin#2 : index, index, index
+}
+
+// -----
+
+// The linearization is `disjoint`, so only the basis mismatch (4 vs. 8)
+// prevents the cancellation.
+// CHECK-LABEL: func @no_cancel_delinearize_linearize_different_basis(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (%[[ARG3]], 4, %[[ARG4]])
+// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (%[[ARG3]], 8, %[[ARG4]])
+// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (%arg3, 8, %arg4)
+    : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
// CHECK-LABEL: func @delinearize_non_loop_like
// CHECK-NOT: affine.delinearize
func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index {
@@ -1577,3 +1630,46 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
%ret = affine.linearize_index [%arg0] by (%arg1) : index
return %ret : index
}
+
+// -----
+
+// Re-linearizing all delinearization results, in order and over the same
+// basis, yields the original linear index.
+// CHECK-LABEL: func @cancel_linearize_delinearize_exact(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: return %[[ARG0]]
+func.func @cancel_linearize_delinearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
+  return %1 : index
+}
+
+// -----
+
+// The delinearization results are re-linearized out of order, so the pair is
+// not an inverse and must be kept.
+// CHECK-LABEL: func @no_cancel_linearize_delinearize_permuted(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], 4, %[[ARG2]])
+// CHECK: return %[[LIN]]
+func.func @no_cancel_linearize_delinearize_permuted(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#2, %0#1] by (%arg1, 4, %arg2) : index
+  return %1 : index
+}
+
+// -----
+
+// The linearization basis (4) differs from the delinearization basis (8), so
+// nothing is canceled.
+// CHECK-LABEL: func @no_cancel_linearize_delinearize_different_basis(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] by (%[[ARG1]], 8, %[[ARG2]])
+// CHECK: return %[[LIN]]
+func.func @no_cancel_linearize_delinearize_different_basis(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 8, %arg2) : index
+  return %1 : index
+}
More information about the Mlir-commits
mailing list