[Mlir-commits] [mlir] 7c3c5b1 - [mlir][Vector] Add option to fully unroll for VectorTransfer to SCF lowering
Nicolas Vasilache
llvmlistbot at llvm.org
Wed May 20 08:06:06 PDT 2020
Author: Nicolas Vasilache
Date: 2020-05-20T11:02:13-04:00
New Revision: 7c3c5b11b1fa285443573e6e8ecc2383fcda6554
URL: https://github.com/llvm/llvm-project/commit/7c3c5b11b1fa285443573e6e8ecc2383fcda6554
DIFF: https://github.com/llvm/llvm-project/commit/7c3c5b11b1fa285443573e6e8ecc2383fcda6554.diff
LOG: [mlir][Vector] Add option to fully unroll for VectorTransfer to SCF lowering
Summary:
Previously, the only supported partial lowering from vector transfers to SCF
went through loops. This requires a dedicated allocation and extra memory
roundtrips because LLVM aggregates cannot be indexed dynamically (for more
details see the [deep-dive](https://mlir.llvm.org/docs/Dialects/Vector/#deeperdive)).
This revision allows specifying full unrolling, which removes this additional
roundtrip. This should be used carefully, though, because full unrolling can
cause register spills, negating the benefits of removing the interim alloc in
the first place. Proper heuristics are left for a later time.
Differential Revision: https://reviews.llvm.org/D80100
Added:
Modified:
mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/include/mlir/Dialect/Vector/VectorUtils.h
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/Dialect/Vector/VectorUtils.cpp
mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
index 976751e48cb1..d7a6f829f10f 100644
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -9,13 +9,160 @@
#ifndef MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_
#define MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_
+#include "mlir/IR/PatternMatch.h"
+
namespace mlir {
class MLIRContext;
class OwningRewritePatternList;
-/// Collect a set of patterns to convert from the Vector dialect to loops + std.
-void populateVectorToSCFConversionPatterns(OwningRewritePatternList &patterns,
- MLIRContext *context);
+/// Control whether unrolling is used when lowering vector transfer ops to SCF.
+///
+/// Case 1:
+/// =======
+/// When `unroll` is false, a temporary buffer is created through which
+/// individual 1-D vectors are staged. This is consistent with the lack of an
+/// LLVM instruction to dynamically index into an aggregate (see the Vector
+/// dialect lowering to LLVM deep dive).
+/// An instruction such as:
+/// ```
+/// vector.transfer_write %vec, %A[%base, %base] :
+/// vector<17x15xf32>, memref<?x?xf32>
+/// ```
+/// Lowers to pseudo-IR resembling:
+/// ```
+/// %0 = alloc() : memref<17xvector<15xf32>>
+/// %1 = vector.type_cast %0 :
+/// memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
+/// store %vec, %1[] : memref<vector<17x15xf32>>
+/// %dim = dim %A, 0 : memref<?x?xf32>
+/// affine.for %I = 0 to 17 {
+/// %add = affine.apply %I + %base
+/// %cmp = cmpi "slt", %add, %dim : index
+/// scf.if %cmp {
+/// %vec_1d = load %0[%I] : memref<17xvector<15xf32>>
+/// vector.transfer_write %vec_1d, %A[%add, %base] :
+/// vector<15xf32>, memref<?x?xf32>
+/// ```
+///
+/// Case 2:
+/// =======
+/// When `unroll` is true, the temporary buffer is skipped and static indices
+/// into aggregates can be used (see the Vector dialect lowering to LLVM deep
+/// dive).
+/// An instruction such as:
+/// ```
+/// vector.transfer_write %vec, %A[%base, %base] :
+/// vector<3x15xf32>, memref<?x?xf32>
+/// ```
+/// Lowers to pseudo-IR resembling:
+/// ```
+/// %0 = vector.extract %vec[0] : vector<3x15xf32>
+/// vector.transfer_write %0, %A[%base, %base] : vector<15xf32>, memref<?x?xf32>
+/// %1 = affine.apply #map1()[%base]
+/// %2 = vector.extract %vec[1] : vector<3x15xf32>
+/// vector.transfer_write %2, %A[%1, %base] : vector<15xf32>, memref<?x?xf32>
+/// %3 = affine.apply #map2()[%base]
+/// %4 = vector.extract %vec[2] : vector<3x15xf32>
+/// vector.transfer_write %4, %A[%3, %base] : vector<15xf32>, memref<?x?xf32>
+/// ```
+struct VectorTransferToSCFOptions {
+  /// When true, fully unroll the major dimensions instead of staging 1-D
+  /// vectors through a temporary buffer (see Case 2 above).
+  bool unroll{false};
+
+  /// Builder-style setter; returns `*this` so calls can be chained.
+  VectorTransferToSCFOptions &setUnroll(bool u) {
+    this->unroll = u;
+    return *this;
+  }
+};
+
+/// Implements lowering of TransferReadOp and TransferWriteOp to a
+/// proper abstraction for the hardware.
+///
+/// There are multiple cases.
+///
+/// Case A: Permutation Map does not permute or broadcast.
+/// ======================================================
+///
+/// Progressive lowering occurs to 1-D vector transfer ops according to the
+/// description in `VectorTransferToSCFOptions`.
+///
+/// Case B: Permutation Map permutes and/or broadcasts.
+/// ======================================================
+///
+/// This path will be progressively deprecated and folded into the case above by
+/// using vector broadcast and transpose operations.
+///
+/// This path only emits a simple loop nest that performs clipped pointwise
+/// copies from a remote to a locally allocated memory.
+///
+/// Consider the case:
+///
+/// ```mlir
+/// // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into
+/// // vector<32x256xf32> and pad with %f0 to handle the boundary case:
+/// %f0 = constant 0.0f : f32
+/// scf.for %i0 = 0 to %0 {
+/// scf.for %i1 = 0 to %1 step %c256 {
+/// scf.for %i2 = 0 to %2 step %c32 {
+/// %v = vector.transfer_read %A[%i0, %i1, %i2], %f0
+/// {permutation_map: (d0, d1, d2) -> (d2, d1)} :
+/// memref<?x?x?xf32>, vector<32x256xf32>
+/// }}}
+/// ```
+///
+/// The rewriters construct loop and indices that access MemRef A in a pattern
+/// resembling the following (while guaranteeing an always full-tile
+/// abstraction):
+///
+/// ```mlir
+/// scf.for %d2 = 0 to %c256 {
+/// scf.for %d1 = 0 to %c32 {
+/// %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32
+/// %tmp[%d2, %d1] = %s
+/// }
+/// }
+/// ```
+///
+/// In the current state, only a clipping transfer is implemented by `clip`,
+/// which creates individual indexing expressions of the form:
+///
+/// ```mlir-dsc
+/// auto condMax = i + ii < N;
+/// auto max = std_select(condMax, i + ii, N - one)
+/// auto cond = i + ii < zero;
+/// std_select(cond, zero, max);
+/// ```
+///
+/// In the future, clipping should not be the only way and instead we should
+/// load vectors + mask them. Similarly on the write side, load/mask/store for
+/// implementing RMW behavior.
+///
+/// Lowers TransferOp into a combination of:
+/// 1. local memory allocation;
+/// 2. perfect loop nest over:
+/// a. scalar load/stores from local buffers (viewed as a scalar memref);
+/// b. scalar store/load to original memref (with clipping).
+/// 3. vector_load/store
+/// 4. local memory deallocation.
+/// Minor variations occur depending on whether a TransferReadOp or
+/// a TransferWriteOp is rewritten.
+template <typename TransferOpTy>
+struct VectorTransferRewriter : public RewritePattern {
+ /// Creates a pattern (benefit 1) rooted at `TransferOpTy` operations that
+ /// lowers them according to `options`.
+ explicit VectorTransferRewriter(VectorTransferToSCFOptions options,
+ MLIRContext *context);
+
+ /// Used for staging the transfer in a local buffer: returns a MemRefType
+ /// with the same shape and element type as the transfer's vector type.
+ MemRefType tmpMemRefType(TransferOpTy transfer) const;
+
+ /// Performs the rewrite. For minor-identity permutation maps, N-D (rank > 1)
+ /// transfers are progressively lowered to 1-D transfers, while 1-D transfers
+ /// are rejected (left to the target-specific lowering).
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
+
+ /// See description of `VectorTransferToSCFOptions`.
+ VectorTransferToSCFOptions options;
+};
+
+/// Collect a set of patterns to convert from the Vector dialect to SCF + std.
+/// `options` controls how N-D vector transfers are lowered (see
+/// `VectorTransferToSCFOptions`); by default no unrolling is performed.
+void populateVectorToSCFConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context,
+ const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 29e72857b291..575b99d51c97 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -373,7 +373,11 @@ def Vector_ExtractOp :
}];
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value source,"
- "ArrayRef<int64_t>">];
+ "ArrayRef<int64_t> position">,
+ // Convenience builder which assumes the values in `position` are defined by
+ // ConstantIndexOp.
+ OpBuilder<"OpBuilder &builder, OperationState &result, Value source,"
+ "ValueRange position">];
let extraClassDeclaration = [{
static StringRef getPositionAttrName() { return "position"; }
VectorType getVectorType() {
@@ -535,8 +539,12 @@ def Vector_InsertOp :
}];
let builders = [OpBuilder<
- "OpBuilder &builder, OperationState &result, Value source, " #
- "Value dest, ArrayRef<int64_t>">];
+ "OpBuilder &builder, OperationState &result, Value source, "
+ "Value dest, ArrayRef<int64_t> position">,
+ // Convenience builder which assumes all values are constant indices.
+ OpBuilder<
+ "OpBuilder &builder, OperationState &result, Value source, "
+ "Value dest, ValueRange position">];
let extraClassDeclaration = [{
static StringRef getPositionAttrName() { return "position"; }
Type getSourceType() { return source().getType(); }
diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index 35527542d639..58f936ca305c 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -1,4 +1,4 @@
-//===- VectorUtils.h - VectorOps Utilities ------------------*- C++ -*-=======//
+//===- VectorUtils.h - Vector Utilities -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -26,18 +26,28 @@ class Operation;
class Value;
class VectorType;
+/// Return the product of all sizes in `basis`, or `0` if `basis` is empty.
+int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
+
+/// Given a shape with sizes greater than 0 along all dimensions,
+/// return the distance, in number of elements, between a slice in a dimension
+/// and the next slice in the same dimension.
+/// e.g. shape[3, 4, 5] -> linearization_basis[20, 5, 1]
+SmallVector<int64_t, 8> computeStrides(ArrayRef<int64_t> shape);
+
/// Given the shape and sizes of a vector, returns the corresponding
/// strides for each dimension.
+/// TODO: needs better doc of how it is used.
SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
ArrayRef<int64_t> sizes);
/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
-/// Given the slice strides together with a linear index in the dimension
+/// Given the strides together with a linear index in the dimension
/// space, returns the vector-space offsets in each dimension for a
/// de-linearized index.
-SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> sliceStrides,
+SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> strides,
int64_t linearIndex);
/// Given the target sizes of a vector, together with vector-space offsets,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index a06c5984c4e9..4b0368f8c8cc 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
@@ -75,8 +76,9 @@ namespace {
template <typename ConcreteOp>
class NDTransferOpHelper {
public:
- NDTransferOpHelper(PatternRewriter &rewriter, ConcreteOp xferOp)
- : rewriter(rewriter), loc(xferOp.getLoc()),
+ NDTransferOpHelper(PatternRewriter &rewriter, ConcreteOp xferOp,
+ const VectorTransferToSCFOptions &options)
+ : rewriter(rewriter), options(options), loc(xferOp.getLoc()),
scope(std::make_unique<ScopedContext>(rewriter, loc)), xferOp(xferOp),
op(xferOp.getOperation()) {
vectorType = xferOp.getVectorType();
@@ -105,19 +107,17 @@ class NDTransferOpHelper {
void emitLoops(Lambda loopBodyBuilder);
/// Operate within the body of `emitLoops` to:
- /// 1. Compute the indexings `majorIvs + majorOffsets`.
- /// 2. Compute a boolean that determines whether the first `majorIvs.rank()`
+ /// 1. Compute the indexings `majorIvs + majorOffsets` and save them in
+ /// `majorIvsPlusOffsets`.
+ /// 2. Return a boolean that determines whether the first `majorIvs.rank()`
/// dimensions `majorIvs + majorOffsets` are all within `memrefBounds`.
- /// 3. Create an IfOp conditioned on the boolean in step 2.
- /// 4. Call a `thenBlockBuilder` and an `elseBlockBuilder` to append
- /// operations to the IfOp blocks as appropriate.
- template <typename LambdaThen, typename LambdaElse>
- void emitInBounds(ValueRange majorIvs, ValueRange majorOffsets,
- MemRefBoundsCapture &memrefBounds,
- LambdaThen thenBlockBuilder, LambdaElse elseBlockBuilder);
+ Value emitInBoundsCondition(ValueRange majorIvs, ValueRange majorOffsets,
+ MemRefBoundsCapture &memrefBounds,
+ SmallVectorImpl<Value> &majorIvsPlusOffsets);
/// Common state to lower vector transfer ops.
PatternRewriter &rewriter;
+ const VectorTransferToSCFOptions &options;
Location loc;
std::unique_ptr<ScopedContext> scope;
ConcreteOp xferOp;
@@ -139,27 +139,43 @@ template <typename Lambda>
void NDTransferOpHelper<ConcreteOp>::emitLoops(Lambda loopBodyBuilder) {
/// Loop nest operates on the major dimensions
MemRefBoundsCapture memrefBoundsCapture(xferOp.memref());
- VectorBoundsCapture vectorBoundsCapture(majorVectorType);
- auto majorLbs = vectorBoundsCapture.getLbs();
- auto majorUbs = vectorBoundsCapture.getUbs();
- auto majorSteps = vectorBoundsCapture.getSteps();
- SmallVector<Value, 8> majorIvs(vectorBoundsCapture.rank());
- AffineLoopNestBuilder(majorIvs, majorLbs, majorUbs, majorSteps)([&] {
+
+ if (options.unroll) {
+ // Unrolled: invoke the body once per position of the major dimensions,
+ // passing constant indices instead of loop induction variables.
+ auto shape = majorVectorType.getShape();
+ auto strides = computeStrides(shape);
+ // Keep the count in int64_t, the type returned by computeMaxLinearIndex
+ // and expected by delinearize, to avoid a narrowing conversion.
+ int64_t numUnrolledInstances = computeMaxLinearIndex(shape);
ValueRange indices(xferOp.indices());
- loopBodyBuilder(majorIvs, indices.take_front(leadingRank),
- indices.drop_front(leadingRank).take_front(majorRank),
- indices.take_back(minorRank), memrefBoundsCapture);
- });
+ for (int64_t idx = 0; idx < numUnrolledInstances; ++idx) {
+ SmallVector<int64_t, 4> offsets = delinearize(strides, idx);
+ SmallVector<Value, 4> offsetValues =
+ llvm::to_vector<4>(llvm::map_range(offsets, [](int64_t off) -> Value {
+ return std_constant_index(off);
+ }));
+ loopBodyBuilder(offsetValues, indices.take_front(leadingRank),
+ indices.drop_front(leadingRank).take_front(majorRank),
+ indices.take_back(minorRank), memrefBoundsCapture);
+ }
+ } else {
+ // Not unrolled: emit an affine loop nest over the major dimensions.
+ VectorBoundsCapture vectorBoundsCapture(majorVectorType);
+ auto majorLbs = vectorBoundsCapture.getLbs();
+ auto majorUbs = vectorBoundsCapture.getUbs();
+ auto majorSteps = vectorBoundsCapture.getSteps();
+ SmallVector<Value, 8> majorIvs(vectorBoundsCapture.rank());
+ AffineLoopNestBuilder(majorIvs, majorLbs, majorUbs, majorSteps)([&] {
+ ValueRange indices(xferOp.indices());
+ loopBodyBuilder(majorIvs, indices.take_front(leadingRank),
+ indices.drop_front(leadingRank).take_front(majorRank),
+ indices.take_back(minorRank), memrefBoundsCapture);
+ });
+ }
}
template <typename ConcreteOp>
-template <typename LambdaThen, typename LambdaElse>
-void NDTransferOpHelper<ConcreteOp>::emitInBounds(
+Value NDTransferOpHelper<ConcreteOp>::emitInBoundsCondition(
ValueRange majorIvs, ValueRange majorOffsets,
- MemRefBoundsCapture &memrefBounds, LambdaThen thenBlockBuilder,
- LambdaElse elseBlockBuilder) {
- Value inBounds;
- SmallVector<Value, 4> majorIvsPlusOffsets;
+ MemRefBoundsCapture &memrefBounds,
+ SmallVectorImpl<Value> &majorIvsPlusOffsets) {
+ Value inBoundsCondition;
majorIvsPlusOffsets.reserve(majorIvs.size());
unsigned idx = 0;
for (auto it : llvm::zip(majorIvs, majorOffsets, memrefBounds.getUbs())) {
@@ -167,42 +183,33 @@ void NDTransferOpHelper<ConcreteOp>::emitInBounds(
using namespace mlir::edsc::op;
majorIvsPlusOffsets.push_back(iv + off);
if (xferOp.isMaskedDim(leadingRank + idx)) {
- Value inBounds2 = majorIvsPlusOffsets.back() < ub;
- inBounds = (inBounds) ? (inBounds && inBounds2) : inBounds2;
+ Value inBounds = majorIvsPlusOffsets.back() < ub;
+ inBoundsCondition =
+ (inBoundsCondition) ? (inBoundsCondition && inBounds) : inBounds;
}
++idx;
}
-
- if (inBounds) {
- auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
- ScopedContext::getLocation(), TypeRange{}, inBounds,
- /*withElseRegion=*/std::is_same<ConcreteOp, TransferReadOp>());
- BlockBuilder(&ifOp.thenRegion().front(),
- Append())([&] { thenBlockBuilder(majorIvsPlusOffsets); });
- if (std::is_same<ConcreteOp, TransferReadOp>())
- BlockBuilder(&ifOp.elseRegion().front(),
- Append())([&] { elseBlockBuilder(majorIvsPlusOffsets); });
- } else {
- // Just build the body of the then block right here.
- thenBlockBuilder(majorIvsPlusOffsets);
- }
+ return inBoundsCondition;
}
template <>
LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
- Value alloc = std_alloc(memRefMinorVectorType);
+ Value alloc, result;
+ if (options.unroll)
+ result = std_splat(vectorType, xferOp.padding());
+ else
+ alloc = std_alloc(memRefMinorVectorType);
emitLoops([&](ValueRange majorIvs, ValueRange leadingOffsets,
ValueRange majorOffsets, ValueRange minorOffsets,
MemRefBoundsCapture &memrefBounds) {
- // If in-bounds, index into memref and lower to 1-D transfer read.
- auto thenBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {
+ /// Lambda to load 1-D vector in the current loop ivs + offset context.
+ auto load1DVector = [&](ValueRange majorIvsPlusOffsets) -> Value {
SmallVector<Value, 8> indexing;
indexing.reserve(leadingRank + majorRank + minorRank);
indexing.append(leadingOffsets.begin(), leadingOffsets.end());
indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
indexing.append(minorOffsets.begin(), minorOffsets.end());
-
Value memref = xferOp.memref();
auto map = TransferReadOp::getTransferMinorIdentityMap(
xferOp.getMemRefType(), minorVectorType);
@@ -211,46 +218,103 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
OpBuilder &b = ScopedContext::getBuilderRef();
masked = b.getBoolArrayAttr({true});
}
- auto loaded1D = vector_transfer_read(minorVectorType, memref, indexing,
- AffineMapAttr::get(map),
- xferOp.padding(), masked);
- // Store the 1-D vector.
- std_store(loaded1D, alloc, majorIvs);
+ return vector_transfer_read(minorVectorType, memref, indexing,
+ AffineMapAttr::get(map), xferOp.padding(),
+ masked);
};
- // If out-of-bounds, just store a splatted vector.
- auto elseBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {
- auto vector = std_splat(minorVectorType, xferOp.padding());
- std_store(vector, alloc, majorIvs);
- };
- emitInBounds(majorIvs, majorOffsets, memrefBounds, thenBlockBuilder,
- elseBlockBuilder);
+
+ // 1. Compute the inBoundsCondition in the current loops ivs + offset
+ // context.
+ SmallVector<Value, 4> majorIvsPlusOffsets;
+ Value inBoundsCondition = emitInBoundsCondition(
+ majorIvs, majorOffsets, memrefBounds, majorIvsPlusOffsets);
+
+ if (inBoundsCondition) {
+ // 2. If the condition is not null, we need an IfOp, which may yield
+ // if `options.unroll` is true.
+ SmallVector<Type, 1> resultType;
+ if (options.unroll)
+ resultType.push_back(vectorType);
+ auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
+ ScopedContext::getLocation(), resultType, inBoundsCondition,
+ /*withElseRegion=*/true);
+
+ // 3.a. If in-bounds, progressively lower to a 1-D transfer read.
+ BlockBuilder(&ifOp.thenRegion().front(), Append())([&] {
+ Value vector = load1DVector(majorIvsPlusOffsets);
+ // 3.a.i. If `options.unroll` is true, insert the 1-D vector in the
+ // aggregate. We must yield and merge with the `else` branch.
+ if (options.unroll) {
+ vector = vector_insert(vector, result, majorIvs);
+ (loop_yield(vector));
+ return;
+ }
+ // 3.a.ii. Otherwise, just go through the temporary `alloc`.
+ std_store(vector, alloc, majorIvs);
+ });
+
+ // 3.b. If not in-bounds, splat a 1-D vector.
+ BlockBuilder(&ifOp.elseRegion().front(), Append())([&] {
+ Value vector = std_splat(minorVectorType, xferOp.padding());
+ // 3.b.i. If `options.unroll` is true, insert the 1-D vector in the
+ // aggregate. We must yield and merge with the `then` branch.
+ if (options.unroll) {
+ vector = vector_insert(vector, result, majorIvs);
+ (loop_yield(vector));
+ return;
+ }
+ // 3.b.ii. Otherwise, just go through the temporary `alloc`.
+ std_store(vector, alloc, majorIvs);
+ });
+ if (!resultType.empty())
+ result = *ifOp.results().begin();
+ } else {
+ // 4. Guaranteed in-bounds, progressively lower to a 1-D transfer read.
+ Value loaded1D = load1DVector(majorIvsPlusOffsets);
+ // 5.a. If `options.unroll` is true, insert the 1-D vector in the
+ // aggregate.
+ if (options.unroll)
+ result = vector_insert(loaded1D, result, majorIvs);
+ // 5.b. Otherwise, just go through the temporary `alloc`.
+ else
+ std_store(loaded1D, alloc, majorIvs);
+ }
});
- Value loaded =
- std_load(vector_type_cast(MemRefType::get({}, vectorType), alloc));
- rewriter.replaceOp(op, loaded);
+ assert((!options.unroll ^ result) && "Expected resulting Value iff unroll");
+ if (!result)
+ result = std_load(vector_type_cast(MemRefType::get({}, vectorType), alloc));
+ rewriter.replaceOp(op, result);
return success();
}
template <>
LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
- Value alloc = std_alloc(memRefMinorVectorType);
-
- std_store(xferOp.vector(),
- vector_type_cast(MemRefType::get({}, vectorType), alloc));
+ Value alloc;
+ if (!options.unroll) {
+ alloc = std_alloc(memRefMinorVectorType);
+ std_store(xferOp.vector(),
+ vector_type_cast(MemRefType::get({}, vectorType), alloc));
+ }
emitLoops([&](ValueRange majorIvs, ValueRange leadingOffsets,
ValueRange majorOffsets, ValueRange minorOffsets,
MemRefBoundsCapture &memrefBounds) {
- auto thenBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {
+ // Lower to 1-D vector_transfer_write and let recursion handle it.
+ auto emitTransferWrite = [&](ValueRange majorIvsPlusOffsets) {
SmallVector<Value, 8> indexing;
indexing.reserve(leadingRank + majorRank + minorRank);
indexing.append(leadingOffsets.begin(), leadingOffsets.end());
indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
indexing.append(minorOffsets.begin(), minorOffsets.end());
- // Lower to 1-D vector_transfer_write and let recursion handle it.
- Value loaded1D = std_load(alloc, majorIvs);
+ Value result;
+ // If `options.unroll` is true, extract the 1-D vector from the
+ // aggregate.
+ if (options.unroll)
+ result = vector_extract(xferOp.vector(), majorIvs);
+ else
+ result = std_load(alloc, majorIvs);
auto map = TransferWriteOp::getTransferMinorIdentityMap(
xferOp.getMemRefType(), minorVectorType);
ArrayAttr masked;
@@ -258,13 +322,28 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
OpBuilder &b = ScopedContext::getBuilderRef();
masked = b.getBoolArrayAttr({true});
}
- vector_transfer_write(loaded1D, xferOp.memref(), indexing,
+ vector_transfer_write(result, xferOp.memref(), indexing,
AffineMapAttr::get(map), masked);
};
- // Don't write anything when out of bounds.
- auto elseBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {};
- emitInBounds(majorIvs, majorOffsets, memrefBounds, thenBlockBuilder,
- elseBlockBuilder);
+
+ // 1. Compute the inBoundsCondition in the current loops ivs + offset
+ // context.
+ SmallVector<Value, 4> majorIvsPlusOffsets;
+ Value inBoundsCondition = emitInBoundsCondition(
+ majorIvs, majorOffsets, memrefBounds, majorIvsPlusOffsets);
+
+ if (inBoundsCondition) {
+ // 2.a. If the condition is not null, we need an IfOp, to write
+ // conditionally. Progressively lower to a 1-D transfer write.
+ auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
+ ScopedContext::getLocation(), TypeRange{}, inBoundsCondition,
+ /*withElseRegion=*/false);
+ BlockBuilder(&ifOp.thenRegion().front(),
+ Append())([&] { emitTransferWrite(majorIvsPlusOffsets); });
+ } else {
+ // 2.b. Guaranteed in-bounds. Progressively lower to a 1-D transfer write.
+ emitTransferWrite(majorIvsPlusOffsets);
+ }
});
rewriter.eraseOp(op);
@@ -358,81 +437,20 @@ clip(TransferOpTy transfer, MemRefBoundsCapture &bounds, ArrayRef<Value> ivs) {
return clippedScalarAccessExprs;
}
-namespace {
-
-/// Implements lowering of TransferReadOp and TransferWriteOp to a
-/// proper abstraction for the hardware.
-///
-/// For now, we only emit a simple loop nest that performs clipped pointwise
-/// copies from a remote to a locally allocated memory.
-///
-/// Consider the case:
-///
-/// ```mlir
-/// // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into
-/// // vector<32x256xf32> and pad with %f0 to handle the boundary case:
-/// %f0 = constant 0.0f : f32
-/// scf.for %i0 = 0 to %0 {
-/// scf.for %i1 = 0 to %1 step %c256 {
-/// scf.for %i2 = 0 to %2 step %c32 {
-/// %v = vector.transfer_read %A[%i0, %i1, %i2], %f0
-/// {permutation_map: (d0, d1, d2) -> (d2, d1)} :
-/// memref<?x?x?xf32>, vector<32x256xf32>
-/// }}}
-/// ```
-///
-/// The rewriters construct loop and indices that access MemRef A in a pattern
-/// resembling the following (while guaranteeing an always full-tile
-/// abstraction):
-///
-/// ```mlir
-/// scf.for %d2 = 0 to %c256 {
-/// scf.for %d1 = 0 to %c32 {
-/// %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32
-/// %tmp[%d2, %d1] = %s
-/// }
-/// }
-/// ```
-///
-/// In the current state, only a clipping transfer is implemented by `clip`,
-/// which creates individual indexing expressions of the form:
-///
-/// ```mlir-dsc
-/// auto condMax = i + ii < N;
-/// auto max = std_select(condMax, i + ii, N - one)
-/// auto cond = i + ii < zero;
-/// std_select(cond, zero, max);
-/// ```
-///
-/// In the future, clipping should not be the only way and instead we should
-/// load vectors + mask them. Similarly on the write side, load/mask/store for
-/// implementing RMW behavior.
-///
-/// Lowers TransferOp into a combination of:
-/// 1. local memory allocation;
-/// 2. perfect loop nest over:
-/// a. scalar load/stores from local buffers (viewed as a scalar memref);
-/// a. scalar store/load to original memref (with clipping).
-/// 3. vector_load/store
-/// 4. local memory deallocation.
-/// Minor variations occur depending on whether a TransferReadOp or
-/// a TransferWriteOp is rewritten.
template <typename TransferOpTy>
-struct VectorTransferRewriter : public RewritePattern {
- explicit VectorTransferRewriter(MLIRContext *context)
- : RewritePattern(TransferOpTy::getOperationName(), 1, context) {}
-
- /// Used for staging the transfer in a local scalar buffer.
- MemRefType tmpMemRefType(TransferOpTy transfer) const {
- auto vectorType = transfer.getVectorType();
- return MemRefType::get(vectorType.getShape(), vectorType.getElementType(),
- {}, 0);
- }
+VectorTransferRewriter<TransferOpTy>::VectorTransferRewriter(
+    VectorTransferToSCFOptions options, MLIRContext *context)
+    : RewritePattern(TransferOpTy::getOperationName(), /*benefit=*/1,
+                     context),
+      options(options) {
+  // Nothing else to set up: the pattern only remembers `options`.
+}
- /// Performs the rewrite.
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override;
-};
+/// Used for staging the transfer in a local buffer: a statically shaped
+/// MemRefType mirroring the transferred vector's shape and element type,
+/// with no layout map and memory space 0.
+template <typename TransferOpTy>
+MemRefType VectorTransferRewriter<TransferOpTy>::tmpMemRefType(
+    TransferOpTy transfer) const {
+  auto vecTy = transfer.getVectorType();
+  auto shape = vecTy.getShape();
+  auto eltTy = vecTy.getElementType();
+  return MemRefType::get(shape, eltTy, /*affineMapComposition=*/{},
+                         /*memorySpace=*/0);
+}
/// Lowers TransferReadOp into a combination of:
/// 1. local memory allocation;
@@ -486,7 +504,8 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
if (AffineMap::isMinorIdentity(transfer.permutation_map())) {
// If > 1D, emit a bunch of loops around 1-D vector transfers.
if (transfer.getVectorType().getRank() > 1)
- return NDTransferOpHelper<TransferReadOp>(rewriter, transfer).doReplace();
+ return NDTransferOpHelper<TransferReadOp>(rewriter, transfer, options)
+ .doReplace();
// If 1-D this is now handled by the target-specific lowering.
if (transfer.getVectorType().getRank() == 1)
return failure();
@@ -558,7 +577,7 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
if (AffineMap::isMinorIdentity(transfer.permutation_map())) {
// If > 1D, emit a bunch of loops around 1-D vector transfers.
if (transfer.getVectorType().getRank() > 1)
- return NDTransferOpHelper<TransferWriteOp>(rewriter, transfer)
+ return NDTransferOpHelper<TransferWriteOp>(rewriter, transfer, options)
.doReplace();
// If 1-D this is now handled by the target-specific lowering.
if (transfer.getVectorType().getRank() == 1)
@@ -603,10 +622,10 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
return success();
}
-} // namespace
-
void mlir::populateVectorToSCFConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
+ OwningRewritePatternList &patterns, MLIRContext *context,
+ const VectorTransferToSCFOptions &options) {
patterns.insert<VectorTransferRewriter<vector::TransferReadOp>,
- VectorTransferRewriter<vector::TransferWriteOp>>(context);
+ VectorTransferRewriter<vector::TransferWriteOp>>(options,
+ context);
}
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 01894d1ad7d1..ca07ee140774 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -470,6 +470,16 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
result.addAttribute(getPositionAttrName(), positionAttr);
}
+// Convenience builder which assumes the values are constant indices.
+void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
+                              Value source, ValueRange position) {
+  // Every `position` value must be produced by a ConstantIndexOp. Assert this
+  // explicitly rather than dereferencing a null op when it does not hold.
+  SmallVector<int64_t, 4> positionConstants =
+      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
+        auto constantOp = pos.getDefiningOp<ConstantIndexOp>();
+        assert(constantOp && "expected position to be a ConstantIndexOp");
+        return constantOp.getValue();
+      }));
+  build(builder, result, source, positionConstants);
+}
+
static void print(OpAsmPrinter &p, vector::ExtractOp op) {
p << op.getOperationName() << " " << op.vector() << op.position();
p.printOptionalAttrDict(op.getAttrs(), {"position"});
@@ -739,6 +749,16 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
result.addAttribute(getPositionAttrName(), positionAttr);
}
+// Convenience builder which assumes the values are constant indices.
+void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
+                     Value dest, ValueRange position) {
+  // Every `position` value must be produced by a ConstantIndexOp. Assert this
+  // explicitly rather than dereferencing a null op when it does not hold.
+  SmallVector<int64_t, 4> positionConstants =
+      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
+        auto constantOp = pos.getDefiningOp<ConstantIndexOp>();
+        assert(constantOp && "expected position to be a ConstantIndexOp");
+        return constantOp.getValue();
+      }));
+  build(builder, result, source, dest, positionConstants);
+}
+
static LogicalResult verify(InsertOp op) {
auto positionAttr = op.position().getValue();
if (positionAttr.empty())
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index cf1bdede9027..1c1de155d8b6 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -39,33 +39,6 @@
using namespace mlir;
using llvm::dbgs;
-/// Given a shape with sizes greater than 0 along all dimensions,
-/// returns the distance, in number of elements, between a slice in a dimension
-/// and the next slice in the same dimension.
-/// e.g. shape[3, 4, 5] -> linearization_basis[20, 5, 1]
-static SmallVector<int64_t, 8> computeStrides(ArrayRef<int64_t> shape) {
- if (shape.empty())
- return {};
- SmallVector<int64_t, 8> tmp;
- tmp.reserve(shape.size());
- int64_t running = 1;
- for (auto size : llvm::reverse(shape)) {
- assert(size > 0 && "size must be nonnegative");
- tmp.push_back(running);
- running *= size;
- }
- return SmallVector<int64_t, 8>(tmp.rbegin(), tmp.rend());
-}
-
-static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
- if (basis.empty())
- return 0;
- int64_t res = 1;
- for (auto b : basis)
- res *= b;
- return res;
-}
-
// Clones `op` into a new operations that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index 1ed89e3f7010..ccd243e8a7de 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
+#include <numeric>
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
@@ -28,6 +29,32 @@ using llvm::SetVector;
using namespace mlir;
+/// Return the number of elements of basis, i.e. the product of all sizes in
+/// `basis`; `0` if `basis` is empty.
+int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
+  if (basis.empty())
+    return 0;
+  // std::accumulate's running value has the type of its init argument, so the
+  // init must be int64_t: a plain `1` would truncate every partial product to
+  // `int` and silently overflow for large shapes.
+  return std::accumulate(basis.begin(), basis.end(), int64_t{1},
+                         std::multiplies<int64_t>());
+}
+
+/// Given a shape with sizes greater than 0 along all dimensions,
+/// return the distance, in number of elements, between a slice in a dimension
+/// and the next slice in the same dimension.
+/// e.g. shape[3, 4, 5] -> linearization_basis[20, 5, 1]
+/// An empty shape yields an empty vector.
+SmallVector<int64_t, 8> mlir::computeStrides(ArrayRef<int64_t> shape) {
+  if (shape.empty())
+    return {};
+  SmallVector<int64_t, 8> tmp;
+  tmp.reserve(shape.size());
+  // Walk the dimensions from last to first: each stride is the product of the
+  // sizes of all dimensions after it, so the last dimension gets stride 1.
+  int64_t running = 1;
+  for (auto size : llvm::reverse(shape)) {
+    assert(size > 0 && "size must be positive");
+    tmp.push_back(running);
+    running *= size;
+  }
+  // Strides were accumulated last-dimension-first; restore dimension order.
+  return SmallVector<int64_t, 8>(tmp.rbegin(), tmp.rend());
+}
+
SmallVector<int64_t, 4> mlir::computeStrides(ArrayRef<int64_t> shape,
ArrayRef<int64_t> sizes) {
int64_t rank = shape.size();
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
index c0bc5542e21d..dc35058cfd89 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -test-convert-vector-to-scf -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-convert-vector-to-scf=full-unroll=true -split-input-file | FileCheck %s --check-prefix=FULL-UNROLL
// CHECK-LABEL: func @materialize_read_1d() {
func @materialize_read_1d() {
@@ -213,32 +214,76 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// FULL-UNROLL-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// FULL-UNROLL-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 1)>
+// FULL-UNROLL-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 + 2)>
+
+// 2-D transfer_read lowered row by row, each row guarded by an scf.if bounds check: through an alloc + affine.for by default, or by vector.insert of each row into the result with full-unroll=true.
// CHECK-LABEL: transfer_read_progressive(
// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index
-func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<17x15xf32> {
+
+// FULL-UNROLL-LABEL: transfer_read_progressive(
+// FULL-UNROLL-SAME: %[[A:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// FULL-UNROLL-SAME: %[[base:[a-zA-Z0-9]+]]: index
+
+func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<3x15xf32> {
// CHECK: %[[cst:.*]] = constant 7.000000e+00 : f32
%f7 = constant 7.0: f32
// CHECK-DAG: %[[splat:.*]] = constant dense<7.000000e+00> : vector<15xf32>
- // CHECK-DAG: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
+ // CHECK-DAG: %[[alloc:.*]] = alloc() : memref<3xvector<15xf32>>
// CHECK-DAG: %[[dim:.*]] = dim %[[A]], 0 : memref<?x?xf32>
- // CHECK: affine.for %[[I:.*]] = 0 to 17 {
+ // CHECK: affine.for %[[I:.*]] = 0 to 3 {
// CHECK: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
// CHECK: %[[cond1:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
// CHECK: scf.if %[[cond1]] {
// CHECK: %[[vec_1d:.*]] = vector.transfer_read %[[A]][%[[add]], %[[base]]], %[[cst]] : memref<?x?xf32>, vector<15xf32>
- // CHECK: store %[[vec_1d]], %[[alloc]][%[[I]]] : memref<17xvector<15xf32>>
+ // CHECK: store %[[vec_1d]], %[[alloc]][%[[I]]] : memref<3xvector<15xf32>>
// CHECK: } else {
- // CHECK: store %[[splat]], %[[alloc]][%[[I]]] : memref<17xvector<15xf32>>
+ // CHECK: store %[[splat]], %[[alloc]][%[[I]]] : memref<3xvector<15xf32>>
// CHECK: }
- // CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
- // CHECK: %[[cst:.*]] = load %[[vmemref]][] : memref<vector<17x15xf32>>
- %f = vector.transfer_read %A[%base, %base], %f7
- {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
- memref<?x?xf32>, vector<17x15xf32>
+ // CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<3xvector<15xf32>> to memref<vector<3x15xf32>>
+ // CHECK: load %[[vmemref]][] : memref<vector<3x15xf32>>
+
+ // FULL-UNROLL: %[[pad:.*]] = constant 7.000000e+00 : f32
+ // FULL-UNROLL: %[[VEC0:.*]] = constant dense<7.000000e+00> : vector<3x15xf32>
+ // FULL-UNROLL: %[[SPLAT:.*]] = constant dense<7.000000e+00> : vector<15xf32>
+ // FULL-UNROLL: %[[DIM:.*]] = dim %[[A]], 0 : memref<?x?xf32>
+ // FULL-UNROLL: cmpi "slt", %[[base]], %[[DIM]] : index
+ // FULL-UNROLL: %[[VEC1:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) {
+ // FULL-UNROLL: vector.transfer_read %[[A]][%[[base]], %[[base]]], %[[pad]] : memref<?x?xf32>, vector<15xf32>
+ // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC0]] [0] : vector<15xf32> into vector<3x15xf32>
+ // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32>
+ // FULL-UNROLL: } else {
+ // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC0]] [0] : vector<15xf32> into vector<3x15xf32>
+ // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32>
+ // FULL-UNROLL: }
+ // FULL-UNROLL: affine.apply #[[MAP1]]()[%[[base]]]
+ // FULL-UNROLL: cmpi "slt", %{{.*}}, %[[DIM]] : index
+ // FULL-UNROLL: %[[VEC2:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) {
+ // FULL-UNROLL: vector.transfer_read %[[A]][%{{.*}}, %[[base]]], %[[pad]] : memref<?x?xf32>, vector<15xf32>
+ // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC1]] [1] : vector<15xf32> into vector<3x15xf32>
+ // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32>
+ // FULL-UNROLL: } else {
+ // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC1]] [1] : vector<15xf32> into vector<3x15xf32>
+ // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32>
+ // FULL-UNROLL: }
+ // FULL-UNROLL: affine.apply #[[MAP2]]()[%[[base]]]
+ // FULL-UNROLL: cmpi "slt", %{{.*}}, %[[DIM]] : index
+ // FULL-UNROLL: %[[VEC3:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) {
+ // FULL-UNROLL: vector.transfer_read %[[A]][%{{.*}}, %[[base]]], %[[pad]] : memref<?x?xf32>, vector<15xf32>
+ // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC2]] [2] : vector<15xf32> into vector<3x15xf32>
+ // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32>
+ // FULL-UNROLL: } else {
+ // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC2]] [2] : vector<15xf32> into vector<3x15xf32>
+ // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32>
+ // FULL-UNROLL: }
+ // FULL-UNROLL: return %[[VEC3]] : vector<3x15xf32>
+ %f = vector.transfer_read %A[%base, %base], %f7 :
+ memref<?x?xf32>, vector<3x15xf32>
- return %f: vector<17x15xf32>
+ return %f: vector<3x15xf32>
}
// -----
@@ -246,25 +291,52 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<17
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// FULL-UNROLL-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// FULL-UNROLL-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 1)>
+// FULL-UNROLL-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 + 2)>
+// 2-D transfer_write lowered row by row, each row guarded by an scf.if bounds check: through an alloc + affine.for by default, or one vector.extract per row with full-unroll=true.
// CHECK-LABEL: transfer_write_progressive(
// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index,
-// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32>
-func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vector<17x15xf32>) {
- // CHECK: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
- // CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
- // CHECK: store %[[vec]], %[[vmemref]][] : memref<vector<17x15xf32>>
+// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32>
+// FULL-UNROLL-LABEL: transfer_write_progressive(
+// FULL-UNROLL-SAME: %[[A:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// FULL-UNROLL-SAME: %[[base:[a-zA-Z0-9]+]]: index,
+// FULL-UNROLL-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32>
+func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vector<3x15xf32>) {
+ // CHECK: %[[alloc:.*]] = alloc() : memref<3xvector<15xf32>>
+ // CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<3xvector<15xf32>> to memref<vector<3x15xf32>>
+ // CHECK: store %[[vec]], %[[vmemref]][] : memref<vector<3x15xf32>>
// CHECK: %[[dim:.*]] = dim %[[A]], 0 : memref<?x?xf32>
- // CHECK: affine.for %[[I:.*]] = 0 to 17 {
+ // CHECK: affine.for %[[I:.*]] = 0 to 3 {
// CHECK: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
// CHECK: %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
// CHECK: scf.if %[[cmp]] {
- // CHECK: %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>>
+ // CHECK: %[[vec_1d:.*]] = load %[[alloc]][%[[I]]] : memref<3xvector<15xf32>>
// CHECK: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
// CHECK: }
- vector.transfer_write %vec, %A[%base, %base]
- {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
- vector<17x15xf32>, memref<?x?xf32>
+
+ // FULL-UNROLL: %[[DIM:.*]] = dim %[[A]], 0 : memref<?x?xf32>
+ // FULL-UNROLL: %[[CMP0:.*]] = cmpi "slt", %[[base]], %[[DIM]] : index
+ // FULL-UNROLL: scf.if %[[CMP0]] {
+ // FULL-UNROLL: %[[V0:.*]] = vector.extract %[[vec]][0] : vector<3x15xf32>
+ // FULL-UNROLL: vector.transfer_write %[[V0]], %[[A]][%[[base]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
+ // FULL-UNROLL: }
+ // FULL-UNROLL: %[[I1:.*]] = affine.apply #[[MAP1]]()[%[[base]]]
+ // FULL-UNROLL: %[[CMP1:.*]] = cmpi "slt", %[[I1]], %[[DIM]] : index
+ // FULL-UNROLL: scf.if %[[CMP1]] {
+ // FULL-UNROLL: %[[V1:.*]] = vector.extract %[[vec]][1] : vector<3x15xf32>
+ // FULL-UNROLL: vector.transfer_write %[[V1]], %[[A]][%[[I1]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
+ // FULL-UNROLL: }
+ // FULL-UNROLL: %[[I2:.*]] = affine.apply #[[MAP2]]()[%[[base]]]
+ // FULL-UNROLL: %[[CMP2:.*]] = cmpi "slt", %[[I2]], %[[DIM]] : index
+ // FULL-UNROLL: scf.if %[[CMP2]] {
+ // FULL-UNROLL: %[[V2:.*]] = vector.extract %[[vec]][2] : vector<3x15xf32>
+ // FULL-UNROLL: vector.transfer_write %[[V2]], %[[A]][%[[I2]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
+ // FULL-UNROLL: }
+
+ vector.transfer_write %vec, %A[%base, %base] :
+ vector<3x15xf32>, memref<?x?xf32>
return
}
@@ -273,20 +345,37 @@ func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vecto
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// FULL-UNROLL-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// FULL-UNROLL-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 1)>
+// FULL-UNROLL-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 + 2)>
+// 2-D transfer_write with masked = [false, false]: the per-row writes are emitted without scf.if bounds checks.
// CHECK-LABEL: transfer_write_progressive_not_masked(
// CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index,
-// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32>
-func @transfer_write_progressive_not_masked(%A : memref<?x?xf32>, %base: index, %vec: vector<17x15xf32>) {
+// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32>
+// FULL-UNROLL-LABEL: transfer_write_progressive_not_masked(
+// FULL-UNROLL-SAME: %[[A:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// FULL-UNROLL-SAME: %[[base:[a-zA-Z0-9]+]]: index,
+// FULL-UNROLL-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32>
+func @transfer_write_progressive_not_masked(%A : memref<?x?xf32>, %base: index, %vec: vector<3x15xf32>) {
// CHECK-NOT: scf.if
- // CHECK-NEXT: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>>
- // CHECK-NEXT: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
- // CHECK-NEXT: store %[[vec]], %[[vmemref]][] : memref<vector<17x15xf32>>
- // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 17 {
+ // CHECK-NEXT: %[[alloc:.*]] = alloc() : memref<3xvector<15xf32>>
+ // CHECK-NEXT: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<3xvector<15xf32>> to memref<vector<3x15xf32>>
+ // CHECK-NEXT: store %[[vec]], %[[vmemref]][] : memref<vector<3x15xf32>>
+ // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 3 {
// CHECK-NEXT: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]]
- // CHECK-NEXT: %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>>
+ // CHECK-NEXT: %[[vec_1d:.*]] = load %[[alloc]][%[[I]]] : memref<3xvector<15xf32>>
// CHECK-NEXT: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
+
+ // FULL-UNROLL: %[[VEC0:.*]] = vector.extract %[[vec]][0] : vector<3x15xf32>
+ // FULL-UNROLL: vector.transfer_write %[[VEC0]], %[[A]][%[[base]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
+ // FULL-UNROLL: %[[I1:.*]] = affine.apply #[[MAP1]]()[%[[base]]]
+ // FULL-UNROLL: %[[VEC1:.*]] = vector.extract %[[vec]][1] : vector<3x15xf32>
+ // FULL-UNROLL: vector.transfer_write %[[VEC1]], %[[A]][%[[I1]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
+ // FULL-UNROLL: %[[I2:.*]] = affine.apply #[[MAP2]]()[%[[base]]]
+ // FULL-UNROLL: %[[VEC2:.*]] = vector.extract %[[vec]][2] : vector<3x15xf32>
+ // FULL-UNROLL: vector.transfer_write %[[VEC2]], %[[A]][%[[I2]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
vector.transfer_write %vec, %A[%base, %base] {masked = [false, false]} :
- vector<17x15xf32>, memref<?x?xf32>
+ vector<3x15xf32>, memref<?x?xf32>
return
}
diff --git a/mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp b/mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp
index fb3010d88b52..7a83e20e47ac 100644
--- a/mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp
+++ b/mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp
@@ -19,10 +19,20 @@ namespace {
+/// Test pass that greedily applies the VectorToSCF conversion patterns to each
+/// function, optionally requesting full unrolling of vector transfers.
struct TestVectorToSCFPass
: public PassWrapper<TestVectorToSCFPass, FunctionPass> {
+ TestVectorToSCFPass() = default;
+ // The copy constructor copies nothing explicitly: the `PassWrapper` base and
+ // `fullUnroll` are default-initialized, so `fullUnroll` starts from its
+ // `llvm::cl::init(false)` default rather than the source pass's value.
+ // NOTE(review): presumably needed because `Option` members are not copyable
+ // and option values are re-applied after cloning; confirm against the pass
+ // infrastructure.
+ TestVectorToSCFPass(const TestVectorToSCFPass &pass) {}
+
+ // Set with `-test-convert-vector-to-scf=full-unroll=true`.
+ Option<bool> fullUnroll{
+ *this, "full-unroll",
+ llvm::cl::desc(
+ "Perform full unrolling when converting vector transfers to SCF"),
+ llvm::cl::init(false)};
+
void runOnFunction() override {
OwningRewritePatternList patterns;
auto *context = &getContext();
- populateVectorToSCFConversionPatterns(patterns, context);
+ // Forward the `full-unroll` option to the lowering patterns.
+ populateVectorToSCFConversionPatterns(
+ patterns, context, VectorTransferToSCFOptions().setUnroll(fullUnroll));
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
};
More information about the Mlir-commits
mailing list