[Mlir-commits] [mlir] [mlir][Linalg] Bugfix for folder of `linalg.transpose` (PR #102888)
Longsheng Mou
llvmlistbot at llvm.org
Mon Aug 12 05:11:00 PDT 2024
https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/102888
The folder of `linalg.transpose` only supports tensor types, so bail out of folding when the input is not a tensor (e.g. a memref). Fix #102576.
>From 9e7fd8651585f0716c79b6a6ca19618af2894a73 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Mon, 12 Aug 2024 20:05:49 +0800
Subject: [PATCH] [mlir][Linalg] Bugfix for folder of `linalg.transpose`
The folder of `linalg.transpose` only supports tensor types.
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 +++
mlir/test/Dialect/Linalg/canonicalize.mlir | 16 ++++++++++
mlir/test/Dialect/Linalg/loops.mlir | 36 ++++++++++++++++++++++
3 files changed, 56 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a101552e419bc8..44b0a4b26588fc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1908,6 +1908,10 @@ void TransposeOp::getEffects(
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &result) {
+ // Only the tensor type is supported.
+ if (!isa<TensorType>(getInput().getType()))
+ return failure();
+
// Single dimension transpose.
if (getPermutation().size() == 0) {
result.push_back(getInput());
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index a50fbb0fc3b86c..4bc2ed140da91a 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1216,3 +1216,19 @@ func.func @concats_of_fill(
// CHECK: %[[CONCAT:.+]] = tensor.concat dim(1) %[[EMPTY0]], %[[EMPTY1]]
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[CONCAT]] :
// CHECK: return %[[FILL]]
+
+// -----
+
+func.func @transpose_buffer(%input: memref<?xf32>,
+ %init: memref<?xf32>) {
+ linalg.transpose ins(%input:memref<?xf32>)
+ outs(%init:memref<?xf32>)
+ permutation = [0]
+ func.return
+}
+
+// CHECK-LABEL: func.func @transpose_buffer(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<?xf32>) {
+// CHECK: linalg.transpose ins(%[[VAL_0]] : memref<?xf32>)
+// CHECK-SAME: outs(%[[VAL_1]] : memref<?xf32>) permutation = [0]
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index b818170a8e7974..6ddbd06389f5eb 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -873,3 +873,39 @@ func.func @lower_to_loops_with_rank_reducing_subviews(
// CHECKPARALLEL: %[[VAL:.+]] = memref.load %{{.+}}[%[[IV]]]
// CHECKPARALLEL: memref.store %[[VAL]], %{{.+}}[%[[IV]]]
// CHECKPARALLEL: }
+
+// -----
+
+func.func @transpose(%input: memref<?xf32>,
+ %init: memref<?xf32>) {
+ linalg.transpose ins(%input:memref<?xf32>)
+ outs(%init:memref<?xf32>)
+ permutation = [0]
+ return
+}
+// CHECK-LABEL: func.func @transpose(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<?xf32>) {
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref<?xf32>
+// CHECK: scf.for %[[VAL_5:.*]] = %[[VAL_3]] to %[[VAL_4]] step %[[VAL_2]] {
+// CHECK: %[[VAL_6:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_5]]] : memref<?xf32>
+// CHECK: memref.store %[[VAL_6]], %[[VAL_1]]{{\[}}%[[VAL_5]]] : memref<?xf32>
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+// CHECKPARALLEL-LABEL: func.func @transpose(
+// CHECKPARALLEL-SAME: %[[VAL_0:.*]]: memref<?xf32>,
+// CHECKPARALLEL-SAME: %[[VAL_1:.*]]: memref<?xf32>) {
+// CHECKPARALLEL: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECKPARALLEL: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECKPARALLEL: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref<?xf32>
+// CHECKPARALLEL: scf.parallel (%[[VAL_5:.*]]) = (%[[VAL_3]]) to (%[[VAL_4]]) step (%[[VAL_2]]) {
+// CHECKPARALLEL: %[[VAL_6:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_5]]] : memref<?xf32>
+// CHECKPARALLEL: memref.store %[[VAL_6]], %[[VAL_1]]{{\[}}%[[VAL_5]]] : memref<?xf32>
+// CHECKPARALLEL: scf.reduce
+// CHECKPARALLEL: }
+// CHECKPARALLEL: return
+// CHECKPARALLEL: }
More information about the Mlir-commits
mailing list