[Mlir-commits] [mlir] fdb21f0 - [mlir][linalg] Remove generic PadTensorOp vectorization pattern
Matthias Springer
llvmlistbot at llvm.org
Sun Jun 13 18:54:22 PDT 2021
Author: Matthias Springer
Date: 2021-06-14T10:53:50+09:00
New Revision: fdb21f0c5edd17b9aeb6f5135d0980b9e4c74bf2
URL: https://github.com/llvm/llvm-project/commit/fdb21f0c5edd17b9aeb6f5135d0980b9e4c74bf2
DIFF: https://github.com/llvm/llvm-project/commit/fdb21f0c5edd17b9aeb6f5135d0980b9e4c74bf2.diff
LOG: [mlir][linalg] Remove generic PadTensorOp vectorization pattern
The generic vectorization pattern handles only those cases where both
low and high padding are zero. Those cases are already handled by a
canonicalization pattern.
Also add a new canonicalization test case to ensure that tensor cast ops
are properly inserted.
A more general vectorization pattern will be added in a subsequent commit.
Differential Revision: https://reviews.llvm.org/D103590
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bf48f0e337e0..3e8fb3813ff2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -671,52 +671,6 @@ static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
return result;
}
-/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
-/// TransferWriteOp. For now, this only applies when all low and high paddings
-/// are determined to be zero.
-struct GenericPadTensorOpVectorizationPattern
- : public OpRewritePattern<PadTensorOp> {
- using OpRewritePattern<PadTensorOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(PadTensorOp padOp,
- PatternRewriter &rewriter) const override {
- /// Given an OpFoldResult, return true if its value is guaranteed to be a
- /// zero integer.
- auto isZeroInt = [&](OpFoldResult ofr) {
- return isEqualConstantIntOrValue(ofr, rewriter.getIndexAttr(0)); };
- // Low padding must be static 0.
- if (!llvm::all_of(padOp.getMixedLowPad(), isZeroInt)) return failure();
- // High padding must be static 0.
- if (!llvm::all_of(padOp.getMixedHighPad(), isZeroInt)) return failure();
- // Pad value must be a constant.
- auto padValue = padOp.getConstantPaddingValue();
- if (!padValue) return failure();
-
- // Bail on non-static shapes.
- auto resultShapedType = padOp.result().getType().cast<ShapedType>();
- if (!resultShapedType.hasStaticShape())
- return failure();
- VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
- if (!vectorType)
- return failure();
-
- // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
- // TransferWriteOp@[0..0].
- SmallVector<Value> indices(
- resultShapedType.getRank(),
- rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
- Value read = rewriter.create<vector::TransferReadOp>(
- padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
- Value init = rewriter.create<InitTensorOp>(
- padOp.getLoc(), resultShapedType.getShape(),
- resultShapedType.getElementType());
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
- indices);
-
- return success();
- }
-};
-
/// Base pattern for rewriting PadTensorOps whose result is consumed by a given
/// operation type OpTy.
template <typename OpTy>
@@ -995,13 +949,14 @@ struct PadTensorOpVectorizationWithSubTensorInsertPattern
void mlir::linalg::populatePadTensorOpVectorizationPatterns(
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
- patterns.add<GenericPadTensorOpVectorizationPattern>(
- patterns.getContext(), baseBenefit);
+ // TODO: Canonicalizer handles simple cases where low = 0 and high = 0, but a
+ // generic vectorization pattern is still missing.
+
// Try these specialized patterns first before resorting to the generic one.
patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
PadTensorOpVectorizationWithTransferWritePattern,
PadTensorOpVectorizationWithSubTensorInsertPattern>(
- patterns.getContext(), baseBenefit.getBenefit() + 1);
+ patterns.getContext(), baseBenefit);
}
// TODO: cleanup all the convolution vectorization patterns.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index c51cbdbb3568..6fa9fc4900f6 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1148,3 +1148,21 @@ func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
// CHECK-LABEL: @tensor_pad_cast
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
// CHECK: return %[[ARG0]]
+
+// -----
+
+// CHECK-LABEL: func @pad_static_zero_cast(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: %[[RESULT:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+// CHECK: return %[[RESULT]]
+func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+ %c0 = constant 0 : index
+ %0 = linalg.pad_tensor %arg0 low[0, %c0, 0] high[0, 0, %c0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index):
+ linalg.yield %pad_value : f32
+ } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+
+ return %0 : tensor<2x3x4xf32>
+}
+
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index bc5a2feff447..43dd39602b35 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -512,27 +512,6 @@ func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12x
// -----
-// CHECK-LABEL: func @pad_static
-// CHECK-NOT: linalg.pad_tensor
-func @pad_static(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
- // CHECK: %[[C0:.*]] = constant 0 : index
- // CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]]
- // CHECK-SAME: : tensor<?x?x?xf32>, vector<2x3x4xf32>
- // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32>
- // CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]]
- // CHECK-SAME: {in_bounds = [true, true, true]} : vector<2x3x4xf32>, tensor<2x3x4xf32>
- %c0 = constant 0 : index
- %0 = linalg.pad_tensor %arg0 low[0, %c0, 0] high[0, 0, %c0] {
- ^bb0(%arg1: index, %arg2: index, %arg3: index):
- linalg.yield %pad_value : f32
- } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
-
- // CHECK: return %[[WRITTEN]] : tensor<2x3x4xf32>
- return %0 : tensor<2x3x4xf32>
-}
-
-// -----
-
// CHECK-LABEL: func @pad_static_high_padding
// CHECK: linalg.pad_tensor
func @pad_static_high_padding(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
More information about the Mlir-commits
mailing list