[Mlir-commits] [mlir] 37fecfa - [mlir] Support rank-reduced extract_slice in ExtractSliceOfPadTensorSwapPattern (#138921)

llvmlistbot at llvm.org
Thu May 8 08:51:51 PDT 2025


Author: Vivian Zhang
Date: 2025-05-08T08:51:48-07:00
New Revision: 37fecfaa63eef7bd9dff9c16d74e61c99e3ce70a

URL: https://github.com/llvm/llvm-project/commit/37fecfaa63eef7bd9dff9c16d74e61c99e3ce70a
DIFF: https://github.com/llvm/llvm-project/commit/37fecfaa63eef7bd9dff9c16d74e61c99e3ce70a.diff

LOG: [mlir] Support rank-reduced extract_slice in ExtractSliceOfPadTensorSwapPattern (#138921)

This PR fixes `ExtractSliceOfPadTensorSwapPattern` to support
rank-reducing `tensor.extract_slice` ops, which were previously
unhandled and could cause crashes. To handle the rank reduction, an
additional `tensor.extract_slice` is inserted after the rewritten
`tensor.pad` to reduce the result rank.
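
For illustration, a sketch of the rewrite on the static case (shapes are
taken from the @static_rank_reduce test added below; SSA names are
illustrative). A rank-reducing slice of a padded tensor such as

  %0 = tensor.pad %arg0 low[0, 2, 0] high[0, 0, 0] {...}
      : tensor<8x16x4xf32> to tensor<8x18x4xf32>
  %1 = tensor.extract_slice %0[0, 0, 0] [1, 16, 4] [1, 1, 1]
      : tensor<8x18x4xf32> to tensor<16x4xf32>

is rewritten into pad(extract_slice(x)) followed by a rank-reducing
extract_slice of the padded result:

  %slice = tensor.extract_slice %arg0[0, 0, 0] [1, 14, 4] [1, 1, 1]
      : tensor<8x16x4xf32> to tensor<1x14x4xf32>
  %padded = tensor.pad %slice low[0, 2, 0] high[0, 0, 0] {...}
      : tensor<1x14x4xf32> to tensor<1x16x4xf32>
  %1 = tensor.extract_slice %padded[0, 0, 0] [1, 16, 4] [1, 1, 1]
      : tensor<1x16x4xf32> to tensor<16x4xf32>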

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 6700b4e0c2cb6..8718c57b9e86c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1017,9 +1017,22 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
                                sliceOp.getMixedSizes(), zeroSliceGuard);
   if (failed(tilingResult))
     return failure();
-  // All shapes are static and the data source is actually used. Rewrite into
-  // pad(extract_slice(x)).
-  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+
+  RankedTensorType sourceType = sliceOp.getSourceType();
+  RankedTensorType resultType = sliceOp.getResultType();
+
+  // If the extract_slice is not rank-reduced, all shapes are static and the
+  // data source is actually used. Rewrite into pad(extract_slice(x)).
+  if (sourceType.getRank() == resultType.getRank()) {
+    rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+    return success();
+  }
+
+  // Handle rank-reduced slice by creating another extract_slice op.
+  Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
+
+  rewriter.replaceOp(sliceOp, rankReduced);
   return success();
 }
 

diff --git a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
index d43b9a7ac6c04..6a056bab98807 100644
--- a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
+++ b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
@@ -129,6 +129,26 @@ func.func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
 
 // -----
 
+// CHECK-LABEL: @static_rank_reduce
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<8x16x4xf32>, %[[PADVAL:.*]]: f32
+//       CHECK:   %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 14, 4] [1, 1, 1] : tensor<8x16x4xf32> to tensor<1x14x4xf32>
+//       CHECK:   %[[PADDED:.*]] = tensor.pad %[[SLICE]] low[0, 2, 0] high[0, 0, 0] {
+//       CHECK:   } : tensor<1x14x4xf32> to tensor<1x16x4xf32>
+//       CHECK:   %[[RESULT:.*]] = tensor.extract_slice %[[PADDED]][0, 0, 0] [1, 16, 4] [1, 1, 1] : tensor<1x16x4xf32> to tensor<16x4xf32>
+//       CHECK: return %[[RESULT]]
+func.func @static_rank_reduce(%arg0: tensor<8x16x4xf32>, %pad: f32)
+    -> tensor<16x4xf32> {
+  %0 = tensor.pad %arg0 low[0, 2, 0] high[0, 0, 0] {
+    ^bb0(%i: index, %j: index, %k: index):
+      tensor.yield %pad : f32
+  } : tensor<8x16x4xf32> to tensor<8x18x4xf32>
+  %1 = tensor.extract_slice %0[0, 0, 0] [1, 16, 4] [1, 1, 1]
+      : tensor<8x18x4xf32> to tensor<16x4xf32>
+  return %1 : tensor<16x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @dynamic_high_pad
 //  CHECK-SAME:     %[[ARG0:.*]]: tensor<?x5xf32>
 //   CHECK-NOT:   tensor.pad
@@ -217,6 +237,27 @@ func.func @dynamic_zero_high_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
   return %1 : tensor<?x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: @dynamic_rank_reduce
+//       CHECK:   %[[TEMP:.*]] = scf.if %{{.*}} -> (tensor<1x4xf32>) {
+//       CHECK:     tensor.generate
+//       CHECK:   } else {
+//       CHECK:     %[[SLICE:.*]] = tensor.extract_slice %{{.*}} : tensor<?x5xf32> to tensor<?x1xf32>
+//       CHECK:     tensor.pad %[[SLICE]] low[0, 0] high[%{{.*}}, 3] {
+//       CHECK:     } : tensor<?x1xf32> to tensor<1x4xf32>
+//       CHECK:   }
+//       CHECK:   %[[RESULT:.*]] = tensor.extract_slice %[[TEMP]]{{.*}} : tensor<1x4xf32> to tensor<4xf32>
+//       CHECK:   return %[[RESULT]]
+func.func @dynamic_rank_reduce(%arg0 : tensor<?x5xf32>, %s1: index, %pad : f32) -> tensor<4xf32> {
+  %0 = tensor.pad %arg0 low[0, 0] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad : f32
+    } : tensor<?x5xf32> to tensor<?x13xf32>
+  %1 = tensor.extract_slice %0[2, 4] [1, 4] [1, 1] : tensor<?x13xf32> to tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
 // -----
 // CHECK-LABEL: @nopaddim_with_dynamic_extract(
 //  CHECK-SAME:     %[[ARG0:.*]]: tensor<3x4x5xf32>


        

