[Mlir-commits] [mlir] b9a071d - [mlir][Linalg] Add folders for `linalg.transpose` (#81709)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 20 17:40:02 PST 2024


Author: Diego Caballero
Date: 2024-02-20T17:39:58-08:00
New Revision: b9a071dc3995c1599724447b9db8ced449318839

URL: https://github.com/llvm/llvm-project/commit/b9a071dc3995c1599724447b9db8ced449318839
DIFF: https://github.com/llvm/llvm-project/commit/b9a071dc3995c1599724447b9db8ced449318839.diff

LOG: [mlir][Linalg] Add folders for `linalg.transpose` (#81709)

This PR adds folders for linalg transpose ops with only one dimension or
an identity permutation. The folding removes the `linalg.transpose` and
just propagates the input tensor.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
    mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
    mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
    mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 272bc3116c5fdc..92d844eefb7207 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -245,7 +245,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
           }
     ```
 
-    Shortened print form is available. Applies to simple maps with one 
+    Shortened print form is available. Applies to simple maps with one
     non-yield operation inside the body.
 
     The example above will be printed as:
@@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
                              ::mlir::OperationState & odsState);
   }];
 
+  let hasFolder = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a0f02f6a7f259d..919f5130e1760f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1786,6 +1786,22 @@ void TransposeOp::getEffects(
                         getDpsInits());
 }
 
+LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
+                                SmallVectorImpl<OpFoldResult> &result) {
+  // Single dimension transpose.
+  if (getPermutation().size() == 0) {
+    result.push_back(getInput());
+    return success();
+  }
+  // Identity permutation.
+  if (isIdentityPermutation(getPermutation())) {
+    result.push_back(getInput());
+    return success();
+  }
+
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 721f35162ef867..7adde3117deeaa 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1029,3 +1029,38 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
   %0 = linalg.broadcast ins(%input: tensor<2x3xf32>) outs(%init: tensor<2x3xf32>) dimensions = []
   return %0 : tensor<2x3xf32>
 }
+
+// -----
+
+func.func @transpose_1d(%input: tensor<16xf32>,
+                        %init: tensor<16xf32>) -> tensor<16xf32> {
+  %transpose = linalg.transpose
+      ins(%input:tensor<16xf32>)
+      outs(%init:tensor<16xf32>)
+      permutation = [0]
+  func.return %transpose : tensor<16xf32>
+}
+
+// CHECK-LABEL: func @transpose_1d(
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<16xf32>,
+//  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<16xf32>)
+//   CHECK-NOT:   linalg.transpose
+//       CHECK:   return %[[INPUT]] : tensor<16xf32>
+
+// -----
+
+func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
+                                   %init: tensor<16x32x64xf32>) -> tensor<16x32x64xf32> {
+  %transpose = linalg.transpose
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<16x32x64xf32>)
+      permutation = [0, 1, 2]
+  func.return %transpose : tensor<16x32x64xf32>
+}
+
+// CHECK-LABEL: func @transpose_identity_perm(
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>,
+//  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>)
+//   CHECK-NOT:   linalg.transpose
+//       CHECK:   return %[[INPUT]] : tensor<16x32x64xf32>
+

diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
index d63433248ab1e0..0a197a0ee9fa68 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
@@ -48,12 +48,8 @@ func.func @pad_and_pack(%arg0: tensor<13x15xf32>, %arg1: tensor<2x8x8x2xf32>, %a
 // CHECK:             %[[PAD:.+]] = tensor.pad %[[SRC_SLICE]]
 // CHECK:               tensor.yield %[[PAD_VAL]]
 // CHECK:             } : tensor<?x?xf32> to tensor<8x2xf32>
-// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK:         %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME:      ins(%[[PAD]] : tensor<8x2xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-SAME:      permutation = [0, 1]
-// CHECK:         %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
+// CHECK-NOT:         linalg.transpose
+// CHECK:             %{{.+}} = tensor.insert_slice %[[PAD]] into %{{.+}}
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -81,12 +77,8 @@ func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>)
 // CHECK-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
 // CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[SRC]]
 // CHECK-SAME:          [%[[IN_K]], %[[IN_C]]] [32, 8] [1, 1]
-// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK:             %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:          ins(%[[TILE]]
-// CHECK-SAME:          outs(%[[EMPTY]]
-// CHECK-SAME:          permutation = [0, 1]
-// CHECK:             %[[SUB_ITER:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
+// CHECK-NOT:         linalg.transpose
+// CHECK:             %[[SUB_ITER:.+]] = tensor.insert_slice %[[TILE]] into %{{[a-zA-Z0-9]+}}
 // CHECK-SAME:          [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<32x8xf32> into tensor<1x1x32x8xf32>
 // CHECK:             %{{.+}} = tensor.insert_slice %[[SUB_ITER]] into %{{[a-zA-Z0-9]+}}
 // CHECK-SAME:          [%[[C]], %[[K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> into tensor<32x4x32x8xf32>

diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index eaad6bd8270476..7d87a0994004fe 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -29,12 +29,8 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
 // CHECK-SAME:    %[[PAD_VAL:[a-zA-Z0-9]+]]
 // CHECK:         %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high[3, 1]
 // CHECK:           tensor.yield %[[PAD_VAL]]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK:         %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME:      ins(%[[PAD]] : tensor<8x2xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-SAME:      permutation = [0, 1]
-// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-NOT:     linalg.transpose
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
 
@@ -47,12 +43,8 @@ func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32
 // CHECK-LABEL: func.func @simple_NC_to_CNnc
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:      ins(%[[SRC]] : tensor<32x8xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<32x8xf32>)
-// CHECK-SAME:      permutation = [0, 1]
-// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-NOT:     linalg.transpose
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
 

diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
index f0d4b790520e03..7d64331c987841 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
@@ -57,12 +57,8 @@ func.func @unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13
 // CHECK-SAME:          [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]]
 // CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
 // CHECK-SAME:          [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<1x1x8x2xf32> to tensor<8x2xf32>
-// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK:             %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:          ins(%[[TILE]] : tensor<8x2xf32>)
-// CHECK-SAME:          outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-SAME:          permutation = [0, 1]
-// CHECK:             %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TRANSP]]
+// CHECK-NOT:         linalg.transpose
+// CHECK:             %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TILE]]
 // CHECK-SAME:          [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
 // CHECK:             %[[INSERT1:.+]] = tensor.insert_slice %[[UNPACK_TILE]] into %[[ITER_SLICE]]
 // CHECK-SAME:          [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
@@ -96,12 +92,8 @@ func.func @CKkc_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>)
 // CHECK-SAME:          [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
 // CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
 // CHECK-SAME:          [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
-// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK:             %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:          ins(%[[TILE]]
-// CHECK-SAME:          outs(%[[EMPTY]]
-// CHECK-SAME:          permutation = [0, 1]
-// CHECK:             %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
+// CHECK-NOT:         linalg.transpose
+// CHECK:             %[[INSERT:.+]] = tensor.insert_slice %[[TILE]] into %{{[a-zA-Z0-9]+}}
 // CHECK-SAME:          [%[[K]], %[[C]]] [32, 8] [1, 1]
 
 

diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
index 02376808865006..153ce68b8f086c 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
@@ -27,14 +27,10 @@ func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output:
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
 // CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:      ins(%[[TILE]] : tensor<8x2xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-SAME:      permutation = [0, 1]
+// CHECK-NOT:     linalg.transpose
 //                They have the same type, so the insert_slice op is folded
 //                away.
-// CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[TRANSP]][0, 0] [5, 1] [1, 1]
+// CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
 // CHECK:         return %[[SLICE]]
 
 // -----
@@ -47,14 +43,10 @@ func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
 // CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:      ins(%[[TILE]] : tensor<32x8xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<32x8xf32>)
-// CHECK-SAME:      permutation = [0, 1]
+// CHECK-NOT:     linalg.transpose
 //                They have the same type, so the insert_slice op is folded
 //                away.
-// CHECK:         return %[[TRANSP]]
+// CHECK:         return %[[TILE]]
 
 // -----
 
@@ -75,7 +67,6 @@ func.func @simple_NCHWc_to_NCHW(%arg0: tensor<2x1x16x8x32xf32>, %arg1: tensor<2x
 //                away.
 // CHECK:         return %[[TRANSP]]
 
-
 // -----
 
 func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x16x8xf32>) -> tensor<1x32x16x8xf32> {


        


More information about the Mlir-commits mailing list