[Mlir-commits] [mlir] 3cb0128 - [mlir][tensor] Do not fold rank-reduced extract_slice into unpack op.

Hanhan Wang llvmlistbot at llvm.org
Tue Aug 15 10:47:35 PDT 2023


Author: Hanhan Wang
Date: 2023-08-15T10:28:21-07:00
New Revision: 3cb0128d267f531fc74b842f1f9f27e6455dc3b2

URL: https://github.com/llvm/llvm-project/commit/3cb0128d267f531fc74b842f1f9f27e6455dc3b2
DIFF: https://github.com/llvm/llvm-project/commit/3cb0128d267f531fc74b842f1f9f27e6455dc3b2.diff

LOG: [mlir][tensor] Do not fold rank-reduced extract_slice into unpack op.

Differential Revision: https://reviews.llvm.org/D157927

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
    mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index 744e49edcb6c92..9eac3e5c7ef910 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -59,6 +59,11 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
     if (!unpackOp)
       return failure();
 
+    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "rank-reduced folding is not supported");
+    }
+
     // Check all offsets are zeros, and all strides are ones.
     if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
         !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {

diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 0981faf8a1f262..5c757896657427 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -45,6 +45,19 @@ func.func @nofold_unpack_slice_non_unit_stride(%arg0 : tensor<?x?x8x4xf32>, %arg
 
 // -----
 
+func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index) -> tensor<f32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, 0] [1, 1] [1, 1] : tensor<?x?xf32> to tensor<f32>
+  return %1 : tensor<f32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_rank_reduced(
+//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
+//       CHECK:   tensor.extract_slice %[[UNPACK]]
+
+// -----
+
 func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32


        


More information about the Mlir-commits mailing list