[Mlir-commits] [mlir] 37fecfa - [mlir] Support rank-reduced extract_slice in ExtractSliceOfPadTensorSwapPattern (#138921)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 8 08:51:51 PDT 2025
Author: Vivian Zhang
Date: 2025-05-08T08:51:48-07:00
New Revision: 37fecfaa63eef7bd9dff9c16d74e61c99e3ce70a
URL: https://github.com/llvm/llvm-project/commit/37fecfaa63eef7bd9dff9c16d74e61c99e3ce70a
DIFF: https://github.com/llvm/llvm-project/commit/37fecfaa63eef7bd9dff9c16d74e61c99e3ce70a.diff
LOG: [mlir] Support rank-reduced extract_slice in ExtractSliceOfPadTensorSwapPattern (#138921)
This PR fixes `ExtractSliceOfPadTensorSwapPattern` to support
rank-reducing `tensor.extract_slice` ops, which were previously
unhandled and could cause crashes. To support this, an additional
`tensor.extract_slice` is inserted after `tensor.pad` to reduce the
result rank.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 6700b4e0c2cb6..8718c57b9e86c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1017,9 +1017,22 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
sliceOp.getMixedSizes(), zeroSliceGuard);
if (failed(tilingResult))
return failure();
- // All shapes are static and the data source is actually used. Rewrite into
- // pad(extract_slice(x)).
- rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+
+ RankedTensorType sourceType = sliceOp.getSourceType();
+ RankedTensorType resultType = sliceOp.getResultType();
+
+ // If the extract_slice is not rank-reduced, all shapes are static and the
+ // data source is actually used. Rewrite into pad(extract_slice(x)).
+ if (sourceType.getRank() == resultType.getRank()) {
+ rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+ return success();
+ }
+
+ // Handle rank-reduced slice by creating another extract_slice op.
+ Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
+
+ rewriter.replaceOp(sliceOp, rankReduced);
return success();
}
diff --git a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
index d43b9a7ac6c04..6a056bab98807 100644
--- a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
+++ b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
@@ -129,6 +129,26 @@ func.func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
// -----
+// CHECK-LABEL: @static_rank_reduce
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x16x4xf32>, %[[PADVAL:.*]]: f32
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 14, 4] [1, 1, 1] : tensor<8x16x4xf32> to tensor<1x14x4xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[SLICE]] low[0, 2, 0] high[0, 0, 0] {
+// CHECK: } : tensor<1x14x4xf32> to tensor<1x16x4xf32>
+// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[PADDED]][0, 0, 0] [1, 16, 4] [1, 1, 1] : tensor<1x16x4xf32> to tensor<16x4xf32>
+// CHECK: return %[[RESULT]]
+func.func @static_rank_reduce(%arg0: tensor<8x16x4xf32>, %pad: f32)
+ -> tensor<16x4xf32> {
+ %0 = tensor.pad %arg0 low[0, 2, 0] high[0, 0, 0] {
+ ^bb0(%i: index, %j: index, %k: index):
+ tensor.yield %pad : f32
+ } : tensor<8x16x4xf32> to tensor<8x18x4xf32>
+ %1 = tensor.extract_slice %0[0, 0, 0] [1, 16, 4] [1, 1, 1]
+ : tensor<8x18x4xf32> to tensor<16x4xf32>
+ return %1 : tensor<16x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @dynamic_high_pad
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x5xf32>
// CHECK-NOT: tensor.pad
@@ -217,6 +237,27 @@ func.func @dynamic_zero_high_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
return %1 : tensor<?x?xf32>
}
+// -----
+
+// CHECK-LABEL: @dynamic_rank_reduce
+// CHECK: %[[TEMP:.*]] = scf.if %{{.*}} -> (tensor<1x4xf32>) {
+// CHECK: tensor.generate
+// CHECK: } else {
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %{{.*}} : tensor<?x5xf32> to tensor<?x1xf32>
+// CHECK: tensor.pad %[[SLICE]] low[0, 0] high[%{{.*}}, 3] {
+// CHECK: } : tensor<?x1xf32> to tensor<1x4xf32>
+// CHECK: }
+// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[TEMP]]{{.*}} : tensor<1x4xf32> to tensor<4xf32>
+// CHECK: return %[[RESULT]]
+func.func @dynamic_rank_reduce(%arg0 : tensor<?x5xf32>, %s1: index, %pad : f32) -> tensor<4xf32> {
+ %0 = tensor.pad %arg0 low[0, 0] high[7, 8] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %pad : f32
+ } : tensor<?x5xf32> to tensor<?x13xf32>
+ %1 = tensor.extract_slice %0[2, 4] [1, 4] [1, 1] : tensor<?x13xf32> to tensor<4xf32>
+ return %1 : tensor<4xf32>
+}
+
// -----
// CHECK-LABEL: @nopaddim_with_dynamic_extract(
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4x5xf32>
More information about the Mlir-commits mailing list.