[Mlir-commits] [mlir] e789efc - [mlir][linalg] Refactor PadTensorOpVectorizationPattern (NFC)
Matthias Springer
llvmlistbot at llvm.org
Fri Jun 4 07:51:20 PDT 2021
Author: Matthias Springer
Date: 2021-06-04T23:45:08+09:00
New Revision: e789efc92a5aa6495a22bd3a93a03f640dc2f32a
URL: https://github.com/llvm/llvm-project/commit/e789efc92a5aa6495a22bd3a93a03f640dc2f32a
DIFF: https://github.com/llvm/llvm-project/commit/e789efc92a5aa6495a22bd3a93a03f640dc2f32a.diff
LOG: [mlir][linalg] Refactor PadTensorOpVectorizationPattern (NFC)
* Rename PadTensorOpVectorizationPattern to GenericPadTensorOpVectorizationPattern.
* Make GenericPadTensorOpVectorizationPattern a private pattern, to be instantiated via populatePadTensorOpVectorizationPatterns.
* Factor out parts of PadTensorOpVectorizationPattern into helper functions.
This commit prepares PadTensorOpVectorizationPattern for a series of subsequent commits that add more specialized PadTensorOp vectorization patterns.
Differential Revision: https://reviews.llvm.org/D103681
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 25f4b3627d6ef..9df8fbb2e4693 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -880,14 +880,14 @@ struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
PatternRewriter &rewriter) const override;
};
-/// PadTensorOp does not implement the LinalgStructuredOpInterface `LinalgOp`,
-/// it needs a specific pattern to vectorize.
-struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
- using OpRewritePattern<PadTensorOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(PadTensorOp padOp,
- PatternRewriter &rewriter) const override;
-};
+/// Populates `patterns` with patterns that vectorize linalg.pad_tensor.
+/// These patterns are meant to apply in a complementary fashion. Benefits
+/// are used to encode a certain ordering of pattern application. To avoid
+/// scattering magic constants throughout the code base, the patterns must be
+/// added with this function. `baseBenefit` can be used to offset the benefit
+/// of all PadTensorOp vectorization patterns by a certain value.
+void populatePadTensorOpVectorizationPatterns(
+ RewritePatternSet &patterns, PatternBenefit baseBenefit = 1);
/// Match and rewrite for the pattern:
/// ```
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 12a8d80c72fcc..b52059d535cf1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -650,66 +650,81 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
// Misc. vectorization patterns.
//----------------------------------------------------------------------------//
-/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
-/// TransferWriteOp. For now, this only applies when all low and high paddings
-/// are determined to be zero.
-LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite(
- linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
- // Helper function to determine whether an OpFoldResult is not a zero Index.
- auto isNotZeroIndex = [](OpFoldResult ofr) {
- if (Attribute attr = ofr.dyn_cast<Attribute>())
- return attr.cast<IntegerAttr>().getInt() != 0;
- Value v = ofr.get<Value>();
- if (auto constOp = v.getDefiningOp<ConstantOp>())
- if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>())
- return intAttr.getValue().getSExtValue() != 0;
- return true;
- };
-
- auto resultShapedType = padOp.result().getType().cast<ShapedType>();
- // Bail on non-static shapes.
- if (!resultShapedType.hasStaticShape())
- return failure();
-
- // If any pad_low is not a static 0, needs a mask. Bail for now.
- if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex))
- return failure();
- VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
- if (!vectorType)
- return failure();
-
- // Only support padding with a constant for now, i.e. either:
- // 1. A BBarg from a different block.
- // 2. A value defined outside of the current block.
- Block &block = padOp.region().front();
+/// Given a block, return the Value that the block yields if that Value is
+/// constant. In this context, "constant" means "defined outside of the block".
+/// Should not be called on blocks that yield more than one value.
+///
+/// Values are considered constant in two cases:
+/// - A basic block argument from a different block.
+/// - A value defined outside of the block.
+///
+/// If the yielded value is not constant, an empty Value is returned.
+static Value getConstantYieldValueFromBlock(Block &block) {
auto yieldOp = cast<YieldOp>(block.getTerminator());
assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
- Value padValue = yieldOp.values().front();
- Operation *definingOp = padValue.getDefiningOp();
+ Value result = yieldOp.values().front();
+ Operation *definingOp = result.getDefiningOp();
+
+ // Check if yield value is defined inside the block.
if (definingOp && definingOp->getBlock() == &block)
- return failure();
- if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
- return failure();
+ return Value();
+ // Check if the yield value is a BB arg of the block.
+ if (!definingOp && result.cast<BlockArgument>().getOwner() == &block)
+ return Value();
- // TODO: if any pad_high is not a static 0, needs a mask. For now, just bail.
- if (llvm::any_of(padOp.getMixedHighPad(),
- [&](OpFoldResult ofr) { return isNotZeroIndex(ofr); }))
- return failure();
+ return result;
+}
+
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
+/// TransferWriteOp. For now, this only applies when all low and high paddings
+/// are determined to be zero.
+struct GenericPadTensorOpVectorizationPattern
+ : public OpRewritePattern<PadTensorOp> {
+ using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(PadTensorOp padOp,
+ PatternRewriter &rewriter) const override {
+ /// Given an OpFoldResult, return true if its value is guaranteed to be a
+ /// zero integer.
+ auto isZeroInt = [&](OpFoldResult ofr) {
+ return isEqualConstantIntOrValue(ofr, rewriter.getIndexAttr(0)); };
+ // Low padding must be static 0.
+ if (!llvm::all_of(padOp.getMixedLowPad(), isZeroInt)) return failure();
+ // High padding must be static 0.
+ if (!llvm::all_of(padOp.getMixedHighPad(), isZeroInt)) return failure();
+ // Pad value must be a constant.
+ auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
+ if (!padValue) return failure();
+
+ // Bail on non-static shapes.
+ auto resultShapedType = padOp.result().getType().cast<ShapedType>();
+ if (!resultShapedType.hasStaticShape())
+ return failure();
+ VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
+ if (!vectorType)
+ return failure();
- // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
- // TransferWriteOp@[0..0].
- SmallVector<Value> indices(
- resultShapedType.getRank(),
- rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
- Value read = rewriter.create<vector::TransferReadOp>(
- padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
- Value init =
- rewriter.create<InitTensorOp>(padOp.getLoc(), resultShapedType.getShape(),
- resultShapedType.getElementType());
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
- indices);
+ // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
+ // TransferWriteOp@[0..0].
+ SmallVector<Value> indices(
+ resultShapedType.getRank(),
+ rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
+ Value read = rewriter.create<vector::TransferReadOp>(
+ padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
+ Value init = rewriter.create<InitTensorOp>(
+ padOp.getLoc(), resultShapedType.getShape(),
+ resultShapedType.getElementType());
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
+ indices);
- return success();
+ return success();
+ }
+};
+
+void mlir::linalg::populatePadTensorOpVectorizationPatterns(
+ RewritePatternSet &patterns, PatternBenefit baseBenefit) {
+ patterns.add<GenericPadTensorOpVectorizationPattern>(
+ patterns.getContext(), baseBenefit);
}
// TODO: cleanup all the convolution vectorization patterns.
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 181e1d93b250a..16eb79ffe4fb6 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -508,7 +508,7 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
funcOp.getContext(),
LinalgTransformationFilter()
.addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
- patterns.add<PadTensorOpVectorizationPattern>(funcOp.getContext());
+ populatePadTensorOpVectorizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
More information about the Mlir-commits
mailing list