[Mlir-commits] [mlir] e826db6 - [mlir][linalg] Move generalization pattern to Transforms (NFC).
Tobias Gysi
llvmlistbot at llvm.org
Tue Oct 5 05:50:47 PDT 2021
Author: Tobias Gysi
Date: 2021-10-05T12:49:42Z
New Revision: e826db624040919a11e9dd3a9f3714105ee031ce
URL: https://github.com/llvm/llvm-project/commit/e826db624040919a11e9dd3a9f3714105ee031ce
DIFF: https://github.com/llvm/llvm-project/commit/e826db624040919a11e9dd3a9f3714105ee031ce.diff
LOG: [mlir][linalg] Move generalization pattern to Transforms (NFC).
Move the generalization pattern to the other Linalg transforms to make it available to the codegen strategy.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D110728
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4a76b9257320..b2bd87086ac5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -234,6 +234,10 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
ArrayRef<unsigned> interchangeVector);
+/// Creates a GenericOp from the given named operation `namedOp`. Assumes
+/// `namedOp` is not a GenericOp and has a region builder.
+GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp);
+
/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
/// smallest constant value for the size of the buffer needed for each
@@ -380,6 +384,9 @@ LogicalResult
interchangeGenericOpPrecondition(GenericOp genericOp,
ArrayRef<unsigned> interchangeVector);
+/// Generalize named operations to generic operations.
+LogicalResult generalizeNamedOpPrecondition(Operation *op);
+
/// Promote std.subviews feeding linalg operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
LinalgPromotionOptions options);
@@ -701,6 +708,31 @@ struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
SmallVector<unsigned, 8> interchangeVector;
};
+///
+/// Linalg generalization pattern.
+///
+/// Apply the `generalization` transformation as a pattern.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `generalization` for more details.
+struct LinalgGeneralizationPattern : public RewritePattern {
+ // Entry point to match any LinalgOp OpInterface.
+ LinalgGeneralizationPattern(
+ MLIRContext *context,
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+ // Entry point to match a specific Linalg op.
+ LinalgGeneralizationPattern(
+ StringRef opName, MLIRContext *context,
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
+
+private:
+ /// LinalgTransformMarker handles special attribute manipulations.
+ LinalgTransformationFilter filter;
+};
+
///
/// Linalg promotion patterns.
///
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index d0d14f86c54d..afa78b8bc384 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -29,10 +29,19 @@
using namespace mlir;
using namespace mlir::linalg;
-// Creates a linalg.generic op from the given `namedOp`. Returns a null op if
-// the given `namedOp` does not have a region builder.
-static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
- PatternRewriter &rewriter) {
+LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
+ LinalgOp namedOp = dyn_cast<LinalgOp>(op);
+ // Check if the operation is a LinalgOp but not a GenericOp.
+ if (!namedOp || isa<GenericOp>(op))
+ return failure();
+ // Check if the operation has a region builder.
+ if (!namedOp.getRegionBuilder())
+ return failure();
+ return success();
+}
+
+GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
+ LinalgOp namedOp) {
SmallVector<Value> inputOperands = namedOp.getInputOperands();
SmallVector<Value> outputOperands = namedOp.getOutputOperands();
SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
@@ -54,10 +63,7 @@ static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
// Otherwise use the region builder to generate a new region.
// TODO: Remove this path once all linalg operations have a region attached.
auto regionBuilder = namedOp.getRegionBuilder();
- if (!regionBuilder) {
- LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
- return nullptr;
- }
+ assert(regionBuilder && "expect the operation to have region builder");
return rewriter.create<GenericOp>(
namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps,
iterators,
@@ -112,41 +118,6 @@ struct GeneralizeConvOp
GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
};
-/// Catch-all pattern for converting all named ops with a region builder into
-/// linalg.generic.
-struct LinalgNamedOpGeneralizationPattern : RewritePattern {
- LinalgNamedOpGeneralizationPattern(MLIRContext *context,
- LinalgTransformationFilter marker,
- PatternBenefit benefit = 1)
- : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
- marker(std::move(marker)) {}
-
- LogicalResult matchAndRewrite(Operation *rootOp,
- PatternRewriter &rewriter) const override {
- auto linalgOp = dyn_cast<LinalgOp>(rootOp);
- if (!linalgOp)
- return failure();
- if (failed(marker.checkAndNotify(rewriter, linalgOp)))
- return failure();
-
- // Nothing to do for linalg.generic.
- if (isa<GenericOp>(rootOp))
- return failure();
-
- GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
- if (!genericOp)
- return failure();
-
- rewriter.replaceOp(rootOp, genericOp.getResults());
- marker.replaceLinalgTransformationFilter(rewriter,
- genericOp.getOperation());
- return success();
- }
-
-private:
- LinalgTransformationFilter marker;
-};
-
struct LinalgGeneralizationPass
: public LinalgGeneralizationBase<LinalgGeneralizationPass> {
void runOnFunction() override;
@@ -187,8 +158,7 @@ void mlir::linalg::populateLinalgConvGeneralizationPatterns(
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
RewritePatternSet &patterns, LinalgTransformationFilter marker) {
- patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
- marker);
+ patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);
}
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index aacb20ca9726..34ff4017a686 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -488,6 +488,30 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
return success();
}
+/// Linalg generalization pattern.
+mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
+ MLIRContext *context, LinalgTransformationFilter filter,
+ PatternBenefit benefit)
+ : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
+
+mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
+ StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
+ PatternBenefit benefit)
+ : RewritePattern(opName, benefit, context, {}), filter(filter) {}
+
+LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ if (failed(filter.checkAndNotify(rewriter, op)))
+ return failure();
+ if (failed(generalizeNamedOpPrecondition(op)))
+ return failure();
+
+ GenericOp genericOp = generalizeNamedOp(rewriter, op);
+ rewriter.replaceOp(op, genericOp.getResults());
+ filter.replaceLinalgTransformationFilter(rewriter, genericOp);
+ return success();
+}
+
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
MLIRContext *context, LinalgTransformationFilter filter,
LinalgPromotionOptions options, PatternBenefit benefit)
More information about the Mlir-commits
mailing list