[Mlir-commits] [mlir] 1a4263d - [mlir][Vector] Add linalg.copy-based pattern for splitting vector.transfer_read into full and partial copies.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Aug 4 05:55:58 PDT 2020


Author: Nicolas Vasilache
Date: 2020-08-04T08:46:08-04:00
New Revision: 1a4263d394c1a93757613bde4b1c2cf8d6a7bbb9

URL: https://github.com/llvm/llvm-project/commit/1a4263d394c1a93757613bde4b1c2cf8d6a7bbb9
DIFF: https://github.com/llvm/llvm-project/commit/1a4263d394c1a93757613bde4b1c2cf8d6a7bbb9.diff

LOG: [mlir][Vector] Add linalg.copy-based pattern for splitting vector.transfer_read into full and partial copies.

This revision adds a transformation and a pattern that rewrites a "maybe masked" `vector.transfer_read %view[...], %pad` into a pattern resembling:

```
   %1:3 = scf.if (%inBounds) {
     scf.yield %view, %i, %j : memref<A...>, index, index
   } else {
     linalg.fill(%extra_alloc, %pad)
     %2 = subview %view [...][...][...]
     linalg.copy(%2, %extra_alloc)
     %3 = memref_cast %extra_alloc: memref<B...> to memref<A...>
     scf.yield %3, %c0, %c0 : memref<A...>, index, index
   }
   %res = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
```
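where `%extra_alloc` is a buffer holding exactly one vector, allocated with `alloca` at the top of the function. The fast path yields the original view and indices; the slow path yields the temporary buffer indexed at zero.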

This rewrite makes it possible to realize the "always full tile" abstraction, where `vector.transfer_read` operations are guaranteed to read from a full, padded buffer.
The extra work (fill and copy) only occurs on the boundary tiles.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Dialect/Vector/CMakeLists.txt
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index edf9557df389..562e07f98774 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -56,22 +56,48 @@ enum class VectorContractLowering {
 };
 /// Enum to control the lowering of `vector.transpose` operations.
 enum class VectorTransposeLowering {
-  // Lower transpose into element-wise extract and inserts.
+  /// Lower transpose into element-wise extract and inserts.
   EltWise = 0,
   /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
   /// intrinsics.
   Flat = 1,
 };
+/// Enum to control the splitting of `vector.transfer` operations into masked
+/// and unmasked variants.
+enum class VectorTransferSplit {
+  /// Do not split vector transfer operations.
+  None = 0,
+  /// Split using masked + unmasked vector.transfer operations.
+  VectorTransfer = 1,
+  /// Split using an unmasked vector.transfer + linalg.fill + linalg.copy
+  /// operations.
+  LinalgCopy = 2,
+  /// Do not split vector transfer operation but instead mark it as "unmasked".
+  ForceUnmasked = 3
+};
 /// Structure to control the behavior of vector transform patterns.
 struct VectorTransformsOptions {
+  /// Option to control the lowering of vector.contract.
   VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
-  VectorTransposeLowering vectorTransposeLowering =
-      VectorTransposeLowering::EltWise;
   VectorTransformsOptions &
   setVectorTransformsOptions(VectorContractLowering opt) {
     vectorContractLowering = opt;
     return *this;
   }
+  /// Option to control the lowering of vector.transpose.
+  VectorTransposeLowering vectorTransposeLowering =
+      VectorTransposeLowering::EltWise;
+  VectorTransformsOptions &
+  setVectorTransposeLowering(VectorTransposeLowering opt) {
+    vectorTransposeLowering = opt;
+    return *this;
+  }
+  /// Option to control the splitting of vector transfers.
+  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
+  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
+    vectorTransferSplit = opt;
+    return *this;
+  }
 };
 
 /// Collect a set of transformation patterns that are related to contracting

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 835ad18a79ad..e6c7b7abebd5 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -109,13 +109,13 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
   FilterConstraintType filter;
 };
 
-/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
-/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
-/// is `success, the `ifOp` points to the newly created conditional upon
-/// function return. To accomodate for the fact that the original
-/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
-/// in the temporary buffer, the scf.if op returns a view and values of type
-/// index. At this time, only vector.transfer_read is implemented.
+/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
+/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
+/// newly created conditional upon function return.
+/// To accommodate the fact that the original vector.transfer indexing may be
+/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// scf.if op returns a view and values of type index.
+/// At this time, only vector.transfer_read case is implemented.
 ///
 /// Example (a 2-D vector.transfer_read):
 /// ```
@@ -124,17 +124,17 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
 /// is transformed into:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
-///       scf.yield %0 : memref<A...>, index, index
-///     } else {
-///       %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
-///       %3 = vector.type_cast %extra_alloc : memref<...> to
-///       memref<vector<...>> store %2, %3[] : memref<vector<...>> %4 =
-///       memref_cast %extra_alloc: memref<B...> to memref<A...> scf.yield %4 :
-///       memref<A...>, index, index
+///      // fastpath, direct cast
+///      memref_cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view : compatibleMemRefType, index, index
+///    } else {
+///      // slowpath, masked vector.transfer or linalg.copy.
+///      memref_cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4 : compatibleMemRefType, index, index
 //     }
 ///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
 /// ```
-/// where `extra_alloc` is a top of the function alloca'ed buffer of one vector.
+/// where `alloc` is a one-vector buffer alloca'ed at the top of the function.
 ///
 /// Preconditions:
 ///  1. `xferOp.permutation_map()` must be a minor identity map
@@ -143,9 +143,10 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
 ///  rank-reducing subviews.
 LogicalResult
 splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
-LogicalResult splitFullAndPartialTransfer(OpBuilder &b,
-                                          VectorTransferOpInterface xferOp,
-                                          scf::IfOp *ifOp = nullptr);
+LogicalResult splitFullAndPartialTransfer(
+    OpBuilder &b, VectorTransferOpInterface xferOp,
+    VectorTransformsOptions options = VectorTransformsOptions(),
+    scf::IfOp *ifOp = nullptr);
 
 /// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
 /// may take an extra filter to perform selection at a finer granularity.
@@ -155,16 +156,19 @@ struct VectorTransferFullPartialRewriter : public RewritePattern {
 
   explicit VectorTransferFullPartialRewriter(
       MLIRContext *context,
+      VectorTransformsOptions options = VectorTransformsOptions(),
       FilterConstraintType filter =
           [](VectorTransferOpInterface op) { return success(); },
       PatternBenefit benefit = 1)
-      : RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter) {}
+      : RewritePattern(benefit, MatchAnyOpTypeTag()), options(options),
+        filter(filter) {}
 
   /// Performs the rewrite.
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override;
 
 private:
+  VectorTransformsOptions options;
   FilterConstraintType filter;
 };
 

diff  --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 13dbf6da73fa..1087feba7fbd 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVector
   MLIRIR
   MLIRStandardOps
   MLIRAffineOps
+  MLIRLinalgOps
   MLIRSCF
   MLIRLoopAnalysis
   MLIRSideEffectInterfaces

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 573b822503f3..3c23c5a6d869 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -2056,7 +2057,16 @@ LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
   return success();
 }
 
-MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
+/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
+/// be cast. If the MemRefTypes don't have the same rank or are not strided,
+/// return null; otherwise:
+///   1. if `aT` and `bT` are cast-compatible, return `aT`.
+///   2. else return a new MemRefType obtained by iterating over the shape and
+///   strides and:
+///     a. keeping the ones that are static and equal across `aT` and `bT`.
+///     b. using a dynamic shape and/or stride for the dimensions that don't
+///        agree.
+static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
   if (MemRefCastOp::areCastCompatible(aT, bT))
     return aT;
   if (aT.getRank() != bT.getRank())
@@ -2086,13 +2096,154 @@ MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
       makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
 }
 
-/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
-/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
-/// is `success, the `ifOp` points to the newly created conditional upon
-/// function return. To accomodate for the fact that the original
-/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
-/// in the temporary buffer, the scf.if op returns a view and values of type
-/// index. At this time, only vector.transfer_read is implemented.
+/// Operates under a scoped context to build the intersection between the
+/// view `xferOp.memref()` @ `xferOp.indices()` and the view `alloc`.
+// TODO: view intersection/union/differences should be a proper std op.
+static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
+                                             Value alloc) {
+  using namespace edsc::intrinsics;
+  int64_t memrefRank = xferOp.getMemRefType().getRank();
+  // TODO: relax this precondition, will require rank-reducing subviews.
+  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
+         "Expected memref rank to match the alloc rank");
+  Value one = std_constant_index(1);
+  ValueRange leadingIndices =
+      xferOp.indices().take_front(xferOp.getLeadingMemRefRank());
+  SmallVector<Value, 4> sizes;
+  sizes.append(leadingIndices.begin(), leadingIndices.end());
+  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    Value dimMemRef = std_dim(xferOp.memref(), indicesIdx);
+    Value dimAlloc = std_dim(alloc, resultIdx);
+    Value index = xferOp.indices()[indicesIdx];
+    AffineExpr i, j, k;
+    bindDims(xferOp.getContext(), i, j, k);
+    SmallVector<AffineMap, 4> maps =
+        AffineMap::inferFromExprList(MapList{{i - j, k}});
+    // affine_min(%dimMemRef - %index, %dimAlloc)
+    Value affineMin = affine_min(index.getType(), maps[0],
+                                 ValueRange{dimMemRef, index, dimAlloc});
+    sizes.push_back(affineMin);
+  });
+  return std_sub_view(xferOp.memref(), xferOp.indices(), sizes,
+                      SmallVector<Value, 4>(memrefRank, one));
+}
+
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      memref_cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view, ... : compatibleMemRefType, index, index
+///    } else {
+///      %2 = linalg.fill(%alloc, %pad)
+///      %3 = subview %view [...][...][...]
+///      linalg.copy(%3, %alloc)
+///      memref_cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4, ... : compatibleMemRefType, index, index
+///   }
+/// ```
+/// Return the produced scf::IfOp.
+static scf::IfOp createScopedFullPartialLinalgCopy(
+    vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
+    MemRefType compatibleMemRefType, Value alloc) {
+  using namespace edsc;
+  using namespace edsc::intrinsics;
+  scf::IfOp fullPartialIfOp;
+  Value zero = std_constant_index(0);
+  Value memref = xferOp.memref();
+  conditionBuilder(
+      returnTypes, inBoundsCond,
+      [&]() -> scf::ValueVector {
+        Value res = memref;
+        if (compatibleMemRefType != xferOp.getMemRefType())
+          res = std_memref_cast(memref, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{res};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+                              xferOp.indices().end());
+        return viewAndIndices;
+      },
+      [&]() -> scf::ValueVector {
+        linalg_fill(alloc, xferOp.padding());
+        // Take partial subview of memref which guarantees no dimension
+        // overflows.
+        Value memRefSubView = createScopedSubViewIntersection(
+            cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
+        linalg_copy(memRefSubView, alloc);
+        Value casted = std_memref_cast(alloc, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{casted};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+                              zero);
+        return viewAndIndices;
+      },
+      &fullPartialIfOp);
+  return fullPartialIfOp;
+}
+
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      memref_cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view, ... : compatibleMemRefType, index, index
+///    } else {
+///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
+///      %3 = vector.type_cast %alloc :
+///        memref<...> to memref<vector<...>>
+///      store %2, %3[] : memref<vector<...>>
+///      %4 = memref_cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4, ... : compatibleMemRefType, index, index
+///   }
+/// ```
+/// Return the produced scf::IfOp.
+static scf::IfOp createScopedFullPartialVectorTransferRead(
+    vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
+    MemRefType compatibleMemRefType, Value alloc) {
+  using namespace edsc;
+  using namespace edsc::intrinsics;
+  scf::IfOp fullPartialIfOp;
+  Value zero = std_constant_index(0);
+  Value memref = xferOp.memref();
+  conditionBuilder(
+      returnTypes, inBoundsCond,
+      [&]() -> scf::ValueVector {
+        Value res = memref;
+        if (compatibleMemRefType != xferOp.getMemRefType())
+          res = std_memref_cast(memref, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{res};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+                              xferOp.indices().end());
+        return viewAndIndices;
+      },
+      [&]() -> scf::ValueVector {
+        Operation *newXfer =
+            ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
+        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
+        std_store(vector, vector_type_cast(
+                              MemRefType::get({}, vector.getType()), alloc));
+
+        Value casted = std_memref_cast(alloc, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{casted};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+                              zero);
+
+        return viewAndIndices;
+      },
+      &fullPartialIfOp);
+  return fullPartialIfOp;
+}
+
+/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
+/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
+/// newly created conditional upon function return.
+/// To accommodate the fact that the original vector.transfer indexing may be
+/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// scf.if op returns a view and values of type index.
+/// At this time, only vector.transfer_read case is implemented.
 ///
 /// Example (a 2-D vector.transfer_read):
 /// ```
@@ -2101,17 +2252,17 @@ MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
 /// is transformed into:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
-///       scf.yield %0 : memref<A...>, index, index
-///     } else {
-///       %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
-///       %3 = vector.type_cast %extra_alloc : memref<...> to
-///       memref<vector<...>> store %2, %3[] : memref<vector<...>> %4 =
-///       memref_cast %extra_alloc: memref<B...> to memref<A...> scf.yield %4 :
-///       memref<A...>, index, index
+///      // fastpath, direct cast
+///      memref_cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view : compatibleMemRefType, index, index
+///    } else {
+///      // slowpath, masked vector.transfer or linalg.copy.
+///      memref_cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4 : compatibleMemRefType, index, index
 //     }
 ///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
 /// ```
-/// where `extra_alloc` is a top of the function alloca'ed buffer of one vector.
+/// where `alloc` is a one-vector buffer alloca'ed at the top of the function.
 ///
 /// Preconditions:
 ///  1. `xferOp.permutation_map()` must be a minor identity map
@@ -2119,10 +2270,21 @@ MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
 ///  must be equal. This will be relaxed in the future but requires
 ///  rank-reducing subviews.
 LogicalResult mlir::vector::splitFullAndPartialTransfer(
-    OpBuilder &b, VectorTransferOpInterface xferOp, scf::IfOp *ifOp) {
+    OpBuilder &b, VectorTransferOpInterface xferOp,
+    VectorTransformsOptions options, scf::IfOp *ifOp) {
   using namespace edsc;
   using namespace edsc::intrinsics;
 
+  if (options.vectorTransferSplit == VectorTransferSplit::None)
+    return failure();
+
+  SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
+  auto unmaskedAttr = b.getBoolArrayAttr(bools);
+  if (options.vectorTransferSplit == VectorTransferSplit::ForceUnmasked) {
+    xferOp.setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);
+    return success();
+  }
+
   assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
          "Expected splitFullAndPartialTransferPrecondition to hold");
   auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
@@ -2154,45 +2316,21 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
                        b.getI64IntegerAttr(32));
   }
 
-  Value memref = xferOp.memref();
-  SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
-  auto unmaskedAttr = b.getBoolArrayAttr(bools);
-
   MemRefType compatibleMemRefType = getCastCompatibleMemRefType(
       xferOp.getMemRefType(), alloc.getType().cast<MemRefType>());
 
   // Read case: full fill + partial copy -> unmasked vector.xfer_read.
-  Value zero = std_constant_index(0);
   SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                    b.getIndexType());
   returnTypes[0] = compatibleMemRefType;
-  scf::IfOp fullPartialIfOp;
-  conditionBuilder(
-      returnTypes, inBoundsCond,
-      [&]() -> scf::ValueVector {
-        Value res = memref;
-        if (compatibleMemRefType != xferOp.getMemRefType())
-          res = std_memref_cast(memref, compatibleMemRefType);
-        scf::ValueVector viewAndIndices{res};
-        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
-                              xferOp.indices().end());
-        return viewAndIndices;
-      },
-      [&]() -> scf::ValueVector {
-        Operation *newXfer =
-            ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
-        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
-        std_store(vector, vector_type_cast(
-                              MemRefType::get({}, vector.getType()), alloc));
-
-        Value casted = std_memref_cast(alloc, compatibleMemRefType);
-        scf::ValueVector viewAndIndices{casted};
-        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
-                              zero);
-
-        return viewAndIndices;
-      },
-      &fullPartialIfOp);
+  scf::IfOp fullPartialIfOp =
+      options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
+          ? createScopedFullPartialVectorTransferRead(
+                xferReadOp, returnTypes, inBoundsCond, compatibleMemRefType,
+                alloc)
+          : createScopedFullPartialLinalgCopy(xferReadOp, returnTypes,
+                                              inBoundsCond,
+                                              compatibleMemRefType, alloc);
   if (ifOp)
     *ifOp = fullPartialIfOp;
 
@@ -2211,7 +2349,7 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
       failed(filter(xferOp)))
     return failure();
   rewriter.startRootUpdate(xferOp);
-  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp))) {
+  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
     rewriter.finalizeRootUpdate(xferOp);
     return success();
   }

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
index ef76247ee9d4..e36454203982 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -1,13 +1,26 @@
 // RUN: mlir-opt %s -test-vector-transfer-full-partial-split | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-full-partial-split=use-linalg-copy | FileCheck %s --check-prefix=LINALG
 
 // CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
 // CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
 // CHECK-DAG: #[[$map_2d_stride_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 
+// LINALG-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
+// LINALG-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
+// LINALG-DAG: #[[$map_2d_stride_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// LINALG-DAG: #[[$map_2d_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+// LINALG-DAG: #[[$bounds_map_4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)>
+// LINALG-DAG: #[[$bounds_map_8:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)>
+
 // CHECK-LABEL: split_vector_transfer_read_2d(
 //  CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
 //  CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
 //  CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index
+
+// LINALG-LABEL: split_vector_transfer_read_2d(
+//  LINALG-SAME: %[[A:[a-zA-Z0-9]*]]: memref
+//  LINALG-SAME: %[[i:[a-zA-Z0-9]*]]: index
+//  LINALG-SAME: %[[j:[a-zA-Z0-9]*]]: index
 func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -> vector<4x8xf32> {
   %c0 = constant 0 : index
   %f0 = constant 0.0 : f32
@@ -43,9 +56,45 @@ func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -
   //      CHECK: }
   //      CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %[[cst]]
   // CHECK_SAME:   {masked = [false, false]} : memref<?x8xf32>, vector<4x8xf32>
+
+  //  LINALG-DAG: %[[c0:.*]] = constant 0 : index
+  //  LINALG-DAG: %[[c1:.*]] = constant 1 : index
+  //  LINALG-DAG: %[[c4:.*]] = constant 4 : index
+  //  LINALG-DAG: %[[c8:.*]] = constant 8 : index
+  //  LINALG-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
+  // alloca for boundary full tile
+  //      LINALG: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
+  // %i + 4 <= dim(%A, 0)
+  //      LINALG: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
+  //      LINALG: %[[d0:.*]] = dim %[[A]], %[[c0]] : memref<?x8xf32>
+  //      LINALG: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[d0]] : index
+  // %j + 8 <= dim(%A, 1)
+  //      LINALG: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
+  //      LINALG: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
+  // are both conds true
+  //      LINALG: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
+  //      LINALG: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32>, index, index) {
+  //               inBounds, just yield %A
+  //      LINALG:   scf.yield %[[A]], %[[i]], %[[j]] : memref<?x8xf32>, index, index
+  //      LINALG: } else {
+  //               slow path, fill tmp alloc and yield a memref_casted version of it
+  //      LINALG:   linalg.fill(%[[alloc]], %[[cst]]) : memref<4x8xf32>, f32
+  //      LINALG:   %[[d0:.*]] = dim %[[A]], %[[c0]] : memref<?x8xf32>
+  //      LINALG:   %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[d0]], %[[i]], %[[c4]])
+  //      LINALG:   %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
+  //      LINALG:   %[[sv:.*]] = subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [%[[c1]], %[[c1]]]
+  // LINALG-SAME:     memref<?x8xf32> to memref<?x?xf32, #[[$map_2d_dynamic]]>
+  //      LINALG:   linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_dynamic]]>, memref<4x8xf32>
+  //      LINALG:   %[[yielded:.*]] = memref_cast %[[alloc]] :
+  // LINALG-SAME:     memref<4x8xf32> to memref<?x8xf32>
+  //      LINALG:   scf.yield %[[yielded]], %[[c0]], %[[c0]] :
+  // LINALG-SAME:     memref<?x8xf32>, index, index
+  //      LINALG: }
+  //      LINALG: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %[[cst]]
+  // LINALG_SAME:   {masked = [false, false]} : memref<?x8xf32>, vector<4x8xf32>
   %1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32>, vector<4x8xf32>
 
-  // CHECK: return %[[res]] : vector<4x8xf32>
+  // LINALG: return %[[res]] : vector<4x8xf32>
   return %1: vector<4x8xf32>
 }
 
@@ -53,6 +102,11 @@ func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -
 //  CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
 //  CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
 //  CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index
+
+// LINALG-LABEL: split_vector_transfer_read_strided_2d(
+//  LINALG-SAME: %[[A:[a-zA-Z0-9]*]]: memref
+//  LINALG-SAME: %[[i:[a-zA-Z0-9]*]]: index
+//  LINALG-SAME: %[[j:[a-zA-Z0-9]*]]: index
 func @split_vector_transfer_read_strided_2d(
     %A: memref<7x8xf32, offset:?, strides:[?, 1]>,
     %i: index, %j: index) -> vector<4x8xf32> {
@@ -94,6 +148,44 @@ func @split_vector_transfer_read_strided_2d(
   //      CHECK: }
   //      CHECK: %[[res:.*]] = vector.transfer_read {{.*}} {masked = [false, false]} :
   // CHECK-SAME:   memref<?x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>
+
+  //  LINALG-DAG: %[[c0:.*]] = constant 0 : index
+  //  LINALG-DAG: %[[c1:.*]] = constant 1 : index
+  //  LINALG-DAG: %[[c4:.*]] = constant 4 : index
+  //  LINALG-DAG: %[[c7:.*]] = constant 7 : index
+  //  LINALG-DAG: %[[c8:.*]] = constant 8 : index
+  //  LINALG-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
+  // alloca for boundary full tile
+  //      LINALG: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
+  // %i + 4 <= dim(%A, 0)
+  //      LINALG: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
+  //      LINALG: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[c7]] : index
+  // %j + 8 <= dim(%A, 1)
+  //      LINALG: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
+  //      LINALG: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
+  // are both conds true
+  //      LINALG: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
+  //      LINALG: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index) {
+  //               inBounds but not cast-compatible: yield a memref_casted form of %A
+  //      LINALG:   %[[casted:.*]] = memref_cast %arg0 :
+  // LINALG-SAME:     memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x8xf32, #[[$map_2d_stride_1]]>
+  //      LINALG:   scf.yield %[[casted]], %[[i]], %[[j]] :
+  // LINALG-SAME:     memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
+  //      LINALG: } else {
+  //               slow path, fill tmp alloc and yield a memref_casted version of it
+  //      LINALG:   linalg.fill(%[[alloc]], %[[cst]]) : memref<4x8xf32>, f32
+  //      LINALG:   %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[c7]], %[[i]], %[[c4]])
+  //      LINALG:   %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
+  //      LINALG:   %[[sv:.*]] = subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [%[[c1]], %[[c1]]]
+  // LINALG-SAME:     memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x?xf32, #[[$map_2d_dynamic]]>
+  //      LINALG:   linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_dynamic]]>, memref<4x8xf32>
+  //      LINALG:   %[[yielded:.*]] = memref_cast %[[alloc]] :
+  // LINALG-SAME:     memref<4x8xf32> to memref<?x8xf32, #[[$map_2d_stride_1]]>
+  //      LINALG:   scf.yield %[[yielded]], %[[c0]], %[[c0]] :
+  // LINALG-SAME:     memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
+  //      LINALG: }
+  //      LINALG: %[[res:.*]] = vector.transfer_read {{.*}} {masked = [false, false]} :
+  // LINALG-SAME:   memref<?x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>
   %1 = vector.transfer_read %A[%i, %j], %f0 :
     memref<7x8xf32, offset:?, strides:[?, 1]>, vector<4x8xf32>
 

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 0bba74e76385..9da3156d5359 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -125,10 +125,23 @@ struct TestVectorUnrollingPatterns
 struct TestVectorTransferFullPartialSplitPatterns
     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                          FunctionPass> {
+  TestVectorTransferFullPartialSplitPatterns() = default;
+  TestVectorTransferFullPartialSplitPatterns(
+      const TestVectorTransferFullPartialSplitPatterns &pass) {}
+  Option<bool> useLinalgOps{
+      *this, "use-linalg-copy",
+      llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
+                     "linalg.copy operations."),
+      llvm::cl::init(false)};
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
-    patterns.insert<VectorTransferFullPartialRewriter>(ctx);
+    VectorTransformsOptions options;
+    if (useLinalgOps)
+      options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
+    else
+      options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
+    patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
   }
 };


        


More information about the Mlir-commits mailing list