[Mlir-commits] [mlir] 5499898 - [mlir] Generalize SCF passes to not have to run on FuncOp.

Stella Laurenzo llvmlistbot at llvm.org
Sun Jun 26 11:05:43 PDT 2022


Author: Stella Laurenzo
Date: 2022-06-26T11:05:35-07:00
New Revision: 54998986c3d9e847615598ebcd294ca5dce6c772

URL: https://github.com/llvm/llvm-project/commit/54998986c3d9e847615598ebcd294ca5dce6c772
DIFF: https://github.com/llvm/llvm-project/commit/54998986c3d9e847615598ebcd294ca5dce6c772.diff

LOG: [mlir] Generalize SCF passes to not have to run on FuncOp.

This restriction seems to have been an accident of history; none of these passes had any reason to be restricted to FuncOp.

Differential Revision: https://reviews.llvm.org/D128614

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
    mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
    mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
    mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
    mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index d910f2e02d090..896859d5ee375 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -11,7 +11,7 @@
 
 include "mlir/Pass/PassBase.td"
 
-def SCFBufferize : Pass<"scf-bufferize", "func::FuncOp"> {
+def SCFBufferize : Pass<"scf-bufferize"> {
   let summary = "Bufferize the scf dialect.";
   let constructor = "mlir::createSCFBufferizePass()";
   let dependentDialects = ["bufferization::BufferizationDialect",
@@ -21,14 +21,14 @@ def SCFBufferize : Pass<"scf-bufferize", "func::FuncOp"> {
 // Note: Making these canonicalization patterns would require a dependency
 // of the SCF dialect on the Affine/Tensor/MemRef dialects or vice versa.
 def SCFForLoopCanonicalization
-    : Pass<"scf-for-loop-canonicalization", "func::FuncOp"> {
+    : Pass<"scf-for-loop-canonicalization"> {
   let summary = "Canonicalize operations within scf.for loop bodies";
   let constructor = "mlir::createSCFForLoopCanonicalizationPass()";
   let dependentDialects = ["AffineDialect", "tensor::TensorDialect",
                            "memref::MemRefDialect"];
 }
 
-def SCFForLoopPeeling : Pass<"scf-for-loop-peeling", "func::FuncOp"> {
+def SCFForLoopPeeling : Pass<"scf-for-loop-peeling"> {
   let summary = "Peel `for` loops at their upper bounds.";
   let constructor = "mlir::createForLoopPeelingPass()";
   let options = [
@@ -40,7 +40,7 @@ def SCFForLoopPeeling : Pass<"scf-for-loop-peeling", "func::FuncOp"> {
   let dependentDialects = ["AffineDialect"];
 }
 
-def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization", "func::FuncOp"> {
+def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
   let summary = "Specialize `for` loops for vectorization";
   let constructor = "mlir::createForLoopSpecializationPass()";
 }
@@ -64,12 +64,12 @@ def SCFParallelLoopCollapsing : Pass<"scf-parallel-loop-collapsing"> {
 }
 
 def SCFParallelLoopSpecialization
-    : Pass<"scf-parallel-loop-specialization", "func::FuncOp"> {
+    : Pass<"scf-parallel-loop-specialization"> {
   let summary = "Specialize parallel loops for vectorization";
   let constructor = "mlir::createParallelLoopSpecializationPass()";
 }
 
-def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling", "func::FuncOp"> {
+def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling"> {
   let summary = "Tile parallel loops";
   let constructor = "mlir::createParallelLoopTilingPass()";
   let options = [
@@ -88,7 +88,7 @@ def SCFForLoopRangeFolding : Pass<"scf-for-loop-range-folding"> {
   let constructor = "mlir::createForLoopRangeFoldingPass()";
 }
 
-def SCFForToWhileLoop : Pass<"scf-for-to-while", "func::FuncOp"> {
+def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   let summary = "Convert SCF for loops to SCF while loops";
   let constructor = "mlir::createForToWhileLoopPass()";
   let description = [{

diff  --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index 17cc6f3773390..14eb075d8c897 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -100,11 +100,11 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
 
 struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
   void runOnOperation() override {
-    func::FuncOp funcOp = getOperation();
-    MLIRContext *ctx = funcOp.getContext();
+    auto *parentOp = getOperation();
+    MLIRContext *ctx = parentOp->getContext();
     RewritePatternSet patterns(ctx);
     patterns.add<ForLoopLoweringPattern>(ctx);
-    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+    (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
   }
 };
 } // namespace

diff  --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index eda6bc6e1cf8b..18d43d72e210b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -195,11 +195,11 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
 struct SCFForLoopCanonicalization
     : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
   void runOnOperation() override {
-    func::FuncOp funcOp = getOperation();
-    MLIRContext *ctx = funcOp.getContext();
+    auto *parentOp = getOperation();
+    MLIRContext *ctx = parentOp->getContext();
     RewritePatternSet patterns(ctx);
     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
-    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
+    if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns))))
       signalPassFailure();
   }
 };

diff  --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 1195f7f8a5672..aa0056b683907 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -237,7 +237,7 @@ namespace {
 struct ParallelLoopSpecialization
     : public SCFParallelLoopSpecializationBase<ParallelLoopSpecialization> {
   void runOnOperation() override {
-    getOperation().walk(
+    getOperation()->walk(
         [](ParallelOp op) { specializeParallelLoopForUnrolling(op); });
   }
 };
@@ -245,20 +245,20 @@ struct ParallelLoopSpecialization
 struct ForLoopSpecialization
     : public SCFForLoopSpecializationBase<ForLoopSpecialization> {
   void runOnOperation() override {
-    getOperation().walk([](ForOp op) { specializeForLoopForUnrolling(op); });
+    getOperation()->walk([](ForOp op) { specializeForLoopForUnrolling(op); });
   }
 };
 
 struct ForLoopPeeling : public SCFForLoopPeelingBase<ForLoopPeeling> {
   void runOnOperation() override {
-    func::FuncOp funcOp = getOperation();
-    MLIRContext *ctx = funcOp.getContext();
+    auto *parentOp = getOperation();
+    MLIRContext *ctx = parentOp->getContext();
     RewritePatternSet patterns(ctx);
     patterns.add<ForLoopPeelingPattern>(ctx, skipPartial);
-    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+    (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
 
     // Drop the markers.
-    funcOp.walk([](Operation *op) {
+    parentOp->walk([](Operation *op) {
       op->removeAttr(kPeeledLoopLabel);
       op->removeAttr(kPartialIterationLabel);
     });

diff  --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
index f20764647d4b4..c39d3afae25ce 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
@@ -195,8 +195,9 @@ struct ParallelLoopTiling
   }
 
   void runOnOperation() override {
+    auto *parentOp = getOperation();
     SmallVector<ParallelOp, 2> innermostPloops;
-    getInnermostParallelLoops(getOperation().getOperation(), innermostPloops);
+    getInnermostParallelLoops(parentOp, innermostPloops);
     for (ParallelOp ploop : innermostPloops) {
       // FIXME: Add reduction support.
       if (ploop.getNumReductions() == 0)


        


More information about the Mlir-commits mailing list