[Mlir-commits] [mlir] c7dd0bf - [mlir][vector] NFC - Split out transfer split patterns

Nicolas Vasilache llvmlistbot at llvm.org
Wed Jan 5 05:38:11 PST 2022

Author: Nicolas Vasilache
Date: 2022-01-05T08:38:04-05:00
New Revision: c7dd0bf41d8ee940893ae144cc0d813828c8233a

URL: https://github.com/llvm/llvm-project/commit/c7dd0bf41d8ee940893ae144cc0d813828c8233a
DIFF: https://github.com/llvm/llvm-project/commit/c7dd0bf41d8ee940893ae144cc0d813828c8233a.diff

LOG: [mlir][vector] NFC - Split out transfer split patterns

Differential Revision: https://reviews.llvm.org/D116648




diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 811d72192910e..32935d547df0d 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -16,6 +16,8 @@ namespace mlir {
 class MLIRContext;
 class VectorTransferOpInterface;
 class RewritePatternSet;
+class RewriterBase;
 using OwningRewritePatternList = RewritePatternSet;
 namespace scf {
@@ -61,7 +63,7 @@ namespace vector {
 ///  must be equal. This will be relaxed in the future but requires
 ///  rank-reducing subviews.
 LogicalResult splitFullAndPartialTransfer(
-    OpBuilder &b, VectorTransferOpInterface xferOp,
+    RewriterBase &b, VectorTransferOpInterface xferOp,
     VectorTransformsOptions options = VectorTransformsOptions(),
     scf::IfOp *ifOp = nullptr);

diff  --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 143c6c7d688d1..36afc18ff0bea 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRVector
+  VectorTransferSplitRewritePatterns.cpp

diff  --git a/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
new file mode 100644
index 0000000000000..a6fbc303d64e3
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
@@ -0,0 +1,625 @@
+//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// This file implements target-independent patterns to rewrite a vector.transfer
+// op into a fully in-bounds part and a partial part.
+#include <type_traits>
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#define DEBUG_TYPE "vector-transfer-split"
+using namespace mlir;
+using namespace mlir::vector;
+/// Return the compile-time integer carried by `v` when `v` is produced either
+/// by an index-typed `arith.constant` or by an `affine.apply` whose map is a
+/// single constant; return None in every other case.
+static Optional<int64_t> extractConstantIndex(Value v) {
+  if (auto constantOp = v.getDefiningOp<arith::ConstantIndexOp>())
+    return constantOp.value();
+  auto applyOp = v.getDefiningOp<AffineApplyOp>();
+  if (!applyOp)
+    return None;
+  AffineMap map = applyOp.getAffineMap();
+  if (!map.isSingleConstant())
+    return None;
+  return map.getSingleConstantResult();
+}
+// Missing foldings of scf.if make it necessary to perform poor man's folding
+// eagerly, especially in the case of unrolling. In the future, this should go
+// away once scf.if folds properly.
+/// Emit `arith.cmpi sle, %v, %ub`, or return a null Value when both operands
+/// are known constants with `v < ub` (statically true, so no op is created).
+/// Only the strict inequality is folded: when the constants are equal, or when
+/// the comparison is statically false, a cmpi is still emitted.
+static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
+  Optional<int64_t> lhsCst = extractConstantIndex(v);
+  Optional<int64_t> rhsCst = extractConstantIndex(ub);
+  const bool provablyTrue = lhsCst && rhsCst && *lhsCst < *rhsCst;
+  if (provablyTrue)
+    return Value();
+  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
+}
+/// Build the condition to ensure that a particular VectorTransferOpInterface
+/// is in-bounds.
+/// For every transferred dimension that is not already marked in-bounds, this
+/// checks `index + vector_size <= dim(source)` and ANDs the checks together.
+/// Checks that createFoldedSLE proves statically true are dropped.
+/// Returns a null Value when no runtime check remains. Even in that case the
+/// index computations built along the way (e.g. the affine.apply created by
+/// makeComposedAffineApply) have already been inserted at the insertion point
+/// of `b`.
+static Value createInBoundsCond(RewriterBase &b,
+                                VectorTransferOpInterface xferOp) {
+  assert(xferOp.permutation_map().isMinorIdentity() &&
+         "Expected minor identity map");
+  Value inBoundsCond;
+  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
+    // Zip over the resulting vector shape and memref indices.
+    // If the dimension is known to be in-bounds, it does not participate in
+    // the construction of `inBoundsCond`.
+    if (xferOp.isDimInBounds(resultIdx))
+      return;
+    // Fold or create the check that `index + vector_size` <= `memref_size`.
+    Location loc = xferOp.getLoc();
+    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
+    auto d0 = getAffineDimExpr(0, xferOp.getContext());
+    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
+    Value sum =
+        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
+    Value cond = createFoldedSLE(
+        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
+    // A null `cond` means the check was proven statically true: skip it.
+    if (!cond)
+      return;
+    // Conjunction over all dims for which we are in-bounds.
+    if (inBoundsCond)
+      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
+    else
+      inBoundsCond = cond;
+  });
+  // Null when every dimension was already in-bounds or proven so statically.
+  return inBoundsCond;
+/// Check whether `xferOp` is a candidate for being split into an in-bounds
+/// fastpath and a slowpath (see splitFullAndPartialTransfer).
+/// Returns failure when:
+///   1. the transfer is 0-d (not supported yet);
+///   2. `xferOp.permutation_map()` is not a minor identity map;
+///   3. every transferred dimension is already marked in-bounds (nothing to
+///      split);
+///   4. the op has no parent op, or its parent is an scf.if. The latter keeps
+///      the pattern from re-applying to the transfers it creates inside the
+///      conditionals.
+static LogicalResult
+splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
+  // TODO: support 0-d corner case.
+  if (xferOp.getTransferRank() == 0)
+    return failure();
+  // TODO: support non-minor-identity permutation maps.
+  if (!xferOp.permutation_map().isMinorIdentity())
+    return failure();
+  // Must have some out-of-bounds dimension to be a candidate for splitting.
+  if (!xferOp.hasOutOfBoundsDim())
+    return failure();
+  // Don't split transfer operations directly under IfOp, this avoids applying
+  // the pattern recursively. A detached op (null parent) is rejected too:
+  // `isa` must not be called on a null pointer.
+  // TODO: improve the filtering condition to make it more applicable.
+  Operation *parentOp = xferOp->getParentOp();
+  if (!parentOp || isa<scf::IfOp>(parentOp))
+    return failure();
+  return success();
+}
+/// Compute a MemRefType to which both `aT` and `bT` can be cast.
+/// Returns:
+///   - `aT` itself when `aT` and `bT` are already cast-compatible;
+///   - a null MemRefType when the ranks differ or either type is not strided;
+///   - otherwise a strided type whose shape, strides and offset keep each
+///     static value on which `aT` and `bT` agree and use the dynamic marker
+///     wherever they disagree.
+static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
+  if (memref::CastOp::areCastCompatible(aT, bT))
+    return aT;
+  int64_t rank = aT.getRank();
+  if (rank != bT.getRank())
+    return MemRefType();
+  SmallVector<int64_t, 4> aStrides, bStrides;
+  int64_t aOffset;
+  int64_t bOffset;
+  if (failed(getStridesAndOffset(aT, aStrides, aOffset)))
+    return MemRefType();
+  if (failed(getStridesAndOffset(bT, bStrides, bOffset)))
+    return MemRefType();
+  if (aStrides.size() != bStrides.size())
+    return MemRefType();
+  // Keep a value only where both sides agree; fall back to `dynamic` otherwise.
+  auto pick = [](int64_t lhs, int64_t rhs, int64_t dynamic) -> int64_t {
+    return lhs == rhs ? lhs : dynamic;
+  };
+  ArrayRef<int64_t> aShape = aT.getShape();
+  ArrayRef<int64_t> bShape = bT.getShape();
+  SmallVector<int64_t, 4> resShape, resStrides;
+  resShape.reserve(rank);
+  resStrides.reserve(rank);
+  for (int64_t dim = 0; dim < rank; ++dim) {
+    resShape.push_back(
+        pick(aShape[dim], bShape[dim], MemRefType::kDynamicSize));
+    resStrides.push_back(pick(aStrides[dim], bStrides[dim],
+                              MemRefType::kDynamicStrideOrOffset));
+  }
+  int64_t resOffset =
+      pick(aOffset, bOffset, MemRefType::kDynamicStrideOrOffset);
+  AffineMap layout =
+      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext());
+  return MemRefType::get(resShape, aT.getElementType(), layout);
+}
+/// Build the pair of subviews used to copy the part of a transfer that lies
+/// within both the original view `xferOp.source()` @ `xferOp.indices()` and
+/// the temporary buffer `alloc` @ [0 ... 0].
+/// Along each transferred dimension the copied size is
+/// `min(dim(source) - index, dim(alloc))`, so neither view overflows.
+/// Returns `(copySource, copyDest)`:
+///   - vector.transfer_read:  (subview of source, subview of alloc)
+///   - vector.transfer_write: (subview of alloc,  subview of source)
+/// The subview of the original memref is always offset by the transfer
+/// indices, and the subview of `alloc` always starts at its origin. The
+/// slowpath keeps its single vector at index 0 of `alloc` in both directions.
+// TODO: view intersection/union/differences should be a proper std op.
+static std::pair<Value, Value>
+createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
+                          Value alloc) {
+  Location loc = xferOp.getLoc();
+  int64_t memrefRank = xferOp.getShapedType().getRank();
+  // TODO: relax this precondition, will require rank-reducing subviews.
+  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
+         "Expected memref rank to match the alloc rank");
+  ValueRange leadingIndices =
+      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
+  SmallVector<OpFoldResult, 4> sizes;
+  sizes.append(leadingIndices.begin(), leadingIndices.end());
+  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
+  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
+                                                xferOp.source(), indicesIdx);
+    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
+    Value index = xferOp.indices()[indicesIdx];
+    AffineExpr i, j, k;
+    bindDims(xferOp.getContext(), i, j, k);
+    SmallVector<AffineMap, 4> maps =
+        AffineMap::inferFromExprList(MapList{{i - j, k}});
+    // affine_min(%dimMemRef - %index, %dimAlloc)
+    Value affineMin = b.create<AffineMinOp>(
+        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
+    sizes.push_back(affineMin);
+  });
+  SmallVector<OpFoldResult> xferIndices = llvm::to_vector<4>(llvm::map_range(
+      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
+  SmallVector<OpFoldResult> zeroIndices(memrefRank, b.getIndexAttr(0));
+  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
+  // The window into the original memref starts at the transfer indices; the
+  // window into `alloc` starts at its origin. Only the copy direction depends
+  // on whether this is a read or a write.
+  Value sourceView = b.create<memref::SubViewOp>(loc, xferOp.source(),
+                                                 xferIndices, sizes, strides);
+  Value allocView = b.create<memref::SubViewOp>(loc, alloc, zeroIndices,
+                                                sizes, strides);
+  if (isaWrite)
+    return std::make_pair(allocView, sourceView);
+  return std::make_pair(sourceView, allocView);
+}
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view, ... : compatibleMemRefType, index, index
+///    } else {
+///      linalg.fill(%pad, %alloc)
+///      %3 = memref.subview %A [...][...][...]
+///      %4 = memref.subview %alloc [0, 0] [...] [...]
+///      linalg.copy(%3, %4)
+///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %5, ... : compatibleMemRefType, index, index
+///   }
+/// ```
+/// The fastpath cast is only emitted when `compatibleMemRefType` differs from
+/// the source type. The fastpath yields the original transfer indices. The
+/// slowpath yields one zero index per transferred dimension, because `alloc`
+/// holds the padded data starting at its origin.
+/// Return the produced scf::IfOp.
+static scf::IfOp
+createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
+                            TypeRange returnTypes, Value inBoundsCond,
+                            MemRefType compatibleMemRefType, Value alloc) {
+  Location loc = xferOp.getLoc();
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value memref = xferOp.source();
+  return b.create<scf::IfOp>(
+      loc, returnTypes, inBoundsCond,
+      [&](OpBuilder &b, Location loc) {
+        // Fastpath: the original view and indices are used as-is.
+        Value res = memref;
+        if (compatibleMemRefType != xferOp.getShapedType())
+          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{res};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+                              xferOp.indices().end());
+        b.create<scf::YieldOp>(loc, viewAndIndices);
+      },
+      [&](OpBuilder &b, Location loc) {
+        // Slowpath: pre-fill the whole buffer with the padding value, so the
+        // elements not covered by the copy below read as padding.
+        b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
+        // Take partial subview of memref which guarantees no dimension
+        // overflows.
+        // NOTE(review): this local IRRewriter is built from the region
+        // builder `b`; ops it creates are presumably not reported to the
+        // listener of the caller's rewriter -- confirm.
+        IRRewriter rewriter(b);
+        std::pair<Value, Value> copyArgs = createSubViewIntersection(
+            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
+            alloc);
+        b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
+        Value casted =
+            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{casted};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+                              zero);
+        b.create<scf::YieldOp>(loc, viewAndIndices);
+      });
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view, ... : compatibleMemRefType, index, index
+///    } else {
+///      %2 = vector.transfer_read %A[...], %pad : memref<A...>, vector<...>
+///      %3 = vector.type_cast %alloc :
+///        memref<...> to memref<vector<...>>
+///      memref.store %2, %3[] : memref<vector<...>>
+///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4, ... : compatibleMemRefType, index, index
+///   }
+/// ```
+/// The fastpath cast is only emitted when `compatibleMemRefType` differs from
+/// the source type. The fastpath yields the original indices and the slowpath
+/// yields zeros. The slowpath read is a clone of `xferOp`, so it keeps the
+/// original out-of-bounds/padding semantics.
+/// Return the produced scf::IfOp.
+static scf::IfOp createFullPartialVectorTransferRead(
+    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
+    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
+  Location loc = xferOp.getLoc();
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value source = xferOp.source();
+  return b.create<scf::IfOp>(
+      loc, returnTypes, inBoundsCond,
+      [&](OpBuilder &b, Location loc) {
+        // Fastpath: read straight from the source at the original indices.
+        Value res = source;
+        if (compatibleMemRefType != xferOp.getShapedType())
+          res = b.create<memref::CastOp>(loc, source, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{res};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+                              xferOp.indices().end());
+        b.create<scf::YieldOp>(loc, viewAndIndices);
+      },
+      [&](OpBuilder &b, Location loc) {
+        // Slowpath: perform the original (possibly out-of-bounds, padded)
+        // read and spill the resulting vector into `alloc`.
+        Operation *newXfer = b.clone(*xferOp.getOperation());
+        Value readVec = cast<VectorTransferOpInterface>(newXfer).vector();
+        b.create<memref::StoreOp>(
+            loc, readVec,
+            b.create<vector::TypeCastOp>(
+                loc, MemRefType::get({}, readVec.getType()), alloc));
+        Value casted =
+            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{casted};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+                              zero);
+        b.create<scf::YieldOp>(loc, viewAndIndices);
+      });
+}
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view, ... : compatibleMemRefType, index, index
+///    } else {
+///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4, ... : compatibleMemRefType, index, index
+///   }
+/// ```
+/// This selects where the full vector is written:
+///   - when in-bounds, the original memref at the original indices;
+///   - otherwise, `alloc` at index 0 along every transferred dimension.
+/// The fastpath cast is only emitted when `compatibleMemRefType` differs from
+/// the source type.
+/// Returns the results of the created scf.if: the selected view followed by
+/// one index per transferred dimension.
+static ValueRange
+getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
+                          TypeRange returnTypes, Value inBoundsCond,
+                          MemRefType compatibleMemRefType, Value alloc) {
+  Location loc = xferOp.getLoc();
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value memref = xferOp.source();
+  return b
+      .create<scf::IfOp>(
+          loc, returnTypes, inBoundsCond,
+          [&](OpBuilder &b, Location loc) {
+            // Fastpath: write directly into the original view.
+            Value res = memref;
+            if (compatibleMemRefType != xferOp.getShapedType())
+              res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+            scf::ValueVector viewAndIndices{res};
+            viewAndIndices.insert(viewAndIndices.end(),
+                                  xferOp.indices().begin(),
+                                  xferOp.indices().end());
+            b.create<scf::YieldOp>(loc, viewAndIndices);
+          },
+          [&](OpBuilder &b, Location loc) {
+            // Slowpath: write into the temporary buffer at its origin.
+            Value casted =
+                b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+            scf::ValueVector viewAndIndices{casted};
+            viewAndIndices.insert(viewAndIndices.end(),
+                                  xferOp.getTransferRank(), zero);
+            b.create<scf::YieldOp>(loc, viewAndIndices);
+          })
+      ->getResults();
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` has been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+///   3. it originally wrote to %view
+/// Produce IR resembling:
+/// ```
+///    %notInBounds = arith.xori %inBounds, %true
+///    scf.if (%notInBounds) {
+///      %3 = memref.subview %alloc ...
+///      %4 = memref.subview %view ...
+///      linalg.copy(%3, %4)
+///   }
+/// ```
+/// The two subviews are the (copy source, copy destination) pair returned by
+/// createSubViewIntersection for a transfer_write: a view of `alloc` copied
+/// into a view of %view. The copy therefore only runs on the slowpath, after
+/// the full vector has been written into `alloc`.
+static void createFullPartialLinalgCopy(RewriterBase &b,
+                                        vector::TransferWriteOp xferOp,
+                                        Value inBoundsCond, Value alloc) {
+  Location loc = xferOp.getLoc();
+  // Negate the i1 condition by xor-ing with the constant `true`.
+  auto notInBounds = b.create<arith::XOrIOp>(
+      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
+  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
+    // NOTE(review): this local IRRewriter is built from the region builder
+    // `b`; ops it creates are presumably not reported to the listener of the
+    // caller's rewriter -- confirm.
+    IRRewriter rewriter(b);
+    std::pair<Value, Value> copyArgs = createSubViewIntersection(
+        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
+        alloc);
+    b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
+    b.create<scf::YieldOp>(loc, ValueRange{});
+  });
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` has been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+///   3. it originally wrote to %view
+/// Produce IR resembling:
+/// ```
+///    %notInBounds = arith.xori %inBounds, %true
+///    scf.if (%notInBounds) {
+///      %2 = vector.type_cast %alloc : memref<...> to memref<vector<...>>
+///      %3 = memref.load %2[] : memref<vector<...>>
+///      vector.transfer_write %3, %view[...] : vector<...>, memref<A...>
+///   }
+/// ```
+/// The inner transfer_write is a clone of `xferOp` whose vector operand is
+/// remapped to the loaded value. It keeps the original indices and in_bounds
+/// attribute, so the out-of-bounds parts are still handled by the op itself.
+static void createFullPartialVectorTransferWrite(RewriterBase &b,
+                                                 vector::TransferWriteOp xferOp,
+                                                 Value inBoundsCond,
+                                                 Value alloc) {
+  Location loc = xferOp.getLoc();
+  // Negate the i1 condition by xor-ing with the constant `true`.
+  auto notInBounds = b.create<arith::XOrIOp>(
+      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
+  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
+    BlockAndValueMapping mapping;
+    // Reload the full vector previously written into `alloc`.
+    Value load = b.create<memref::LoadOp>(
+        loc, b.create<vector::TypeCastOp>(
+                 loc, MemRefType::get({}, xferOp.vector().getType()), alloc));
+    mapping.map(xferOp.vector(), load);
+    b.clone(*xferOp.getOperation(), mapping);
+    b.create<scf::YieldOp>(loc, ValueRange{});
+  });
+/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
+/// masking) fastpath and a slowpath.
+///
+/// For vector.transfer_read:
+/// If `ifOp` is not null and the result is `success`, `ifOp` points to the
+/// newly created conditional upon function return.
+/// The original vector.transfer indexing may be arbitrary, while the slow path
+/// indexes @[0...0] in the temporary buffer. To accommodate this, the scf.if
+/// op returns a view and values of type index.
+/// Example (a 2-D vector.transfer_read):
+/// ```
+///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
+/// ```
+/// is transformed into:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      // fastpath, direct cast
+///      memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view : compatibleMemRefType, index, index
+///    } else {
+///      // slowpath, not in-bounds vector.transfer or linalg.copy.
+///      memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4 : compatibleMemRefType, index, index
+///    }
+///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
+/// ```
+/// where `alloc` is a top of the function alloca'ed buffer of one vector.
+///
+/// For vector.transfer_write:
+/// There are 2 conditional blocks. First a block to decide which memref and
+/// indices to use for an unmasked, inbounds write. Then a conditional block to
+/// further copy a partial buffer into the final result in the slow path case.
+/// Example (a 2-D vector.transfer_write):
+/// ```
+///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
+/// ```
+/// is transformed into:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view : compatibleMemRefType, index, index
+///    } else {
+///      memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4 : compatibleMemRefType, index, index
+///    }
+///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
+///    scf.if (%notInBounds) {
+///      // slowpath: not in-bounds vector.transfer or linalg.copy.
+///    }
+/// ```
+/// where `alloc` is a top of the function alloca'ed buffer of one vector.
+/// The original transfer_write is erased through the rewriter.
+///
+/// Preconditions:
+///  1. `xferOp.permutation_map()` must be a minor identity map
+///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
+///  must be equal. This will be relaxed in the future but requires
+///  rank-reducing subviews.
+/// Returns failure, before creating any IR, when the split mode is `None`,
+/// the op is masked or is neither a read nor a write, the source is not a
+/// memref, the op is not nested in a FuncOp, or no cast-compatible memref type
+/// exists. It also returns failure when every in-bounds check folds away
+/// statically (see the TODO below).
+LogicalResult mlir::vector::splitFullAndPartialTransfer(
+    RewriterBase &b, VectorTransferOpInterface xferOp,
+    VectorTransformsOptions options, scf::IfOp *ifOp) {
+  if (options.vectorTransferSplit == VectorTransferSplit::None)
+    return failure();
+  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
+  auto inBoundsAttr = b.getBoolArrayAttr(bools);
+  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
+    // Mutate through the rewriter so that listeners (e.g. a pattern driver)
+    // are told the op changed.
+    b.updateRootInPlace(xferOp.getOperation(), [&]() {
+      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
+    });
+    return success();
+  }
+  // Assert preconditions. Additionally, keep the variables in an inner scope to
+  // ensure they aren't used in the wrong scopes further down.
+  {
+    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
+           "Expected splitFullAndPartialTransferPrecondition to hold");
+    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
+    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
+    if (!(xferReadOp || xferWriteOp))
+      return failure();
+    if (xferWriteOp && xferWriteOp.mask())
+      return failure();
+    if (xferReadOp && xferReadOp.mask())
+      return failure();
+  }
+  // Reject the remaining unsupported cases before any IR is created: a
+  // rewrite must not modify the IR and then report failure.
+  // Only memref sources are supported: both paths cast/subview the source.
+  auto sourceType = xferOp.getShapedType().dyn_cast<MemRefType>();
+  if (!sourceType)
+    return failure();
+  // The temporary buffer is allocated at the top of the enclosing function.
+  FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
+  if (!funcOp)
+    return failure();
+  VectorType vectorType = xferOp.getVectorType();
+  MemRefType allocType =
+      MemRefType::get(vectorType.getShape(), vectorType.getElementType());
+  MemRefType compatibleMemRefType =
+      getCastCompatibleMemRefType(sourceType, allocType);
+  if (!compatibleMemRefType)
+    return failure();
+  RewriterBase::InsertionGuard guard(b);
+  b.setInsertionPoint(xferOp);
+  Value inBoundsCond = createInBoundsCond(
+      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
+  // TODO: when every check folds away, createInBoundsCond has already inserted
+  // (now dead) index computations, and failing here leaves them behind.
+  if (!inBoundsCond)
+    return failure();
+  // Top of the function `alloc` for transient storage.
+  Value alloc;
+  {
+    RewriterBase::InsertionGuard guard(b);
+    b.setInsertionPointToStart(&funcOp.getRegion().front());
+    alloc = b.create<memref::AllocaOp>(funcOp.getLoc(), allocType,
+                                       ValueRange{}, b.getI64IntegerAttr(32));
+  }
+  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
+                                   b.getIndexType());
+  returnTypes[0] = compatibleMemRefType;
+  if (auto xferReadOp =
+          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
+    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
+    scf::IfOp fullPartialIfOp =
+        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
+            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
+                                                  inBoundsCond,
+                                                  compatibleMemRefType, alloc)
+            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
+                                          inBoundsCond, compatibleMemRefType,
+                                          alloc);
+    if (ifOp)
+      *ifOp = fullPartialIfOp;
+    // Set existing read op to in-bounds, it always reads from a full buffer.
+    // Operands [0, rank] are the source view followed by the indices.
+    b.updateRootInPlace(xferReadOp.getOperation(), [&]() {
+      for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
+        xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
+      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
+    });
+    return success();
+  }
+  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
+  // Decide which location to write the entire vector to.
+  auto memrefAndIndices = getLocationToWriteFullVec(
+      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
+  // Do an in bounds write to either the output or the extra allocated buffer.
+  // The operation is cloned to prevent deleting information needed for the
+  // later IR creation.
+  BlockAndValueMapping mapping;
+  mapping.map(xferWriteOp.source(), memrefAndIndices.front());
+  mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front());
+  auto *clone = b.clone(*xferWriteOp, mapping);
+  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
+  // Create a potential copy from the allocated buffer to the final output in
+  // the slow path case.
+  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
+    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
+  else
+    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
+  // Erase through the rewriter so that listeners are notified of the removal.
+  b.eraseOp(xferOp.getOperation());
+  return success();
+}
+/// Pattern entry point: checks the split precondition and the user filter,
+/// then delegates to splitFullAndPartialTransfer.
+LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
+  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
+      failed(filter(xferOp)))
+    return failure();
+  // splitFullAndPartialTransfer reports every in-place update and erasure
+  // through `rewriter` itself. Do not wrap it in start/finalizeRootUpdate:
+  // on the transfer_write path `xferOp` is erased, and finalizing a root
+  // update would then reference an erased op.
+  return splitFullAndPartialTransfer(rewriter, xferOp, options);
+}

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 0b49ccd58b27d..cd26750c9eb7a 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -144,7 +144,6 @@ static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
 namespace {
 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
 // Example:
@@ -1642,587 +1641,6 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
 } // namespace mlir
-/// Return the constant value of the index `v` when it is defined by an
-/// `arith.constant` of index type or by an `affine.apply` whose map is a
-/// single constant; return None otherwise.
-static Optional<int64_t> extractConstantIndex(Value v) {
-  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
-    return cstOp.value();
-  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
-    if (affineApplyOp.getAffineMap().isSingleConstant())
-      return affineApplyOp.getAffineMap().getSingleConstantResult();
-  return None;
-// Missing foldings of scf.if make it necessary to perform poor man's folding
-// eagerly, especially in the case of unrolling. In the future, this should go
-// away once scf.if folds properly.
-/// Build the i1 value `v <= ub` (signed `arith.cmpi sle`) at the insertion
-/// point of `b`. When both operands are known constants with `v < ub`, no op
-/// is created and a null Value is returned; createInBoundsCond treats that
-/// null result as "statically in-bounds" and skips the dimension.
-// NOTE(review): the constant fold uses `<` while the emitted predicate is
-// `sle`, so the boundary case `v == ub` (vector fits exactly) is not folded
-// and a runtime comparison is still emitted; `<=` would fold it too.
-static Value createFoldedSLE(OpBuilder &b, Value v, Value ub) {
-  auto maybeCstV = extractConstantIndex(v);
-  auto maybeCstUb = extractConstantIndex(ub);
-  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
-    return Value();
-  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
-/// Build the i1 condition under which every potentially out-of-bounds
-/// dimension of `xferOp` is actually in-bounds: the conjunction, over those
-/// dimensions, of `index + vector_size <= dim(source)`.
-/// Dimensions marked in-bounds, and checks that fold statically, are skipped.
-/// Returns a null Value when no runtime check remains at all.
-/// Requires a minor identity permutation map (asserted).
-static Value createInBoundsCond(OpBuilder &b,
-                                VectorTransferOpInterface xferOp) {
-  assert(xferOp.permutation_map().isMinorIdentity() &&
-         "Expected minor identity map");
-  Value inBoundsCond;
-  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
-    // Zip over the resulting vector shape and memref indices.
-    // If the dimension is known to be in-bounds, it does not participate in
-    // the construction of `inBoundsCond`.
-    if (xferOp.isDimInBounds(resultIdx))
-      return;
-    // Fold or create the check that `index + vector_size` <= `memref_size`.
-    // Note: the affine.apply / dim ops below may be created even when the
-    // comparison itself folds away.
-    Location loc = xferOp.getLoc();
-    ImplicitLocOpBuilder lb(loc, b);
-    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
-    auto d0 = getAffineDimExpr(0, xferOp.getContext());
-    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
-    Value sum =
-        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
-    Value cond = createFoldedSLE(
-        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
-    // A null `cond` means this dimension is statically in-bounds.
-    if (!cond)
-      return;
-    // Conjunction over all dims for which we are in-bounds.
-    if (inBoundsCond)
-      inBoundsCond = lb.create<arith::AndIOp>(inBoundsCond, cond);
-    else
-      inBoundsCond = cond;
-  });
-  return inBoundsCond;
-/// Return success if `xferOp` is a candidate for splitting into an in-bounds
-/// fast path and a slow path (see splitFullAndPartialTransfer), i.e. if:
-///   1. it is not a 0-d transfer,
-///   2. its permutation map is a minor identity,
-///   3. it has at least one dimension that may be out-of-bounds,
-///   4. it is not directly nested in an `scf.if` (this keeps the pattern from
-///      re-applying to the transfers it creates inside conditionals).
-/// Note: the equal-rank requirement between the source memref and the vector,
-/// documented on splitFullAndPartialTransfer, is not checked here
-/// (createSubViewIntersection asserts it).
-static LogicalResult
-splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
-  // TODO: support 0-d corner case.
-  if (xferOp.getTransferRank() == 0)
-    return failure();
-  // TODO: expand support to these 2 cases.
-  if (!xferOp.permutation_map().isMinorIdentity())
-    return failure();
-  // Must have some out-of-bounds dimension to be a candidate for splitting.
-  if (!xferOp.hasOutOfBoundsDim())
-    return failure();
-  // Don't split transfer operations directly under IfOp, this avoids applying
-  // the pattern recursively.
-  // TODO: improve the filtering condition to make it more applicable.
-  if (isa<scf::IfOp>(xferOp->getParentOp()))
-    return failure();
-  return success();
-/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
-/// be cast. If the MemRefTypes don't have the same rank or are not strided,
-/// return null; otherwise:
-///   1. if `aT` and `bT` are cast-compatible, return `aT`.
-///   2. else return a new MemRefType obtained by iterating over the shape and
-///   strides and:
-///     a. keeping the ones that are static and equal across `aT` and `bT`.
-///     b. using a dynamic shape and/or stride for the dimensions that don't
-///        agree.
-///   The offset is merged the same way; the element type is taken from `aT`.
-/// NOTE(review): no memory space is passed to MemRefType::get in case 2, so a
-/// non-default memory space on `aT`/`bT` is presumably dropped -- confirm.
-static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
-  if (memref::CastOp::areCastCompatible(aT, bT))
-    return aT;
-  if (aT.getRank() != bT.getRank())
-    return MemRefType();
-  int64_t aOffset, bOffset;
-  SmallVector<int64_t, 4> aStrides, bStrides;
-  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
-      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
-      aStrides.size() != bStrides.size())
-    return MemRefType();
-  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
-  int64_t resOffset;
-  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
-      resStrides(bT.getRank(), 0);
-  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
-    resShape[idx] =
-        (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
-    resStrides[idx] = (aStrides[idx] == bStrides[idx])
-                          ? aStrides[idx]
-                          : MemRefType::kDynamicStrideOrOffset;
-  }
-  resOffset =
-      (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
-  return MemRefType::get(
-      resShape, aT.getElementType(),
-      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
-/// Build the intersection between the view `xferOp.source()` @
-/// `xferOp.indices()` and the view `alloc`, returned as a pair of
-/// memref.subview values (copy source, copy destination) of equal sizes.
-/// For every transferred dimension the size is
-/// `affine_min(dim(source) - index, dim(alloc))`. For a transfer_read the copy
-/// goes from `source` into `alloc`; for a transfer_write it goes from `alloc`
-/// into `source`.
-/// NOTE(review): the offsets do not follow that direction. The copy source
-/// always gets offsets `xferOp.indices()` and the copy destination always gets
-/// offsets [0, ..., 0]. For a transfer_write this subviews `alloc` at the
-/// original indices and `source` at 0. That looks inverted: the vector sits at
-/// `alloc[0, ..., 0]` and belongs at `source[indices]` -- confirm.
-/// NOTE(review): leading (non-transferred) indices are appended as *sizes*.
-/// This is only harmless while the asserted rank equality makes the leading
-/// rank 0 (assuming `alloc` has the vector's rank).
-// TODO: view intersection/union/differences should be a proper std op.
-static std::pair<Value, Value>
-createSubViewIntersection(OpBuilder &b, VectorTransferOpInterface xferOp,
-                          Value alloc) {
-  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
-  int64_t memrefRank = xferOp.getShapedType().getRank();
-  // TODO: relax this precondition, will require rank-reducing subviews.
-  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
-         "Expected memref rank to match the alloc rank");
-  ValueRange leadingIndices =
-      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
-  SmallVector<OpFoldResult, 4> sizes;
-  sizes.append(leadingIndices.begin(), leadingIndices.end());
-  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
-  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
-    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
-                                                xferOp.source(), indicesIdx);
-    Value dimAlloc = lb.create<memref::DimOp>(alloc, resultIdx);
-    Value index = xferOp.indices()[indicesIdx];
-    AffineExpr i, j, k;
-    bindDims(xferOp.getContext(), i, j, k);
-    SmallVector<AffineMap, 4> maps =
-        AffineMap::inferFromExprList(MapList{{i - j, k}});
-    // affine_min(%dimMemRef - %index, %dimAlloc)
-    Value affineMin = lb.create<AffineMinOp>(
-        index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
-    sizes.push_back(affineMin);
-  });
-  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
-      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
-  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
-  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
-  auto copySrc = lb.create<memref::SubViewOp>(
-      isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
-  auto copyDest = lb.create<memref::SubViewOp>(
-      isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
-  return std::make_pair(copySrc, copyDest);
-/// Given an `xferOp` for which:
-///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
-///   2. a memref of single vector `alloc` has been allocated.
-/// Produce IR resembling:
-/// ```
-///    %1:3 = scf.if (%inBounds) {
-///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
-///      scf.yield %view, ... : compatibleMemRefType, index, index
-///    } else {
-///      %2 = linalg.fill(%pad, %alloc)
-///      %3 = subview %view [...][...][...]
-///      %4 = subview %alloc [0, 0] [...] [...]
-///      linalg.copy(%3, %4)
-///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %5, ... : compatibleMemRefType, index, index
-///   }
-/// ```
-/// Return the produced scf::IfOp.
-static scf::IfOp
-createFullPartialLinalgCopy(OpBuilder &b, vector::TransferReadOp xferOp,
-                            TypeRange returnTypes, Value inBoundsCond,
-                            MemRefType compatibleMemRefType, Value alloc) {
-  Location loc = xferOp.getLoc();
-  // Slow-path indices: the one-vector buffer is always accessed at [0, ..., 0].
-  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  Value memref = xferOp.source();
-  return b.create<scf::IfOp>(
-      loc, returnTypes, inBoundsCond,
-      [&](OpBuilder &b, Location loc) {
-        // Fast path: yield the original memref (cast only when its type
-        // differs from `compatibleMemRefType`) and the original indices.
-        Value res = memref;
-        if (compatibleMemRefType != xferOp.getShapedType())
-          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
-        scf::ValueVector viewAndIndices{res};
-        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
-                              xferOp.indices().end());
-        b.create<scf::YieldOp>(loc, viewAndIndices);
-      },
-      [&](OpBuilder &b, Location loc) {
-        // Slow path: fill the buffer with the padding value so that elements
-        // not covered by the partial copy below read as padding.
-        b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
-        // Take partial subview of memref which guarantees no dimension
-        // overflows.
-        std::pair<Value, Value> copyArgs = createSubViewIntersection(
-            b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
-        b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
-        Value casted =
-            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
-        scf::ValueVector viewAndIndices{casted};
-        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
-                              zero);
-        b.create<scf::YieldOp>(loc, viewAndIndices);
-      });
-/// Given an `xferOp` for which:
-///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
-///   2. a memref of single vector `alloc` has been allocated.
-/// Produce IR resembling:
-/// ```
-///    %1:3 = scf.if (%inBounds) {
-///      memref.cast %A: memref<A...> to compatibleMemRefType
-///      scf.yield %view, ... : compatibleMemRefType, index, index
-///    } else {
-///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
-///      %3 = vector.type_cast %extra_alloc :
-///        memref<...> to memref<vector<...>>
-///      store %2, %3[] : memref<vector<...>>
-///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %4, ... : compatibleMemRefType, index, index
-///   }
-/// ```
-/// Return the produced scf::IfOp.
-static scf::IfOp createFullPartialVectorTransferRead(
-    OpBuilder &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
-    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
-  Location loc = xferOp.getLoc();
-  // NOTE(review): `fullPartialIfOp` is never used; the scf.if is returned
-  // directly below.
-  scf::IfOp fullPartialIfOp;
-  // Slow-path indices: the one-vector buffer is always accessed at [0, ..., 0].
-  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  Value memref = xferOp.source();
-  return b.create<scf::IfOp>(
-      loc, returnTypes, inBoundsCond,
-      [&](OpBuilder &b, Location loc) {
-        // Fast path: yield the original memref (cast only when its type
-        // differs from `compatibleMemRefType`) and the original indices.
-        Value res = memref;
-        if (compatibleMemRefType != xferOp.getShapedType())
-          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
-        scf::ValueVector viewAndIndices{res};
-        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
-                              xferOp.indices().end());
-        b.create<scf::YieldOp>(loc, viewAndIndices);
-      },
-      [&](OpBuilder &b, Location loc) {
-        // Slow path: clone the original read. The clone still has the original
-        // operands and in_bounds attribute (the caller only rewires `xferOp`
-        // after this returns), so it keeps the out-of-bounds handling. Its
-        // result vector is then stored into the buffer.
-        Operation *newXfer = b.clone(*xferOp.getOperation());
-        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
-        b.create<memref::StoreOp>(
-            loc, vector,
-            b.create<vector::TypeCastOp>(
-                loc, MemRefType::get({}, vector.getType()), alloc));
-        Value casted =
-            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
-        scf::ValueVector viewAndIndices{casted};
-        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
-                              zero);
-        b.create<scf::YieldOp>(loc, viewAndIndices);
-      });
-/// Given an `xferOp` for which:
-///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
-///   2. a memref of single vector `alloc` has been allocated.
-/// Produce IR resembling:
-/// ```
-///    %1:3 = scf.if (%inBounds) {
-///      memref.cast %A: memref<A...> to compatibleMemRefType
-///      scf.yield %view, ... : compatibleMemRefType, index, index
-///    } else {
-///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %4, ... : compatibleMemRefType, index, index
-///   }
-/// ```
-/// and return the scf.if results. The first result is the memref the full
-/// vector should be written to. It is followed by the indices to write it at:
-/// the original indices in the fast path, all zeros in the slow path.
-/// The fast-path memref.cast is only emitted when the source type differs
-/// from `compatibleMemRefType`.
-static ValueRange
-getLocationToWriteFullVec(OpBuilder &b, vector::TransferWriteOp xferOp,
-                          TypeRange returnTypes, Value inBoundsCond,
-                          MemRefType compatibleMemRefType, Value alloc) {
-  Location loc = xferOp.getLoc();
-  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  Value memref = xferOp.source();
-  return b
-      .create<scf::IfOp>(
-          loc, returnTypes, inBoundsCond,
-          [&](OpBuilder &b, Location loc) {
-            Value res = memref;
-            if (compatibleMemRefType != xferOp.getShapedType())
-              res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
-            scf::ValueVector viewAndIndices{res};
-            viewAndIndices.insert(viewAndIndices.end(),
-                                  xferOp.indices().begin(),
-                                  xferOp.indices().end());
-            b.create<scf::YieldOp>(loc, viewAndIndices);
-          },
-          [&](OpBuilder &b, Location loc) {
-            Value casted =
-                b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
-            scf::ValueVector viewAndIndices{casted};
-            viewAndIndices.insert(viewAndIndices.end(),
-                                  xferOp.getTransferRank(), zero);
-            b.create<scf::YieldOp>(loc, viewAndIndices);
-          })
-      ->getResults();
-/// Given an `xferOp` for which:
-///   1. `inBoundsCond` has been computed.
-///   2. a memref of single vector `alloc` has been allocated.
-///   3. it originally wrote to %view
-/// Produce IR resembling:
-/// ```
-///    %notInBounds = arith.xori %inBounds, %true
-///    scf.if (%notInBounds) {
-///      %3 = subview %alloc [...][...][...]
-///      %4 = subview %view [0, 0][...][...]
-///      linalg.copy(%3, %4)
-///   }
-/// ```
-/// i.e. in the slow path, copy the part of the buffer that fits into the
-/// original destination.
-/// NOTE(review): as sketched above (and as createSubViewIntersection builds
-/// it), `alloc` is offset by the original indices and %view by zeros. The
-/// buffer holds the vector at [0, ..., 0], so the offsets look swapped --
-/// confirm.
-static void createFullPartialLinalgCopy(OpBuilder &b,
-                                        vector::TransferWriteOp xferOp,
-                                        Value inBoundsCond, Value alloc) {
-  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
-  auto notInBounds = lb.create<arith::XOrIOp>(
-      inBoundsCond, lb.create<arith::ConstantIntOp>(true, 1));
-  lb.create<scf::IfOp>(notInBounds, [&](OpBuilder &b, Location loc) {
-    std::pair<Value, Value> copyArgs = createSubViewIntersection(
-        b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
-    b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
-    b.create<scf::YieldOp>(loc, ValueRange{});
-  });
-/// Given an `xferOp` for which:
-///   1. `inBoundsCond` has been computed.
-///   2. a memref of single vector `alloc` has been allocated.
-///   3. it originally wrote to %view
-/// Produce IR resembling:
-/// ```
-///    %notInBounds = arith.xori %inBounds, %true
-///    scf.if (%notInBounds) {
-///      %2 = vector.type_cast %alloc : memref<...> to memref<vector<...>>
-///      %3 = memref.load %2[] : memref<vector<...>>
-///      vector.transfer_write %3, %view[...] : vector<...>, memref<A...>
-///   }
-/// ```
-static void createFullPartialVectorTransferWrite(OpBuilder &b,
-                                                 vector::TransferWriteOp xferOp,
-                                                 Value inBoundsCond,
-                                                 Value alloc) {
-  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
-  auto notInBounds = lb.create<arith::XOrIOp>(
-      inBoundsCond, lb.create<arith::ConstantIntOp>(true, 1));
-  lb.create<scf::IfOp>(notInBounds, [&](OpBuilder &b, Location loc) {
-    BlockAndValueMapping mapping;
-    Value load = b.create<memref::LoadOp>(
-        loc, b.create<vector::TypeCastOp>(
-                 loc, MemRefType::get({}, xferOp.vector().getType()), alloc));
-    mapping.map(xferOp.vector(), load);
-    // Re-issue the original write (original destination, indices and
-    // in_bounds attribute) with the vector reloaded from the buffer, so the
-    // out-of-bounds handling of the original op is preserved.
-    b.clone(*xferOp.getOperation(), mapping);
-    b.create<scf::YieldOp>(loc, ValueRange{});
-  });
-/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
-/// masking) fastpath and a slowpath.
-/// For vector.transfer_read:
-/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
-/// newly created conditional upon function return.
-/// To accommodate the fact that the original vector.transfer indexing may be
-/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
-/// scf.if op returns a view and values of type index.
-/// Example (a 2-D vector.transfer_read):
-/// ```
-///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
-/// ```
-/// is transformed into:
-/// ```
-///    %1:3 = scf.if (%inBounds) {
-///      // fastpath, direct cast
-///      memref.cast %A: memref<A...> to compatibleMemRefType
-///      scf.yield %view : compatibleMemRefType, index, index
-///    } else {
-///      // slowpath, not in-bounds vector.transfer or linalg.copy.
-///      memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %4 : compatibleMemRefType, index, index
-///    }
-///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
-/// ```
-/// where `alloc` is a top of the function alloca'ed buffer of one vector.
-/// For vector.transfer_write:
-/// There are 2 conditional blocks. First a block to decide which memref and
-/// indices to use for an unmasked, inbounds write. Then a conditional block to
-/// further copy a partial buffer into the final result in the slow path case.
-/// Example (a 2-D vector.transfer_write):
-/// ```
-///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
-/// ```
-/// is transformed into:
-/// ```
-///    %1:3 = scf.if (%inBounds) {
-///      memref.cast %A: memref<A...> to compatibleMemRefType
-///      scf.yield %view : compatibleMemRefType, index, index
-///    } else {
-///      memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %4 : compatibleMemRefType, index, index
-///     }
-///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
-///                                                               true]}
-///    scf.if (%notInBounds) {
-///      // slowpath: not in-bounds vector.transfer or linalg.copy.
-///    }
-/// ```
-/// where `alloc` is a top of the function alloca'ed buffer of one vector.
-/// The transfer_read is updated in place; the original transfer_write is
-/// replaced by a clone and erased.
-/// With `VectorTransferSplit::None` this returns failure without creating IR;
-/// with `VectorTransferSplit::ForceInBounds` it only marks every dimension
-/// in-bounds (no split, no precondition check).
-/// Preconditions:
-///  1. `xferOp.permutation_map()` must be a minor identity map
-///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
-///  must be equal. This will be relaxed in the future but requires
-///  rank-reducing subviews.
-/// NOTE(review): several `failure()` returns below happen after IR has already
-/// been created (ops built by createInBoundsCond; the alloca is created before
-/// the compatibleMemRefType check). Pattern drivers expect a failing pattern
-/// to leave the IR unchanged.
-LogicalResult mlir::vector::splitFullAndPartialTransfer(
-    OpBuilder &b, VectorTransferOpInterface xferOp,
-    VectorTransformsOptions options, scf::IfOp *ifOp) {
-  if (options.vectorTransferSplit == VectorTransferSplit::None)
-    return failure();
-  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
-  auto inBoundsAttr = b.getBoolArrayAttr(bools);
-  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
-    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
-    return success();
-  }
-  // Assert preconditions. Additionally, keep the variables in an inner scope to
-  // ensure they aren't used in the wrong scopes further down.
-  {
-    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
-           "Expected splitFullAndPartialTransferPrecondition to hold");
-    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
-    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
-    if (!(xferReadOp || xferWriteOp))
-      return failure();
-    // Masked transfers are not supported.
-    if (xferWriteOp && xferWriteOp.mask())
-      return failure();
-    if (xferReadOp && xferReadOp.mask())
-      return failure();
-  }
-  OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPoint(xferOp);
-  Value inBoundsCond = createInBoundsCond(
-      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
-  // A null condition means every dimension is statically in-bounds.
-  if (!inBoundsCond)
-    return failure();
-  // Top of the function `alloc` for transient storage.
-  Value alloc;
-  {
-    FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
-    OpBuilder::InsertionGuard guard(b);
-    b.setInsertionPointToStart(&funcOp.getRegion().front());
-    auto shape = xferOp.getVectorType().getShape();
-    Type elementType = xferOp.getVectorType().getElementType();
-    alloc = b.create<memref::AllocaOp>(funcOp.getLoc(),
-                                       MemRefType::get(shape, elementType),
-                                       ValueRange{}, b.getI64IntegerAttr(32));
-  }
-  MemRefType compatibleMemRefType =
-      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
-                                  alloc.getType().cast<MemRefType>());
-  if (!compatibleMemRefType)
-    return failure();
-  // The scf.if yields the memref to access followed by one index per
-  // transferred dimension.
-  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
-                                   b.getIndexType());
-  returnTypes[0] = compatibleMemRefType;
-  if (auto xferReadOp =
-          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
-    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
-    scf::IfOp fullPartialIfOp =
-        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
-            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
-                                                  inBoundsCond,
-                                                  compatibleMemRefType, alloc)
-            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
-                                          inBoundsCond, compatibleMemRefType,
-                                          alloc);
-    if (ifOp)
-      *ifOp = fullPartialIfOp;
-    // Set existing read op to in-bounds, it always reads from a full buffer.
-    // Operands [0, 1 + rank) are presumably the source followed by the
-    // indices, matching the scf.if results.
-    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
-      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
-    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
-    return success();
-  }
-  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
-  // Decide which location to write the entire vector to.
-  auto memrefAndIndices = getLocationToWriteFullVec(
-      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
-  // Do an in bounds write to either the output or the extra allocated buffer.
-  // The operation is cloned to prevent deleting information needed for the
-  // later IR creation.
-  BlockAndValueMapping mapping;
-  mapping.map(xferWriteOp.source(), memrefAndIndices.front());
-  mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front());
-  auto *clone = b.clone(*xferWriteOp, mapping);
-  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
-  // Create a potential copy from the allocated buffer to the final output in
-  // the slow path case.
-  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
-    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
-  else
-    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
-  // NOTE(review): erased directly rather than through a rewriter, so a
-  // PatternRewriter caller is not notified and must not touch `xferOp` after
-  // a successful write split.
-  xferOp->erase();
-  return success();
-/// Apply splitFullAndPartialTransfer to `op` when it is a vector transfer that
-/// satisfies splitFullAndPartialTransferPrecondition and the pattern's
-/// `filter`.
-/// NOTE(review): for a vector.transfer_write (any split mode other than
-/// ForceInBounds), a successful split erases `xferOp`. The subsequent
-/// `finalizeRootUpdate(xferOp)` then operates on a deleted operation
-/// (use-after-free). Only the read path modifies `xferOp` in place.
-LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
-  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
-  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
-      failed(filter(xferOp)))
-    return failure();
-  rewriter.startRootUpdate(xferOp);
-  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
-    rewriter.finalizeRootUpdate(xferOp);
-    return success();
-  }
-  rewriter.cancelRootUpdate(xferOp);
-  return failure();
 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
     ArrayRef<int64_t> multiplicity, const AffineMap &map) {


More information about the Mlir-commits mailing list