[Mlir-commits] [mlir] 49f90e7 - [mlir][affine] Cancel exactly-matching delinearize/linearize pairs (#115758)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 12 13:36:11 PST 2024


Author: Krzysztof Drewniak
Date: 2024-11-12T15:36:07-06:00
New Revision: 49f90e798fe5667ac5e71a796aa897af3185137d

URL: https://github.com/llvm/llvm-project/commit/49f90e798fe5667ac5e71a796aa897af3185137d
DIFF: https://github.com/llvm/llvm-project/commit/49f90e798fe5667ac5e71a796aa897af3185137d.diff

LOG: [mlir][affine] Cancel exactly-matching delinearize/linearize pairs (#115758)

If we linearize values (with an assertion that they are disjoint) and
then delinearize that linear index with the exact same basis, we know
that these operations are exact inverses of each other and can be
replaced with the original inputs to the linearization.

Similarly, if we take a linear index, delinearize it with some basis,
and then re-linearize it with that same basis (noting that the outputs
of the delinearization are guaranteed to be `disjoint`, even if this is
not asserted on the linearize_index operation), the re-linearization is
the inverse of the delinearization, so those two operations can also be
canceled out.

This commit adds canonicalization patterns for these simple
cancelations.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Dialect/Affine/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dd9b9a440ecc8..c9d9202ae3cf1a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1090,7 +1090,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   let results = (outs Variadic<Index>:$multi_index);
 
   let assemblyFormat = [{
-    $linear_index `into` ` `
+    $linear_index `into`
     custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
     attr-dict `:` type($multi_index)
   }];

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3d38de4bf1068e..f1b0fe7e645051 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4586,7 +4586,8 @@ struct DropUnitExtentBasis
     }
 
     if (newOperands.size() == delinearizeOp.getStaticBasis().size())
-      return failure();
+      return rewriter.notifyMatchFailure(delinearizeOp,
+                                         "no unit basis elements");
 
     if (!newOperands.empty()) {
       auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
@@ -4619,17 +4620,48 @@ struct DropDelinearizeOneBasisElement
   LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                 PatternRewriter &rewriter) const override {
     if (delinearizeOp.getStaticBasis().size() != 1)
-      return failure();
+      return rewriter.notifyMatchFailure(delinearizeOp,
+                                         "doesn't have a length-1 basis");
     rewriter.replaceOp(delinearizeOp, delinearizeOp.getLinearIndex());
     return success();
   }
 };
 
+/// If an `affine.delinearize_index`'s input is an `affine.linearize_index
+/// disjoint` and the two operations have the same basis, replace the
+/// delinearization results with the inputs of the `affine.linearize_index`
+/// since they are exact inverses of each other.
+///
+/// The `disjoint` flag is needed on the `affine.linearize_index` because
+/// otherwise, there is no guarantee that the inputs to the linearization are
+/// in-bounds the way the outputs of the delinearization would be.
+struct CancelDelinearizeOfLinearizeDisjointExact
+    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto linearizeOp = delinearizeOp.getLinearIndex()
+                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
+    if (!linearizeOp)
+      return rewriter.notifyMatchFailure(delinearizeOp,
+                                         "index doesn't come from linearize");
+
+    // Both conditions are required for the pair to be exact inverses. The
+    // failure is attributed to the op being matched (the delinearize), like
+    // the failure path above.
+    if (!linearizeOp.getDisjoint() ||
+        linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "not disjoint or basis doesn't match delinearize");
+
+    // Each delinearize result is replaced by the corresponding linearize input.
+    rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.insert<DropDelinearizeOneBasisElement, DropUnitExtentBasis>(context);
+  // Registers the delinearize_index canonicalizations, including cancellation
+  // of a delinearize fed by a `disjoint` linearize with an identical basis.
+  // Uses `add`, matching the linearize_index registration; the pattern order
+  // is intentionally left unchanged.
+  patterns.add<CancelDelinearizeOfLinearizeDisjointExact,
+               DropDelinearizeOneBasisElement, DropUnitExtentBasis>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4723,7 +4755,8 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
       }
     }
     if (newIndices.size() == numIndices)
-      return failure();
+      return rewriter.notifyMatchFailure(op,
+                                         "no unit basis entries to replace");
 
     if (newIndices.size() == 0) {
       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
@@ -4746,17 +4779,53 @@ struct DropLinearizeOneBasisElement final
   LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
     if (op.getStaticBasis().size() != 1 || op.getMultiIndex().size() != 1)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "doesn't have a a length-1 basis");
     rewriter.replaceOp(op, op.getMultiIndex().front());
     return success();
   }
 };
+
+/// Cancel out linearize_index(delinearize_index(x, B), B).
+///
+/// That is, rewrite
+/// ```
+/// %0:N = affine.delinearize_index %x into (%b1, %b2, ... %bN)
+/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)]
+///          by (%b1, %b2, ... %bN)
+/// ```
+/// by replacing `%y` with `%x`.
+///
+/// No `disjoint` flag is required on the linearize: the outputs of a
+/// delinearization are guaranteed to be disjoint, so re-linearizing all of
+/// them, in order, with the same basis is the exact inverse.
+struct CancelLinearizeOfDelinearizeExact final
+    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto multiIndex = linearizeOp.getMultiIndex();
+    // Guard so that `front()` below is never called on an empty range.
+    if (multiIndex.empty())
+      return rewriter.notifyMatchFailure(linearizeOp, "no indices");
+
+    // The first index identifies the candidate delinearize; the range
+    // comparison further down confirms that *all* indices are its results,
+    // in order.
+    auto delinearizeOp =
+        multiIndex.front().getDefiningOp<affine::AffineDelinearizeIndexOp>();
+    if (!delinearizeOp)
+      return rewriter.notifyMatchFailure(
+          linearizeOp, "first entry doesn't come from a delinearize");
+
+    if (linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+      return rewriter.notifyMatchFailure(
+          linearizeOp,
+          "basis of linearize and delinearize don't match exactly");
+
+    if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
+      return rewriter.notifyMatchFailure(
+          linearizeOp, "not all indices come from delinearize");
+
+    rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero,
-               DropLinearizeOneBasisElement>(context);
+  // Registers the linearize_index canonicalizations, including cancellation
+  // of a linearize whose indices are exactly the in-order results of a
+  // delinearize with the same basis.
+  patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeOneBasisElement,
+               DropLinearizeUnitComponentsIfDisjointOrZero>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index fa179744094c67..8988f779ad02b4 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1535,6 +1535,60 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
 
 // -----
 
+// A `disjoint` linearization that is immediately undone by a delinearization
+// over the identical basis folds away to the original indices.
+// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_exact(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]], %[[ARG1]], %[[ARG2]]
+func.func @cancel_delinearize_linearize_disjoint_exact(%i0: index, %i1: index, %i2: index, %s0: index, %s2: index) -> (index, index, index) {
+  %lin = affine.linearize_index disjoint [%i0, %i1, %i2] by (%s0, 4, %s2) : index
+  %delin:3 = affine.delinearize_index %lin into (%s0, 4, %s2) : index, index, index
+  return %delin#0, %delin#1, %delin#2 : index, index, index
+}
+
+// -----
+
+// Without `disjoint` the linearize inputs may be out of bounds, so this pair
+// is not an exact round trip and must be left in place.
+// CHECK-LABEL: func @no_cancel_delinearize_linearize_exact(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (%[[ARG3]], 4, %[[ARG4]])
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (%[[ARG3]], 4, %[[ARG4]])
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @no_cancel_delinearize_linearize_exact(%i0: index, %i1: index, %i2: index, %s0: index, %s2: index) -> (index, index, index) {
+  %lin = affine.linearize_index [%i0, %i1, %i2] by (%s0, 4, %s2) : index
+  %delin:3 = affine.delinearize_index %lin into (%s0, 4, %s2) : index, index, index
+  return %delin#0, %delin#1, %delin#2 : index, index, index
+}
+
+// -----
+
+// These don't cancel because the delinearize and linearize have a different basis.
+// The linearize is marked `disjoint` so that the basis mismatch is the only
+// reason the pair is kept.
+// CHECK-LABEL: func @no_cancel_delinearize_linearize_different_basis(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (%[[ARG3]], 4, %[[ARG4]])
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (%[[ARG3]], 8, %[[ARG4]])
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (%arg3, 8, %arg4)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: @linearize_unit_basis_disjoint
 // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
 // CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
@@ -1577,3 +1631,48 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
   %ret = affine.linearize_index [%arg0] by (%arg1) : index
   return %ret : index
 }
+
+// -----
+
+// Re-linearizing every delinearize result, in order and with the same basis,
+// recovers the original linear index.
+// CHECK-LABEL: func @cancel_linearize_delinearize_exact(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]]
+func.func @cancel_linearize_delinearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
+  return %1 : index
+}
+
+// -----
+
+// Don't cancel because the values from the delinearize aren't used in order.
+// CHECK-LABEL: func @no_cancel_linearize_delinearize_permuted(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], 4, %[[ARG2]])
+//       CHECK:     return %[[LIN]]
+func.func @no_cancel_linearize_delinearize_permuted(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#2, %0#1] by (%arg1, 4, %arg2) : index
+  return %1 : index
+}
+
+// -----
+
+// Won't cancel because the linearize and delinearize are using a different basis.
+// CHECK-LABEL: func @no_cancel_linearize_delinearize_different_basis(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] by (%[[ARG1]], 8, %[[ARG2]])
+//       CHECK:     return %[[LIN]]
+func.func @no_cancel_linearize_delinearize_different_basis(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 8, %arg2) : index
+  return %1 : index
+}


        


More information about the Mlir-commits mailing list