[Mlir-commits] [mlir] 2ca887d - [mlir] VectorToSCF target rank is a pass option
Matthias Springer
llvmlistbot at llvm.org
Thu May 13 18:30:52 PDT 2021
Author: Matthias Springer
Date: 2021-05-14T10:30:43+09:00
New Revision: 2ca887de6e3c83f77e3ccde172ff55042cece6ab
URL: https://github.com/llvm/llvm-project/commit/2ca887de6e3c83f77e3ccde172ff55042cece6ab
DIFF: https://github.com/llvm/llvm-project/commit/2ca887de6e3c83f77e3ccde172ff55042cece6ab.diff
LOG: [mlir] VectorToSCF target rank is a pass option
Make "target rank" a pass option of VectorToSCF.
Depends On D102101
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D102123
Added:
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b26b708c06fb6..b440578754dad 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -519,6 +519,8 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
let options = [
Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false",
"Perform full unrolling when converting vector transfers to SCF">,
+ Option<"targetRank", "target-rank", "unsigned", /*default=*/"1",
+ "Target vector rank to which transfer ops should be lowered">,
];
}
diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
index 5a42b9a070f84..03765cb5532c8 100644
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -49,10 +49,17 @@ class RewritePatternSet;
struct VectorTransferToSCFOptions {
bool unroll = false;
+ unsigned targetRank = 1;
+
VectorTransferToSCFOptions &setUnroll(bool u) {
unroll = u;
return *this;
}
+
+ VectorTransferToSCFOptions &setTargetRank(unsigned r) {
+ targetRank = r;
+ return *this;
+ }
};
/// Collect a set of patterns to convert from the Vector dialect to SCF + std.
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 5b5769c9ad066..a209bc4672e36 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -38,8 +38,16 @@ namespace {
/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";
-/// Lower to 1D transfer ops. Target-specific lowering will lower those.
-static const int64_t kTargetRank = 1;
+/// Patterns that inherit from this struct have access to
+/// VectorTransferToSCFOptions.
+template <typename OpTy>
+struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
+ explicit VectorToSCFPattern(MLIRContext *context,
+ VectorTransferToSCFOptions opt)
+ : OpRewritePattern<OpTy>(context), options(opt) {}
+
+ VectorTransferToSCFOptions options;
+};
/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
@@ -270,8 +278,9 @@ static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) {
/// Add the pass label to a vector transfer op if its rank is not the target
/// rank.
template <typename OpTy>
-static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp) {
- if (newXferOp.getVectorType().getRank() > kTargetRank)
+static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp,
+ unsigned targetRank) {
+ if (newXferOp.getVectorType().getRank() > targetRank)
newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
}
@@ -347,8 +356,10 @@ struct Strategy<TransferReadOp> {
/// Note: The loop and type cast are generated in TransferOpConversion.
/// The original TransferReadOp and store op are deleted in `cleanup`.
/// Note: The `mask` operand is set in TransferOpConversion.
- static TransferReadOp rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
- Value buffer, Value iv) {
+ static TransferReadOp rewriteOp(OpBuilder &builder,
+ VectorTransferToSCFOptions options,
+ TransferReadOp xferOp, Value buffer,
+ Value iv) {
SmallVector<Value, 8> storeIndices;
getBufferIndices(xferOp, storeIndices);
storeIndices.push_back(iv);
@@ -367,7 +378,8 @@ struct Strategy<TransferReadOp> {
.value;
maybeApplyPassLabel(builder,
- dyn_cast<TransferReadOp>(newXfer.getDefiningOp()));
+ dyn_cast<TransferReadOp>(newXfer.getDefiningOp()),
+ options.targetRank);
memref_store(newXfer, buffer, storeIndices);
return newXfer.getDefiningOp<TransferReadOp>();
@@ -428,8 +440,10 @@ struct Strategy<TransferWriteOp> {
/// to memory.
///
/// Note: For more details, see comments on Strategy<TransferReadOp>.
- static TransferWriteOp rewriteOp(OpBuilder &builder, TransferWriteOp xferOp,
- Value buffer, Value iv) {
+ static TransferWriteOp rewriteOp(OpBuilder &builder,
+ VectorTransferToSCFOptions options,
+ TransferWriteOp xferOp, Value buffer,
+ Value iv) {
SmallVector<Value, 8> loadIndices;
getBufferIndices(xferOp, loadIndices);
loadIndices.push_back(iv);
@@ -444,7 +458,7 @@ struct Strategy<TransferWriteOp> {
AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(),
inBoundsAttr);
- maybeApplyPassLabel(builder, newXfer.op);
+ maybeApplyPassLabel(builder, newXfer.op, options.targetRank);
return newXfer;
}
@@ -460,10 +474,10 @@ struct Strategy<TransferWriteOp> {
};
template <typename OpTy>
-LogicalResult checkPrepareXferOp(OpTy xferOp) {
+LogicalResult checkPrepareXferOp(OpTy xferOp, unsigned targetRank) {
if (xferOp->hasAttr(kPassLabel))
return failure();
- if (xferOp.getVectorType().getRank() <= kTargetRank)
+ if (xferOp.getVectorType().getRank() <= targetRank)
return failure();
return success();
}
@@ -491,12 +505,13 @@ LogicalResult checkPrepareXferOp(OpTy xferOp) {
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
-struct PrepareTransferReadConversion : public OpRewritePattern<TransferReadOp> {
- using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+struct PrepareTransferReadConversion
+ : public VectorToSCFPattern<TransferReadOp> {
+ using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
LogicalResult matchAndRewrite(TransferReadOp xferOp,
PatternRewriter &rewriter) const override {
- if (checkPrepareXferOp(xferOp).failed())
+ if (checkPrepareXferOp(xferOp, options.targetRank).failed())
return failure();
ScopedContext scope(rewriter, xferOp.getLoc());
@@ -539,12 +554,12 @@ struct PrepareTransferReadConversion : public OpRewritePattern<TransferReadOp> {
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
- : public OpRewritePattern<TransferWriteOp> {
- using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
+ : public VectorToSCFPattern<TransferWriteOp> {
+ using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
LogicalResult matchAndRewrite(TransferWriteOp xferOp,
PatternRewriter &rewriter) const override {
- if (checkPrepareXferOp(xferOp).failed())
+ if (checkPrepareXferOp(xferOp, options.targetRank).failed())
return failure();
ScopedContext scope(rewriter, xferOp.getLoc());
@@ -583,8 +598,8 @@ struct PrepareTransferWriteConversion
/// out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
template <typename OpTy>
-struct TransferOpConversion : public OpRewritePattern<OpTy> {
- using OpRewritePattern<OpTy>::OpRewritePattern;
+struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
+ using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
LogicalResult matchAndRewrite(OpTy xferOp,
PatternRewriter &rewriter) const override {
@@ -635,8 +650,8 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
/*inBoundsCase=*/
[&](OpBuilder &b, Location /*loc*/) {
// Create new transfer op.
- OpTy newXfer =
- Strategy<OpTy>::rewriteOp(b, xferOp, castedDataBuffer, iv);
+ OpTy newXfer = Strategy<OpTy>::rewriteOp(
+ b, this->options, xferOp, castedDataBuffer, iv);
// If old transfer op has a mask: Set mask on new transfer op.
// Special case: If the mask of the old transfer op is 1D and
@@ -731,8 +746,9 @@ static void maybeAssignMask(OpBuilder &builder, OpTy xferOp, OpTy newXferOp,
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is created.
/// Instead, the new TransferReadOp results are inserted into that vector.
-struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
- using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+struct UnrollTransferReadConversion
+ : public VectorToSCFPattern<TransferReadOp> {
+ using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
/// Return the vector into which the newly created TransferReadOp results
/// are inserted.
@@ -770,7 +786,7 @@ struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
/// accesses, and broadcasts and transposes in permutation maps.
LogicalResult matchAndRewrite(TransferReadOp xferOp,
PatternRewriter &rewriter) const override {
- if (xferOp.getVectorType().getRank() <= kTargetRank)
+ if (xferOp.getVectorType().getRank() <= options.targetRank)
return failure();
ScopedContext scope(rewriter, xferOp.getLoc());
@@ -861,8 +877,8 @@ struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
/// doing so, `a` may become dead, and the number of ExtractOps generated during
/// recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
- : public OpRewritePattern<TransferWriteOp> {
- using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
+ : public VectorToSCFPattern<TransferWriteOp> {
+ using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
/// Return the vector from which newly generated ExtracOps will extract.
Value getDataVector(TransferWriteOp xferOp) const {
@@ -893,7 +909,7 @@ struct UnrollTransferWriteConversion
/// accesses, and broadcasts and transposes in permutation maps.
LogicalResult matchAndRewrite(TransferWriteOp xferOp,
PatternRewriter &rewriter) const override {
- if (xferOp.getVectorType().getRank() <= kTargetRank)
+ if (xferOp.getVectorType().getRank() <= options.targetRank)
return failure();
ScopedContext scope(rewriter, xferOp.getLoc());
@@ -1062,8 +1078,8 @@ static bool isLastMemrefDimUnitStride(MemRefType type) {
/// }
/// ```
template <typename OpTy>
-struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
- using OpRewritePattern<OpTy>::OpRewritePattern;
+struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
+ using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
LogicalResult matchAndRewrite(OpTy xferOp,
PatternRewriter &rewriter) const override {
@@ -1106,17 +1122,18 @@ void populateVectorToSCFConversionPatterns(
RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
if (options.unroll) {
patterns.add<UnrollTransferReadConversion, UnrollTransferWriteConversion>(
- patterns.getContext());
+ patterns.getContext(), options);
} else {
patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
TransferOpConversion<TransferReadOp>,
- TransferOpConversion<TransferWriteOp>>(patterns.getContext());
+ TransferOpConversion<TransferWriteOp>>(patterns.getContext(),
+ options);
}
- if (kTargetRank == 1) {
+ if (options.targetRank == 1) {
patterns.add<TransferOp1dConversion<TransferReadOp>,
- TransferOp1dConversion<TransferWriteOp>>(
- patterns.getContext());
+ TransferOp1dConversion<TransferWriteOp>>(patterns.getContext(),
+ options);
}
}
@@ -1129,12 +1146,16 @@ struct ConvertVectorToSCFPass
ConvertVectorToSCFPass() = default;
ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
this->fullUnroll = options.unroll;
+ this->targetRank = options.targetRank;
}
void runOnFunction() override {
+ VectorTransferToSCFOptions options;
+ options.setUnroll(fullUnroll);
+ options.setTargetRank(targetRank);
+
RewritePatternSet patterns(getFunction().getContext());
- populateVectorToSCFConversionPatterns(
- patterns, VectorTransferToSCFOptions().setUnroll(fullUnroll));
+ populateVectorToSCFConversionPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
More information about the Mlir-commits mailing list