[Mlir-commits] [mlir] e000b62 - [mlir][vector] Separate out vector transfer + tensor slice patterns

Lei Zhang llvmlistbot at llvm.org
Wed May 17 09:01:27 PDT 2023


Author: Lei Zhang
Date: 2023-05-17T09:01:19-07:00
New Revision: e000b62a342cac907fd77cfdd070f0b055f0c3c4

URL: https://github.com/llvm/llvm-project/commit/e000b62a342cac907fd77cfdd070f0b055f0c3c4
DIFF: https://github.com/llvm/llvm-project/commit/e000b62a342cac907fd77cfdd070f0b055f0c3c4.diff

LOG: [mlir][vector] Separate out vector transfer + tensor slice patterns

These patterns touch the structure generated from tiling, so they affect
later steps such as bufferization and vector hoisting. Instead of putting
them in canonicalization, this commit creates separate entry points so
that they can be invoked explicitly.

This is NFC with respect to the functionality and tests of those patterns.
It also addresses two TODO items in the codebase.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D150702
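
For illustration, here is a minimal sketch (assumed usage, not part of this
patch) of what invoking these patterns explicitly can look like in a pass;
the populate entry point and the greedy driver call mirror what the test
pass added by this commit does, while `MyPass` itself is hypothetical:

    #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Hypothetical FuncOp pass body: add the patterns explicitly instead of
    // relying on canonicalization to pick them up.
    void MyPass::runOnOperation() {
      RewritePatternSet patterns(&getContext());
      vector::populateVectorTransferTensorSliceTransforms(patterns);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }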

Added: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 325860079b3db..2912c02528723 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -218,6 +218,17 @@ void populateBreakDownVectorBitCastOpPatterns(
 void populateVectorInsertExtractStridedSliceTransforms(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+/// Collect patterns to fold tensor.extract_slice -> vector.transfer_read and
+/// vector.transfer_write -> tensor.insert_slice op chains into vector transfer
+/// read and write ops.
+///
+/// If `controlFn` is not nullptr, the pattern will only apply to ops where
+/// `controlFn` returns true, given the vector transfer read/write op as input.
+void populateVectorTransferTensorSliceTransforms(
+    RewritePatternSet &patterns,
+    std::function<bool(Operation *vectorOp)> controlFn = nullptr,
+    PatternBenefit benefit = 1);
+
 /// Collect a set of pattern to unroll vector operations to a smaller shapes.
 /// `options` structure controls which operations are unrolled and the target
 /// shape.
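
For illustration, a hedged sketch (not part of this patch) of using the
`controlFn` hook declared above; the predicate below is a made-up example
that restricts the folds to unmasked transfer ops, and only the entry-point
signature comes from this change:

    RewritePatternSet patterns(ctx);
    // The control function receives the vector.transfer_read/write op under
    // consideration; returning false skips the fold for that op.
    vector::populateVectorTransferTensorSliceTransforms(
        patterns, [](Operation *vectorOp) {
          if (auto read = dyn_cast<vector::TransferReadOp>(vectorOp))
            return !read.getMask();
          if (auto write = dyn_cast<vector::TransferWriteOp>(vectorOp))
            return !write.getMask();
          return true;
        });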

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index baabf1ae67fc9..4fe9b9fab6a7a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -29,6 +29,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/TilingInterface.h"
@@ -2880,6 +2881,7 @@ transform::VectorizeOp::applyToOne(Operation *target,
                                                        /*benefit=*/2);
   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
+  vector::populateVectorTransferTensorSliceTransforms(patterns);
 
   patterns.add<CopyVectorizationPattern>(ctx);
 

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 80ea03f6e8d81..1549237f8c9fd 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3692,108 +3692,7 @@ void TransferReadOp::getEffects(
                          SideEffects::DefaultResource::get());
 }
 
-/// Returns true if all rank reduced in the given `extractOp` happen in leading
-/// dimensions earlier than last `trailingRank` dimensions.
-static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
-                                        unsigned trailingRank) {
-  // If no ranks are reduced at all, it's a degenerated case; always true.
-  if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
-    return true;
-
-  RankedTensorType inferredType = extractOp.inferResultType(
-      extractOp.getSourceType(), extractOp.getMixedOffsets(),
-      extractOp.getMixedSizes(), extractOp.getMixedStrides());
-  return extractOp.getType().getShape().take_back(trailingRank) ==
-         inferredType.getShape().take_back(trailingRank);
-}
-
 namespace {
-/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
-///
-/// ```
-/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
-///     : tensor<?x?xf32> to tensor<?x?xf32>
-/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
-///     : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %p0 = arith.addi %a, %e : index
-/// %p1 = arith.addi %b, %f : index
-/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
-///     : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-struct FoldExtractSliceIntoTransferRead
-    : public OpRewritePattern<TransferReadOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TransferReadOp xferOp,
-                                PatternRewriter &rewriter) const override {
-    // TODO: support 0-d corner case.
-    if (xferOp.getTransferRank() == 0)
-      return failure();
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (!xferOp.getPermutationMap().isMinorIdentity())
-      return failure();
-    if (xferOp.getMask())
-      return failure();
-    auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
-    if (!extractOp)
-      return failure();
-    if (!extractOp.hasUnitStride())
-      return failure();
-
-    // Bail on illegal rank-reduction: we need to check that the rank-reduced
-    // dims are exactly the leading dims. I.e. the following is illegal:
-    // ```
-    //    %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
-    //      tensor<2x1x4xf32> to tensor<2x4xf32>
-    //    %1 = vector.transfer_read %0[0,0], %cst :
-    //      tensor<2x4xf32>, vector<2x4xf32>
-    // ```
-    //
-    // Cannot fold into:
-    // ```
-    //    %0 = vector.transfer_read %t[0,0,0], %cst :
-    //      tensor<2x1x4xf32>, vector<2x4xf32>
-    // ```
-    // For this, check the trailing `vectorRank` dims of the extract_slice
-    // result tensor match the trailing dims of the inferred result tensor.
-    if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
-      return failure();
-
-    int64_t rankReduced =
-        extractOp.getSourceType().getRank() - extractOp.getType().getRank();
-
-    SmallVector<Value> newIndices;
-    // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
-    // indices first.
-    for (int64_t i = 0; i < rankReduced; ++i) {
-      OpFoldResult offset = extractOp.getMixedOffsets()[i];
-      newIndices.push_back(getValueOrCreateConstantIndexOp(
-          rewriter, extractOp.getLoc(), offset));
-    }
-    for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
-      OpFoldResult offset =
-          extractOp.getMixedOffsets()[it.index() + rankReduced];
-      newIndices.push_back(rewriter.create<arith::AddIOp>(
-          xferOp->getLoc(), it.value(),
-          getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
-                                          offset)));
-    }
-    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<TransferReadOp>(
-        xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
-        xferOp.getPadding(), ArrayRef<bool>{inBounds});
-
-    return success();
-  }
-};
-
 /// Store to load forwarding for transfer operations with permuation maps.
 /// Even if the permutation maps are different we can still propagate the store
 /// into the load if the size of the dimensions read and written match. Then we
@@ -3875,13 +3774,7 @@ struct TransferReadAfterWriteToBroadcast
 
 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  // clang-format off
-  results.add <
-               // TODO: this is brittle and should be deprecated in favor of a
-               // more general pattern that applies on-demand.
-               FoldExtractSliceIntoTransferRead,
-               TransferReadAfterWriteToBroadcast>(context);
-  // clang-format on
+  results.add<TransferReadAfterWriteToBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4217,93 +4110,6 @@ class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
   }
 };
 
-/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
-/// could directly write to the insert_slice's destination. E.g.:
-///
-/// ```
-/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
-///     : vector<4x5xf32>, tensor<4x5xf32>
-/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
-///     : tensor<4x5xf32> into tensor<?x?xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
-///     : vector<4x5xf32>, tensor<?x?xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-struct FoldInsertSliceIntoTransferWrite
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    if (!insertOp.hasUnitStride())
-      return failure();
-
-    auto xferOp = insertOp.getSource().getDefiningOp<TransferWriteOp>();
-    if (!xferOp)
-      return failure();
-    // TODO: support 0-d corner case.
-    if (xferOp.getTransferRank() == 0)
-      return failure();
-
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
-      return failure();
-    if (xferOp.getMask())
-      return failure();
-    // Fold only if the TransferWriteOp completely overwrites the `source` with
-    // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
-    // content is the data of the vector.
-    if (!llvm::equal(xferOp.getVectorType().getShape(),
-                     xferOp.getShapedType().getShape()))
-      return failure();
-    if (!xferOp.getPermutationMap().isIdentity())
-      return failure();
-
-    // Bail on illegal rank-reduction: we need to check that the rank-reduced
-    // dims are exactly the leading dims. I.e. the following is illegal:
-    // ```
-    //    %0 = vector.transfer_write %v, %t[0,0], %cst :
-    //      vector<2x4xf32>, tensor<2x4xf32>
-    //    %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
-    //      tensor<2x4xf32> into tensor<2x1x4xf32>
-    // ```
-    //
-    // Cannot fold into:
-    // ```
-    //    %0 = vector.transfer_write %v, %t[0,0,0], %cst :
-    //      vector<2x4xf32>, tensor<2x1x4xf32>
-    // ```
-    // For this, check the trailing `vectorRank` dims of the insert_slice result
-    // tensor match the trailing dims of the inferred result tensor.
-    int64_t rankReduced =
-        insertOp.getType().getRank() - insertOp.getSourceType().getRank();
-    int64_t vectorRank = xferOp.getVectorType().getRank();
-    RankedTensorType inferredSourceTensorType =
-        tensor::ExtractSliceOp::inferResultType(
-            insertOp.getType(), insertOp.getMixedOffsets(),
-            insertOp.getMixedSizes(), insertOp.getMixedStrides());
-    auto actualSourceTensorShape = insertOp.getSourceType().getShape();
-    if (rankReduced > 0 &&
-        actualSourceTensorShape.take_back(vectorRank) !=
-            inferredSourceTensorType.getShape().take_back(vectorRank))
-      return failure();
-
-    SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
-        rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
-    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
-                                                 insertOp.getDest(), indices,
-                                                 ArrayRef<bool>{inBounds});
-    return success();
-  }
-};
-
 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
 /// overwritten and inserted into another tensor. After this rewrite, the
@@ -4415,13 +4221,7 @@ struct SwapExtractSliceOfTransferWrite
 
 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  // clang-format off
-  results.add<FoldWaw,
-              // TODO: this is brittle and should be deprecated in favor of a
-              // more general pattern that applies on-demand.
-              FoldInsertSliceIntoTransferWrite,
-              SwapExtractSliceOfTransferWrite>(context);
-  // clang-format on
+  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index deba91573e0ff..2d269ca3555d5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorDropLeadUnitDim.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
   VectorTransferOpTransforms.cpp
+  VectorTransferTensorSliceTransforms.cpp
   VectorTransferSplitRewritePatterns.cpp
   VectorTransforms.cpp
   VectorUnroll.cpp

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
new file mode 100644
index 0000000000000..b3bd2cc85dfec
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
@@ -0,0 +1,237 @@
+//===- VectorTransferTensorSliceTransforms.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+/// Returns true if all rank reduced in the given `extractOp` happen in leading
+/// dimensions earlier than last `trailingRank` dimensions.
+static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
+                                        unsigned trailingRank) {
+  // If no ranks are reduced at all, it's a degenerated case; always true.
+  if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
+    return true;
+
+  RankedTensorType inferredType = extractOp.inferResultType(
+      extractOp.getSourceType(), extractOp.getMixedOffsets(),
+      extractOp.getMixedSizes(), extractOp.getMixedStrides());
+  return extractOp.getType().getShape().take_back(trailingRank) ==
+         inferredType.getShape().take_back(trailingRank);
+}
+
+namespace {
+/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
+///
+/// ```
+/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
+///     : tensor<?x?xf32> to tensor<?x?xf32>
+/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
+///     : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %p0 = arith.addi %a, %e : index
+/// %p1 = arith.addi %b, %f : index
+/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
+///     : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
+class FoldExtractSliceIntoTransferRead final
+    : public OpRewritePattern<vector::TransferReadOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  FoldExtractSliceIntoTransferRead(MLIRContext *context,
+                                   std::function<bool(Operation *op)> controlFn,
+                                   PatternBenefit benefit)
+      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    if (controlFn && !controlFn(xferOp))
+      return failure();
+
+    // TODO: support 0-d corner case.
+    if (xferOp.getTransferRank() == 0)
+      return failure();
+    if (xferOp.hasOutOfBoundsDim())
+      return failure();
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    if (xferOp.getMask())
+      return failure();
+    auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      return failure();
+    if (!extractOp.hasUnitStride())
+      return failure();
+
+    // Bail on illegal rank-reduction: we need to check that the rank-reduced
+    // dims are exactly the leading dims. I.e. the following is illegal:
+    // ```
+    //    %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
+    //      tensor<2x1x4xf32> to tensor<2x4xf32>
+    //    %1 = vector.transfer_read %0[0,0], %cst :
+    //      tensor<2x4xf32>, vector<2x4xf32>
+    // ```
+    //
+    // Cannot fold into:
+    // ```
+    //    %0 = vector.transfer_read %t[0,0,0], %cst :
+    //      tensor<2x1x4xf32>, vector<2x4xf32>
+    // ```
+    // For this, check the trailing `vectorRank` dims of the extract_slice
+    // result tensor match the trailing dims of the inferred result tensor.
+    if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
+      return failure();
+
+    int64_t rankReduced =
+        extractOp.getSourceType().getRank() - extractOp.getType().getRank();
+
+    SmallVector<Value> newIndices;
+    // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
+    // indices first.
+    for (int64_t i = 0; i < rankReduced; ++i) {
+      OpFoldResult offset = extractOp.getMixedOffsets()[i];
+      newIndices.push_back(getValueOrCreateConstantIndexOp(
+          rewriter, extractOp.getLoc(), offset));
+    }
+    for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
+      OpFoldResult offset =
+          extractOp.getMixedOffsets()[it.index() + rankReduced];
+      newIndices.push_back(rewriter.create<arith::AddIOp>(
+          xferOp->getLoc(), it.value(),
+          getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
+                                          offset)));
+    }
+    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+        xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
+        xferOp.getPadding(), ArrayRef<bool>{inBounds});
+
+    return success();
+  }
+
+private:
+  std::function<bool(Operation *)> controlFn;
+};
+
+/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
+/// could directly write to the insert_slice's destination. E.g.:
+///
+/// ```
+/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
+///     : vector<4x5xf32>, tensor<4x5xf32>
+/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
+///     : tensor<4x5xf32> into tensor<?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
+///     : vector<4x5xf32>, tensor<?x?xf32>
+/// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
+class FoldInsertSliceIntoTransferWrite final
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  FoldInsertSliceIntoTransferWrite(MLIRContext *context,
+                                   std::function<bool(Operation *op)> controlFn,
+                                   PatternBenefit benefit)
+      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    if (!insertOp.hasUnitStride())
+      return failure();
+
+    auto xferOp = insertOp.getSource().getDefiningOp<vector::TransferWriteOp>();
+    if (!xferOp)
+      return failure();
+    if (controlFn && !controlFn(xferOp))
+      return failure();
+
+    // TODO: support 0-d corner case.
+    if (xferOp.getTransferRank() == 0)
+      return failure();
+
+    if (xferOp.hasOutOfBoundsDim())
+      return failure();
+    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
+      return failure();
+    if (xferOp.getMask())
+      return failure();
+    // Fold only if the TransferWriteOp completely overwrites the `source` with
+    // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
+    // content is the data of the vector.
+    if (!llvm::equal(xferOp.getVectorType().getShape(),
+                     xferOp.getShapedType().getShape()))
+      return failure();
+    if (!xferOp.getPermutationMap().isIdentity())
+      return failure();
+
+    // Bail on illegal rank-reduction: we need to check that the rank-reduced
+    // dims are exactly the leading dims. I.e. the following is illegal:
+    // ```
+    //    %0 = vector.transfer_write %v, %t[0,0], %cst :
+    //      vector<2x4xf32>, tensor<2x4xf32>
+    //    %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
+    //      tensor<2x4xf32> into tensor<2x1x4xf32>
+    // ```
+    //
+    // Cannot fold into:
+    // ```
+    //    %0 = vector.transfer_write %v, %t[0,0,0], %cst :
+    //      vector<2x4xf32>, tensor<2x1x4xf32>
+    // ```
+    // For this, check the trailing `vectorRank` dims of the insert_slice result
+    // tensor match the trailing dims of the inferred result tensor.
+    int64_t rankReduced =
+        insertOp.getType().getRank() - insertOp.getSourceType().getRank();
+    int64_t vectorRank = xferOp.getVectorType().getRank();
+    RankedTensorType inferredSourceTensorType =
+        tensor::ExtractSliceOp::inferResultType(
+            insertOp.getType(), insertOp.getMixedOffsets(),
+            insertOp.getMixedSizes(), insertOp.getMixedStrides());
+    auto actualSourceTensorShape = insertOp.getSourceType().getShape();
+    if (rankReduced > 0 &&
+        actualSourceTensorShape.take_back(vectorRank) !=
+            inferredSourceTensorType.getShape().take_back(vectorRank))
+      return failure();
+
+    SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
+        rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
+    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        insertOp, xferOp.getVector(), insertOp.getDest(), indices,
+        ArrayRef<bool>{inBounds});
+    return success();
+  }
+
+private:
+  std::function<bool(Operation *)> controlFn;
+};
+
+} // namespace
+
+void vector::populateVectorTransferTensorSliceTransforms(
+    RewritePatternSet &patterns,
+    std::function<bool(Operation *vectorOp)> controlFn,
+    PatternBenefit benefit) {
+  patterns
+      .add<FoldExtractSliceIntoTransferRead, FoldInsertSliceIntoTransferWrite>(
+          patterns.getContext(), controlFn, benefit);
+}

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 88c91ff46a8b7..4ce4350f0e4f3 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1201,116 +1201,6 @@ func.func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
 
 // -----
 
-// CHECK-LABEL: func @transfer_read_of_extract_slice(
-//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
-//   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
-//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
-//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
-//       CHECK:   return %[[r]]
-func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
-  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
-  return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice(
-//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
-//   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
-//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
-//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
-//       CHECK:   return %[[r]]
-func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
-  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32>
-  return %1 : vector<6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
-//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
-//   CHECK-DAG:   %[[c5:.*]] = arith.constant 5 : index
-//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10 : index
-//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c3]]
-//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
-//       CHECK:   return %[[r]]
-func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
-  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
-  return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
-//       CHECK:   extract_slice
-//       CHECK:   vector.transfer_read
-func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
-  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
-  return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write(
-//  CHECK-SAME:     %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
-//       CHECK:   %[[c3:.*]] = arith.constant 3 : index
-//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x12xf32>
-//       CHECK:   return %[[r]]
-func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x12xf32> {
-  %c0 = arith.constant 0 : index
-  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
-  %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x12xf32>
-  return %1 : tensor<?x12xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
-//       CHECK:   vector.transfer_write
-//       CHECK:   insert_slice
-func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
-  %c0 = arith.constant 0 : index
-  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
-  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
-  return %1 : tensor<?x?x12xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
-//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
-//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
-//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
-//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x12xf32>
-//       CHECK:   return %[[r]]
-func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
-  %c0 = arith.constant 0 : index
-  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
-  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
-  return %1 : tensor<?x?x12xf32>
-}
-
-// -----
-
 //       CHECK: #[[$MAP:[0-9a-z]+]] = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: func @swap_extract_slice_transfer_write

diff --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
new file mode 100644
index 0000000000000..cc17025fe0f1e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt -split-input-file -test-vector-transfer-tensor-slice-patterns %s | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
+//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
+//       CHECK:   return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
+//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
+//       CHECK:   return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32>
+  return %1 : vector<6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[c5:.*]] = arith.constant 5 : index
+//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10 : index
+//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c3]]
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
+//       CHECK:   return %[[r]]
+func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
+//       CHECK:   extract_slice
+//       CHECK:   vector.transfer_read
+func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+//       CHECK:   %[[c3:.*]] = arith.constant 3 : index
+//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x12xf32>
+//       CHECK:   return %[[r]]
+func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x12xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x12xf32>
+  return %1 : tensor<?x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
+//       CHECK:   vector.transfer_write
+//       CHECK:   insert_slice
+func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+  return %1 : tensor<?x?x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x12xf32>
+//       CHECK:   return %[[r]]
+func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+  return %1 : tensor<?x?x12xf32>
+}

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index d0c79ab989151..50dfeff635ccf 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -679,6 +679,26 @@ struct TestVectorGatherLowering
   }
 };
 
+struct TestVectorTransferTensorSlicePatterns
+    : public PassWrapper<TestVectorTransferTensorSlicePatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorTransferTensorSlicePatterns)
+
+  StringRef getArgument() const final {
+    return "test-vector-transfer-tensor-slice-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns that fold vector transfer and tensor slice ops";
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorTransferTensorSliceTransforms(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -713,6 +733,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestCreateVectorBroadcast>();
 
   PassRegistration<TestVectorGatherLowering>();
+
+  PassRegistration<TestVectorTransferTensorSlicePatterns>();
 }
 } // namespace test
 } // namespace mlir

