[Mlir-commits] [mlir] [mlir][affine] Cancel exactly-matching delinearize/linearize pairs (PR #115758)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 11 11:09:23 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-affine

Author: Krzysztof Drewniak (krzysz00)

<details>
<summary>Changes</summary>

If we linearize values (with an assertion that they are disjoint) and then delinearize that linear index with the exact same basis, we know that these operations are exact inverses of each other and can be replaced with the original inputs to the linearization.

Similarly, if we take a linear index, delinearize it with some basis, and then re-linearize it with that same basis (noting that the outputs of the delinearization are guaranteed to be `disjoint`, even if this is not asserted on the linearize_index operation), the re-linearization is the inverse of the delinearization, so those two operations can also be canceled out.

This commit adds canonicalization patterns for these simple cancelations.

---
Full diff: https://github.com/llvm/llvm-project/pull/115758.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+1-1) 
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+63-2) 
- (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+96) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dd9b9a440ecc8..c9d9202ae3cf1a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1090,7 +1090,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   let results = (outs Variadic<Index>:$multi_index);
 
   let assemblyFormat = [{
-    $linear_index `into` ` `
+    $linear_index `into`
     custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
     attr-dict `:` type($multi_index)
   }];
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3d38de4bf1068e..d73d808753ba54 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4625,11 +4625,39 @@ struct DropDelinearizeOneBasisElement
   }
 };
 
+/// If a `affine.delinearize_index`'s input is a `affine.linearize_index
+/// disjoint` and the two operations have the same basis, replace the
+/// delinearizeation results with the inputs of the `affine.linearize_index`
+/// since they are exact inverses of each other.
+///
+/// The `disjoint` flag is needed on the `affine.linearize_index` because
+/// otherwise, there is no guarantee that the inputs to the linearization are
+/// in-bounds the way the outputs of the delinearization would be.
+struct CancelDelinearizeOfLinearizeDisjointExact
+    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto linearizeOp = delinearizeOp.getLinearIndex()
+                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
+    if (!linearizeOp)
+      return failure();
+
+    if (!linearizeOp.getDisjoint() ||
+        linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+      return failure();
+
+    rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.insert<DropDelinearizeOneBasisElement, DropUnitExtentBasis>(context);
+  patterns.insert<DropDelinearizeOneBasisElement, DropUnitExtentBasis,
+                  CancelDelinearizeOfLinearizeDisjointExact>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4751,12 +4779,45 @@ struct DropLinearizeOneBasisElement final
     return success();
   }
 };
+
+/// Cancel out linearize_index(delinearize_index(x, B), B).
+///
+/// That is, rewrite
+/// ```
+/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN)
+/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
+/// %bN)
+/// ```
+/// to replacing `%y` with `%x`.
+struct CancelLinearizeOfDelinearizeExact final
+    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto delinearizeOp = linearizeOp.getMultiIndex()
+                             .front()
+                             .getDefiningOp<affine::AffineDelinearizeIndexOp>();
+    if (!delinearizeOp)
+      return failure();
+
+    if (linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+      return failure();
+
+    if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
+      return failure();
+
+    rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero,
-               DropLinearizeOneBasisElement>(context);
+               DropLinearizeOneBasisElement, CancelLinearizeOfDelinearizeExact>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index fa179744094c67..99c115ba782c01 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1526,6 +1526,59 @@ func.func @delinearize_non_induction_variable(%arg0: memref<?xi32>, %i : index,
 
 // -----
 
+// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_exact(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]], %[[ARG1]], %[[ARG2]]
+func.func @cancel_delinearize_linearize_disjoint_exact(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (%arg3, 4, %arg4)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// Without `disjoint`, the cancelation isn't guaranteed to be the identity.
+// CHECK-LABEL: func @no_cancel_delinearize_linearize_exact(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (%[[ARG3]], 4, %[[ARG4]])
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (%[[ARG3]], 4, %[[ARG4]])
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @no_cancel_delinearize_linearize_exact(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (%arg3, 4, %arg4)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @no_cancel_delinearize_linearize_different_basis(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (%[[ARG3]], 4, %[[ARG4]])
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (%[[ARG3]], 8, %[[ARG4]])
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (%arg3, 8, %arg4)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: func @delinearize_non_loop_like
 // CHECK-NOT: affine.delinearize
 func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index {
@@ -1577,3 +1630,46 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
   %ret = affine.linearize_index [%arg0] by (%arg1) : index
   return %ret : index
 }
+
+// -----
+
+// CHECK-LABEL: func @cancel_linearize_denearize_exact(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]]
+func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @no_cancel_linearize_denearize_permuted(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], 4, %[[ARG2]])
+//       CHECK:     return %[[LIN]]
+func.func @no_cancel_linearize_denearize_permuted(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#2, %0#1] by (%arg1, 4, %arg2) : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @no_cancel_linearize_denearize_different_basis(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] by (%[[ARG1]], 8, %[[ARG2]])
+//       CHECK:     return %[[LIN]]
+func.func @no_cancel_linearize_denearize_different_basis(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 8, %arg2) : index
+  return %1 : index
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/115758


More information about the Mlir-commits mailing list