[Mlir-commits] [mlir] 867afe5 - [mlir][vector] Remove duplicate tensor subset <-> vector transfer patterns

Matthias Springer llvmlistbot at llvm.org
Tue Jul 11 02:16:23 PDT 2023


Author: Matthias Springer
Date: 2023-07-11T11:12:29+02:00
New Revision: 867afe5e53194607d671dbf1ae5d190793f02d1f

URL: https://github.com/llvm/llvm-project/commit/867afe5e53194607d671dbf1ae5d190793f02d1f
DIFF: https://github.com/llvm/llvm-project/commit/867afe5e53194607d671dbf1ae5d190793f02d1f.diff

LOG: [mlir][vector] Remove duplicate tensor subset <-> vector transfer patterns

Remove patterns that fold tensor subset ops into vector transfer ops from the vector dialect. These patterns already exist in the tensor dialect.

Differential Revision: https://reviews.llvm.org/D154932

Added: 
    mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir

Modified: 
    mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
    mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
    mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
    mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Removed: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 5d9ced89603859..66c6021418b471 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -63,6 +63,18 @@ def ApplyFoldTensorSubsetOpsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor.extract_slice -> vector.transfer_read and
+    vector.transfer_write -> tensor.insert_slice op chains should be folded into
+    vector transfer read and write ops
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyMergeConsecutiveInsertExtractSlicePatternsOp : Op<Transform_Dialect,
     "apply_patterns.tensor.merge_consecutive_insert_extract_slice",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index c7e157e01d06a7..705b30e7ded477 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -34,10 +34,15 @@ FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
 // Populate functions.
 //===----------------------------------------------------------------------===//
 
-/// Appends patterns for folding tensor aliasing ops into consumer load/store
-/// ops into `patterns`.
+/// Appends patterns for folding tensor subset ops into consumer load/store
+/// ops into `patterns`. (This includes patterns for folding tensor subset ops
+/// into vector transfer ops.)
 void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns);
 
+/// Appends patterns for folding tensor subset ops into vector transfer ops.
+void populateFoldTensorSubsetIntoVectorTransferPatterns(
+    RewritePatternSet &patterns);
+
 /// Collects patterns to merge consecutive tensor.insert_slice/extract_slice
 /// into one. These patterns are in this separate entry point because the
 /// bufferization is sensitive to IR structure, particularly those

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 806c3f9fca50db..253aeedf15aba5 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,16 +306,4 @@ def ApplyTransferToScfPatternsOp : Op<Transform_Dialect,
   }];
 }
 
-def ApplyFoldTensorSliceIntoTransferPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.vector.fold_tensor_slice_into_transfer",
-    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
-  let description = [{
-    Indicates that tensor.extract_slice -> vector.transfer_read and
-    vector.transfer_write -> tensor.insert_slice op chains should be folded into
-    vector tranfer read and write ops
-  }];
-
-  let assemblyFormat = "attr-dict";
-}
-
 #endif // VECTOR_TRANSFORM_OPS

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 55fd2fcd34b68b..12254bc215db57 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -214,17 +214,6 @@ void populateBreakDownVectorBitCastOpPatterns(
 void populateVectorInsertExtractStridedSliceTransforms(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Collect patterns to fold tensor.extract_slice -> vector.transfer_read and
-/// vector.transfer_write -> tensor.insert_slice op chains into vector tranfer
-/// read and write ops.
-///
-/// If `controlFn` is not nullptr, the pattern will only apply to ops where
-/// `controlFn` returns true, given the vector transfer read/write op as input.
-void populateVectorTransferTensorSliceTransforms(
-    RewritePatternSet &patterns,
-    std::function<bool(Operation *vectorOp)> controlFn = nullptr,
-    PatternBenefit benefit = 1);
-
 /// Collect a set of pattern to unroll vector operations to a smaller shapes.
 /// `options` structure controls which operations are unrolled and the target
 /// shape.

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index cf12e78145fd77..e86bb7c545109c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2973,7 +2973,7 @@ transform::VectorizeOp::applyToOne(transform::TransformRewriter &rewriter,
                                                        /*benefit=*/2);
   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
-  vector::populateVectorTransferTensorSliceTransforms(patterns);
+  tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
 
   patterns.add<CopyVectorizationPattern>(ctx);
 

diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 6ee4bfad23bbf8..3cec9138939224 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -103,6 +103,11 @@ void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns(
   tensor::populateFoldTensorSubsetOpPatterns(patterns);
 }
 
+void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
+}
+
 void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
     populatePatterns(RewritePatternSet &patterns) {
   tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 882b8db6b6c57e..3b8d3708bb7314 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -222,12 +222,18 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
 };
 
 void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
-  patterns.add<TransferReadOfExtractSliceOpFolder,
-               InsertSliceOfTransferWriteOpFolder,
-               InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
+  populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
+  patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
                InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
       patterns.getContext());
 }
+
+void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferReadOfExtractSliceOpFolder,
+               InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // Pass registration
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 505cb5c11253a5..da99232ed6ab8f 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -138,11 +138,6 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
 }
 
-void transform::ApplyFoldTensorSliceIntoTransferPatternsOp::populatePatterns(
-    RewritePatternSet &patterns) {
-  populateVectorTransferTensorSliceTransforms(patterns);
-}
-
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 2d269ca3555d56..deba91573e0ff1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -14,7 +14,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorDropLeadUnitDim.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
   VectorTransferOpTransforms.cpp
-  VectorTransferTensorSliceTransforms.cpp
   VectorTransferSplitRewritePatterns.cpp
   VectorTransforms.cpp
   VectorUnroll.cpp

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
deleted file mode 100644
index b3bd2cc85dfecc..00000000000000
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
+++ /dev/null
@@ -1,237 +0,0 @@
-//===- VectorTransferTensorSliceTransforms.cpp ----------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
-#include "mlir/IR/PatternMatch.h"
-
-using namespace mlir;
-
-/// Returns true if all rank reduced in the given `extractOp` happen in leading
-/// dimensions earlier than last `trailingRank` dimensions.
-static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
-                                        unsigned trailingRank) {
-  // If no ranks are reduced at all, it's a degenerated case; always true.
-  if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
-    return true;
-
-  RankedTensorType inferredType = extractOp.inferResultType(
-      extractOp.getSourceType(), extractOp.getMixedOffsets(),
-      extractOp.getMixedSizes(), extractOp.getMixedStrides());
-  return extractOp.getType().getShape().take_back(trailingRank) ==
-         inferredType.getShape().take_back(trailingRank);
-}
-
-namespace {
-/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
-///
-/// ```
-/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
-///     : tensor<?x?xf32> to tensor<?x?xf32>
-/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
-///     : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %p0 = arith.addi %a, %e : index
-/// %p1 = arith.addi %b, %f : index
-/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
-///     : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-class FoldExtractSliceIntoTransferRead final
-    : public OpRewritePattern<vector::TransferReadOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  FoldExtractSliceIntoTransferRead(MLIRContext *context,
-                                   std::function<bool(Operation *op)> controlFn,
-                                   PatternBenefit benefit)
-      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
-
-  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
-                                PatternRewriter &rewriter) const override {
-    if (controlFn && !controlFn(xferOp))
-      return failure();
-
-    // TODO: support 0-d corner case.
-    if (xferOp.getTransferRank() == 0)
-      return failure();
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (!xferOp.getPermutationMap().isMinorIdentity())
-      return failure();
-    if (xferOp.getMask())
-      return failure();
-    auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
-    if (!extractOp)
-      return failure();
-    if (!extractOp.hasUnitStride())
-      return failure();
-
-    // Bail on illegal rank-reduction: we need to check that the rank-reduced
-    // dims are exactly the leading dims. I.e. the following is illegal:
-    // ```
-    //    %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
-    //      tensor<2x1x4xf32> to tensor<2x4xf32>
-    //    %1 = vector.transfer_read %0[0,0], %cst :
-    //      tensor<2x4xf32>, vector<2x4xf32>
-    // ```
-    //
-    // Cannot fold into:
-    // ```
-    //    %0 = vector.transfer_read %t[0,0,0], %cst :
-    //      tensor<2x1x4xf32>, vector<2x4xf32>
-    // ```
-    // For this, check the trailing `vectorRank` dims of the extract_slice
-    // result tensor match the trailing dims of the inferred result tensor.
-    if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
-      return failure();
-
-    int64_t rankReduced =
-        extractOp.getSourceType().getRank() - extractOp.getType().getRank();
-
-    SmallVector<Value> newIndices;
-    // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
-    // indices first.
-    for (int64_t i = 0; i < rankReduced; ++i) {
-      OpFoldResult offset = extractOp.getMixedOffsets()[i];
-      newIndices.push_back(getValueOrCreateConstantIndexOp(
-          rewriter, extractOp.getLoc(), offset));
-    }
-    for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
-      OpFoldResult offset =
-          extractOp.getMixedOffsets()[it.index() + rankReduced];
-      newIndices.push_back(rewriter.create<arith::AddIOp>(
-          xferOp->getLoc(), it.value(),
-          getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
-                                          offset)));
-    }
-    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
-        xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
-        xferOp.getPadding(), ArrayRef<bool>{inBounds});
-
-    return success();
-  }
-
-private:
-  std::function<bool(Operation *)> controlFn;
-};
-
-/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
-/// could directly write to the insert_slice's destination. E.g.:
-///
-/// ```
-/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
-///     : vector<4x5xf32>, tensor<4x5xf32>
-/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
-///     : tensor<4x5xf32> into tensor<?x?xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
-///     : vector<4x5xf32>, tensor<?x?xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-class FoldInsertSliceIntoTransferWrite final
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  FoldInsertSliceIntoTransferWrite(MLIRContext *context,
-                                   std::function<bool(Operation *op)> controlFn,
-                                   PatternBenefit benefit)
-      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
-
-  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    if (!insertOp.hasUnitStride())
-      return failure();
-
-    auto xferOp = insertOp.getSource().getDefiningOp<vector::TransferWriteOp>();
-    if (!xferOp)
-      return failure();
-    if (controlFn && !controlFn(xferOp))
-      return failure();
-
-    // TODO: support 0-d corner case.
-    if (xferOp.getTransferRank() == 0)
-      return failure();
-
-    if (xferOp.hasOutOfBoundsDim())
-      return failure();
-    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
-      return failure();
-    if (xferOp.getMask())
-      return failure();
-    // Fold only if the TransferWriteOp completely overwrites the `source` with
-    // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
-    // content is the data of the vector.
-    if (!llvm::equal(xferOp.getVectorType().getShape(),
-                     xferOp.getShapedType().getShape()))
-      return failure();
-    if (!xferOp.getPermutationMap().isIdentity())
-      return failure();
-
-    // Bail on illegal rank-reduction: we need to check that the rank-reduced
-    // dims are exactly the leading dims. I.e. the following is illegal:
-    // ```
-    //    %0 = vector.transfer_write %v, %t[0,0], %cst :
-    //      vector<2x4xf32>, tensor<2x4xf32>
-    //    %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
-    //      tensor<2x4xf32> into tensor<2x1x4xf32>
-    // ```
-    //
-    // Cannot fold into:
-    // ```
-    //    %0 = vector.transfer_write %v, %t[0,0,0], %cst :
-    //      vector<2x4xf32>, tensor<2x1x4xf32>
-    // ```
-    // For this, check the trailing `vectorRank` dims of the insert_slice result
-    // tensor match the trailing dims of the inferred result tensor.
-    int64_t rankReduced =
-        insertOp.getType().getRank() - insertOp.getSourceType().getRank();
-    int64_t vectorRank = xferOp.getVectorType().getRank();
-    RankedTensorType inferredSourceTensorType =
-        tensor::ExtractSliceOp::inferResultType(
-            insertOp.getType(), insertOp.getMixedOffsets(),
-            insertOp.getMixedSizes(), insertOp.getMixedStrides());
-    auto actualSourceTensorShape = insertOp.getSourceType().getShape();
-    if (rankReduced > 0 &&
-        actualSourceTensorShape.take_back(vectorRank) !=
-            inferredSourceTensorType.getShape().take_back(vectorRank))
-      return failure();
-
-    SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
-        rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
-    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        insertOp, xferOp.getVector(), insertOp.getDest(), indices,
-        ArrayRef<bool>{inBounds});
-    return success();
-  }
-
-private:
-  std::function<bool(Operation *)> controlFn;
-};
-
-} // namespace
-
-void vector::populateVectorTransferTensorSliceTransforms(
-    RewritePatternSet &patterns,
-    std::function<bool(Operation *vectorOp)> controlFn,
-    PatternBenefit benefit) {
-  patterns
-      .add<FoldExtractSliceIntoTransferRead, FoldInsertSliceIntoTransferWrite>(
-          patterns.getContext(), controlFn, benefit);
-}

diff --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
similarity index 76%
rename from mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
rename to mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
index 5fa9e1ae3160c5..e335277ccf18a1 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
@@ -3,15 +3,18 @@
 transform.sequence failures(propagate) {
 ^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
-    transform.apply_patterns.vector.fold_tensor_slice_into_transfer
+    transform.apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers
   } : !transform.op<"func.func">
 }
 
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 4)>
+// CHECK: #[[$map1:.*]] = affine_map<()[s0] -> (s0 + 3)>
+// CHECK: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+
 // CHECK-LABEL: func @transfer_read_of_extract_slice(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
 //   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
-//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+//       CHECK:   %[[add:.*]] = affine.apply #[[$map]]()[%[[s1]]]
 //       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
 //       CHECK:   return %[[r]]
 func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
@@ -25,9 +28,8 @@ func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2
 
 // CHECK-LABEL: func @transfer_read_of_extract_slice_1d(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
 //   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
-//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+//       CHECK:   %[[add:.*]] = affine.apply #[[$map]]()[%[[s1]]]
 //       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
 //       CHECK:   return %[[r]]
 func.func @transfer_read_of_extract_slice_1d(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
@@ -41,10 +43,9 @@ func.func @transfer_read_of_extract_slice_1d(%t : tensor<?x?xf32>, %s1 : index,
 
 // CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
 //   CHECK-DAG:   %[[c5:.*]] = arith.constant 5 : index
 //   CHECK-DAG:   %[[c10:.*]] = arith.constant 10 : index
-//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c3]]
+//       CHECK:   %[[add:.*]] = affine.apply #[[$map1]]()[%[[s1]]]
 //       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
 //       CHECK:   return %[[r]]
 func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
@@ -56,10 +57,13 @@ func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>,
   return %1 : vector<5x6xf32>
 }
 
-// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
-//       CHECK:   extract_slice
-//       CHECK:   vector.transfer_read
-func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+// CHECK-LABEL: func @transfer_read_of_extract_slice_non_leading_rank_reduction(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
+//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10 : index
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[s1]], %[[c10]]], %{{.*}} {in_bounds = [true, true], permutation_map = #[[$map2]]} : tensor<?x?x?xf32>, vector<5x6xf32>
+//       CHECK:   return %[[r]]
+func.func @transfer_read_of_extract_slice_non_leading_rank_reduction(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
   %c3 = arith.constant 3 : index
   %c4 = arith.constant 4 : index
   %cst = arith.constant 0.0 : f32
@@ -80,10 +84,12 @@ func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x
   return %1 : tensor<?x12xf32>
 }
 
-// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
-//       CHECK:   vector.transfer_write
-//       CHECK:   insert_slice
-func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+// CHECK-LABEL: func @insert_slice_of_transfer_write_non_leading_rank_reduction(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true], permutation_map = #[[$map2]]} : vector<5x6xf32>, tensor<?x?x12xf32>
+func.func @insert_slice_of_transfer_write_non_leading_rank_reduction(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
   %c0 = arith.constant 0 : index
   %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
   %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>


        


More information about the Mlir-commits mailing list