[Mlir-commits] [mlir] [mlir][linalg] Canonicalize non-identity `linalg.generic` ops (PR #101430)

Ian Wood llvmlistbot at llvm.org
Wed Jul 31 16:29:14 PDT 2024


https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/101430

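Extend `linalg.generic`'s canonicalization patterns to also erase ops whose indexing maps are not identities but that are still no-ops: when every indexing map is the same permutation and the body only yields its input, the op copies the input unchanged and can be folded away.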

>From afcf8f1680b2aa074acb7dfdd3779b6db0266569 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Wed, 31 Jul 2024 22:51:52 +0000
Subject: [PATCH] Canonicalize non-identity `linalg.generic` ops

Extend `linalg.generic`'s canonicalization patterns to also erase
ops whose indexing maps are not identities but are all the same
permutation; such ops are still no-ops.
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   |  9 ++++++---
 mlir/test/Dialect/Linalg/canonicalize.mlir | 20 ++++++++++++++++++++
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 99b625d99fec2..34f403330621f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1217,9 +1217,12 @@ struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
 
   LogicalResult matchAndRewrite(OpTy linalgOp,
                                 PatternRewriter &rewriter) const override {
-    // Check all indexing maps are identity.
-    if (llvm::any_of(linalgOp.getIndexingMapsArray(),
-                     [](AffineMap map) { return !map.isIdentity(); }))
+    // All indexing maps must be equal permutations
+    auto indexingMaps = linalgOp.getIndexingMapsArray();
+    if (!llvm::all_equal(indexingMaps))
+      return failure();
+
+    if (!indexingMaps.empty() && !indexingMaps.front().isPermutation())
       return failure();
 
     // Check that the body of the linalg operation is just a linalg.yield
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index b1212b8863a5d..a50fbb0fc3b86 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -911,6 +911,26 @@ func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) {
 
 // -----
 
+#map = affine_map<(d0, d1) -> (d1, d0)>
+func.func @erase_non_identity_noop(%arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0 : tensor<?x?xf32>)
+    outs(%arg1 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in: f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// Erase no-op generics whose indexing maps are all the same permutation.
+// CHECK-LABEL: func @erase_non_identity_noop
+//  CHECK-SAME:   (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>)
+//       CHECK:   return %[[ARG0]] : tensor<?x?xf32>
+
+// -----
+
 // Just make sure that we don't crash.
 
 // CHECK-LABEL: func @dedeplicate_regression_test



More information about the Mlir-commits mailing list