[Mlir-commits] [mlir] 7063c94 - [mlir][Linalg] Bugfix for folder of `linalg.transpose` (#102888)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 21 00:49:13 PDT 2024


Author: Longsheng Mou
Date: 2024-08-21T15:49:10+08:00
New Revision: 7063c9427e11b5028ab2e926768faa7ff431bb85

URL: https://github.com/llvm/llvm-project/commit/7063c9427e11b5028ab2e926768faa7ff431bb85
DIFF: https://github.com/llvm/llvm-project/commit/7063c9427e11b5028ab2e926768faa7ff431bb85.diff

LOG: [mlir][Linalg] Bugfix for folder of `linalg.transpose` (#102888)

The folder of `linalg.transpose` only supports tensor types, so it now bails out on non-tensor (e.g. memref) operands. Fixes #102576.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/loops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 775ed8f37344ed..76df3ecf2d2bd4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1908,6 +1908,10 @@ void TransposeOp::getEffects(
 
 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &result) {
+  // Only the tensor type is supported.
+  if (!isa<TensorType>(getInput().getType()))
+    return failure();
+
   // Single dimension transpose.
   if (getPermutation().size() == 0) {
     result.push_back(getInput());

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index a50fbb0fc3b86c..4bc2ed140da91a 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1216,3 +1216,19 @@ func.func @concats_of_fill(
 //       CHECK:   %[[CONCAT:.+]] = tensor.concat dim(1) %[[EMPTY0]], %[[EMPTY1]]
 //       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[CONCAT]] :
 //       CHECK:   return %[[FILL]]
+
+// -----
+
+func.func @transpose_buffer(%input: memref<?xf32>,
+                            %init: memref<?xf32>) {
+  linalg.transpose ins(%input:memref<?xf32>)
+                   outs(%init:memref<?xf32>)
+                   permutation = [0]
+  func.return
+}
+
+// CHECK-LABEL:   func.func @transpose_buffer(
+//  CHECK-SAME:            %[[VAL_0:.*]]: memref<?xf32>,
+//  CHECK-SAME:            %[[VAL_1:.*]]: memref<?xf32>) {
+//       CHECK:     linalg.transpose ins(%[[VAL_0]] : memref<?xf32>)
+//  CHECK-SAME:       outs(%[[VAL_1]] : memref<?xf32>) permutation = [0]

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index b818170a8e7974..6ddbd06389f5eb 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -873,3 +873,39 @@ func.func @lower_to_loops_with_rank_reducing_subviews(
 //       CHECKPARALLEL:     %[[VAL:.+]] = memref.load %{{.+}}[%[[IV]]]
 //       CHECKPARALLEL:     memref.store %[[VAL]], %{{.+}}[%[[IV]]]
 //       CHECKPARALLEL:   }
+
+// -----
+
+func.func @transpose(%input: memref<?xf32>,
+                     %init: memref<?xf32>) {
+  linalg.transpose ins(%input:memref<?xf32>)
+                   outs(%init:memref<?xf32>)
+                   permutation = [0]
+  return
+}
+// CHECK-LABEL:   func.func @transpose(
+// CHECK-SAME:                         %[[VAL_0:.*]]: memref<?xf32>,
+// CHECK-SAME:                         %[[VAL_1:.*]]: memref<?xf32>) {
+//      CHECK:      %[[VAL_2:.*]] = arith.constant 1 : index
+//      CHECK:      %[[VAL_3:.*]] = arith.constant 0 : index
+//      CHECK:      %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref<?xf32>
+//      CHECK:      scf.for %[[VAL_5:.*]] = %[[VAL_3]] to %[[VAL_4]] step %[[VAL_2]] {
+//      CHECK:        %[[VAL_6:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_5]]] : memref<?xf32>
+//      CHECK:        memref.store %[[VAL_6]], %[[VAL_1]]{{\[}}%[[VAL_5]]] : memref<?xf32>
+//      CHECK:      }
+//      CHECK:      return
+//      CHECK:    }
+
+// CHECKPARALLEL-LABEL:   func.func @transpose(
+// CHECKPARALLEL-SAME:                         %[[VAL_0:.*]]: memref<?xf32>,
+// CHECKPARALLEL-SAME:                         %[[VAL_1:.*]]: memref<?xf32>) {
+//      CHECKPARALLEL:      %[[VAL_2:.*]] = arith.constant 1 : index
+//      CHECKPARALLEL:      %[[VAL_3:.*]] = arith.constant 0 : index
+//      CHECKPARALLEL:      %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref<?xf32>
+//      CHECKPARALLEL:      scf.parallel (%[[VAL_5:.*]]) = (%[[VAL_3]]) to (%[[VAL_4]]) step (%[[VAL_2]]) {
+//      CHECKPARALLEL:        %[[VAL_6:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_5]]] : memref<?xf32>
+//      CHECKPARALLEL:        memref.store %[[VAL_6]], %[[VAL_1]]{{\[}}%[[VAL_5]]] : memref<?xf32>
+//      CHECKPARALLEL:        scf.reduce
+//      CHECKPARALLEL:      }
+//      CHECKPARALLEL:      return
+//      CHECKPARALLEL:    }


        


More information about the Mlir-commits mailing list