[Mlir-commits] [mlir] ee285fa - [mlir] Do not bubble up extract slice when it is rank-reducing.

Okwan Kwon llvmlistbot at llvm.org
Fri Apr 22 12:25:50 PDT 2022


Author: Okwan Kwon
Date: 2022-04-22T12:21:47-07:00
New Revision: ee285faed2e86e0f913f621ee86732ac8f0798f7

URL: https://github.com/llvm/llvm-project/commit/ee285faed2e86e0f913f621ee86732ac8f0798f7
DIFF: https://github.com/llvm/llvm-project/commit/ee285faed2e86e0f913f621ee86732ac8f0798f7.diff

LOG: [mlir] Do not bubble up extract slice when it is rank-reducing.

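The bubble-up logic was written assuming that the slice operation is
always a normal (non-rank-reducing) slice whose result tensor has the
same rank as its source. The pattern now reports a match failure when
the slice is rank-reducing.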

Differential Revision: https://reviews.llvm.org/D124283

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
    mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
index 53200d86511ce..7b45f3eaa0d4d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -76,6 +76,10 @@ struct BubbleUpExtractSliceOpPattern
     if (!sliceOp.hasUnitStride())
       return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
 
+    if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
+      return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
+    }
+
     OpOperand *outOperand = linalgOp.getOutputOperand(0);
     AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
     if (!indexingMap.isProjectedPermutation()) {

diff  --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
index 126927e8f07d8..a6aa94cb1e8ad 100644
--- a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
@@ -156,3 +156,22 @@ func.func @conv_slice(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32x
 // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST:.+]] : f32) outs(%[[SLICE2]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
 // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[SLICE0]], %[[SLICE1]] : tensor<1x65x65x3xf32>, tensor<3x3x3x16xf32>) outs(%[[FILL]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
 // CHECK: return %[[CONV]] : tensor<1x32x32x16xf32>
+
+// -----
+
+// A rank-reducing slice (tensor<1x?xf32> to tensor<?xf32>) must not be bubbled up above the fill.
+func.func @rank_reducing_slice(%width : index) -> tensor<1x1x1x?xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %init = linalg.init_tensor [1, %width] : tensor<1x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x?xf32>) -> tensor<1x?xf32>
+  %slice = tensor.extract_slice %fill[0, 0] [1, %width] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
+  %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] : tensor<?xf32> into tensor<1x1x1x?xf32>
+  return %expand : tensor<1x1x1x?xf32>
+}
+
+// CHECK-LABEL: func @rank_reducing_slice
+// CHECK: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: %[[FILL:.+]] = linalg.fill ins
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]{{.*}} : tensor<1x?xf32> to tensor<?xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]]
+// CHECK: return %[[EXPAND]]


        


More information about the Mlir-commits mailing list