[Mlir-commits] [mlir] b1b8227 - [mlir] Vectorize linalg.pad_tensor consumed by transfer_read

Matthias Springer llvmlistbot at llvm.org
Sun Jun 13 17:57:41 PDT 2021


Author: Matthias Springer
Date: 2021-06-14T09:52:25+09:00
New Revision: b1b822714db8ea15f811ab03084ee60ff32def21

URL: https://github.com/llvm/llvm-project/commit/b1b822714db8ea15f811ab03084ee60ff32def21
DIFF: https://github.com/llvm/llvm-project/commit/b1b822714db8ea15f811ab03084ee60ff32def21.diff

LOG: [mlir] Vectorize linalg.pad_tensor consumed by transfer_read

Vectorize linalg.pad_tensor without generating a linalg.init_tensor when consumed by a transfer_read.

Differential Revision: https://reviews.llvm.org/D103735

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 36cdf7948de6c..819b8382432de 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -696,10 +696,80 @@ struct GenericPadTensorOpVectorizationPattern
   }
 };
 
+/// Base pattern for rewriting PadTensorOps whose result is consumed by a given
+/// operation type OpTy.
+template <typename OpTy>
+struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const final {
+    bool changed = false;
+    // Copy the users into a vector first: rewriting may replace/remove users.
+    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
+      if (auto op = dyn_cast<OpTy>(user))
+        changed |= rewriteUser(rewriter, padOp, op).succeeded();
+    return success(changed);
+  }
+
+ protected:
+  virtual LogicalResult rewriteUser(
+      PatternRewriter &rewriter, PadTensorOp padOp, OpTy op) const = 0;
+};
+
+/// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
+/// ```
+/// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %r = vector.transfer_read %0[%c0, %c0], %cst
+///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %r = vector.transfer_read %src[%c0, %c0], %padding
+///     {in_bounds = [false, false]}
+///     : tensor<?x?xf32>, vector<17x5xf32>
+/// ```
+/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
+/// sure that the original padding value %cst was never used.
+///
+/// This rewrite is possible if:
+/// - `xferOp` has no out-of-bounds dims or mask.
+/// - Low padding is static 0.
+/// - Single, scalar padding value.
+struct PadTensorOpVectorizationWithTransferReadPattern
+    : public VectorizePadTensorOpUserPattern<vector::TransferReadOp> {
+  using VectorizePadTensorOpUserPattern<vector::TransferReadOp>
+      ::VectorizePadTensorOpUserPattern;
+
+  LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
+                            vector::TransferReadOp xferOp) const override {
+    // Low padding must be static 0.
+    if (!padOp.hasZeroLowPad()) return failure();
+    // Pad value must be a constant.
+    auto padValue = padOp.getConstantPaddingValue();
+    if (!padValue) return failure();
+    // Padding value of `xferOp` must be unused (all in-bounds, no mask).
+    if (xferOp.hasOutOfBoundsDim() || xferOp.mask()) return failure();
+
+    rewriter.updateRootInPlace(xferOp, [&]() {
+      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+      xferOp->setAttr(xferOp.getInBoundsAttrName(),
+                      rewriter.getBoolArrayAttr(inBounds));
+      xferOp.sourceMutable().assign(padOp.source());
+      xferOp.paddingMutable().assign(padValue);
+    });
+
+    return success();
+  }
+};
+
 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
   patterns.add<GenericPadTensorOpVectorizationPattern>(
       patterns.getContext(), baseBenefit);
+  // Try these specialized patterns first before resorting to the generic one.
+  patterns.add<PadTensorOpVectorizationWithTransferReadPattern>(
+      patterns.getContext(), baseBenefit.getBenefit() + 1);
 }
 
 // TODO: cleanup all the convolution vectorization patterns.

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 236bf3b335472..ab55f879f8ae0 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -558,6 +558,28 @@ func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
 
 // -----
 
+// CHECK-LABEL: func @pad_and_transfer_read
+//  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>
+//   CHECK-NOT:   linalg.pad_tensor
+//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG:   %[[C5:.*]] = constant 5.0
+//       CHECK:   %[[RESULT:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
+//       CHECK:   return %[[RESULT]]
+func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %c6 = constant 6.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[5, 7] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %c5 : f32
+  } : tensor<5x6xf32> to tensor<10x13xf32>
+  %1 = vector.transfer_read %0[%c0, %c0], %c6
+      : tensor<10x13xf32>, vector<7x9xf32>
+  return %1 : vector<7x9xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)>
 
 // CHECK-LABEL: func @sum_exp


        


More information about the Mlir-commits mailing list