[Mlir-commits] [mlir] cb1a423 - [mlir][vector] Move splitting transfer ops into a separate entry point
Lei Zhang
llvmlistbot at llvm.org
Tue Feb 16 07:04:49 PST 2021
Author: Lei Zhang
Date: 2021-02-16T10:04:34-05:00
New Revision: cb1a42359bff2ba49d072df88ad3ffb4c66c16d8
URL: https://github.com/llvm/llvm-project/commit/cb1a42359bff2ba49d072df88ad3ffb4c66c16d8
DIFF: https://github.com/llvm/llvm-project/commit/cb1a42359bff2ba49d072df88ad3ffb4c66c16d8.diff
LOG: [mlir][vector] Move splitting transfer ops into a separate entry point
These patterns unroll transfer read/write ops if the vector consumers/
producers are extract/insert slices ops. Transfer ops can map to hardware
load/store functionalities, where the vector size matters for bandwidth
considerations. So these patterns should be collected separately, instead
of being generic canonicalization patterns.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D96782
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 3650470bc7be..ee7ed62dcf01 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -45,6 +45,18 @@ void populateVectorToVectorCanonicalizationPatterns(
void populateVectorToVectorTransformationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context);
+/// Collect a set of patterns to split transfer read/write ops.
+///
+/// These patterns unroll transfer read/write ops whose vector consumers/
+/// producers are extract/insert slices ops. Transfer ops can map to hardware
+/// load/store functionalities, where the vector size matters for bandwidth
+/// considerations, so these patterns are collected separately instead of
+/// being generic canonicalization patterns. If `ignoreFilter` is non-null,
+/// the patterns fail to match any op for which it returns true.
+void populateSplitVectorTransferPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context,
+ std::function<bool(Operation *)> ignoreFilter = nullptr);
+
/// Collect a set of leading one dimension removal patterns.
///
/// These patterns insert vector.shape_cast to remove leading one dimensions
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index dd3ea8ead746..0a6c88d4d99b 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -705,22 +705,33 @@ mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
namespace {
-// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
-// scheme of its unique ExtractSlicesOp user.
-struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+// Splits a TransferReadOp into smaller TransferReadOps based on slicing
+// scheme of its unique ExtractSlicesOp user.
+class SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
+public:
+ SplitTransferReadOp(MLIRContext *context,
+ std::function<bool(Operation *)> ignoreFilter = nullptr,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), ignoreFilter(ignoreFilter) {}
- LogicalResult matchAndRewrite(vector::TransferReadOp xferReadOp,
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
- // TODO: Support splitting TransferReadOp with non-identity
- // permutation maps. Repurpose code from MaterializeVectors transformation.
- if (!isIdentitySuffix(xferReadOp.permutation_map()))
+ if (ignoreFilter && ignoreFilter(readOp))
+ return failure();
+
+ // TODO: Support splitting TransferReadOp with non-identity permutation
+ // maps. Repurpose code from MaterializeVectors transformation.
+ if (!isIdentitySuffix(readOp.permutation_map()))
+ return failure();
+
+ // Return unless there is only one user, and it is an ExtractSlicesOp.
+ Value readResult = readOp.getResult();
+ if (!readResult.hasOneUse())
return failure();
- // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
- Value xferReadResult = xferReadOp.getResult();
+
auto extractSlicesOp =
- dyn_cast<vector::ExtractSlicesOp>(*xferReadResult.getUsers().begin());
- if (!xferReadResult.hasOneUse() || !extractSlicesOp)
+ dyn_cast<vector::ExtractSlicesOp>(readResult.use_begin()->getOwner());
+ if (!extractSlicesOp)
return failure();
// Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
@@ -730,37 +741,48 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
extractSlicesOp.getStrides(strides);
assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
- Value newVec = unrollTransferReadOp(xferReadOp, sizes, rewriter);
+ Value newVec = unrollTransferReadOp(readOp, sizes, rewriter);
if (!newVec)
return failure();
- rewriter.replaceOp(xferReadOp, newVec);
+ rewriter.replaceOp(readOp, newVec);
return success();
}
+
+private:
+ std::function<bool(Operation *)> ignoreFilter;
};
-// Splits vector TransferWriteOp into smaller TransferWriteOps for each source.
-struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+// Splits a TransferWriteOp into smaller TransferWriteOps for each source.
+class SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
+public:
+ SplitTransferWriteOp(MLIRContext *context,
+ std::function<bool(Operation *)> ignoreFilter = nullptr,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), ignoreFilter(ignoreFilter) {}
- LogicalResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
- // TODO: Support splitting TransferWriteOp with non-identity
- // permutation maps. Repurpose code from MaterializeVectors transformation.
- if (!isIdentitySuffix(xferWriteOp.permutation_map()))
+ if (ignoreFilter && ignoreFilter(writeOp))
+ return failure();
+
+ // TODO: Support splitting TransferWriteOp with non-identity permutation
+ // maps. Repurpose code from MaterializeVectors transformation.
+ if (!isIdentitySuffix(writeOp.permutation_map()))
return failure();
- // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
- auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
- auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
+
+ // Fail to match unless this is writing a vector resulting from an
+ // InsertSlicesOp.
+ auto insertSlicesOp =
+ writeOp.vector().getDefiningOp<vector::InsertSlicesOp>();
if (!insertSlicesOp)
return failure();
- // Get TupleOp operand of 'insertSlicesOp'.
- auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
- insertSlicesOp.vectors().getDefiningOp());
+ // Get the TupleOp operand of the InsertSlicesOp.
+ auto tupleOp = insertSlicesOp.vectors().getDefiningOp<vector::TupleOp>();
if (!tupleOp)
return failure();
- // Get 'sizes' and 'strides' parameters from InsertSlicesOp user.
+ // Get 'sizes' and 'strides' parameters from the producing InsertSlicesOp.
auto sourceTupleType = insertSlicesOp.getSourceTupleType();
auto resultVectorType = insertSlicesOp.getResultVectorType();
SmallVector<int64_t, 4> sizes;
@@ -768,21 +790,20 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
SmallVector<int64_t, 4> strides;
insertSlicesOp.getStrides(strides);
- Location loc = xferWriteOp.getLoc();
+ Location loc = writeOp.getLoc();
auto shapedElementType =
- xferWriteOp.source().getType().cast<ShapedType>().getElementType();
- SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
- xferWriteOp.indices().end());
+ writeOp.source().getType().cast<ShapedType>().getElementType();
+ auto indices = llvm::to_vector<4>(writeOp.indices());
Value resultTensor;
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
// Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
- // `masked` attribute propagates conservatively: if the coarse op didn't
+ // 'masked' attribute propagates conservatively: if the coarse op didn't
// need masking, the fine op doesn't either.
Operation *write = rewriter.create<vector::TransferWriteOp>(
loc, tupleOp.getOperand(index),
- resultTensor ? resultTensor : xferWriteOp.source(), sliceIndices,
- xferWriteOp.permutation_map(),
- xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
+ resultTensor ? resultTensor : writeOp.source(), sliceIndices,
+ writeOp.permutation_map(),
+ writeOp.masked() ? *writeOp.masked() : ArrayAttr());
if (!write->getResults().empty())
resultTensor = write->getResult(0);
};
@@ -790,13 +811,15 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
sourceTupleType, sizes, strides, indices, rewriter,
createSlice);
- // Erase old 'xferWriteOp'.
if (resultTensor)
- rewriter.replaceOp(xferWriteOp, ArrayRef<Value>(resultTensor));
+ rewriter.replaceOp(writeOp, ArrayRef<Value>(resultTensor));
else
- rewriter.eraseOp(xferWriteOp);
+ rewriter.eraseOp(writeOp);
return success();
}
+
+private:
+ std::function<bool(Operation *)> ignoreFilter;
};
/// Decomposes ShapeCastOp on tuple-of-vectors to multiple ShapeCastOps, each
@@ -3105,15 +3128,16 @@ struct BubbleUpBitCastForStridedSliceInsert
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
- // clang-format off
- patterns.insert<ShapeCastOpDecomposer,
- ShapeCastOpFolder,
- SplitTransferReadOp,
- SplitTransferWriteOp,
- TupleGetFolderOp,
- TransferReadExtractPattern,
- TransferWriteInsertPattern>(context);
- // clang-format on
+ patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp,
+ TransferReadExtractPattern, TransferWriteInsertPattern>(
+ context);
+}
+
+void mlir::vector::populateSplitVectorTransferPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context,
+ std::function<bool(Operation *)> ignoreFilter) {
+ // Both split patterns receive the same (possibly null) filter.
+ patterns.insert<SplitTransferReadOp, SplitTransferWriteOp>(context,
+ ignoreFilter);
+}
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 8ec970f68b23..17da2d42fa32 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -47,6 +47,7 @@ struct TestVectorToVectorConversion
populateVectorToVectorTransformationPatterns(patterns, ctx);
populateBubbleVectorBitCastOpPatterns(patterns, ctx);
populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
+ populateSplitVectorTransferPatterns(patterns, ctx);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
More information about the Mlir-commits
mailing list