[Mlir-commits] [mlir] c7dd0bf - [mlir][vector] NFC - Split out transfer split patterns
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Jan 5 05:38:11 PST 2022
Author: Nicolas Vasilache
Date: 2022-01-05T08:38:04-05:00
New Revision: c7dd0bf41d8ee940893ae144cc0d813828c8233a
URL: https://github.com/llvm/llvm-project/commit/c7dd0bf41d8ee940893ae144cc0d813828c8233a
DIFF: https://github.com/llvm/llvm-project/commit/c7dd0bf41d8ee940893ae144cc0d813828c8233a.diff
LOG: [mlir][vector] NFC - Split out transfer split patterns
Differential Revision: https://reviews.llvm.org/D116648
Added:
mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
Modified:
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/lib/Dialect/Vector/CMakeLists.txt
mlir/lib/Dialect/Vector/VectorTransforms.cpp
Removed:
################################################################################
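For context on how the moved pattern is typically exercised, the sketch below wires the `VectorTransferFullPartialRewriter` pattern (defined in the new file in the diff below) into a greedy rewrite driver. This is a hedged illustration only: the `setVectorTransferSplit` helper, the pattern's constructor arguments, and the `applyPatternsAndFoldGreedily` call are assumptions based on the MLIR API of this period and are not part of this commit.

```
// Hypothetical driver (not part of this commit): apply the transfer split
// rewrite to every qualifying vector.transfer op in a function.
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult splitTransfersInFunc(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  // Request the scf.if-based split rather than only forcing in_bounds.
  vector::VectorTransformsOptions options;
  options.setVectorTransferSplit(vector::VectorTransferSplit::VectorTransfer);

  RewritePatternSet patterns(ctx);
  // Constructor arguments are assumed; see VectorTransforms.h for the exact
  // signature of VectorTransferFullPartialRewriter.
  patterns.add<vector::VectorTransferFullPartialRewriter>(ctx, options);
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
```

Choosing `VectorTransferSplit::ForceInBounds` instead would only stamp the in_bounds attribute without creating the scf.if fast/slow paths, as the implementation in the diff below shows.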
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 811d72192910e..32935d547df0d 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -16,6 +16,8 @@ namespace mlir {
class MLIRContext;
class VectorTransferOpInterface;
class RewritePatternSet;
+class RewriterBase;
+
using OwningRewritePatternList = RewritePatternSet;
namespace scf {
@@ -61,7 +63,7 @@ namespace vector {
/// must be equal. This will be relaxed in the future but requires
/// rank-reducing subviews.
LogicalResult splitFullAndPartialTransfer(
- OpBuilder &b, VectorTransferOpInterface xferOp,
+ RewriterBase &b, VectorTransferOpInterface xferOp,
VectorTransformsOptions options = VectorTransformsOptions(),
scf::IfOp *ifOp = nullptr);
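Note that the only interface change in this move is the switch from `OpBuilder &` to `RewriterBase &` in `splitFullAndPartialTransfer`. A caller holding a plain builder can adapt by wrapping it in an `IRRewriter`, mirroring what the new implementation does internally; the snippet below is a hedged, incomplete illustration in which `builder` and `xferOp` are assumed to be in scope.

```
// Hypothetical call site (illustration only): adapt an existing OpBuilder to
// the RewriterBase-based entry point by wrapping it in an IRRewriter.
scf::IfOp ifOp;
IRRewriter rewriter(builder); // `builder` is an assumed in-scope OpBuilder.
if (failed(vector::splitFullAndPartialTransfer(
        rewriter, xferOp, vector::VectorTransformsOptions(), &ifOp)))
  return failure();
```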
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 143c6c7d688d1..36afc18ff0bea 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRVector
VectorMultiDimReductionTransforms.cpp
VectorOps.cpp
VectorTransferOpTransforms.cpp
+ VectorTransferSplitRewritePatterns.cpp
VectorTransferPermutationMapRewritePatterns.cpp
VectorTransforms.cpp
VectorUnrollDistribute.cpp
diff --git a/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
new file mode 100644
index 0000000000000..a6fbc303d64e3
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
@@ -0,0 +1,625 @@
+//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent patterns to rewrite a vector.transfer
+// op into a fully in-bounds part and a partial part.
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "vector-transfer-split"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+static Optional<int64_t> extractConstantIndex(Value v) {
+ if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
+ return cstOp.value();
+ if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
+ if (affineApplyOp.getAffineMap().isSingleConstant())
+ return affineApplyOp.getAffineMap().getSingleConstantResult();
+ return None;
+}
+
+// Missing foldings of scf.if make it necessary to perform poor man's folding
+// eagerly, especially in the case of unrolling. In the future, this should go
+// away once scf.if folds properly.
+static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
+ auto maybeCstV = extractConstantIndex(v);
+ auto maybeCstUb = extractConstantIndex(ub);
+ if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
+ return Value();
+ return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
+}
+
+/// Build the condition to ensure that a particular VectorTransferOpInterface
+/// is in-bounds.
+static Value createInBoundsCond(RewriterBase &b,
+ VectorTransferOpInterface xferOp) {
+ assert(xferOp.permutation_map().isMinorIdentity() &&
+ "Expected minor identity map");
+ Value inBoundsCond;
+ xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
+ // Zip over the resulting vector shape and memref indices.
+ // If the dimension is known to be in-bounds, it does not participate in
+ // the construction of `inBoundsCond`.
+ if (xferOp.isDimInBounds(resultIdx))
+ return;
+ // Fold or create the check that `index + vector_size` <= `memref_size`.
+ Location loc = xferOp.getLoc();
+ int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
+ auto d0 = getAffineDimExpr(0, xferOp.getContext());
+ auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
+ Value sum =
+ makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
+ Value cond = createFoldedSLE(
+ b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
+ if (!cond)
+ return;
+ // Conjunction over all dims for which we are in-bounds.
+ if (inBoundsCond)
+ inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
+ else
+ inBoundsCond = cond;
+ });
+ return inBoundsCond;
+}
+
+/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
+/// masking) fastpath and a slowpath.
+/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
+/// newly created conditional upon function return.
+/// To accommodate the fact that the original vector.transfer indexing may be
+/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// scf.if op returns a view and values of type index.
+/// At this time, only the vector.transfer_read case is implemented.
+///
+/// Example (a 2-D vector.transfer_read):
+/// ```
+/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
+/// ```
+/// is transformed into:
+/// ```
+/// %1:3 = scf.if (%inBounds) {
+/// // fastpath, direct cast
+/// memref.cast %A: memref<A...> to compatibleMemRefType
+/// scf.yield %view : compatibleMemRefType, index, index
+/// } else {
+/// // slowpath, not in-bounds vector.transfer or linalg.copy.
+/// memref.cast %alloc: memref<B...> to compatibleMemRefType
+/// scf.yield %4 : compatibleMemRefType, index, index
+// }
+/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
+/// ```
+/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
+///
+/// Preconditions:
+/// 1. `xferOp.permutation_map()` must be a minor identity map
+/// 2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
+/// must be equal. This will be relaxed in the future but requires
+/// rank-reducing subviews.
+static LogicalResult
+splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
+ // TODO: support 0-d corner case.
+ if (xferOp.getTransferRank() == 0)
+ return failure();
+
+ // TODO: expand support to these 2 cases.
+ if (!xferOp.permutation_map().isMinorIdentity())
+ return failure();
+ // Must have some out-of-bounds dimension to be a candidate for splitting.
+ if (!xferOp.hasOutOfBoundsDim())
+ return failure();
+  // Don't split transfer operations directly under an IfOp; this avoids
+  // applying the pattern recursively.
+ // TODO: improve the filtering condition to make it more applicable.
+ if (isa<scf::IfOp>(xferOp->getParentOp()))
+ return failure();
+ return success();
+}
+
+/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
+/// be cast. If the MemRefTypes don't have the same rank or are not strided,
+/// return null; otherwise:
+/// 1. if `aT` and `bT` are cast-compatible, return `aT`.
+/// 2. else return a new MemRefType obtained by iterating over the shape and
+/// strides and:
+/// a. keeping the ones that are static and equal across `aT` and `bT`.
+/// b. using a dynamic shape and/or stride for the dimensions that don't
+/// agree.
+static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
+ if (memref::CastOp::areCastCompatible(aT, bT))
+ return aT;
+ if (aT.getRank() != bT.getRank())
+ return MemRefType();
+ int64_t aOffset, bOffset;
+ SmallVector<int64_t, 4> aStrides, bStrides;
+ if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
+ failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
+ aStrides.size() != bStrides.size())
+ return MemRefType();
+
+ ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
+ int64_t resOffset;
+ SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
+ resStrides(bT.getRank(), 0);
+ for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
+ resShape[idx] =
+ (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
+ resStrides[idx] = (aStrides[idx] == bStrides[idx])
+ ? aStrides[idx]
+ : MemRefType::kDynamicStrideOrOffset;
+ }
+ resOffset =
+ (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
+ return MemRefType::get(
+ resShape, aT.getElementType(),
+ makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
+}
+
+/// Operates under a scoped context to build the intersection between the
+/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
+// TODO: view intersection/union/differences should be a proper std op.
+static std::pair<Value, Value>
+createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
+ Value alloc) {
+ Location loc = xferOp.getLoc();
+ int64_t memrefRank = xferOp.getShapedType().getRank();
+ // TODO: relax this precondition, will require rank-reducing subviews.
+ assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
+ "Expected memref rank to match the alloc rank");
+ ValueRange leadingIndices =
+ xferOp.indices().take_front(xferOp.getLeadingShapedRank());
+ SmallVector<OpFoldResult, 4> sizes;
+ sizes.append(leadingIndices.begin(), leadingIndices.end());
+ auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
+ xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+ Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
+ xferOp.source(), indicesIdx);
+ Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
+ Value index = xferOp.indices()[indicesIdx];
+ AffineExpr i, j, k;
+ bindDims(xferOp.getContext(), i, j, k);
+ SmallVector<AffineMap, 4> maps =
+ AffineMap::inferFromExprList(MapList{{i - j, k}});
+ // affine_min(%dimMemRef - %index, %dimAlloc)
+ Value affineMin = b.create<AffineMinOp>(
+ loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
+ sizes.push_back(affineMin);
+ });
+
+ SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
+ xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
+ SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
+ auto copySrc = b.create<memref::SubViewOp>(
+ loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
+ auto copyDest = b.create<memref::SubViewOp>(
+ loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
+ return std::make_pair(copySrc, copyDest);
+}
+
+/// Given an `xferOp` for which:
+/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+/// 2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+/// %1:3 = scf.if (%inBounds) {
+/// %view = memref.cast %A: memref<A...> to compatibleMemRefType
+/// scf.yield %view, ... : compatibleMemRefType, index, index
+/// } else {
+/// %2 = linalg.fill(%pad, %alloc)
+/// %3 = subview %view [...][...][...]
+/// %4 = subview %alloc [0, 0] [...] [...]
+/// linalg.copy(%3, %4)
+/// %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+/// scf.yield %5, ... : compatibleMemRefType, index, index
+/// }
+/// ```
+/// Return the produced scf::IfOp.
+static scf::IfOp
+createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
+ TypeRange returnTypes, Value inBoundsCond,
+ MemRefType compatibleMemRefType, Value alloc) {
+ Location loc = xferOp.getLoc();
+ Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value memref = xferOp.source();
+ return b.create<scf::IfOp>(
+ loc, returnTypes, inBoundsCond,
+ [&](OpBuilder &b, Location loc) {
+ Value res = memref;
+ if (compatibleMemRefType != xferOp.getShapedType())
+ res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{res};
+ viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+ xferOp.indices().end());
+ b.create<scf::YieldOp>(loc, viewAndIndices);
+ },
+ [&](OpBuilder &b, Location loc) {
+ b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
+ // Take partial subview of memref which guarantees no dimension
+ // overflows.
+ IRRewriter rewriter(b);
+ std::pair<Value, Value> copyArgs = createSubViewIntersection(
+ rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
+ alloc);
+ b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
+ Value casted =
+ b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{casted};
+ viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+ zero);
+ b.create<scf::YieldOp>(loc, viewAndIndices);
+ });
+}
+
+/// Given an `xferOp` for which:
+/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+/// 2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+/// %1:3 = scf.if (%inBounds) {
+/// memref.cast %A: memref<A...> to compatibleMemRefType
+/// scf.yield %view, ... : compatibleMemRefType, index, index
+/// } else {
+/// %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
+/// %3 = vector.type_cast %extra_alloc :
+/// memref<...> to memref<vector<...>>
+/// store %2, %3[] : memref<vector<...>>
+/// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+/// scf.yield %4, ... : compatibleMemRefType, index, index
+/// }
+/// ```
+/// Return the produced scf::IfOp.
+static scf::IfOp createFullPartialVectorTransferRead(
+ RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
+ Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
+ Location loc = xferOp.getLoc();
+ scf::IfOp fullPartialIfOp;
+ Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value memref = xferOp.source();
+ return b.create<scf::IfOp>(
+ loc, returnTypes, inBoundsCond,
+ [&](OpBuilder &b, Location loc) {
+ Value res = memref;
+ if (compatibleMemRefType != xferOp.getShapedType())
+ res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{res};
+ viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+ xferOp.indices().end());
+ b.create<scf::YieldOp>(loc, viewAndIndices);
+ },
+ [&](OpBuilder &b, Location loc) {
+ Operation *newXfer = b.clone(*xferOp.getOperation());
+ Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
+ b.create<memref::StoreOp>(
+ loc, vector,
+ b.create<vector::TypeCastOp>(
+ loc, MemRefType::get({}, vector.getType()), alloc));
+
+ Value casted =
+ b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{casted};
+ viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+ zero);
+ b.create<scf::YieldOp>(loc, viewAndIndices);
+ });
+}
+
+/// Given an `xferOp` for which:
+/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+/// 2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+/// %1:3 = scf.if (%inBounds) {
+/// memref.cast %A: memref<A...> to compatibleMemRefType
+/// scf.yield %view, ... : compatibleMemRefType, index, index
+/// } else {
+/// %3 = vector.type_cast %extra_alloc :
+/// memref<...> to memref<vector<...>>
+/// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+/// scf.yield %4, ... : compatibleMemRefType, index, index
+/// }
+/// ```
+static ValueRange
+getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
+ TypeRange returnTypes, Value inBoundsCond,
+ MemRefType compatibleMemRefType, Value alloc) {
+ Location loc = xferOp.getLoc();
+ Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value memref = xferOp.source();
+ return b
+ .create<scf::IfOp>(
+ loc, returnTypes, inBoundsCond,
+ [&](OpBuilder &b, Location loc) {
+ Value res = memref;
+ if (compatibleMemRefType != xferOp.getShapedType())
+ res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{res};
+ viewAndIndices.insert(viewAndIndices.end(),
+ xferOp.indices().begin(),
+ xferOp.indices().end());
+ b.create<scf::YieldOp>(loc, viewAndIndices);
+ },
+ [&](OpBuilder &b, Location loc) {
+ Value casted =
+ b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{casted};
+ viewAndIndices.insert(viewAndIndices.end(),
+ xferOp.getTransferRank(), zero);
+ b.create<scf::YieldOp>(loc, viewAndIndices);
+ })
+ ->getResults();
+}
+
+/// Given an `xferOp` for which:
+/// 1. `inBoundsCond` has been computed.
+/// 2. a memref of single vector `alloc` has been allocated.
+/// 3. it originally wrote to %view
+/// Produce IR resembling:
+/// ```
+/// %notInBounds = arith.xori %inBounds, %true
+/// scf.if (%notInBounds) {
+/// %3 = subview %alloc [...][...][...]
+/// %4 = subview %view [0, 0][...][...]
+/// linalg.copy(%3, %4)
+/// }
+/// ```
+static void createFullPartialLinalgCopy(RewriterBase &b,
+ vector::TransferWriteOp xferOp,
+ Value inBoundsCond, Value alloc) {
+ Location loc = xferOp.getLoc();
+ auto notInBounds = b.create<arith::XOrIOp>(
+ loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
+ b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
+ IRRewriter rewriter(b);
+ std::pair<Value, Value> copyArgs = createSubViewIntersection(
+ rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
+ alloc);
+ b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
+ b.create<scf::YieldOp>(loc, ValueRange{});
+ });
+}
+
+/// Given an `xferOp` for which:
+/// 1. `inBoundsCond` has been computed.
+/// 2. a memref of single vector `alloc` has been allocated.
+/// 3. it originally wrote to %view
+/// Produce IR resembling:
+/// ```
+/// %notInBounds = arith.xori %inBounds, %true
+/// scf.if (%notInBounds) {
+/// %2 = load %alloc : memref<vector<...>>
+/// vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
+/// }
+/// ```
+static void createFullPartialVectorTransferWrite(RewriterBase &b,
+ vector::TransferWriteOp xferOp,
+ Value inBoundsCond,
+ Value alloc) {
+ Location loc = xferOp.getLoc();
+ auto notInBounds = b.create<arith::XOrIOp>(
+ loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
+ b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
+ BlockAndValueMapping mapping;
+ Value load = b.create<memref::LoadOp>(
+ loc, b.create<vector::TypeCastOp>(
+ loc, MemRefType::get({}, xferOp.vector().getType()), alloc));
+ mapping.map(xferOp.vector(), load);
+ b.clone(*xferOp.getOperation(), mapping);
+ b.create<scf::YieldOp>(loc, ValueRange{});
+ });
+}
+
+/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
+/// masking) fastpath and a slowpath.
+///
+/// For vector.transfer_read:
+/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
+/// newly created conditional upon function return.
+/// To accommodate the fact that the original vector.transfer indexing may be
+/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// scf.if op returns a view and values of type index.
+///
+/// Example (a 2-D vector.transfer_read):
+/// ```
+/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
+/// ```
+/// is transformed into:
+/// ```
+/// %1:3 = scf.if (%inBounds) {
+/// // fastpath, direct cast
+/// memref.cast %A: memref<A...> to compatibleMemRefType
+/// scf.yield %view : compatibleMemRefType, index, index
+/// } else {
+/// // slowpath, not in-bounds vector.transfer or linalg.copy.
+/// memref.cast %alloc: memref<B...> to compatibleMemRefType
+/// scf.yield %4 : compatibleMemRefType, index, index
+// }
+/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
+/// ```
+/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
+///
+/// For vector.transfer_write:
+/// There are two conditional blocks. First, a block decides which memref and
+/// indices to use for an unmasked, in-bounds write. Then, a conditional block
+/// copies the partial buffer into the final result in the slow path case.
+///
+/// Example (a 2-D vector.transfer_write):
+/// ```
+/// vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
+/// ```
+/// is transformed into:
+/// ```
+/// %1:3 = scf.if (%inBounds) {
+/// memref.cast %A: memref<A...> to compatibleMemRefType
+/// scf.yield %view : compatibleMemRefType, index, index
+/// } else {
+/// memref.cast %alloc: memref<B...> to compatibleMemRefType
+/// scf.yield %4 : compatibleMemRefType, index, index
+/// }
+/// %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
+/// true]}
+/// scf.if (%notInBounds) {
+/// // slowpath: not in-bounds vector.transfer or linalg.copy.
+/// }
+/// ```
+/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
+///
+/// Preconditions:
+/// 1. `xferOp.permutation_map()` must be a minor identity map
+/// 2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
+/// must be equal. This will be relaxed in the future but requires
+/// rank-reducing subviews.
+LogicalResult mlir::vector::splitFullAndPartialTransfer(
+ RewriterBase &b, VectorTransferOpInterface xferOp,
+ VectorTransformsOptions options, scf::IfOp *ifOp) {
+ if (options.vectorTransferSplit == VectorTransferSplit::None)
+ return failure();
+
+ SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
+ auto inBoundsAttr = b.getBoolArrayAttr(bools);
+ if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
+ xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
+ return success();
+ }
+
+ // Assert preconditions. Additionally, keep the variables in an inner scope to
+ // ensure they aren't used in the wrong scopes further down.
+ {
+ assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
+ "Expected splitFullAndPartialTransferPrecondition to hold");
+
+ auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
+ auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
+
+ if (!(xferReadOp || xferWriteOp))
+ return failure();
+ if (xferWriteOp && xferWriteOp.mask())
+ return failure();
+ if (xferReadOp && xferReadOp.mask())
+ return failure();
+ }
+
+ RewriterBase::InsertionGuard guard(b);
+ b.setInsertionPoint(xferOp);
+ Value inBoundsCond = createInBoundsCond(
+ b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
+ if (!inBoundsCond)
+ return failure();
+
+ // Top of the function `alloc` for transient storage.
+ Value alloc;
+ {
+ FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
+ RewriterBase::InsertionGuard guard(b);
+ b.setInsertionPointToStart(&funcOp.getRegion().front());
+ auto shape = xferOp.getVectorType().getShape();
+ Type elementType = xferOp.getVectorType().getElementType();
+ alloc = b.create<memref::AllocaOp>(funcOp.getLoc(),
+ MemRefType::get(shape, elementType),
+ ValueRange{}, b.getI64IntegerAttr(32));
+ }
+
+ MemRefType compatibleMemRefType =
+ getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
+ alloc.getType().cast<MemRefType>());
+ if (!compatibleMemRefType)
+ return failure();
+
+ SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
+ b.getIndexType());
+ returnTypes[0] = compatibleMemRefType;
+
+ if (auto xferReadOp =
+ dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
+ // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
+ scf::IfOp fullPartialIfOp =
+ options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
+ ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
+ inBoundsCond,
+ compatibleMemRefType, alloc)
+ : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
+ inBoundsCond, compatibleMemRefType,
+ alloc);
+ if (ifOp)
+ *ifOp = fullPartialIfOp;
+
+    // Set existing read op to in-bounds; it always reads from a full buffer.
+ for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
+ xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
+
+ xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
+
+ return success();
+ }
+
+ auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
+
+ // Decide which location to write the entire vector to.
+ auto memrefAndIndices = getLocationToWriteFullVec(
+ b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
+
+  // Do an in-bounds write to either the output or the extra allocated buffer.
+ // The operation is cloned to prevent deleting information needed for the
+ // later IR creation.
+ BlockAndValueMapping mapping;
+ mapping.map(xferWriteOp.source(), memrefAndIndices.front());
+ mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front());
+ auto *clone = b.clone(*xferWriteOp, mapping);
+ clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
+
+ // Create a potential copy from the allocated buffer to the final output in
+ // the slow path case.
+ if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
+ createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
+ else
+ createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
+
+ xferOp->erase();
+
+ return success();
+}
+
+LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
+ if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
+ failed(filter(xferOp)))
+ return failure();
+ rewriter.startRootUpdate(xferOp);
+ if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
+ rewriter.finalizeRootUpdate(xferOp);
+ return success();
+ }
+ rewriter.cancelRootUpdate(xferOp);
+ return failure();
+}
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 0b49ccd58b27d..cd26750c9eb7a 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -144,7 +144,6 @@ static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
namespace {
-
/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
@@ -1642,587 +1641,6 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
} // namespace mlir
-static Optional<int64_t> extractConstantIndex(Value v) {
- if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
- return cstOp.value();
- if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
- if (affineApplyOp.getAffineMap().isSingleConstant())
- return affineApplyOp.getAffineMap().getSingleConstantResult();
- return None;
-}
-
-// Missing foldings of scf.if make it necessary to perform poor man's folding
-// eagerly, especially in the case of unrolling. In the future, this should go
-// away once scf.if folds properly.
-static Value createFoldedSLE(OpBuilder &b, Value v, Value ub) {
- auto maybeCstV = extractConstantIndex(v);
- auto maybeCstUb = extractConstantIndex(ub);
- if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
- return Value();
- return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
-}
-
-// Operates under a scoped context to build the condition to ensure that a
-// particular VectorTransferOpInterface is in-bounds.
-static Value createInBoundsCond(OpBuilder &b,
- VectorTransferOpInterface xferOp) {
- assert(xferOp.permutation_map().isMinorIdentity() &&
- "Expected minor identity map");
- Value inBoundsCond;
- xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
- // Zip over the resulting vector shape and memref indices.
- // If the dimension is known to be in-bounds, it does not participate in
- // the construction of `inBoundsCond`.
- if (xferOp.isDimInBounds(resultIdx))
- return;
- // Fold or create the check that `index + vector_size` <= `memref_size`.
- Location loc = xferOp.getLoc();
- ImplicitLocOpBuilder lb(loc, b);
- int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
- auto d0 = getAffineDimExpr(0, xferOp.getContext());
- auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
- Value sum =
- makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
- Value cond = createFoldedSLE(
- b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
- if (!cond)
- return;
- // Conjunction over all dims for which we are in-bounds.
- if (inBoundsCond)
- inBoundsCond = lb.create<arith::AndIOp>(inBoundsCond, cond);
- else
- inBoundsCond = cond;
- });
- return inBoundsCond;
-}
-/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
-/// masking) fastpath and a slowpath.
-/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
-/// newly created conditional upon function return.
-/// To accomodate for the fact that the original vector.transfer indexing may be
-/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
-/// scf.if op returns a view and values of type index.
-/// At this time, only vector.transfer_read case is implemented.
-///
-/// Example (a 2-D vector.transfer_read):
-/// ```
-/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
-/// ```
-/// is transformed into:
-/// ```
-/// %1:3 = scf.if (%inBounds) {
-/// // fastpath, direct cast
-/// memref.cast %A: memref<A...> to compatibleMemRefType
-/// scf.yield %view : compatibleMemRefType, index, index
-/// } else {
-/// // slowpath, not in-bounds vector.transfer or linalg.copy.
-/// memref.cast %alloc: memref<B...> to compatibleMemRefType
-/// scf.yield %4 : compatibleMemRefType, index, index
-// }
-/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
-/// ```
-/// where `alloc` is a top of the function alloca'ed buffer of one vector.
-///
-/// Preconditions:
-/// 1. `xferOp.permutation_map()` must be a minor identity map
-/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
-/// must be equal. This will be relaxed in the future but requires
-/// rank-reducing subviews.
-static LogicalResult
-splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
- // TODO: support 0-d corner case.
- if (xferOp.getTransferRank() == 0)
- return failure();
-
- // TODO: expand support to these 2 cases.
- if (!xferOp.permutation_map().isMinorIdentity())
- return failure();
- // Must have some out-of-bounds dimension to be a candidate for splitting.
- if (!xferOp.hasOutOfBoundsDim())
- return failure();
- // Don't split transfer operations directly under IfOp, this avoids applying
- // the pattern recursively.
- // TODO: improve the filtering condition to make it more applicable.
- if (isa<scf::IfOp>(xferOp->getParentOp()))
- return failure();
- return success();
-}
-
-/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
-/// be cast. If the MemRefTypes don't have the same rank or are not strided,
-/// return null; otherwise:
-/// 1. if `aT` and `bT` are cast-compatible, return `aT`.
-/// 2. else return a new MemRefType obtained by iterating over the shape and
-/// strides and:
-/// a. keeping the ones that are static and equal across `aT` and `bT`.
-/// b. using a dynamic shape and/or stride for the dimensions that don't
-/// agree.
-static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
- if (memref::CastOp::areCastCompatible(aT, bT))
- return aT;
- if (aT.getRank() != bT.getRank())
- return MemRefType();
- int64_t aOffset, bOffset;
- SmallVector<int64_t, 4> aStrides, bStrides;
- if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
- failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
- aStrides.size() != bStrides.size())
- return MemRefType();
-
- ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
- int64_t resOffset;
- SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
- resStrides(bT.getRank(), 0);
- for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
- resShape[idx] =
- (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
- resStrides[idx] = (aStrides[idx] == bStrides[idx])
- ? aStrides[idx]
- : MemRefType::kDynamicStrideOrOffset;
- }
- resOffset =
- (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
- return MemRefType::get(
- resShape, aT.getElementType(),
- makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
-}
-
-/// Operates under a scoped context to build the intersection between the
-/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
-// TODO: view intersection/union/differences should be a proper std op.
-static std::pair<Value, Value>
-createSubViewIntersection(OpBuilder &b, VectorTransferOpInterface xferOp,
- Value alloc) {
- ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
- int64_t memrefRank = xferOp.getShapedType().getRank();
- // TODO: relax this precondition, will require rank-reducing subviews.
- assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
- "Expected memref rank to match the alloc rank");
- ValueRange leadingIndices =
- xferOp.indices().take_front(xferOp.getLeadingShapedRank());
- SmallVector<OpFoldResult, 4> sizes;
- sizes.append(leadingIndices.begin(), leadingIndices.end());
- auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
- xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
- using MapList = ArrayRef<ArrayRef<AffineExpr>>;
- Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
- xferOp.source(), indicesIdx);
- Value dimAlloc = lb.create<memref::DimOp>(alloc, resultIdx);
- Value index = xferOp.indices()[indicesIdx];
- AffineExpr i, j, k;
- bindDims(xferOp.getContext(), i, j, k);
- SmallVector<AffineMap, 4> maps =
- AffineMap::inferFromExprList(MapList{{i - j, k}});
- // affine_min(%dimMemRef - %index, %dimAlloc)
- Value affineMin = lb.create<AffineMinOp>(
- index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
- sizes.push_back(affineMin);
- });
-
- SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
- xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
- SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
- SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
- auto copySrc = lb.create<memref::SubViewOp>(
- isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
- auto copyDest = lb.create<memref::SubViewOp>(
- isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
- return std::make_pair(copySrc, copyDest);
-}
-
-/// Given an `xferOp` for which:
-/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
-/// 2. a memref of single vector `alloc` has been allocated.
-/// Produce IR resembling:
-/// ```
-/// %1:3 = scf.if (%inBounds) {
-/// %view = memref.cast %A: memref<A...> to compatibleMemRefType
-/// scf.yield %view, ... : compatibleMemRefType, index, index
-/// } else {
-/// %2 = linalg.fill(%pad, %alloc)
-/// %3 = subview %view [...][...][...]
-/// %4 = subview %alloc [0, 0] [...] [...]
-/// linalg.copy(%3, %4)
-/// %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
-/// scf.yield %5, ... : compatibleMemRefType, index, index
-/// }
-/// ```
-/// Return the produced scf::IfOp.
-static scf::IfOp
-createFullPartialLinalgCopy(OpBuilder &b, vector::TransferReadOp xferOp,
- TypeRange returnTypes, Value inBoundsCond,
- MemRefType compatibleMemRefType, Value alloc) {
- Location loc = xferOp.getLoc();
- Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
- Value memref = xferOp.source();
- return b.create<scf::IfOp>(
- loc, returnTypes, inBoundsCond,
- [&](OpBuilder &b, Location loc) {
- Value res = memref;
- if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
- scf::ValueVector viewAndIndices{res};
- viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
- xferOp.indices().end());
- b.create<scf::YieldOp>(loc, viewAndIndices);
- },
- [&](OpBuilder &b, Location loc) {
- b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
- // Take partial subview of memref which guarantees no dimension
- // overflows.
- std::pair<Value, Value> copyArgs = createSubViewIntersection(
- b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
- b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
- Value casted =
- b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
- scf::ValueVector viewAndIndices{casted};
- viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
- zero);
- b.create<scf::YieldOp>(loc, viewAndIndices);
- });
-}
-
-/// Given an `xferOp` for which:
-/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
-/// 2. a memref of single vector `alloc` has been allocated.
-/// Produce IR resembling:
-/// ```
-/// %1:3 = scf.if (%inBounds) {
-/// memref.cast %A: memref<A...> to compatibleMemRefType
-/// scf.yield %view, ... : compatibleMemRefType, index, index
-/// } else {
-/// %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
-/// %3 = vector.type_cast %extra_alloc :
-/// memref<...> to memref<vector<...>>
-/// store %2, %3[] : memref<vector<...>>
-/// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
-/// scf.yield %4, ... : compatibleMemRefType, index, index
-/// }
-/// ```
-/// Return the produced scf::IfOp.
-static scf::IfOp createFullPartialVectorTransferRead(
- OpBuilder &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
- Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
- Location loc = xferOp.getLoc();
- scf::IfOp fullPartialIfOp;
- Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
- Value memref = xferOp.source();
- return b.create<scf::IfOp>(
- loc, returnTypes, inBoundsCond,
- [&](OpBuilder &b, Location loc) {
- Value res = memref;
- if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
- scf::ValueVector viewAndIndices{res};
- viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
- xferOp.indices().end());
- b.create<scf::YieldOp>(loc, viewAndIndices);
- },
- [&](OpBuilder &b, Location loc) {
- Operation *newXfer = b.clone(*xferOp.getOperation());
- Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
- b.create<memref::StoreOp>(
- loc, vector,
- b.create<vector::TypeCastOp>(
- loc, MemRefType::get({}, vector.getType()), alloc));
-
- Value casted =
- b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
- scf::ValueVector viewAndIndices{casted};
- viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
- zero);
- b.create<scf::YieldOp>(loc, viewAndIndices);
- });
-}
-
-/// Given an `xferOp` for which:
-/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
-/// 2. a memref of single vector `alloc` has been allocated.
-/// Produce IR resembling:
-/// ```
-/// %1:3 = scf.if (%inBounds) {
-/// memref.cast %A: memref<A...> to compatibleMemRefType
-/// scf.yield %view, ... : compatibleMemRefType, index, index
-/// } else {
-/// %3 = vector.type_cast %extra_alloc :
-/// memref<...> to memref<vector<...>>
-/// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
-/// scf.yield %4, ... : compatibleMemRefType, index, index
-/// }
-/// ```
-static ValueRange
-getLocationToWriteFullVec(OpBuilder &b, vector::TransferWriteOp xferOp,
- TypeRange returnTypes, Value inBoundsCond,
- MemRefType compatibleMemRefType, Value alloc) {
- Location loc = xferOp.getLoc();
- Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
- Value memref = xferOp.source();
- return b
- .create<scf::IfOp>(
- loc, returnTypes, inBoundsCond,
- [&](OpBuilder &b, Location loc) {
- Value res = memref;
- if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
- scf::ValueVector viewAndIndices{res};
- viewAndIndices.insert(viewAndIndices.end(),
- xferOp.indices().begin(),
- xferOp.indices().end());
- b.create<scf::YieldOp>(loc, viewAndIndices);
- },
- [&](OpBuilder &b, Location loc) {
- Value casted =
- b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
- scf::ValueVector viewAndIndices{casted};
- viewAndIndices.insert(viewAndIndices.end(),
- xferOp.getTransferRank(), zero);
- b.create<scf::YieldOp>(loc, viewAndIndices);
- })
- ->getResults();
-}
-
-/// Given an `xferOp` for which:
-/// 1. `inBoundsCond` has been computed.
-/// 2. a memref of single vector `alloc` has been allocated.
-/// 3. it originally wrote to %view
-/// Produce IR resembling:
-/// ```
-/// %notInBounds = arith.xori %inBounds, %true
-/// scf.if (%notInBounds) {
-/// %3 = subview %alloc [...][...][...]
-/// %4 = subview %view [0, 0][...][...]
-/// linalg.copy(%3, %4)
-/// }
-/// ```
-static void createFullPartialLinalgCopy(OpBuilder &b,
- vector::TransferWriteOp xferOp,
- Value inBoundsCond, Value alloc) {
- ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
- auto notInBounds = lb.create<arith::XOrIOp>(
- inBoundsCond, lb.create<arith::ConstantIntOp>(true, 1));
- lb.create<scf::IfOp>(notInBounds, [&](OpBuilder &b, Location loc) {
- std::pair<Value, Value> copyArgs = createSubViewIntersection(
- b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
- b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
- b.create<scf::YieldOp>(loc, ValueRange{});
- });
-}
-
-/// Given an `xferOp` for which:
-/// 1. `inBoundsCond` has been computed.
-/// 2. a memref of single vector `alloc` has been allocated.
-/// 3. it originally wrote to %view
-/// Produce IR resembling:
-/// ```
-/// %notInBounds = arith.xori %inBounds, %true
-/// scf.if (%notInBounds) {
-/// %2 = load %alloc : memref<vector<...>>
-/// vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
-/// }
-/// ```
-static void createFullPartialVectorTransferWrite(OpBuilder &b,
- vector::TransferWriteOp xferOp,
- Value inBoundsCond,
- Value alloc) {
- ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
- auto notInBounds = lb.create<arith::XOrIOp>(
- inBoundsCond, lb.create<arith::ConstantIntOp>(true, 1));
- lb.create<scf::IfOp>(notInBounds, [&](OpBuilder &b, Location loc) {
- BlockAndValueMapping mapping;
- Value load = b.create<memref::LoadOp>(
- loc, b.create<vector::TypeCastOp>(
- loc, MemRefType::get({}, xferOp.vector().getType()), alloc));
- mapping.map(xferOp.vector(), load);
- b.clone(*xferOp.getOperation(), mapping);
- b.create<scf::YieldOp>(loc, ValueRange{});
- });
-}
-
-/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
-/// masking) fastpath and a slowpath.
-///
-/// For vector.transfer_read:
-/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
-/// newly created conditional upon function return.
-/// To accomodate for the fact that the original vector.transfer indexing may be
-/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
-/// scf.if op returns a view and values of type index.
-///
-/// Example (a 2-D vector.transfer_read):
-/// ```
-/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
-/// ```
-/// is transformed into:
-/// ```
-/// %1:3 = scf.if (%inBounds) {
-/// // fastpath, direct cast
-/// memref.cast %A: memref<A...> to compatibleMemRefType
-/// scf.yield %view : compatibleMemRefType, index, index
-/// } else {
-/// // slowpath, not in-bounds vector.transfer or linalg.copy.
-/// memref.cast %alloc: memref<B...> to compatibleMemRefType
-/// scf.yield %4 : compatibleMemRefType, index, index
-// }
-/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
-/// ```
-/// where `alloc` is a top of the function alloca'ed buffer of one vector.
-///
-/// For vector.transfer_write:
-/// There are 2 conditional blocks. First a block to decide which memref and
-/// indices to use for an unmasked, inbounds write. Then a conditional block to
-/// further copy a partial buffer into the final result in the slow path case.
-///
-/// Example (a 2-D vector.transfer_write):
-/// ```
-/// vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
-/// ```
-/// is transformed into:
-/// ```
-/// %1:3 = scf.if (%inBounds) {
-/// memref.cast %A: memref<A...> to compatibleMemRefType
-/// scf.yield %view : compatibleMemRefType, index, index
-/// } else {
-/// memref.cast %alloc: memref<B...> to compatibleMemRefType
-/// scf.yield %4 : compatibleMemRefType, index, index
-/// }
-/// %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
-/// true]}
-/// scf.if (%notInBounds) {
-/// // slowpath: not in-bounds vector.transfer or linalg.copy.
-/// }
-/// ```
-/// where `alloc` is a top of the function alloca'ed buffer of one vector.
-///
-/// Preconditions:
-/// 1. `xferOp.permutation_map()` must be a minor identity map
-/// 2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
-/// must be equal. This will be relaxed in the future but requires
-/// rank-reducing subviews.
-LogicalResult mlir::vector::splitFullAndPartialTransfer(
- OpBuilder &b, VectorTransferOpInterface xferOp,
- VectorTransformsOptions options, scf::IfOp *ifOp) {
- if (options.vectorTransferSplit == VectorTransferSplit::None)
- return failure();
-
- SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
- auto inBoundsAttr = b.getBoolArrayAttr(bools);
- if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
- xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
- return success();
- }
-
- // Assert preconditions. Additionally, keep the variables in an inner scope to
- // ensure they aren't used in the wrong scopes further down.
- {
- assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
- "Expected splitFullAndPartialTransferPrecondition to hold");
-
- auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
- auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
-
- if (!(xferReadOp || xferWriteOp))
- return failure();
- if (xferWriteOp && xferWriteOp.mask())
- return failure();
- if (xferReadOp && xferReadOp.mask())
- return failure();
- }
-
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPoint(xferOp);
- Value inBoundsCond = createInBoundsCond(
- b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
- if (!inBoundsCond)
- return failure();
-
- // Top of the function `alloc` for transient storage.
- Value alloc;
- {
- FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointToStart(&funcOp.getRegion().front());
- auto shape = xferOp.getVectorType().getShape();
- Type elementType = xferOp.getVectorType().getElementType();
- alloc = b.create<memref::AllocaOp>(funcOp.getLoc(),
- MemRefType::get(shape, elementType),
- ValueRange{}, b.getI64IntegerAttr(32));
- }
-
- MemRefType compatibleMemRefType =
- getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
- alloc.getType().cast<MemRefType>());
- if (!compatibleMemRefType)
- return failure();
-
- SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
- b.getIndexType());
- returnTypes[0] = compatibleMemRefType;
-
- if (auto xferReadOp =
- dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
- // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
- scf::IfOp fullPartialIfOp =
- options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
- ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
- inBoundsCond,
- compatibleMemRefType, alloc)
- : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
- inBoundsCond, compatibleMemRefType,
- alloc);
- if (ifOp)
- *ifOp = fullPartialIfOp;
-
- // Set existing read op to in-bounds, it always reads from a full buffer.
- for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
- xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
-
- xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
-
- return success();
- }
-
- auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
-
- // Decide which location to write the entire vector to.
- auto memrefAndIndices = getLocationToWriteFullVec(
- b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
-
- // Do an in bounds write to either the output or the extra allocated buffer.
- // The operation is cloned to prevent deleting information needed for the
- // later IR creation.
- BlockAndValueMapping mapping;
- mapping.map(xferWriteOp.source(), memrefAndIndices.front());
- mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front());
- auto *clone = b.clone(*xferWriteOp, mapping);
- clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
-
- // Create a potential copy from the allocated buffer to the final output in
- // the slow path case.
- if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
- createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
- else
- createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
-
- xferOp->erase();
-
- return success();
-}
-
-LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
- Operation *op, PatternRewriter &rewriter) const {
- auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
- if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
- failed(filter(xferOp)))
- return failure();
- rewriter.startRootUpdate(xferOp);
- if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
- rewriter.finalizeRootUpdate(xferOp);
- return success();
- }
- rewriter.cancelRootUpdate(xferOp);
- return failure();
-}
-
Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
ArrayRef<int64_t> multiplicity, const AffineMap &map) {