[Mlir-commits] [mlir] 6c3d0fa - [mlir][Linalg] Relax restriction of Linalg passes on FuncOp

Nicolas Vasilache llvmlistbot at llvm.org
Wed Jul 5 12:55:06 PDT 2023


Author: Nicolas Vasilache
Date: 2023-07-05T19:55:01Z
New Revision: 6c3d0fa0cda8c7a144ba3318376d916116d47efd

URL: https://github.com/llvm/llvm-project/commit/6c3d0fa0cda8c7a144ba3318376d916116d47efd
DIFF: https://github.com/llvm/llvm-project/commit/6c3d0fa0cda8c7a144ba3318376d916116d47efd.diff

LOG: [mlir][Linalg] Relax restriction of Linalg passes on FuncOp

Several existing Linalg passes are still anchored on func::FuncOp.
Relax this unnecessary restriction so that they can run on any operation.

Differential Revision: https://reviews.llvm.org/D154497

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
    mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index f2b5788af20bb0..5f46affe592a2d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -38,31 +38,28 @@ std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
 
 std::unique_ptr<Pass> createLinalgNamedOpConversionPass();
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-createLinalgInlineScalarOperandsPass();
+std::unique_ptr<Pass> createLinalgInlineScalarOperandsPass();
 
 /// Create a pass to convert Linalg operations to scf.for loops and
 /// memref.load/memref.store accesses.
-std::unique_ptr<OperationPass<func::FuncOp>> createConvertLinalgToLoopsPass();
+std::unique_ptr<Pass> createConvertLinalgToLoopsPass();
 
 /// Create a pass to convert Linalg operations to scf.parallel loops and
 /// memref.load/memref.store accesses.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createConvertLinalgToParallelLoopsPass();
+std::unique_ptr<Pass> createConvertLinalgToParallelLoopsPass();
 
 /// Create a pass to convert Linalg operations to affine.for loops and
 /// affine_load/affine_store accesses.
 /// Placeholder for now, this is NYI.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createConvertLinalgToAffineLoopsPass();
+std::unique_ptr<Pass> createConvertLinalgToAffineLoopsPass();
 
 /// Create a pass to convert Linalg operations which work on tensors to use
 /// buffers instead.
-std::unique_ptr<OperationPass<func::FuncOp>> createLinalgBufferizePass();
+std::unique_ptr<Pass> createLinalgBufferizePass();
 
 /// Create a pass to convert named Linalg operations to Linalg generic
 /// operations.
-std::unique_ptr<OperationPass<func::FuncOp>> createLinalgGeneralizationPass();
+std::unique_ptr<Pass> createLinalgGeneralizationPass();
 
 /// Create a pass to convert Linalg operations to equivalent operations that
 /// work on primitive types, if possible.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index be4c870d4129c3..1ed867b3a0df75 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -55,7 +55,7 @@ def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
   let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
 }
 
-def LinalgInlineScalarOperands : Pass<"linalg-inline-scalar-operands", "func::FuncOp"> {
+def LinalgInlineScalarOperands : Pass<"linalg-inline-scalar-operands"> {
   let summary = "Inline scalar operands into linalg generic ops";
   let constructor = "mlir::createLinalgInlineScalarOperandsPass()";
   let dependentDialects = [
@@ -63,7 +63,7 @@ def LinalgInlineScalarOperands : Pass<"linalg-inline-scalar-operands", "func::Fu
   ];
 }
 
-def LinalgLowerToAffineLoops : Pass<"convert-linalg-to-affine-loops", "func::FuncOp"> {
+def LinalgLowerToAffineLoops : Pass<"convert-linalg-to-affine-loops"> {
   let summary = "Lower the operations from the linalg dialect into affine "
                 "loops";
   let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
@@ -71,7 +71,7 @@ def LinalgLowerToAffineLoops : Pass<"convert-linalg-to-affine-loops", "func::Fun
     "affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"];
 }
 
-def LinalgLowerToLoops : Pass<"convert-linalg-to-loops", "func::FuncOp"> {
+def LinalgLowerToLoops : Pass<"convert-linalg-to-loops"> {
   let summary = "Lower the operations from the linalg dialect into loops";
   let constructor = "mlir::createConvertLinalgToLoopsPass()";
   let dependentDialects = [
@@ -82,7 +82,7 @@ def LinalgLowerToLoops : Pass<"convert-linalg-to-loops", "func::FuncOp"> {
 }
 
 def LinalgLowerToParallelLoops
-    : Pass<"convert-linalg-to-parallel-loops", "func::FuncOp"> {
+    : Pass<"convert-linalg-to-parallel-loops"> {
   let summary = "Lower the operations from the linalg dialect into parallel "
                 "loops";
   let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
@@ -94,7 +94,7 @@ def LinalgLowerToParallelLoops
   ];
 }
 
-def LinalgBufferize : Pass<"linalg-bufferize", "func::FuncOp"> {
+def LinalgBufferize : Pass<"linalg-bufferize"> {
   let summary = "Bufferize the linalg dialect";
   let constructor = "mlir::createLinalgBufferizePass()";
   let dependentDialects = [
@@ -105,7 +105,7 @@ def LinalgBufferize : Pass<"linalg-bufferize", "func::FuncOp"> {
   ];
 }
 
-def LinalgGeneralization : Pass<"linalg-generalize-named-ops", "func::FuncOp"> {
+def LinalgGeneralization : Pass<"linalg-generalize-named-ops"> {
   let summary = "Convert named ops into generic ops";
   let constructor = "mlir::createLinalgGeneralizationPass()";
   let dependentDialects = ["linalg::LinalgDialect"];

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index ca74a17f6cd0f8..73c4d4779750d9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -49,6 +49,6 @@ struct LinalgBufferizePass
 };
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>> mlir::createLinalgBufferizePass() {
+std::unique_ptr<Pass> mlir::createLinalgBufferizePass() {
   return std::make_unique<LinalgBufferizePass>();
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 146217d55b124d..2382903bf3785b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -84,10 +84,9 @@ struct LinalgGeneralizationPass
 } // namespace
 
 void LinalgGeneralizationPass::runOnOperation() {
-  func::FuncOp func = getOperation();
   RewritePatternSet patterns(&getContext());
   populateLinalgNamedOpsGeneralizationPatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
@@ -95,7 +94,6 @@ void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
   patterns.add<LinalgGeneralizationPattern>(patterns.getContext());
 }
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createLinalgGeneralizationPass() {
+std::unique_ptr<Pass> mlir::createLinalgGeneralizationPass() {
   return std::make_unique<LinalgGeneralizationPass>();
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index c9b118ba0825e9..9cf8f7c4e1ca0a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -103,17 +103,15 @@ struct LinalgInlineScalarOperandsPass
     : public impl::LinalgInlineScalarOperandsBase<
           LinalgInlineScalarOperandsPass> {
   void runOnOperation() override {
-    func::FuncOp funcOp = getOperation();
-    MLIRContext *context = funcOp.getContext();
-    RewritePatternSet patterns(context);
-
+    Operation *op = getOperation();
+    MLIRContext &ctx = getContext();
+    RewritePatternSet patterns(&ctx);
     populateInlineConstantOperandsPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createLinalgInlineScalarOperandsPass() {
+std::unique_ptr<Pass> mlir::createLinalgInlineScalarOperandsPass() {
   return std::make_unique<LinalgInlineScalarOperandsPass>();
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index d91d8c4bf61003..5a44a85c95c75a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -312,8 +312,8 @@ struct FoldAffineOp : public RewritePattern {
 };
 
 template <typename LoopType>
-static void lowerLinalgToLoopsImpl(func::FuncOp funcOp) {
-  MLIRContext *context = funcOp.getContext();
+static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
+  MLIRContext *context = enclosingOp->getContext();
   RewritePatternSet patterns(context);
   patterns.add<LinalgRewritePattern<LoopType>>(context);
   memref::DimOp::getCanonicalizationPatterns(patterns, context);
@@ -321,7 +321,7 @@ static void lowerLinalgToLoopsImpl(func::FuncOp funcOp) {
   affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);
   // Just apply the patterns greedily.
-  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  (void)applyPatternsAndFoldGreedily(enclosingOp, std::move(patterns));
 }
 
 struct LowerToAffineLoops
@@ -352,18 +352,15 @@ struct LowerToParallelLoops
 
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createConvertLinalgToLoopsPass() {
+std::unique_ptr<Pass> mlir::createConvertLinalgToLoopsPass() {
   return std::make_unique<LowerToLoops>();
 }
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createConvertLinalgToParallelLoopsPass() {
+std::unique_ptr<Pass> mlir::createConvertLinalgToParallelLoopsPass() {
   return std::make_unique<LowerToParallelLoops>();
 }
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createConvertLinalgToAffineLoopsPass() {
+std::unique_ptr<Pass> mlir::createConvertLinalgToAffineLoopsPass() {
   return std::make_unique<LowerToAffineLoops>();
 }
 


        


More information about the Mlir-commits mailing list