[Mlir-commits] [mlir] 1a4263d - [mlir][Vector] Add linalg.copy-based pattern for splitting vector.transfer_read into full and partial copies.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Aug 4 05:55:58 PDT 2020
Author: Nicolas Vasilache
Date: 2020-08-04T08:46:08-04:00
New Revision: 1a4263d394c1a93757613bde4b1c2cf8d6a7bbb9
URL: https://github.com/llvm/llvm-project/commit/1a4263d394c1a93757613bde4b1c2cf8d6a7bbb9
DIFF: https://github.com/llvm/llvm-project/commit/1a4263d394c1a93757613bde4b1c2cf8d6a7bbb9.diff
LOG: [mlir][Vector] Add linalg.copy-based pattern for splitting vector.transfer_read into full and partial copies.
This revision adds a transformation and a pattern that rewrites a "maybe masked" `vector.transfer_read %view[...], %pad` into IR resembling:
```
%1:3 = scf.if (%inBounds) {
  scf.yield %view, ... : memref<A...>, index, index
} else {
  %2 = linalg.fill(%alloc, %pad)
  %3 = subview %view [...][...][...]
  linalg.copy(%3, %alloc)
  %4 = memref_cast %alloc : memref<B...> to memref<A...>
  scf.yield %4, ... : memref<A...>, index, index
}
%res = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
```
where `alloc` is a buffer alloca'ed at the top of the function, large enough to hold a single vector.
This rewrite makes it possible to realize the "always full tile" abstraction where vector.transfer_read operations are guaranteed to read from a padded full buffer.
The extra work only occurs on the boundary tiles.
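For reference, a minimal sketch (not part of this commit; the helper name is made up) of how a pass can opt into the new linalg.copy-based split. It simply mirrors the TestVectorTransforms.cpp change further below:
```
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::vector;

// Populate `patterns` with the full/partial split rewrite, requesting the
// linalg.fill + linalg.copy slow path added in this revision.
static void populateTransferSplitPatterns(MLIRContext *ctx,
                                          OwningRewritePatternList &patterns) {
  VectorTransformsOptions options;
  options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
  patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
}
```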
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/lib/Dialect/Vector/CMakeLists.txt
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index edf9557df389..562e07f98774 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -56,22 +56,48 @@ enum class VectorContractLowering {
};
/// Enum to control the lowering of `vector.transpose` operations.
enum class VectorTransposeLowering {
- // Lower transpose into element-wise extract and inserts.
+ /// Lower transpose into element-wise extract and inserts.
EltWise = 0,
/// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
/// intrinsics.
Flat = 1,
};
+/// Enum to control the splitting of `vector.transfer` operations into masked
+/// and unmasked variants.
+enum class VectorTransferSplit {
+ /// Do not split vector transfer operations.
+ None = 0,
+ /// Split using masked + unmasked vector.transfer operations.
+ VectorTransfer = 1,
+ /// Split using an unmasked vector.transfer + linalg.fill + linalg.copy
+ /// operations.
+ LinalgCopy = 2,
+ /// Do not split vector transfer operation but instead mark it as "unmasked".
+ ForceUnmasked = 3
+};
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
+ /// Option to control the lowering of vector.contract.
VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
- VectorTransposeLowering vectorTransposeLowering =
- VectorTransposeLowering::EltWise;
VectorTransformsOptions &
setVectorTransformsOptions(VectorContractLowering opt) {
vectorContractLowering = opt;
return *this;
}
+ /// Option to control the lowering of vector.transpose.
+ VectorTransposeLowering vectorTransposeLowering =
+ VectorTransposeLowering::EltWise;
+ VectorTransformsOptions &
+ setVectorTransposeLowering(VectorTransposeLowering opt) {
+ vectorTransposeLowering = opt;
+ return *this;
+ }
+ /// Option to control the splitting of vector transfers.
+ VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
+ VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
+ vectorTransferSplit = opt;
+ return *this;
+ }
};
/// Collect a set of transformation patterns that are related to contracting
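As an aside on the options struct above: each setter returns `*this`, so the options compose. A small illustrative sketch (the particular enum values are arbitrary, and the helper name is made up):
```
#include "mlir/Dialect/Vector/VectorOps.h"

using namespace mlir;
using namespace mlir::vector;

// Illustrative only: chain the fluent setters added in this revision.
static VectorTransformsOptions makeSplitOptions() {
  VectorTransformsOptions options;
  options.setVectorTransposeLowering(VectorTransposeLowering::Flat)
      .setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
  return options;
}
```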
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 835ad18a79ad..e6c7b7abebd5 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -109,13 +109,13 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
FilterConstraintType filter;
};
-/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
-/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
-/// is `success, the `ifOp` points to the newly created conditional upon
-/// function return. To accomodate for the fact that the original
-/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
-/// in the temporary buffer, the scf.if op returns a view and values of type
-/// index. At this time, only vector.transfer_read is implemented.
+/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
+/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
+/// newly created conditional upon function return.
+/// To accommodate the fact that the original vector.transfer indexing may be
+/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// scf.if op returns a view and values of type index.
+/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
@@ -124,17 +124,17 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
/// is transformed into:
/// ```
/// %1:3 = scf.if (%inBounds) {
-/// scf.yield %0 : memref<A...>, index, index
-/// } else {
-/// %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
-/// %3 = vector.type_cast %extra_alloc : memref<...> to
-/// memref<vector<...>> store %2, %3[] : memref<vector<...>> %4 =
-/// memref_cast %extra_alloc: memref<B...> to memref<A...> scf.yield %4 :
-/// memref<A...>, index, index
+/// // fastpath, direct cast
+/// memref_cast %A: memref<A...> to compatibleMemRefType
+/// scf.yield %view : compatibleMemRefType, index, index
+/// } else {
+/// // slowpath, masked vector.transfer or linalg.copy.
+/// memref_cast %alloc: memref<B...> to compatibleMemRefType
+/// scf.yield %4 : compatibleMemRefType, index, index
// }
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
-/// where `extra_alloc` is a top of the function alloca'ed buffer of one vector.
+/// where `alloc` is a top of the function alloca'ed buffer of one vector.
///
/// Preconditions:
/// 1. `xferOp.permutation_map()` must be a minor identity map
@@ -143,9 +143,10 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
/// rank-reducing subviews.
LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
-LogicalResult splitFullAndPartialTransfer(OpBuilder &b,
- VectorTransferOpInterface xferOp,
- scf::IfOp *ifOp = nullptr);
+LogicalResult splitFullAndPartialTransfer(
+ OpBuilder &b, VectorTransferOpInterface xferOp,
+ VectorTransformsOptions options = VectorTransformsOptions(),
+ scf::IfOp *ifOp = nullptr);
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
@@ -155,16 +156,19 @@ struct VectorTransferFullPartialRewriter : public RewritePattern {
explicit VectorTransferFullPartialRewriter(
MLIRContext *context,
+ VectorTransformsOptions options = VectorTransformsOptions(),
FilterConstraintType filter =
[](VectorTransferOpInterface op) { return success(); },
PatternBenefit benefit = 1)
- : RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter) {}
+ : RewritePattern(benefit, MatchAnyOpTypeTag()), options(options),
+ filter(filter) {}
/// Performs the rewrite.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
+ VectorTransformsOptions options;
FilterConstraintType filter;
};
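A hypothetical use of the filter hook above (not part of this patch): restrict the split so it only fires on 2-D transfers.
```
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::vector;

// Sketch: register the split pattern so that it only rewrites 2-D transfers.
static void populateSelective2DSplit(MLIRContext *ctx,
                                     OwningRewritePatternList &patterns) {
  VectorTransformsOptions options;
  options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
  patterns.insert<VectorTransferFullPartialRewriter>(
      ctx, options, [](VectorTransferOpInterface xferOp) {
        // The filter allows the rewrite only when it returns success().
        return success(xferOp.getTransferRank() == 2);
      });
}
```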
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 13dbf6da73fa..1087feba7fbd 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVector
MLIRIR
MLIRStandardOps
MLIRAffineOps
+ MLIRLinalgOps
MLIRSCF
MLIRLoopAnalysis
MLIRSideEffectInterfaces
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 573b822503f3..3c23c5a6d869 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -2056,7 +2057,16 @@ LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
return success();
}
-MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
+/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
+/// be cast. If the MemRefTypes don't have the same rank or are not strided,
+/// return null; otherwise:
+/// 1. if `aT` and `bT` are cast-compatible, return `aT`.
+/// 2. else return a new MemRefType obtained by iterating over the shape and
+/// strides and:
+/// a. keeping the ones that are static and equal across `aT` and `bT`.
+/// b. using a dynamic shape and/or stride for the dimensions that don't
+/// agree.
+static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
if (MemRefCastOp::areCastCompatible(aT, bT))
return aT;
if (aT.getRank() != bT.getRank())
@@ -2086,13 +2096,154 @@ MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}
-/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
-/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
-/// is `success, the `ifOp` points to the newly created conditional upon
-/// function return. To accomodate for the fact that the original
-/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
-/// in the temporary buffer, the scf.if op returns a view and values of type
-/// index. At this time, only vector.transfer_read is implemented.
+/// Operates under a scoped context to build the intersection between the
+/// view `xferOp.memref()` @ `xferOp.indices()` and the view `alloc`.
+// TODO: view intersection/union/differences should be a proper std op.
+static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
+ Value alloc) {
+ using namespace edsc::intrinsics;
+ int64_t memrefRank = xferOp.getMemRefType().getRank();
+ // TODO: relax this precondition, will require rank-reducing subviews.
+ assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
+ "Expected memref rank to match the alloc rank");
+ Value one = std_constant_index(1);
+ ValueRange leadingIndices =
+ xferOp.indices().take_front(xferOp.getLeadingMemRefRank());
+ SmallVector<Value, 4> sizes;
+ sizes.append(leadingIndices.begin(), leadingIndices.end());
+ xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+ Value dimMemRef = std_dim(xferOp.memref(), indicesIdx);
+ Value dimAlloc = std_dim(alloc, resultIdx);
+ Value index = xferOp.indices()[indicesIdx];
+ AffineExpr i, j, k;
+ bindDims(xferOp.getContext(), i, j, k);
+ SmallVector<AffineMap, 4> maps =
+ AffineMap::inferFromExprList(MapList{{i - j, k}});
+ // affine_min(%dimMemRef - %index, %dimAlloc)
+ Value affineMin = affine_min(index.getType(), maps[0],
+ ValueRange{dimMemRef, index, dimAlloc});
+ sizes.push_back(affineMin);
+ });
+ return std_sub_view(xferOp.memref(), xferOp.indices(), sizes,
+ SmallVector<Value, 4>(memrefRank, one));
+}
+
+/// Given an `xferOp` for which:
+/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+/// 2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+/// %1:3 = scf.if (%inBounds) {
+/// memref_cast %A: memref<A...> to compatibleMemRefType
+/// scf.yield %view, ... : compatibleMemRefType, index, index
+/// } else {
+/// %2 = linalg.fill(%alloc, %pad)
+/// %3 = subview %view [...][...][...]
+/// linalg.copy(%3, %alloc)
+/// memref_cast %alloc: memref<B...> to compatibleMemRefType
+/// scf.yield %4, ... : compatibleMemRefType, index, index
+/// }
+/// ```
+/// Return the produced scf::IfOp.
+static scf::IfOp createScopedFullPartialLinalgCopy(
+ vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
+ MemRefType compatibleMemRefType, Value alloc) {
+ using namespace edsc;
+ using namespace edsc::intrinsics;
+ scf::IfOp fullPartialIfOp;
+ Value zero = std_constant_index(0);
+ Value memref = xferOp.memref();
+ conditionBuilder(
+ returnTypes, inBoundsCond,
+ [&]() -> scf::ValueVector {
+ Value res = memref;
+ if (compatibleMemRefType != xferOp.getMemRefType())
+ res = std_memref_cast(memref, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{res};
+ viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+ xferOp.indices().end());
+ return viewAndIndices;
+ },
+ [&]() -> scf::ValueVector {
+ linalg_fill(alloc, xferOp.padding());
+ // Take partial subview of memref which guarantees no dimension
+ // overflows.
+ Value memRefSubView = createScopedSubViewIntersection(
+ cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
+ linalg_copy(memRefSubView, alloc);
+ Value casted = std_memref_cast(alloc, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{casted};
+ viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+ zero);
+ return viewAndIndices;
+ },
+ &fullPartialIfOp);
+ return fullPartialIfOp;
+}
+
+/// Given an `xferOp` for which:
+/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+/// 2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+/// %1:3 = scf.if (%inBounds) {
+/// memref_cast %A: memref<A...> to compatibleMemRefType
+/// scf.yield %view, ... : compatibleMemRefType, index, index
+/// } else {
+/// %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
+/// %3 = vector.type_cast %extra_alloc :
+/// memref<...> to memref<vector<...>>
+/// store %2, %3[] : memref<vector<...>>
+/// %4 = memref_cast %alloc: memref<B...> to compatibleMemRefType
+/// scf.yield %4, ... : compatibleMemRefType, index, index
+/// }
+/// ```
+/// Return the produced scf::IfOp.
+static scf::IfOp createScopedFullPartialVectorTransferRead(
+ vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
+ MemRefType compatibleMemRefType, Value alloc) {
+ using namespace edsc;
+ using namespace edsc::intrinsics;
+ scf::IfOp fullPartialIfOp;
+ Value zero = std_constant_index(0);
+ Value memref = xferOp.memref();
+ conditionBuilder(
+ returnTypes, inBoundsCond,
+ [&]() -> scf::ValueVector {
+ Value res = memref;
+ if (compatibleMemRefType != xferOp.getMemRefType())
+ res = std_memref_cast(memref, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{res};
+ viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+ xferOp.indices().end());
+ return viewAndIndices;
+ },
+ [&]() -> scf::ValueVector {
+ Operation *newXfer =
+ ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
+ Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
+ std_store(vector, vector_type_cast(
+ MemRefType::get({}, vector.getType()), alloc));
+
+ Value casted = std_memref_cast(alloc, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{casted};
+ viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+ zero);
+
+ return viewAndIndices;
+ },
+ &fullPartialIfOp);
+ return fullPartialIfOp;
+}
+
+/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
+/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
+/// newly created conditional upon function return.
+/// To accommodate the fact that the original vector.transfer indexing may be
+/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// scf.if op returns a view and values of type index.
+/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
@@ -2101,17 +2252,17 @@ MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
/// is transformed into:
/// ```
/// %1:3 = scf.if (%inBounds) {
-/// scf.yield %0 : memref<A...>, index, index
-/// } else {
-/// %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
-/// %3 = vector.type_cast %extra_alloc : memref<...> to
-/// memref<vector<...>> store %2, %3[] : memref<vector<...>> %4 =
-/// memref_cast %extra_alloc: memref<B...> to memref<A...> scf.yield %4 :
-/// memref<A...>, index, index
+/// // fastpath, direct cast
+/// memref_cast %A: memref<A...> to compatibleMemRefType
+/// scf.yield %view : compatibleMemRefType, index, index
+/// } else {
+/// // slowpath, masked vector.transfer or linalg.copy.
+/// memref_cast %alloc: memref<B...> to compatibleMemRefType
+/// scf.yield %4 : compatibleMemRefType, index, index
// }
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
-/// where `extra_alloc` is a top of the function alloca'ed buffer of one vector.
+/// where `alloc` is a top of the function alloca'ed buffer of one vector.
///
/// Preconditions:
/// 1. `xferOp.permutation_map()` must be a minor identity map
@@ -2119,10 +2270,21 @@ MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
/// must be equal. This will be relaxed in the future but requires
/// rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
- OpBuilder &b, VectorTransferOpInterface xferOp, scf::IfOp *ifOp) {
+ OpBuilder &b, VectorTransferOpInterface xferOp,
+ VectorTransformsOptions options, scf::IfOp *ifOp) {
using namespace edsc;
using namespace edsc::intrinsics;
+ if (options.vectorTransferSplit == VectorTransferSplit::None)
+ return failure();
+
+ SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
+ auto unmaskedAttr = b.getBoolArrayAttr(bools);
+ if (options.vectorTransferSplit == VectorTransferSplit::ForceUnmasked) {
+ xferOp.setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);
+ return success();
+ }
+
assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
"Expected splitFullAndPartialTransferPrecondition to hold");
auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
@@ -2154,45 +2316,21 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
b.getI64IntegerAttr(32));
}
- Value memref = xferOp.memref();
- SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
- auto unmaskedAttr = b.getBoolArrayAttr(bools);
-
MemRefType compatibleMemRefType = getCastCompatibleMemRefType(
xferOp.getMemRefType(), alloc.getType().cast<MemRefType>());
// Read case: full fill + partial copy -> unmasked vector.xfer_read.
- Value zero = std_constant_index(0);
SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
b.getIndexType());
returnTypes[0] = compatibleMemRefType;
- scf::IfOp fullPartialIfOp;
- conditionBuilder(
- returnTypes, inBoundsCond,
- [&]() -> scf::ValueVector {
- Value res = memref;
- if (compatibleMemRefType != xferOp.getMemRefType())
- res = std_memref_cast(memref, compatibleMemRefType);
- scf::ValueVector viewAndIndices{res};
- viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
- xferOp.indices().end());
- return viewAndIndices;
- },
- [&]() -> scf::ValueVector {
- Operation *newXfer =
- ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
- Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
- std_store(vector, vector_type_cast(
- MemRefType::get({}, vector.getType()), alloc));
-
- Value casted = std_memref_cast(alloc, compatibleMemRefType);
- scf::ValueVector viewAndIndices{casted};
- viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
- zero);
-
- return viewAndIndices;
- },
- &fullPartialIfOp);
+ scf::IfOp fullPartialIfOp =
+ options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
+ ? createScopedFullPartialVectorTransferRead(
+ xferReadOp, returnTypes, inBoundsCond, compatibleMemRefType,
+ alloc)
+ : createScopedFullPartialLinalgCopy(xferReadOp, returnTypes,
+ inBoundsCond,
+ compatibleMemRefType, alloc);
if (ifOp)
*ifOp = fullPartialIfOp;
@@ -2211,7 +2349,7 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
failed(filter(xferOp)))
return failure();
rewriter.startRootUpdate(xferOp);
- if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp))) {
+ if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
rewriter.finalizeRootUpdate(xferOp);
return success();
}
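A hedged sketch of driving the updated splitFullAndPartialTransfer entry point directly rather than through the pattern; the builder and transfer op are assumed to be provided by the caller, and the helper name is made up:
```
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/Builders.h"

using namespace mlir;
using namespace mlir::vector;

// Directly split one transfer op, requesting the linalg.copy slow path, and
// recover the generated scf.if.
static void splitWithLinalgCopy(OpBuilder &b, VectorTransferOpInterface xferOp) {
  if (failed(splitFullAndPartialTransferPrecondition(xferOp)))
    return;
  VectorTransformsOptions options;
  options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
  scf::IfOp ifOp;
  if (succeeded(splitFullAndPartialTransfer(b, xferOp, options, &ifOp))) {
    // `ifOp` now points to the newly created conditional.
    (void)ifOp;
  }
}
```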
diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
index ef76247ee9d4..e36454203982 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -1,13 +1,26 @@
// RUN: mlir-opt %s -test-vector-transfer-full-partial-split | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-full-partial-split=use-linalg-copy | FileCheck %s --check-prefix=LINALG
// CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK-DAG: #[[$map_2d_stride_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// LINALG-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
+// LINALG-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
+// LINALG-DAG: #[[$map_2d_stride_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// LINALG-DAG: #[[$map_2d_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+// LINALG-DAG: #[[$bounds_map_4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)>
+// LINALG-DAG: #[[$bounds_map_8:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)>
+
// CHECK-LABEL: split_vector_transfer_read_2d(
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
// CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
// CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index
+
+// LINALG-LABEL: split_vector_transfer_read_2d(
+// LINALG-SAME: %[[A:[a-zA-Z0-9]*]]: memref
+// LINALG-SAME: %[[i:[a-zA-Z0-9]*]]: index
+// LINALG-SAME: %[[j:[a-zA-Z0-9]*]]: index
func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -> vector<4x8xf32> {
%c0 = constant 0 : index
%f0 = constant 0.0 : f32
@@ -43,9 +56,45 @@ func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -
// CHECK: }
// CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %[[cst]]
// CHECK_SAME: {masked = [false, false]} : memref<?x8xf32>, vector<4x8xf32>
+
+ // LINALG-DAG: %[[c0:.*]] = constant 0 : index
+ // LINALG-DAG: %[[c1:.*]] = constant 1 : index
+ // LINALG-DAG: %[[c4:.*]] = constant 4 : index
+ // LINALG-DAG: %[[c8:.*]] = constant 8 : index
+ // LINALG-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
+ // alloca for boundary full tile
+ // LINALG: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
+ // %i + 4 <= dim(%A, 0)
+ // LINALG: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
+ // LINALG: %[[d0:.*]] = dim %[[A]], %[[c0]] : memref<?x8xf32>
+ // LINALG: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[d0]] : index
+ // %j + 8 <= dim(%A, 1)
+ // LINALG: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
+ // LINALG: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
+ // are both conds true
+ // LINALG: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
+ // LINALG: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32>, index, index) {
+ // inBounds, just yield %A
+ // LINALG: scf.yield %[[A]], %[[i]], %[[j]] : memref<?x8xf32>, index, index
+ // LINALG: } else {
+ // slow path, fill tmp alloc and yield a memref_casted version of it
+ // LINALG: linalg.fill(%[[alloc]], %[[cst]]) : memref<4x8xf32>, f32
+ // LINALG: %[[d0:.*]] = dim %[[A]], %[[c0]] : memref<?x8xf32>
+ // LINALG: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[d0]], %[[i]], %[[c4]])
+ // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
+ // LINALG: %[[sv:.*]] = subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [%[[c1]], %[[c1]]]
+ // LINALG-SAME: memref<?x8xf32> to memref<?x?xf32, #[[$map_2d_dynamic]]>
+ // LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_dynamic]]>, memref<4x8xf32>
+ // LINALG: %[[yielded:.*]] = memref_cast %[[alloc]] :
+ // LINALG-SAME: memref<4x8xf32> to memref<?x8xf32>
+ // LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
+ // LINALG-SAME: memref<?x8xf32>, index, index
+ // LINALG: }
+ // LINALG: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %[[cst]]
+ // LINALG-SAME: {masked = [false, false]} : memref<?x8xf32>, vector<4x8xf32>
%1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32>, vector<4x8xf32>
- // CHECK: return %[[res]] : vector<4x8xf32>
+ // LINALG: return %[[res]] : vector<4x8xf32>
return %1: vector<4x8xf32>
}
@@ -53,6 +102,11 @@ func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
// CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
// CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index
+
+// LINALG-LABEL: split_vector_transfer_read_strided_2d(
+// LINALG-SAME: %[[A:[a-zA-Z0-9]*]]: memref
+// LINALG-SAME: %[[i:[a-zA-Z0-9]*]]: index
+// LINALG-SAME: %[[j:[a-zA-Z0-9]*]]: index
func @split_vector_transfer_read_strided_2d(
%A: memref<7x8xf32, offset:?, strides:[?, 1]>,
%i: index, %j: index) -> vector<4x8xf32> {
@@ -94,6 +148,44 @@ func @split_vector_transfer_read_strided_2d(
// CHECK: }
// CHECK: %[[res:.*]] = vector.transfer_read {{.*}} {masked = [false, false]} :
// CHECK-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>
+
+ // LINALG-DAG: %[[c0:.*]] = constant 0 : index
+ // LINALG-DAG: %[[c1:.*]] = constant 1 : index
+ // LINALG-DAG: %[[c4:.*]] = constant 4 : index
+ // LINALG-DAG: %[[c7:.*]] = constant 7 : index
+ // LINALG-DAG: %[[c8:.*]] = constant 8 : index
+ // LINALG-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
+ // alloca for boundary full tile
+ // LINALG: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
+ // %i + 4 <= dim(%A, 0)
+ // LINALG: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
+ // LINALG: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[c7]] : index
+ // %j + 8 <= dim(%A, 1)
+ // LINALG: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
+ // LINALG: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
+ // are both conds true
+ // LINALG: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
+ // LINALG: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index) {
+ // inBounds but not cast-compatible: yield a memref_casted form of %A
+ // LINALG: %[[casted:.*]] = memref_cast %arg0 :
+ // LINALG-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x8xf32, #[[$map_2d_stride_1]]>
+ // LINALG: scf.yield %[[casted]], %[[i]], %[[j]] :
+ // LINALG-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
+ // LINALG: } else {
+ // slow path, fill tmp alloc and yield a memref_casted version of it
+ // LINALG: linalg.fill(%[[alloc]], %[[cst]]) : memref<4x8xf32>, f32
+ // LINALG: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[c7]], %[[i]], %[[c4]])
+ // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
+ // LINALG: %[[sv:.*]] = subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [%[[c1]], %[[c1]]]
+ // LINALG-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x?xf32, #[[$map_2d_dynamic]]>
+ // LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_dynamic]]>, memref<4x8xf32>
+ // LINALG: %[[yielded:.*]] = memref_cast %[[alloc]] :
+ // LINALG-SAME: memref<4x8xf32> to memref<?x8xf32, #[[$map_2d_stride_1]]>
+ // LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
+ // LINALG-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
+ // LINALG: }
+ // LINALG: %[[res:.*]] = vector.transfer_read {{.*}} {masked = [false, false]} :
+ // LINALG-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>
%1 = vector.transfer_read %A[%i, %j], %f0 :
memref<7x8xf32, offset:?, strides:[?, 1]>, vector<4x8xf32>
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 0bba74e76385..9da3156d5359 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -125,10 +125,23 @@ struct TestVectorUnrollingPatterns
struct TestVectorTransferFullPartialSplitPatterns
: public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
FunctionPass> {
+ TestVectorTransferFullPartialSplitPatterns() = default;
+ TestVectorTransferFullPartialSplitPatterns(
+ const TestVectorTransferFullPartialSplitPatterns &pass) {}
+ Option<bool> useLinalgOps{
+ *this, "use-linalg-copy",
+ llvm::cl::desc("Split using an unmasked vector.transfer + linalg.fill + "
+ "linalg.copy operations."),
+ llvm::cl::init(false)};
void runOnFunction() override {
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
- patterns.insert<VectorTransferFullPartialRewriter>(ctx);
+ VectorTransformsOptions options;
+ if (useLinalgOps)
+ options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
+ else
+ options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
+ patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
};