[Mlir-commits] [mlir] 9b77be5 - [mlir] Unrolled progressive-vector-to-scf.
Matthias Springer
llvmlistbot at llvm.org
Wed May 12 21:09:09 PDT 2021
Author: Matthias Springer
Date: 2021-05-13T13:08:48+09:00
New Revision: 9b77be5583d2da03f2ccd7319d33a2daedf8b1b3
URL: https://github.com/llvm/llvm-project/commit/9b77be5583d2da03f2ccd7319d33a2daedf8b1b3
DIFF: https://github.com/llvm/llvm-project/commit/9b77be5583d2da03f2ccd7319d33a2daedf8b1b3.diff
LOG: [mlir] Unrolled progressive-vector-to-scf.
Instead of an SCF for loop, these patterns generate a fully unrolled sequence of transfer ops with no temporary buffer allocations.
Differential Revision: https://reviews.llvm.org/D101981
Added:
mlir/test/Conversion/VectorToSCF/unrolled-vector-to-loops.mlir
Modified:
mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h
mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
mlir/lib/Interfaces/VectorInterfaces.cpp
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h
index 5f765b18d339..b69ec01a0d5c 100644
--- a/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h
@@ -47,12 +47,24 @@ class RewritePatternSet;
/// When applying the pattern a second time, the existing alloca() operation
/// is reused and only a second vector.type_cast is added.
+struct ProgressiveVectorTransferToSCFOptions {
+ bool unroll = false;
+ ProgressiveVectorTransferToSCFOptions &setUnroll(bool u) {
+ unroll = u;
+ return *this;
+ }
+};
+
/// Collect a set of patterns to convert from the Vector dialect to SCF + std.
void populateProgressiveVectorToSCFConversionPatterns(
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns,
+ const ProgressiveVectorTransferToSCFOptions &options =
+ ProgressiveVectorTransferToSCFOptions());
/// Create a pass to convert a subset of vector ops to SCF.
-std::unique_ptr<Pass> createProgressiveConvertVectorToSCFPass();
+std::unique_ptr<Pass> createProgressiveConvertVectorToSCFPass(
+ const ProgressiveVectorTransferToSCFOptions &options =
+ ProgressiveVectorTransferToSCFOptions());
} // namespace mlir
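
A minimal usage sketch of the API declared above (illustration only, not part
of this commit; it assumes a PassManager `pm` that runs passes on the
enclosing function op):

  ProgressiveVectorTransferToSCFOptions options;
  options.setUnroll(true); // Default is false: the buffer-based SCF lowering.
  pm.addPass(createProgressiveConvertVectorToSCFPass(options));
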
diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
index 3b46df8fcc1b..7a016bd88547 100644
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -262,6 +262,14 @@ static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) {
return ArrayAttr::get(builder.getContext(), attr.getValue().drop_front());
}
+/// Add the pass label to a vector transfer op if its rank is greater than the
+/// target rank.
+template <typename OpTy>
+static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp) {
+ if (newXferOp.getVectorType().getRank() > kTargetRank)
+ newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
+}
+
/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
@@ -352,8 +360,8 @@ struct Strategy<TransferReadOp> {
AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)),
xferOp.padding(), Value(), inBoundsAttr).value;
- if (vecType.getRank() > kTargetRank)
- newXfer.getDefiningOp()->setAttr(kPassLabel, builder.getUnitAttr());
+ maybeApplyPassLabel(builder,
+ dyn_cast<TransferReadOp>(newXfer.getDefiningOp()));
memref_store(newXfer, buffer, storeIndices);
return newXfer.getDefiningOp<TransferReadOp>();
@@ -424,15 +432,13 @@ struct Strategy<TransferWriteOp> {
getXferIndices(xferOp, iv, xferIndices);
auto vec = memref_load(buffer, loadIndices);
- auto vecType = vec.value.getType().dyn_cast<VectorType>();
auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
auto newXfer = vector_transfer_write(
Type(), vec, xferOp.source(), xferIndices,
AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)),
Value(), inBoundsAttr);
- if (vecType.getRank() > kTargetRank)
- newXfer.op->setAttr(kPassLabel, builder.getUnitAttr());
+ maybeApplyPassLabel(builder, newXfer.op);
return newXfer;
}
@@ -663,6 +669,264 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
}
};
+/// If the original transfer op has a mask, compute the mask of the new transfer
+/// op (for the current iteration `i`) and assign it.
+template <typename OpTy>
+static void maybeAssignMask(OpBuilder &builder, OpTy xferOp, OpTy newXferOp,
+ int64_t i) {
+ if (!xferOp.mask())
+ return;
+
+ if (xferOp.isBroadcastDim(0)) {
+ // To-be-unpacked dimension is a broadcast, which does not have a
+ // corresponding mask dimension. Mask attribute remains unchanged.
+ newXferOp.maskMutable().assign(xferOp.mask());
+ return;
+ }
+
+ if (xferOp.getMaskType().getRank() > 1) {
+ // Unpack one dimension of the mask.
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPoint(newXferOp); // Insert the mask extract before newXfer.
+
+ llvm::SmallVector<int64_t, 1> indices({i});
+ auto newMask = vector_extract(xferOp.mask(), indices).value;
+ newXferOp.maskMutable().assign(newMask);
+ return;
+ }
+
+ // If we end up here: The mask of the old transfer op is 1D and the unpacked
+ // dim is not a broadcast, so no mask is needed on the new transfer op.
+ // `generateInBoundsCheck` will have evaluated the mask already.
+}
+
+/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
+/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
+/// memref buffer is allocated and the SCF loop is fully unrolled.
+///
+/// E.g.:
+/// ```
+/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
+/// : memref<?x?x?xf32>, vector<5x4xf32>
+/// ```
+/// is rewritten to IR such as (simplified):
+/// ```
+/// %v_init = splat %padding : vector<5x4xf32>
+/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
+/// : memref<?x?x?xf32>, vector<4xf32>
+/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
+/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
+/// : memref<?x?x?xf32>, vector<4xf32>
+/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
+/// ...
+/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
+/// : memref<?x?x?xf32>, vector<4xf32>
+/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
+/// ```
+///
+/// Note: A pass label is attached to new TransferReadOps, so that subsequent
+/// applications of this pattern do not create an additional %v_init vector.
+struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
+ using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+
+ /// Find the result vector %v_init, or create a new vector if this is the
+ /// first application of the pattern.
+ Value getResultVector(TransferReadOp xferOp,
+ PatternRewriter &rewriter) const {
+ if (xferOp->hasAttr(kPassLabel)) {
+ return getInsertOp(xferOp).dest();
+ }
+ return std_splat(xferOp.getVectorType(), xferOp.padding()).value;
+ }
+
+ /// Assuming that this is not the first application of the pattern, return
+ /// the vector.insert op in which the result of this transfer op is used.
+ vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
+ Operation *xferOpUser = *xferOp->getUsers().begin();
+ return dyn_cast<vector::InsertOp>(xferOpUser);
+ }
+
+ /// Assuming that this is not the first application of the pattern, return the
+ /// indices of the vector.insert op in which the result of this transfer op
+ /// is used.
+ void getInsertionIndices(TransferReadOp xferOp,
+ SmallVector<int64_t, 8> &indices) const {
+ if (xferOp->hasAttr(kPassLabel)) {
+ llvm::for_each(getInsertOp(xferOp).position(), [&](Attribute attr) {
+ indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
+ });
+ }
+ }
+
+ /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
+ /// accesses, and broadcasts and transposes in permutation maps.
+ LogicalResult matchAndRewrite(TransferReadOp xferOp,
+ PatternRewriter &rewriter) const override {
+ if (xferOp.getVectorType().getRank() <= kTargetRank)
+ return failure();
+
+ ScopedContext scope(rewriter, xferOp.getLoc());
+ auto vec = getResultVector(xferOp, rewriter);
+ auto vecType = vec.getType().dyn_cast<VectorType>();
+ auto xferVecType = xferOp.getVectorType();
+ auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
+ xferVecType.getElementType());
+ int64_t dimSize = xferVecType.getShape()[0];
+
+ // Generate fully unrolled loop of transfer ops.
+ for (int64_t i = 0; i < dimSize; ++i) {
+ Value iv = std_constant_index(i);
+
+ vec = generateInBoundsCheck(
+ xferOp, iv, rewriter, unpackedDim(xferOp), TypeRange(vecType),
+ /*inBoundsCase=*/
+ [&](OpBuilder &b, Location loc) {
+ ScopedContext scope(b, loc);
+
+ // Indices for the new transfer op.
+ SmallVector<Value, 8> xferIndices;
+ getXferIndices(xferOp, iv, xferIndices);
+
+ // Indices for the new vector.insert op.
+ SmallVector<int64_t, 8> insertionIndices;
+ getInsertionIndices(xferOp, insertionIndices);
+ insertionIndices.push_back(i);
+
+ auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
+ auto newXferOpVal =
+ vector_transfer_read(
+ newXferVecType, xferOp.source(), xferIndices,
+ AffineMapAttr::get(unpackedPermutationMap(xferOp, b)),
+ xferOp.padding(), Value(), inBoundsAttr)
+ .value;
+ auto newXferOp =
+ dyn_cast<TransferReadOp>(newXferOpVal.getDefiningOp());
+
+ maybeAssignMask(b, xferOp, newXferOp, i);
+ maybeApplyPassLabel(b, newXferOp);
+
+ return vector_insert(newXferOp, vec, insertionIndices).value;
+ },
+ /*outOfBoundsCase=*/
+ [&](OpBuilder &b, Location loc) {
+ // Out of bounds: pass the vector through unchanged; these elements
+ // keep their padding/initial values.
+ return vec;
+ });
+ }
+
+ if (xferOp->hasAttr(kPassLabel)) {
+ rewriter.replaceOp(getInsertOp(xferOp), vec);
+ rewriter.eraseOp(xferOp);
+ } else {
+ rewriter.replaceOp(xferOp, vec);
+ }
+
+ return success();
+ }
+};
+
+/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
+/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
+/// memref buffer is allocated and the SCF loop is fully unrolled.
+///
+/// E.g.:
+/// ```
+/// vector.transfer_write %vec, %A[%a, %b, %c]
+/// : vector<5x4xf32>, memref<?x?x?xf32>
+/// ```
+/// is rewritten to IR such as (simplified):
+/// ```
+/// %v0 = vector.extract %vec[0] : vector<5x4xf32>
+/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
+/// %v1 = vector.extract %vec[1] : vector<5x4xf32>
+/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
+/// ...
+/// %v4 = vector.extract %vec[4] : vector<5x4xf32>
+/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
+/// ```
+///
+/// Note: A pass label is attached to new TransferWriteOps, so that subsequent
+/// applications of this pattern can read the indices of previously generated
+/// vector.extract ops.
+struct UnrollTransferWriteConversion
+ : public OpRewritePattern<TransferWriteOp> {
+ using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
+
+ /// If this is not the first application of the pattern, find the original
+ /// vector %vec that is written by this transfer op. Otherwise, return the
+ /// vector of this transfer op.
+ Value getDataVector(TransferWriteOp xferOp) const {
+ if (xferOp->hasAttr(kPassLabel))
+ return getExtractOp(xferOp).vector();
+ return xferOp.vector();
+ }
+
+ /// Assuming that this is not the first application of the pattern, find the
+ /// vector.extract op whose result is written by this transfer op.
+ vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
+ return dyn_cast<vector::ExtractOp>(xferOp.vector().getDefiningOp());
+ }
+
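+ /// Assuming that this is not the first application of the pattern, return
+ /// the indices of the vector.extract op that feeds this transfer op.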
+ void getExtractionIndices(TransferWriteOp xferOp,
+ SmallVector<int64_t, 8> &indices) const {
+ if (xferOp->hasAttr(kPassLabel)) {
+ llvm::for_each(getExtractOp(xferOp).position(), [&](Attribute attr) {
+ indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
+ });
+ }
+ }
+
+ /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
+ /// accesses, and broadcasts and transposes in permutation maps.
+ LogicalResult matchAndRewrite(TransferWriteOp xferOp,
+ PatternRewriter &rewriter) const override {
+ if (xferOp.getVectorType().getRank() <= kTargetRank)
+ return failure();
+
+ ScopedContext scope(rewriter, xferOp.getLoc());
+ auto vec = getDataVector(xferOp);
+ auto xferVecType = xferOp.getVectorType();
+ int64_t dimSize = xferVecType.getShape()[0];
+
+ // Generate fully unrolled loop of transfer ops.
+ for (int64_t i = 0; i < dimSize; ++i) {
+ Value iv = std_constant_index(i);
+
+ generateInBoundsCheck(
+ xferOp, iv, rewriter, unpackedDim(xferOp),
+ /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
+ ScopedContext scope(b, loc);
+
+ // Indices for the new transfer op.
+ SmallVector<Value, 8> xferIndices;
+ getXferIndices(xferOp, iv, xferIndices);
+
+ // Indices for the new vector.extract op.
+ SmallVector<int64_t, 8> extractionIndices;
+ getExtractionIndices(xferOp, extractionIndices);
+ extractionIndices.push_back(i);
+
+ auto extracted = vector_extract(vec, extractionIndices).value;
+ auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
+
+ auto newXferOp =
+ vector_transfer_write(
+ Type(), extracted, xferOp.source(), xferIndices,
+ AffineMapAttr::get(unpackedPermutationMap(xferOp, b)),
+ Value(), inBoundsAttr)
+ .op;
+
+ maybeAssignMask(b, xferOp, newXferOp, i);
+ maybeApplyPassLabel(b, newXferOp);
+ });
+ }
+
+ rewriter.eraseOp(xferOp);
+ return success();
+ }
+};
+
/// Compute the indices into the memref for the LoadOp/StoreOp generated as
/// part of TransferOp1dConversion. Return the memref dimension on which
/// the transfer is operating. A return value of None indicates a broadcast.
@@ -819,11 +1083,16 @@ struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
namespace mlir {
void populateProgressiveVectorToSCFConversionPatterns(
- RewritePatternSet &patterns) {
- patterns.add<PrepareTransferReadConversion,
- PrepareTransferWriteConversion,
- TransferOpConversion<TransferReadOp>,
- TransferOpConversion<TransferWriteOp>>(patterns.getContext());
+ RewritePatternSet &patterns,
+ const ProgressiveVectorTransferToSCFOptions &options) {
+ if (options.unroll) {
+ patterns.add<UnrollTransferReadConversion, UnrollTransferWriteConversion>(
+ patterns.getContext());
+ } else {
+ patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
+ TransferOpConversion<TransferReadOp>,
+ TransferOpConversion<TransferWriteOp>>(patterns.getContext());
+ }
if (kTargetRank == 1) {
patterns.add<TransferOp1dConversion<TransferReadOp>,
@@ -834,16 +1103,22 @@ void populateProgressiveVectorToSCFConversionPatterns(
struct ConvertProgressiveVectorToSCFPass
: public ConvertVectorToSCFBase<ConvertProgressiveVectorToSCFPass> {
+ ConvertProgressiveVectorToSCFPass(
+ const ProgressiveVectorTransferToSCFOptions &opt)
+ : options(opt) {}
+
void runOnFunction() override {
RewritePatternSet patterns(getFunction().getContext());
- populateProgressiveVectorToSCFConversionPatterns(patterns);
+ populateProgressiveVectorToSCFConversionPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
+
+ ProgressiveVectorTransferToSCFOptions options;
};
} // namespace mlir
-std::unique_ptr<Pass>
-mlir::createProgressiveConvertVectorToSCFPass() {
- return std::make_unique<ConvertProgressiveVectorToSCFPass>();
+std::unique_ptr<Pass> mlir::createProgressiveConvertVectorToSCFPass(
+ const ProgressiveVectorTransferToSCFOptions &options) {
+ return std::make_unique<ConvertProgressiveVectorToSCFPass>(options);
}
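
For illustration, a sketch of how a client pass would drive the unrolled
lowering (mirroring ConvertProgressiveVectorToSCFPass above; the enclosing
pass class is assumed). Each application of the unroll patterns peels one
vector dimension; new transfer ops whose rank is still above kTargetRank
carry the pass label and are matched again by the greedy rewriter, so a
rank-3 transfer is fully lowered after two rounds when kTargetRank is 1:

  void runOnFunction() override {
    // Collect the unrolled variant of the lowering patterns.
    RewritePatternSet patterns(getFunction().getContext());
    ProgressiveVectorTransferToSCFOptions options;
    options.unroll = true;
    populateProgressiveVectorToSCFConversionPatterns(patterns, options);
    // Apply to a fixed point; labeled ops are rewritten until they reach
    // the target rank.
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
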
diff --git a/mlir/lib/Interfaces/VectorInterfaces.cpp b/mlir/lib/Interfaces/VectorInterfaces.cpp
index 36dfd4ff87a5..625ffa985239 100644
--- a/mlir/lib/Interfaces/VectorInterfaces.cpp
+++ b/mlir/lib/Interfaces/VectorInterfaces.cpp
@@ -16,7 +16,7 @@ VectorType mlir::vector::detail::transferMaskType(VectorType vecType,
SmallVector<int64_t, 8> shape;
for (int64_t i = 0; i < vecType.getRank(); ++i) {
// Only result dims have a corresponding dim in the mask.
- if (auto expr = map.getResult(i).template isa<AffineDimExpr>()) {
+ if (map.getResult(i).template isa<AffineDimExpr>()) {
shape.push_back(vecType.getDimSize(i));
}
}
diff --git a/mlir/test/Conversion/VectorToSCF/unrolled-vector-to-loops.mlir b/mlir/test/Conversion/VectorToSCF/unrolled-vector-to-loops.mlir
new file mode 100644
index 000000000000..f90d20a518a6
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSCF/unrolled-vector-to-loops.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -test-unrolled-progressive-convert-vector-to-scf -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_inbounds
+func @transfer_read_inbounds(%A : memref<?x?x?xf32>) -> (vector<2x3x4xf32>) {
+ %f0 = constant 0.0: f32
+ %c0 = constant 0: index
+
+ // CHECK: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [1, 1] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [1, 2] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NOT: scf.if
+ // CHECK-NOT: scf.for
+ %vec = vector.transfer_read %A[%c0, %c0, %c0], %f0 {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<2x3x4xf32>
+ return %vec : vector<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_out_of_bounds
+func @transfer_read_out_of_bounds(%A : memref<?x?x?xf32>) -> (vector<2x3x4xf32>) {
+ %f0 = constant 0.0: f32
+ %c0 = constant 0: index
+
+ // CHECK: scf.if
+ // CHECK: scf.if
+ // CHECK: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK: vector.insert {{.*}} [0, 0] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK: scf.if
+ // CHECK: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK: vector.insert {{.*}} [0, 1] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK: scf.if
+ // CHECK: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK: vector.insert {{.*}} [0, 2] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK: scf.if
+ // CHECK: scf.if
+ // CHECK: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK: vector.insert {{.*}} [1, 0] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK: scf.if
+ // CHECK: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK: vector.insert {{.*}} [1, 1] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK: scf.if
+ // CHECK: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK: vector.insert {{.*}} [1, 2] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NOT: scf.for
+ %vec = vector.transfer_read %A[%c0, %c0, %c0], %f0 : memref<?x?x?xf32>, vector<2x3x4xf32>
+ return %vec : vector<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_mask
+func @transfer_read_mask(%A : memref<?x?x?xf32>, %mask : vector<2x3x4xi1>) -> (vector<2x3x4xf32>) {
+ %f0 = constant 0.0: f32
+ %c0 = constant 0: index
+
+ // CHECK: vector.extract %{{.*}}[0, 0] : vector<2x3x4xi1>
+ // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NEXT: vector.extract %{{.*}}[0, 1] : vector<2x3x4xi1>
+ // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NEXT: vector.extract %{{.*}}[0, 2] : vector<2x3x4xi1>
+ // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NEXT: vector.extract %{{.*}}[1, 0] : vector<2x3x4xi1>
+ // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NEXT: vector.extract %{{.*}}[1, 1] : vector<2x3x4xi1>
+ // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [1, 1] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NEXT: vector.extract %{{.*}}[1, 2] : vector<2x3x4xi1>
+ // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [1, 2] : vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NOT: scf.if
+ // CHECK-NOT: scf.for
+ %vec = vector.transfer_read %A[%c0, %c0, %c0], %f0, %mask {in_bounds = [true, true, true]}: memref<?x?x?xf32>, vector<2x3x4xf32>
+ return %vec : vector<2x3x4xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
index 4521b293ddb3..d7dc9d6f1e59 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
@@ -1,5 +1,10 @@
// RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -test-unrolled-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
index c391e3d39a2c..1fc11fab8528 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -3,6 +3,11 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -test-unrolled-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. ],
[10., 11., 12., 13.],
[20., 21., 22., 23.]]>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
index ff64dbbc8e4c..902f4f50d223 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -3,6 +3,11 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -test-unrolled-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
func @transfer_read_3d(%A : memref<?x?x?x?xf32>,
%o: index, %a: index, %b: index, %c: index) {
%fm42 = constant -42.0: f32
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 0e62d2c09be7..d1ac5e1b994f 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -390,16 +390,20 @@ struct TestVectorMultiReductionLoweringPatterns
}
};
+template <bool Unroll>
struct TestProgressiveVectorToSCFLoweringPatterns
- : public PassWrapper<TestProgressiveVectorToSCFLoweringPatterns,
+ : public PassWrapper<TestProgressiveVectorToSCFLoweringPatterns<Unroll>,
FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect, scf::SCFDialect, AffineDialect>();
}
void runOnFunction() override {
- RewritePatternSet patterns(&getContext());
- populateProgressiveVectorToSCFConversionPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+ RewritePatternSet patterns(&this->getContext());
+ ProgressiveVectorTransferToSCFOptions options;
+ options.unroll = Unroll;
+ populateProgressiveVectorToSCFConversionPatterns(patterns, options);
+ (void)applyPatternsAndFoldGreedily(this->getFunction(),
+ std::move(patterns));
}
};
@@ -450,9 +454,18 @@ void registerTestVectorConversions() {
"test-vector-transfer-lowering-patterns",
"Test conversion patterns to lower transfer ops to other vector ops");
- PassRegistration<TestProgressiveVectorToSCFLoweringPatterns> transferOpToSCF(
- "test-progressive-convert-vector-to-scf",
- "Test conversion patterns to progressively lower transfer ops to SCF");
+ PassRegistration<TestProgressiveVectorToSCFLoweringPatterns<
+ /*Unroll=*/false>>
+ transferOpToSCF("test-progressive-convert-vector-to-scf",
+ "Test conversion patterns to progressively lower "
+ "transfer ops to SCF");
+
+ PassRegistration<TestProgressiveVectorToSCFLoweringPatterns<
+ /*Unroll=*/true>>
+ transferOpToSCFUnrolled(
+ "test-unrolled-progressive-convert-vector-to-scf",
+ "Test conversion patterns to progressively lower transfer ops to SCF"
+ "(unrolled variant)");
PassRegistration<TestVectorMultiReductionLoweringPatterns>
multiDimReductionOpLoweringPass(