[Mlir-commits] [mlir] e826db6 - [mlir][linalg] Move generalization pattern to Transforms (NFC).

Tobias Gysi llvmlistbot at llvm.org
Tue Oct 5 05:50:47 PDT 2021


Author: Tobias Gysi
Date: 2021-10-05T12:49:42Z
New Revision: e826db624040919a11e9dd3a9f3714105ee031ce

URL: https://github.com/llvm/llvm-project/commit/e826db624040919a11e9dd3a9f3714105ee031ce
DIFF: https://github.com/llvm/llvm-project/commit/e826db624040919a11e9dd3a9f3714105ee031ce.diff

LOG: [mlir][linalg] Move generalization pattern to Transforms (NFC).

Move the generalization pattern to the other Linalg transforms to make it available to the codegen strategy.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110728

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4a76b9257320..b2bd87086ac5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -234,6 +234,10 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
 void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
                           ArrayRef<unsigned> interchangeVector);
 
+/// Creates a GenericOp from the given named operation `namedOp`. Assumes
+/// `namedOp` is not a GenericOp and has a region builder.
+GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp);
+
 /// Callback function type used to perform the allocation for the promoted
 /// `subView`. In `boundingSubViewsize` a best attempt is made to find the
 /// smallest constant value for the size of the buffer needed for each
@@ -380,6 +384,9 @@ LogicalResult
 interchangeGenericOpPrecondition(GenericOp genericOp,
                                  ArrayRef<unsigned> interchangeVector);
 
+/// Succeeds if `op` is a LinalgOp, but not a GenericOp, with a region builder.
+LogicalResult generalizeNamedOpPrecondition(Operation *op);
+
 /// Promote std.subviews feeding linalg operations.
 LogicalResult promoteSubviewsPrecondition(Operation *op,
                                           LinalgPromotionOptions options);
@@ -701,6 +708,31 @@ struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
   SmallVector<unsigned, 8> interchangeVector;
 };
 
+///
+/// Linalg generalization pattern.
+///
+/// Apply the `generalizeNamedOp` transformation as a pattern.
+/// `filter` controls transformation marker matching and update when specified.
+/// See `generalizeNamedOp` for more details.
+struct LinalgGeneralizationPattern : public RewritePattern {
+  // Entry point to match any LinalgOp OpInterface.
+  LinalgGeneralizationPattern(
+      MLIRContext *context,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+  // Entry point to match a specific Linalg op.
+  LinalgGeneralizationPattern(
+      StringRef opName, MLIRContext *context,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// LinalgTransformationFilter handles special attribute manipulations.
+  LinalgTransformationFilter filter;
+};
+
 ///
 /// Linalg promotion patterns.
 ///

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index d0d14f86c54d..afa78b8bc384 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -29,10 +29,19 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-// Creates a linalg.generic op from the given `namedOp`. Returns a null op if
-// the given `namedOp` does not have a region builder.
-static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
-                                            PatternRewriter &rewriter) {
+LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
+  LinalgOp namedOp = dyn_cast<LinalgOp>(op);
+  // Check if the operation is a LinalgOp but not a GenericOp.
+  if (!namedOp || isa<GenericOp>(op))
+    return failure();
+  // Check if the operation has a region builder; generalizeNamedOp needs it.
+  if (!namedOp.getRegionBuilder())
+    return failure();
+  return success();
+}
+
+GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
+                                          LinalgOp namedOp) {
   SmallVector<Value> inputOperands = namedOp.getInputOperands();
   SmallVector<Value> outputOperands = namedOp.getOutputOperands();
   SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
@@ -54,10 +63,7 @@ static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
   // Otherwise use the region builder to generate a new region.
   // TODO: Remove this path once all linag operations have a region attached.
   auto regionBuilder = namedOp.getRegionBuilder();
-  if (!regionBuilder) {
-    LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
-    return nullptr;
-  }
+  assert(regionBuilder && "expect the operation to have region builder");
   return rewriter.create<GenericOp>(
       namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps,
       iterators,
@@ -112,41 +118,6 @@ struct GeneralizeConvOp
   GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
 };
 
-/// Catch-all pattern for converting all named ops with a region builder into
-/// linalg.generic.
-struct LinalgNamedOpGeneralizationPattern : RewritePattern {
-  LinalgNamedOpGeneralizationPattern(MLIRContext *context,
-                                     LinalgTransformationFilter marker,
-                                     PatternBenefit benefit = 1)
-      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
-        marker(std::move(marker)) {}
-
-  LogicalResult matchAndRewrite(Operation *rootOp,
-                                PatternRewriter &rewriter) const override {
-    auto linalgOp = dyn_cast<LinalgOp>(rootOp);
-    if (!linalgOp)
-      return failure();
-    if (failed(marker.checkAndNotify(rewriter, linalgOp)))
-      return failure();
-
-    // No nothing to do for linalg.generic.
-    if (isa<GenericOp>(rootOp))
-      return failure();
-
-    GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
-    if (!genericOp)
-      return failure();
-
-    rewriter.replaceOp(rootOp, genericOp.getResults());
-    marker.replaceLinalgTransformationFilter(rewriter,
-                                             genericOp.getOperation());
-    return success();
-  }
-
-private:
-  LinalgTransformationFilter marker;
-};
-
 struct LinalgGeneralizationPass
     : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
   void runOnFunction() override;
@@ -187,8 +158,7 @@ void mlir::linalg::populateLinalgConvGeneralizationPatterns(
 
 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
     RewritePatternSet &patterns, LinalgTransformationFilter marker) {
-  patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
-                                                   marker);
+  patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);
 }
 
 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index aacb20ca9726..34ff4017a686 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -488,6 +488,30 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
   return success();
 }
 
+/// Linalg generalization pattern: rewrites a named LinalgOp into a GenericOp.
+mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
+    MLIRContext *context, LinalgTransformationFilter filter,
+    PatternBenefit benefit)
+    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
+
+mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
+    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
+    PatternBenefit benefit)
+    : RewritePattern(opName, benefit, context, {}), filter(filter) {}
+
+LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, op)))
+    return failure();
+  if (failed(generalizeNamedOpPrecondition(op)))
+    return failure();
+
+  GenericOp genericOp = generalizeNamedOp(rewriter, op);
+  rewriter.replaceOp(op, genericOp.getResults());
+  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
+  return success();
+}
+
 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
     MLIRContext *context, LinalgTransformationFilter filter,
     LinalgPromotionOptions options, PatternBenefit benefit)


        


More information about the Mlir-commits mailing list