[Mlir-commits] [mlir] b9a071d - [mlir][Linalg] Add folders for `linalg.transpose` (#81709)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 20 17:40:02 PST 2024
Author: Diego Caballero
Date: 2024-02-20T17:39:58-08:00
New Revision: b9a071dc3995c1599724447b9db8ced449318839
URL: https://github.com/llvm/llvm-project/commit/b9a071dc3995c1599724447b9db8ced449318839
DIFF: https://github.com/llvm/llvm-project/commit/b9a071dc3995c1599724447b9db8ced449318839.diff
LOG: [mlir][Linalg] Add folders for `linalg.transpose` (#81709)
This PR adds a folder for `linalg.transpose` ops whose permutation is empty
(a 0-d transpose) or the identity (which includes the 1-d `[0]` case). The
fold removes the `linalg.transpose` and forwards its input tensor unchanged.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 272bc3116c5fdc..92d844eefb7207 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -245,7 +245,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
}
```
- Shortened print form is available. Applies to simple maps with one
+ Shortened print form is available. Applies to simple maps with one
non-yield operation inside the body.
The example above will be printed as:
@@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
::mlir::OperationState & odsState);
}];
+ let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a0f02f6a7f259d..919f5130e1760f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1786,6 +1786,25 @@ void TransposeOp::getEffects(
getDpsInits());
}
+/// Folds a transpose that does not reorder any dimension by forwarding its
+/// input. Only tensor-semantics transposes are folded: with memref operands
+/// the op produces no results, so there is no value to replace.
+LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
+                                SmallVectorImpl<OpFoldResult> &result) {
+  if (!isa<TensorType>(getInput().getType()))
+    return failure();
+
+  // An empty permutation (0-d transpose) and an identity permutation (which
+  // also covers the 1-d `[0]` case) both leave the input unchanged.
+  ArrayRef<int64_t> perm = getPermutation();
+  if (perm.empty() || isIdentityPermutation(perm)) {
+    result.push_back(getInput());
+    return success();
+  }
+
+  return failure();
+}
+
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 721f35162ef867..7adde3117deeaa 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1029,3 +1029,38 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
%0 = linalg.broadcast ins(%input: tensor<2x3xf32>) outs(%init: tensor<2x3xf32>) dimensions = []
return %0 : tensor<2x3xf32>
}
+
+// -----
+
+func.func @transpose_1d(%input: tensor<16xf32>,
+                        %init: tensor<16xf32>) -> tensor<16xf32> {
+  %transpose = linalg.transpose
+      ins(%input:tensor<16xf32>)
+      outs(%init:tensor<16xf32>)
+      permutation = [0]
+  func.return %transpose : tensor<16xf32>
+}
+
+// CHECK-LABEL: func @transpose_1d(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<16xf32>,
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<16xf32>)
+// CHECK-NOT: linalg.transpose
+// CHECK: return %[[INPUT]] : tensor<16xf32>
+
+// -----
+
+func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
+                                   %init: tensor<16x32x64xf32>) -> tensor<16x32x64xf32> {
+  %transpose = linalg.transpose
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<16x32x64xf32>)
+      permutation = [0, 1, 2]
+  func.return %transpose : tensor<16x32x64xf32>
+}
+
+// CHECK-LABEL: func @transpose_identity_perm(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>,
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>)
+// CHECK-NOT: linalg.transpose
+// CHECK: return %[[INPUT]] : tensor<16x32x64xf32>
+
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
index d63433248ab1e0..0a197a0ee9fa68 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
@@ -48,12 +48,8 @@ func.func @pad_and_pack(%arg0: tensor<13x15xf32>, %arg1: tensor<2x8x8x2xf32>, %a
// CHECK: %[[PAD:.+]] = tensor.pad %[[SRC_SLICE]]
// CHECK: tensor.yield %[[PAD_VAL]]
// CHECK: } : tensor<?x?xf32> to tensor<8x2xf32>
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[PAD]] : tensor<8x2xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-SAME: permutation = [0, 1]
-// CHECK: %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
+// CHECK-NOT: linalg.transpose
+// CHECK: %{{.+}} = tensor.insert_slice %[[PAD]] into %{{.+}}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -81,12 +77,8 @@ func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>)
// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]]
// CHECK-SAME: [%[[IN_K]], %[[IN_C]]] [32, 8] [1, 1]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[TILE]]
-// CHECK-SAME: outs(%[[EMPTY]]
-// CHECK-SAME: permutation = [0, 1]
-// CHECK: %[[SUB_ITER:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
+// CHECK-NOT: linalg.transpose
+// CHECK: %[[SUB_ITER:.+]] = tensor.insert_slice %[[TILE]] into %{{[a-zA-Z0-9]+}}
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<32x8xf32> into tensor<1x1x32x8xf32>
// CHECK: %{{.+}} = tensor.insert_slice %[[SUB_ITER]] into %{{[a-zA-Z0-9]+}}
// CHECK-SAME: [%[[C]], %[[K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> into tensor<32x4x32x8xf32>
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index eaad6bd8270476..7d87a0994004fe 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -29,12 +29,8 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
// CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high[3, 1]
// CHECK: tensor.yield %[[PAD_VAL]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[PAD]] : tensor<8x2xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-SAME: permutation = [0, 1]
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-NOT: linalg.transpose
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]
@@ -47,12 +43,8 @@ func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32
// CHECK-LABEL: func.func @simple_NC_to_CNnc
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[SRC]] : tensor<32x8xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x8xf32>)
-// CHECK-SAME: permutation = [0, 1]
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-NOT: linalg.transpose
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
index f0d4b790520e03..7d64331c987841 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir
@@ -57,12 +57,8 @@ func.func @unpack_and_extract_slice(%arg0: tensor<2x8x8x2xf32>, %arg1: tensor<13
// CHECK-SAME: [%[[I]], %[[J]]] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]]
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<1x1x8x2xf32> to tensor<8x2xf32>
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[TILE]] : tensor<8x2xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-SAME: permutation = [0, 1]
-// CHECK: %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TRANSP]]
+// CHECK-NOT: linalg.transpose
+// CHECK: %[[UNPACK_TILE:.+]] = tensor.extract_slice %[[TILE]]
// CHECK-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[UNPACK_TILE]] into %[[ITER_SLICE]]
// CHECK-SAME: [0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]]] [1, 1]
@@ -96,12 +92,8 @@ func.func @CKkc_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>)
// CHECK-SAME: [%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[TILE]]
-// CHECK-SAME: outs(%[[EMPTY]]
-// CHECK-SAME: permutation = [0, 1]
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %{{[a-zA-Z0-9]+}}
+// CHECK-NOT: linalg.transpose
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TILE]] into %{{[a-zA-Z0-9]+}}
// CHECK-SAME: [%[[K]], %[[C]]] [32, 8] [1, 1]
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
index 02376808865006..153ce68b8f086c 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
@@ -27,14 +27,10 @@ func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output:
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
-// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[TILE]] : tensor<8x2xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xf32>)
-// CHECK-SAME: permutation = [0, 1]
+// CHECK-NOT: linalg.transpose
// They have the same type, so the insert_slice op is folded
// away.
-// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TRANSP]][0, 0] [5, 1] [1, 1]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
// CHECK: return %[[SLICE]]
// -----
@@ -47,14 +43,10 @@ func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[TILE]] : tensor<32x8xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x8xf32>)
-// CHECK-SAME: permutation = [0, 1]
+// CHECK-NOT: linalg.transpose
// They have the same type, so the insert_slice op is folded
// away.
-// CHECK: return %[[TRANSP]]
+// CHECK: return %[[TILE]]
// -----
@@ -75,7 +67,6 @@ func.func @simple_NCHWc_to_NCHW(%arg0: tensor<2x1x16x8x32xf32>, %arg1: tensor<2x
// away.
// CHECK: return %[[TRANSP]]
-
// -----
func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x16x8xf32>) -> tensor<1x32x16x8xf32> {
More information about the Mlir-commits
mailing list