[Mlir-commits] [mlir] 800694a - [mlir][Linalg] Make a LinalgStrategyDecomposePass available.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Nov 11 09:47:31 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-11T17:47:27Z
New Revision: 800694a6977c02cfd9770ef3bd5530e6fb4ff2f7

URL: https://github.com/llvm/llvm-project/commit/800694a6977c02cfd9770ef3bd5530e6fb4ff2f7
DIFF: https://github.com/llvm/llvm-project/commit/800694a6977c02cfd9770ef3bd5530e6fb4ff2f7.diff

LOG: [mlir][Linalg] Make a LinalgStrategyDecomposePass available.

Differential Revision: https://reviews.llvm.org/D113684

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 1c9b7252942f3..c0173ec2f443a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -67,7 +67,7 @@ std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
 /// buffers instead.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
 
-/// Create a pass to conver named Linalg operations to Linalg generic
+/// Create a pass to convert named Linalg operations to Linalg generic
 /// operations.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
 
@@ -108,6 +108,11 @@ createLinalgStrategyGeneralizePass(StringRef opName = "",
                                    linalg::LinalgTransformationFilter filter =
                                        linalg::LinalgTransformationFilter());
 
+/// Create a pass that decomposes coarser-grained named Linalg ops into
+/// finer-grained named ops; it applies to all supported ops for now.
+// TODO: expose opName + filter + a proper pattern if finer control is needed.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyDecomposePass();
+
 /// Create a LinalgStrategyInterchangePass.
 std::unique_ptr<OperationPass<FuncOp>>
 createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange = {},

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 5d05039b5ea59..c9bcfebecb022 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -287,6 +287,19 @@ def LinalgStrategyGeneralizePass
   ];
 }
 
+// TODO: atm this is applied to all supported ops. If/when we need finer control
+// this should be exposed with an opName + filter and a proper pattern.
+def LinalgStrategyDecomposePass
+    : FunctionPass<"linalg-strategy-decompose-pass"> {
+  let summary = "Configurable pass to apply pattern-based decomposition.";
+  let constructor = "mlir::createLinalgStrategyDecomposePass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+      "Which func op is the anchor to latch on.">,
+  ];
+}
+
 def LinalgStrategyInterchangePass
     : FunctionPass<"linalg-strategy-interchange-pass"> {
   let summary = "Configurable pass to apply pattern-based iterator interchange.";

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 4a97b926655fa..24cec12cec62d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -134,6 +134,25 @@ struct LinalgStrategyGeneralizePass
   LinalgTransformationFilter filter;
 };
 
+/// Decomposes coarser-grained named linalg ops into finer-grained named ops by
+/// greedily applying populateDecomposeConvolutionPatterns; when `anchor-func`
+/// is set, only the function with that name is rewritten.
+struct LinalgStrategyDecomposePass
+    : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> {
+  LinalgStrategyDecomposePass() = default;
+
+  void runOnFunction() override {
+    FuncOp func = getFunction();
+    // Leave every function other than the requested anchor untouched.
+    if (!anchorFuncName.empty() && func.getName() != anchorFuncName)
+      return;
+    RewritePatternSet patterns(func.getContext());
+    populateDecomposeConvolutionPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 /// Configurable pass to apply pattern-based linalg generalization.
 struct LinalgStrategyInterchangePass
     : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
@@ -389,6 +408,13 @@ mlir::createLinalgStrategyGeneralizePass(StringRef opName,
                                          LinalgTransformationFilter filter) {
   return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
 }
+
+/// Create a LinalgStrategyDecomposePass.
+// TODO: applied to all supported ops for now; add opName + filter if needed.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyDecomposePass() {
+  return std::make_unique<LinalgStrategyDecomposePass>();
+}
 
 /// Create a LinalgStrategyInterchangePass.
 std::unique_ptr<OperationPass<FuncOp>>

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index c3415c6e2a5d6..8f0b43afafcb1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -14,13 +14,14 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/ADT/SetVector.h"
@@ -554,12 +555,6 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
-static void applyDecomposeConvolutionPatterns(FuncOp funcOp) {
-  RewritePatternSet patterns(funcOp.getContext());
-  populateDecomposeConvolutionPatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-}
-
 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
@@ -726,8 +721,13 @@ void TestLinalgTransforms::runOnFunction() {
   if (testTileScalarizeDynamicDims)
     return applyTilePattern(getFunction(), loopType, tileSizes,
                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
-  if (testDecomposeConvolutionPattern)
-    return applyDecomposeConvolutionPatterns(getFunction());
+  if (testDecomposeConvolutionPattern) {
+    // TODO: thread all tests through LinalgStrategy passes.
+    OpPassManager dynamicPM("builtin.func");
+    dynamicPM.addPass(createLinalgStrategyDecomposePass());
+    if (failed(runPipeline(dynamicPM, getFunction())))
+      return signalPassFailure();
+  }
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list