[Mlir-commits] [mlir] 6859f8e - [mlir][linalg] Adapt the PadTensorOpVectorizationWithInsertSlicePattern matching.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 13 05:00:14 PST 2021
Author: gysit
Date: 2021-12-13T12:55:07Z
New Revision: 6859f8ed1ef795cc032e0224bd903fdd54f79c16
URL: https://github.com/llvm/llvm-project/commit/6859f8ed1ef795cc032e0224bd903fdd54f79c16
DIFF: https://github.com/llvm/llvm-project/commit/6859f8ed1ef795cc032e0224bd903fdd54f79c16.diff
LOG: [mlir][linalg] Adapt the PadTensorOpVectorizationWithInsertSlicePattern matching.
Tighten the matcher of the PadTensorOpVectorizationWithInsertSlicePattern pattern: only match if the PadOp result is used as the source operand of the InsertSliceOp, and fail if the result is used as the InsertSliceOp destination operand.
Depends On D115336
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D115359
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 099a91d0a0a8..bc02298d9d00 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1015,6 +1015,7 @@ struct PadTensorOpVectorizationWithTransferWritePattern
/// (Implies that sizes of `insertOp` are all static.)
/// - Only unit strides in `insertOp`.
/// - Single, scalar padding value.
+/// - `padOp` result not used as destination.
struct PadTensorOpVectorizationWithInsertSlicePattern
: public VectorizePadTensorOpUserPattern<tensor::InsertSliceOp> {
using VectorizePadTensorOpUserPattern<
@@ -1035,6 +1036,9 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
// Dynamic shapes not supported.
if (!padOp.result().getType().cast<ShapedType>().hasStaticShape())
return failure();
+ // Pad result not used as destination.
+ if (insertOp.dest() == padOp.result())
+ return failure();
auto vecType = VectorType::get(padOp.getType().getShape(),
padOp.getType().getElementType());
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 934f8e349936..370c3e009814 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -684,7 +684,7 @@ func @pad_and_transfer_write_dynamic_static(
func private @make_vector() -> tensor<12x13xf32>
-// CHECK-LABEL: func @pad_and_insert_slice
+// CHECK-LABEL: func @pad_and_insert_slice_source
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK-NOT: linalg.pad_tensor
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -693,7 +693,7 @@ func private @make_vector() -> tensor<12x13xf32>
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[VEC0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
// CHECK: return %[[WRITE]]
-func @pad_and_insert_slice(
+func @pad_and_insert_slice_source(
%arg0: tensor<5x6xf32>) -> tensor<12x13xf32> {
%c0 = arith.constant 0 : index
%c5 = arith.constant 5.0 : f32
@@ -708,6 +708,26 @@ func @pad_and_insert_slice(
// -----
+func private @make_vector() -> tensor<12x13xf32>
+
+// CHECK-LABEL: func @pad_and_insert_slice_dest
+// Check the insert slice is not rewritten if the padded result is used by the destination operand.
+// CHECK: %[[T1:.*]] = call @make_vector() : () -> tensor<12x13xf32>
+// CHECK: = tensor.insert_slice %[[T1]] into
+func @pad_and_insert_slice_dest(
+ %arg0: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
+ %c5 = arith.constant 5.0 : f32
+ %0 = linalg.pad_tensor %arg0 low[0, 0, 0] high[0, 7, 7] {
+ ^bb0(%arg2: index, %arg3: index, %arg4: index):
+ linalg.yield %c5 : f32
+ } : tensor<1x5x6xf32> to tensor<1x12x13xf32>
+ %1 = call @make_vector() : () -> tensor<12x13xf32>
+ %r = tensor.insert_slice %1 into %0[0, 0, 0][1, 12, 13][1, 1, 1] : tensor<12x13xf32> into tensor<1x12x13xf32>
+ return %r : tensor<1x12x13xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @pad_tensor_non_const_pad_value
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK-NOT: linalg.pad_tensor
More information about the Mlir-commits mailing list