[Mlir-commits] [mlir] [mlir][Linalg] Add folders for `linalg.transpose` (PR #81709)

Diego Caballero llvmlistbot at llvm.org
Tue Feb 20 16:49:41 PST 2024


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/81709

>From 49af44262a2f86074cb10eecc1af69baa2ab7be3 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Wed, 14 Feb 2024 06:16:22 +0000
Subject: [PATCH 1/2] [mlir][Linalg] Add folders for `linalg.transpose`

This PR adds folders for `linalg.transpose` ops that are rank-0 or have an
identity permutation (the 1-D case is always an identity permutation). The
folding removes the `linalg.transpose` and
just propagates the input tensor. Given that this is a DPS op, I'm now
wondering if this folding is incorrect and we should instead replace the
op with a `linalg.copy` so that the init tensor is still used. Feedback
would be appreciated. I think that propagating the input tensor if the
DPS op is folded away should be ok, given that all the uses of the init
tensor are replaced with the input tensor, but I might be missing something.
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  3 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 16 +++++++++
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 34 +++++++++++++++++++
 3 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 751edd02288301..de9414598b0282 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -245,7 +245,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
           }
     ```
 
-    Shortened print form is available. Applies to simple maps with one 
+    Shortened print form is available. Applies to simple maps with one
     non-yield operation inside the body.
 
     The example above will be printed as:
@@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
                              ::mlir::OperationState & odsState);
   }];
 
+  let hasFolder = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e86b9762d8581f..2f6ab7e32e5872 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1785,6 +1785,22 @@ void TransposeOp::getEffects(
                         getDpsInits());
 }
 
+LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
+                                SmallVectorImpl<OpFoldResult> &result) {
+  // Only tensor semantics have a result to forward: the memref form has no
+  // results and its write into the init buffer must be preserved.
+  if (!isa<TensorType>(getInput().getType()) ||
+      !isa<TensorType>(getInit().getType()))
+    return failure();
+  // A rank-0 transpose or an identity permutation (incl. 1-D) is a no-op.
+  if (getPermutation().empty() || isIdentityPermutation(getPermutation())) {
+    result.push_back(getInput());
+    return success();
+  }
+
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 052dc367ca6779..5bc4cb82f8cdbc 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,3 +1017,37 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso
   %copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %copy : tensor<?x?xf32>
 }
+
+// -----
+
+func.func @transpose_1d(%input: tensor<16xf32>,
+                        %init: tensor<16xf32>) -> tensor<16xf32> {
+  %transposed = linalg.transpose
+      ins(%input : tensor<16xf32>)
+      outs(%init : tensor<16xf32>)
+      permutation = [0]
+  return %transposed : tensor<16xf32>
+}
+
+// CHECK-LABEL: func @transpose_1d(
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<16xf32>,
+//  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<16xf32>)
+//   CHECK-NOT:   linalg.transpose
+//       CHECK:   return %[[INPUT]] : tensor<16xf32>
+
+// -----
+
+func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
+                                   %init: tensor<16x32x64xf32>) -> tensor<16x32x64xf32> {
+  %transposed = linalg.transpose
+      ins(%input : tensor<16x32x64xf32>)
+      outs(%init : tensor<16x32x64xf32>)
+      permutation = [0, 1, 2]
+  return %transposed : tensor<16x32x64xf32>
+}
+
+// CHECK-LABEL: func @transpose_identity_perm(
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>,
+//  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>)
+//   CHECK-NOT:   linalg.transpose
+//       CHECK:   return %[[INPUT]] : tensor<16x32x64xf32>

>From 5715b89b7799093e058f8f12f30db9578ef8972a Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Wed, 21 Feb 2024 00:49:20 +0000
Subject: [PATCH 2/2] Fix tests

---
 .../Linalg/generalize-tensor-pack-tile.mlir     | 16 ++++------------
 .../Dialect/Linalg/generalize-tensor-pack.mlir  | 16 ++++------------
 .../Linalg/generalize-tensor-unpack-tile.mlir   | 16 ++++------------
 .../Linalg/generalize-tensor-unpack.mlir        | 17 ++++-------------
 4 files changed, 16 insertions(+), 49 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
index d63433248ab1e0..0a197a0ee9fa68 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
@@ -48,12 +48,8 @@ func.func @pad_and_pack(%arg0: tensor<13x15xf32>, %arg1: tensor<2x8x8x2xf32>, %a
 // CHECK:             %[[PAD:.+]] = tensor.pad %[[SRC_SLICE]]
 // CHECK:               tensor.yield %[[PAD_VAL]]
 // CHECK:             } : tensor<?x?xf32> to tensor<8x2xf32>
-// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK:         %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME:      ins(%[[PAD]] : tensor<8x2xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-SAME:      permutation = [0, 1]
-// CHECK:         %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
+// CHECK-NOT:         linalg.transpose
+// CHECK:             %{{.+}} = tensor.insert_slice %[[PAD]] into %{{.+}}
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -81,12 +77,8 @@ func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>)
 // CHECK-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
 // CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[SRC]]
 // CHECK-SAME:          [%[[IN_K]], %[[IN_C]]] [32, 8] [1, 1]
-// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK:             %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:          ins(%[[TILE]]
-// CHECK-SAME:          outs(%[[EMPTY]]
-// CHECK-SAME:          permutation = [0, 1]
-// CHECK:             %[[SUB_ITER:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
+// CHECK-NOT:         linalg.transpose
+// CHECK:             %[[SUB_ITER:.+]] = tensor.insert_slice %[[TILE]] into %{{[a-zA-Z0-9]+}}
 // CHECK-SAME:          [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<32x8xf32> into tensor<1x1x32x8xf32>
 // CHECK:             %{{.+}} = tensor.insert_slice %[[SUB_ITER]] into %{{[a-zA-Z0-9]+}}
 // CHECK-SAME:          [%[[C]], %[[K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> into tensor<32x4x32x8xf32>
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index eaad6bd8270476..7d87a0994004fe 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -29,12 +29,8 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
 // CHECK-SAME:    %[[PAD_VAL:[a-zA-Z0-9]+]]
 // CHECK:         %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high[3, 1]
 // CHECK:           tensor.yield %[[PAD_VAL]]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK:         %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME:      ins(%[[PAD]] : tensor<8x2xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-SAME:      permutation = [0, 1]
-// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-NOT:     linalg.transpose
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
 
@@ -47,12 +43,8 @@ func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32
 // CHECK-LABEL: func.func @simple_NC_to_CNnc
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:      ins(%[[SRC]] : tensor<32x8xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<32x8xf32>)
-// CHECK-SAME:      permutation = [0, 1]
-// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-NOT:     linalg.transpose
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
 
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
index f0d4b790520e03..7d64331c987841 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
@@ -57,12 +57,8 @@ func.func @unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13
 // CHECK-SAME:          [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]]
 // CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
 // CHECK-SAME:          [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<1x1x8x2xf32> to tensor<8x2xf32>
-// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK:             %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:          ins(%[[TILE]] : tensor<8x2xf32>)
-// CHECK-SAME:          outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-SAME:          permutation = [0, 1]
-// CHECK:             %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TRANSP]]
+// CHECK-NOT:         linalg.transpose
+// CHECK:             %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TILE]]
 // CHECK-SAME:          [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
 // CHECK:             %[[INSERT1:.+]] = tensor.insert_slice %[[UNPACK_TILE]] into %[[ITER_SLICE]]
 // CHECK-SAME:          [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
@@ -96,12 +92,8 @@ func.func @CKkc_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>)
 // CHECK-SAME:          [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
 // CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
 // CHECK-SAME:          [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
-// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK:             %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:          ins(%[[TILE]]
-// CHECK-SAME:          outs(%[[EMPTY]]
-// CHECK-SAME:          permutation = [0, 1]
-// CHECK:             %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
+// CHECK-NOT:         linalg.transpose
+// CHECK:             %[[INSERT:.+]] = tensor.insert_slice %[[TILE]] into %{{[a-zA-Z0-9]+}}
 // CHECK-SAME:          [%[[K]], %[[C]]] [32, 8] [1, 1]
 
 
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
index 02376808865006..153ce68b8f086c 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
@@ -27,14 +27,10 @@ func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output:
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
 // CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:      ins(%[[TILE]] : tensor<8x2xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-SAME:      permutation = [0, 1]
+// CHECK-NOT:     linalg.transpose
 //                They have the same type, so the insert_slice op is folded
 //                away.
-// CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[TRANSP]][0, 0] [5, 1] [1, 1]
+// CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
 // CHECK:         return %[[SLICE]]
 
 // -----
@@ -47,14 +43,10 @@ func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
 // CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:      ins(%[[TILE]] : tensor<32x8xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<32x8xf32>)
-// CHECK-SAME:      permutation = [0, 1]
+// CHECK-NOT:     linalg.transpose
 //                They have the same type, so the insert_slice op is folded
 //                away.
-// CHECK:         return %[[TRANSP]]
+// CHECK:         return %[[TILE]]
 
 // -----
 
@@ -75,7 +67,6 @@ func.func @simple_NCHWc_to_NCHW(%arg0: tensor<2x1x16x8x32xf32>, %arg1: tensor<2x
 //                away.
 // CHECK:         return %[[TRANSP]]
 
-
 // -----
 
 func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x16x8xf32>) -> tensor<1x32x16x8xf32> {



More information about the Mlir-commits mailing list