[Mlir-commits] [mlir] 92ea624 - [mlir][Linalg] Rewrite CodegenStrategy to populate a pass pipeline.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Sep 29 06:39:04 PDT 2021


Author: Nicolas Vasilache
Date: 2021-09-29T13:35:45Z
New Revision: 92ea624a1345fc9f0512bab2bd5d0d1ebeb8cf21

URL: https://github.com/llvm/llvm-project/commit/92ea624a1345fc9f0512bab2bd5d0d1ebeb8cf21
DIFF: https://github.com/llvm/llvm-project/commit/92ea624a1345fc9f0512bab2bd5d0d1ebeb8cf21.diff

LOG: [mlir][Linalg] Rewrite CodegenStrategy to populate a pass pipeline.

This revision retires a good portion of the complexity of the codegen strategy and moves that logic behind dedicated passes.

Differential revision: https://reviews.llvm.org/D110678

Added: 
    mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Linalg/codegen-strategy.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir
    mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 56c709b543517..867921c51a51e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_LINALG_PASSES_H_
 #define MLIR_DIALECT_LINALG_PASSES_H_
 
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Pass/Pass.h"
 
@@ -77,6 +78,43 @@ std::unique_ptr<Pass> createLinalgDetensorizePass();
 /// Create a pass to tile a LinalgOp and fuse its producers.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFuseTensorOpsPass();
 
+//===----------------------------------------------------------------------===//
+/// Linalg strategy passes.
+//===----------------------------------------------------------------------===//
+/// Create a LinalgStrategyTilePass.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyTilePass(
+    StringRef opName = "",
+    linalg::LinalgTilingOptions opt = linalg::LinalgTilingOptions(),
+    linalg::LinalgTransformationFilter filter =
+        linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyPromotePass.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyPromotePass(
+    StringRef opName = "",
+    linalg::LinalgPromotionOptions opt = linalg::LinalgPromotionOptions(),
+    linalg::LinalgTransformationFilter filter =
+        linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyVectorizePass.
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgStrategyVectorizePass(StringRef opName = "",
+                                  linalg::LinalgVectorizationOptions opt =
+                                      linalg::LinalgVectorizationOptions(),
+                                  linalg::LinalgTransformationFilter filter =
+                                      linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyEnablePass.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyEnablePass(
+    linalg::LinalgEnablingOptions opt = linalg::LinalgEnablingOptions(),
+    linalg::LinalgTransformationFilter filter =
+        linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyLowerVectorsPass.
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgStrategyLowerVectorsPass(linalg::LinalgVectorLoweringOptions opt =
+                                         linalg::LinalgVectorLoweringOptions(),
+                                     linalg::LinalgTransformationFilter filter =
+                                         linalg::LinalgTransformationFilter());
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 3f331b1fff502..32327cd968096 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -229,4 +229,66 @@ def LinalgTileAndFuseTensorOps
   let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"];
 }
 
+def LinalgStrategyTilePass
+    : FunctionPass<"linalg-strategy-tile-pass"> {
+  let summary = "Configurable pass to apply pattern-based linalg tiling.";
+  let constructor = "mlir::createLinalgStrategyTilePass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+      "Which func op is the anchor to latch on.">,
+    Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+      "Which linalg op within the func is the anchor to latch on.">,
+  ];
+}
+
+def LinalgStrategyPromotePass
+    : FunctionPass<"linalg-strategy-promote-pass"> {
+  let summary = "Configurable pass to apply pattern-based linalg promotion.";
+  let constructor = "mlir::createLinalgStrategyPromotePass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+      "Which func op is the anchor to latch on.">,
+    Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+      "Which linalg op within the func is the anchor to latch on.">,
+  ];
+}
+
+def LinalgStrategyVectorizePass
+    : FunctionPass<"linalg-strategy-vectorize-pass"> {
+  let summary = "Configurable pass to apply pattern-based linalg vectorization.";
+  let constructor = "mlir::createLinalgStrategyVectorizePass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+      "Which func op is the anchor to latch on.">,
+    Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+      "Which linalg op within the func is the anchor to latch on.">,
+  ];
+}
+
+def LinalgStrategyEnablePass
+    : FunctionPass<"linalg-strategy-enable-pass"> {
+  let summary = "Configurable pass to enable the application of other "
+    "pattern-based linalg passes.";
+  let constructor = "mlir::createLinalgStrategyEnablePass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+      "Which func op is the anchor to latch on.">,
+  ];
+}
+
+def LinalgStrategyLowerVectorsPass
+    : FunctionPass<"linalg-strategy-lower-vectors-pass"> {
+  let summary = "Configurable pass to lower vector operations.";
+  let constructor = "mlir::createLinalgStrategyLowerVectorsPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+      "Which func op is the anchor to latch on.">,
+  ];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index d33dd81fa323a..ff372079f690d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -10,7 +10,8 @@
 #define MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_
 
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Pass/PassManager.h"
 
 namespace mlir {
 
@@ -21,69 +22,23 @@ namespace linalg {
 /// Abstract Transformation class applied in a sequence that also handles state
 /// through markers.
 struct Transformation {
-  explicit Transformation(linalg::LinalgTransformationFilter::FilterFunction f)
+  explicit Transformation(LinalgTransformationFilter::FilterFunction f)
       : filter(f) {}
   virtual ~Transformation() = default;
-  virtual RewritePatternSet
-  buildRewritePatterns(MLIRContext *context,
-                       linalg::LinalgTransformationFilter m) = 0;
-  linalg::LinalgTransformationFilter::FilterFunction filter = nullptr;
+  virtual void addToPassPipeline(OpPassManager &pm,
+                                 LinalgTransformationFilter m) const = 0;
+  LinalgTransformationFilter::FilterFunction filter = nullptr;
 };
 
-/// SFINAE: Enqueue helper for ConcreteOpType that have a `getOperationName`.
-template <template <typename> class PatternType, typename ConcreteOpType,
-          typename OptionsType,
-          typename = std::enable_if_t<std::is_member_function_pointer<
-              decltype(&ConcreteOpType::getOperationName)>::value>>
-void sfinae_enqueue(RewritePatternSet &patternList, OptionsType options,
-                    StringRef opName, linalg::LinalgTransformationFilter m) {
-  assert(opName == ConcreteOpType::getOperationName() &&
-         "explicit name must match ConcreteOpType::getOperationName");
-  patternList.add<PatternType<ConcreteOpType>>(patternList.getContext(),
-                                               options, m);
-}
-
-/// SFINAE: Enqueue helper for OpType that do not have a `getOperationName`
-/// (e.g. LinalgOp, other interfaces, Operation*).
-template <template <typename> class PatternType, typename OpType,
-          typename OptionsType>
-void sfinae_enqueue(RewritePatternSet &patternList, OptionsType options,
-                    StringRef opName, linalg::LinalgTransformationFilter m) {
-  assert(!opName.empty() && "opName must not be empty");
-  patternList.add<PatternType<OpType>>(opName, patternList.getContext(),
-                                       options, m);
-}
-
-template <typename PatternType, typename OpType, typename OptionsType>
-void enqueue(RewritePatternSet &patternList, OptionsType options,
-             StringRef opName, linalg::LinalgTransformationFilter m) {
-  if (!opName.empty())
-    patternList.add<PatternType>(opName, patternList.getContext(), options, m);
-  else
-    patternList.add<PatternType>(patternList.getContext(),
-                                 m.addOpFilter<OpType>(), options);
-}
-
-/// Promotion transformation enqueues a particular stage-1 pattern for
-/// `Tile<LinalgOpType>`with the appropriate `options`.
-template <typename LinalgOpType>
+/// Represent one application of LinalgStrategyTilePass.
 struct Tile : public Transformation {
-  explicit Tile(linalg::LinalgTilingOptions options,
-                linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(f), opName(LinalgOpType::getOperationName()),
-        options(options) {}
-
   Tile(StringRef name, linalg::LinalgTilingOptions options,
-       linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+       LinalgTransformationFilter::FilterFunction f = nullptr)
       : Transformation(f), opName(name), options(options) {}
 
-  RewritePatternSet
-  buildRewritePatterns(MLIRContext *context,
-                       linalg::LinalgTransformationFilter m) override {
-    RewritePatternSet tilingPatterns(context);
-    sfinae_enqueue<linalg::LinalgTilingPattern, LinalgOpType>(
-        tilingPatterns, options, opName, m);
-    return tilingPatterns;
+  void addToPassPipeline(OpPassManager &pm,
+                         LinalgTransformationFilter m) const override {
+    pm.addPass(createLinalgStrategyTilePass(opName, options, m));
   }
 
 private:
@@ -91,27 +46,15 @@ struct Tile : public Transformation {
   linalg::LinalgTilingOptions options;
 };
 
-/// Promotion transformation enqueues a particular stage-1 pattern for
-/// `Promote<LinalgOpType>`with the appropriate `options`.
-template <typename LinalgOpType>
+/// Represent one application of LinalgStrategyPromotePass.
 struct Promote : public Transformation {
-  explicit Promote(
-      linalg::LinalgPromotionOptions options,
-      linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(f), opName(LinalgOpType::getOperationName()),
-        options(options) {}
-
   Promote(StringRef name, linalg::LinalgPromotionOptions options,
-          linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+          LinalgTransformationFilter::FilterFunction f = nullptr)
       : Transformation(f), opName(name), options(options) {}
 
-  RewritePatternSet
-  buildRewritePatterns(MLIRContext *context,
-                       linalg::LinalgTransformationFilter m) override {
-    RewritePatternSet promotionPatterns(context);
-    sfinae_enqueue<linalg::LinalgPromotionPattern, LinalgOpType>(
-        promotionPatterns, options, opName, m);
-    return promotionPatterns;
+  void addToPassPipeline(OpPassManager &pm,
+                         LinalgTransformationFilter m) const override {
+    pm.addPass(createLinalgStrategyPromotePass(opName, options, m));
   }
 
 private:
@@ -119,30 +62,19 @@ struct Promote : public Transformation {
   linalg::LinalgPromotionOptions options;
 };
 
-/// Vectorization transformation enqueues a particular stage-1 pattern for
-/// `LinalgVectorizationPattern<LinalgOpType>` as well as copy to vector
-/// transfer rewrite forwarding patterns.
-template <typename LinalgOpType = LinalgOp>
+/// Represent one application of LinalgStrategyVectorizePass.
 struct Vectorize : public Transformation {
-  explicit Vectorize(
-      linalg::LinalgVectorizationOptions options,
-      linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+  explicit Vectorize(linalg::LinalgVectorizationOptions options,
+                     LinalgTransformationFilter::FilterFunction f = nullptr)
       : Transformation(f), opName(), options(options) {}
 
   Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
-            linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+            LinalgTransformationFilter::FilterFunction f = nullptr)
       : Transformation(f), opName(name), options(options) {}
 
-  RewritePatternSet
-  buildRewritePatterns(MLIRContext *context,
-                       linalg::LinalgTransformationFilter m) override {
-    RewritePatternSet vectorizationPatterns(context);
-    enqueue<linalg::LinalgVectorizationPattern, LinalgOpType>(
-        vectorizationPatterns, options, opName, m);
-    vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
-                              linalg::LinalgCopyVTWForwardingPattern>(
-        context, /*benefit=*/2);
-    return vectorizationPatterns;
+  void addToPassPipeline(OpPassManager &pm,
+                         LinalgTransformationFilter m) const override {
+    pm.addPass(createLinalgStrategyVectorizePass(opName, options, m));
   }
 
 private:
@@ -150,129 +82,47 @@ struct Vectorize : public Transformation {
   linalg::LinalgVectorizationOptions options;
 };
 
-/// Options to control the application of late transformations.
-struct LateCodegenStrategyOptions {
-  /// Hoisting transformations are always deemed beneficial and must disabled
-  /// explicitly.
-  bool enableLICM = true;
-  bool enableHoistRedundantVectorTransfers = true;
-  bool enableHoistRedundantVectorTransfersOnTensor = true;
-  /// Vector lowering operations may result in surprising behavior when
-  /// composing multiple codegen strategies and must be enabled explicitly.
-  bool enableVectorTransferPartialRewrite = false;
-  bool enableVectorContractLowering = false;
-  bool enableVectorToSCFConversion = false;
-};
-
 /// Codegen strategy controls how a Linalg op is progressively lowered.
-/// The application uses a 3-level staged patterns strategy which allows
-/// ordering transformations by using the Linalg `applyStagedPatterns`
-/// function, where:
-///   1. The first stage consists of the successive `tile`, `promote` and
-///   `vectorize` patterns, applied sequentially.
-///   2. The second stage consists of common local canonicalization patterns
-///   that are applied eagerly after each stage-1 pattern.
-///   3. the third stage consists of more global transformation, also applied
-///   eagerly, after all stage-2 patterns. Such more global transformations
 struct CodegenStrategy {
-  /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
+  /// Append a pattern to add a level of tiling for Op `opName` with tiling
   /// `options`.
-  template <typename LinalgOpType>
-  CodegenStrategy &
-  tile(linalg::LinalgTilingOptions options,
-       linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
-    transformationSequence.emplace_back(
-        std::make_unique<Tile<LinalgOpType>>(options, f));
-    return *this;
-  }
-  /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
-  /// `options`.
-  template <typename LinalgOpType>
   CodegenStrategy &
   tile(StringRef opName, linalg::LinalgTilingOptions options,
-       linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+       LinalgTransformationFilter::FilterFunction f = nullptr) {
     transformationSequence.emplace_back(
-        std::make_unique<Tile<LinalgOpType>>(opName, options, f));
+        std::make_unique<Tile>(opName, options, f));
     return *this;
   }
   /// Conditionally append a pattern to add a level of tiling for
   /// `LinalgOpType` with tiling `options`.
-  template <typename LinalgOpType>
-  CodegenStrategy &
-  tileIf(bool b, linalg::LinalgTilingOptions options,
-         linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
-    return b ? tile<LinalgOpType>(options) : *this;
-  }
-  /// Conditionally append a pattern to add a level of tiling for
-  /// `LinalgOpType` with tiling `options`.
-  template <typename LinalgOpType>
   CodegenStrategy &
   tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options,
-         linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
-    return b ? tile<LinalgOpType>(opName, options) : *this;
+         LinalgTransformationFilter::FilterFunction f = nullptr) {
+    return b ? tile(opName, options, f) : *this;
   }
   /// Append a pattern to add a level of promotion for `LinalgOpType` with
   /// promotion `options`.
-  template <typename LinalgOpType>
-  CodegenStrategy &
-  promote(linalg::LinalgPromotionOptions options,
-          linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
-    transformationSequence.emplace_back(
-        std::make_unique<Promote<LinalgOpType>>(options, f));
-    return *this;
-  }
-  /// Append a pattern to add a level of promotion for `LinalgOpType` with
-  /// promotion `options`.
-  template <typename LinalgOpType>
   CodegenStrategy &
   promote(StringRef opName, linalg::LinalgPromotionOptions options,
-          linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+          LinalgTransformationFilter::FilterFunction f = nullptr) {
     transformationSequence.emplace_back(
-        std::make_unique<Promote<LinalgOpType>>(opName, options, f));
+        std::make_unique<Promote>(opName, options, f));
     return *this;
   }
   /// Conditionally append a pattern to add a level of promotion for
   /// `LinalgOpType` with promotion `options`.
-  template <typename LinalgOpType>
   CodegenStrategy &
   promoteIf(bool b, StringRef opName, linalg::LinalgPromotionOptions options,
-            linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
-    return b ? promote<LinalgOpType>(opName, options, f) : *this;
-    return *this;
-  }
-  /// Conditionally append a pattern to add a level of promotion for
-  /// `LinalgOpType` with promotion `options`.
-  template <typename LinalgOpType>
-  CodegenStrategy &
-  promoteIf(bool b, linalg::LinalgPromotionOptions options,
-            linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
-    return b ? promote<LinalgOpType>(options, f) : *this;
-    return *this;
-  }
-  /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
-  template <typename LinalgOpType>
-  CodegenStrategy &
-  vectorize(linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
-    transformationSequence.emplace_back(
-        std::make_unique<Vectorize<LinalgOpType>>(
-            linalg::LinalgVectorizationOptions(), f));
-    return *this;
-  }
-  /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
-  /// operation.
-  template <typename LinalgOpType>
-  CodegenStrategy &
-  vectorizeIf(bool b,
-              linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
-    return b ? vectorize<LinalgOpType>(f) : *this;
+            LinalgTransformationFilter::FilterFunction f = nullptr) {
+    return b ? promote(opName, options, f) : *this;
     return *this;
   }
   /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
   CodegenStrategy &
   vectorize(StringRef opName,
-            linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+            LinalgTransformationFilter::FilterFunction f = nullptr) {
     assert(!opName.empty() && "expected an op name");
-    transformationSequence.emplace_back(std::make_unique<Vectorize<LinalgOp>>(
+    transformationSequence.emplace_back(std::make_unique<Vectorize>(
         opName, linalg::LinalgVectorizationOptions(), f));
     return *this;
   }
@@ -280,14 +130,14 @@ struct CodegenStrategy {
   /// operation.
   CodegenStrategy &
   vectorizeIf(bool b, StringRef opName,
-              linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+              LinalgTransformationFilter::FilterFunction f = nullptr) {
     return b ? vectorize(opName, f) : *this;
     return *this;
   }
   /// Configure the post staged-patterns late vector transformations.
   CodegenStrategy &
   setVectorTransformsOptions(vector::VectorTransformsOptions options) {
-    vectorTransformsOptions = options;
+    vectorTransformOptions = options;
     return *this;
   }
   /// Configure the post staged-patterns late vector.transfer to scf
@@ -328,12 +178,13 @@ struct CodegenStrategy {
 
   /// Apply the transformation patterns in sequence with cleanup
   /// transformations interleaved.
-  void transform(FuncOp func) const;
+  LogicalResult transform(FuncOp func) const;
+  void configurePassPipeline(OpPassManager &pm, MLIRContext *context) const;
 
 private:
   LogicalResult postPatternTransforms(Operation *func) const;
 
-  vector::VectorTransformsOptions vectorTransformsOptions;
+  vector::VectorTransformsOptions vectorTransformOptions;
   VectorTransferToSCFOptions vectorToSCFOptions;
   SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
   LateCodegenStrategyOptions lateCodegenStrategyOptions;

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 03843bdfaa012..4a76b92573209 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -9,6 +9,7 @@
 #ifndef DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
 #define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
 
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Utils.h"
@@ -593,6 +594,35 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
   }
 };
 
+struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
+  /// Entry point to match any LinalgOp OpInterface.
+  /// MatchAnyOpTag-based constructor with a mandatory `filter`.
+  LinalgGenericTilingPattern(
+      MLIRContext *context, LinalgTransformationFilter filter,
+      LinalgTilingOptions options = LinalgTilingOptions(),
+      PatternBenefit benefit = 1)
+      : LinalgBaseTilingPattern(context, options, filter, benefit) {}
+  /// Entry point to match a specific Linalg op.
+  LinalgGenericTilingPattern(
+      StringRef opName, MLIRContext *context, LinalgTilingOptions options,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : LinalgBaseTilingPattern(opName, context, options, filter, benefit) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    TiledLinalgOp tiledLinalgOp;
+    if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
+                                                            tiledLinalgOp)))
+      return failure();
+    if (tiledLinalgOp.tensorResults.empty())
+      rewriter.eraseOp(op);
+    else
+      rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
+    return success();
+  }
+};
+
 struct LinalgFusionOptions {
   /// List of operands indices to use for fusion.
   llvm::SmallSet<unsigned, 1> indicesToFuse = {};
@@ -678,6 +708,13 @@ struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
 /// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `promoteSubViews` for more details.
 struct LinalgBasePromotionPattern : public RewritePattern {
+  /// Entry point to match any LinalgOp OpInterface.
+  /// MatchAnyOpTag-based constructor with a mandatory `filter`.
+  LinalgBasePromotionPattern(
+      MLIRContext *context, LinalgTransformationFilter filter,
+      LinalgPromotionOptions options = LinalgPromotionOptions(),
+      PatternBenefit benefit = 1);
+  /// Entry point to match a specific Linalg op.
   LinalgBasePromotionPattern(
       StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
       LinalgTransformationFilter filter = LinalgTransformationFilter(),
@@ -757,6 +794,39 @@ struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
       : LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
 };
 
+/// Options to control the application of late transformations.
+struct LateCodegenStrategyOptions {
+  /// Hoisting transformations are always deemed beneficial and must disabled
+  /// explicitly.
+  bool enableLICM = true;
+  bool enableHoistRedundantVectorTransfers = true;
+  bool enableHoistRedundantVectorTransfersOnTensor = true;
+  /// Vector lowering operations may result in surprising behavior when
+  /// composing multiple codegen strategies and must be enabled explicitly.
+  bool enableVectorTransferPartialRewrite = false;
+  bool enableVectorContractLowering = false;
+  bool enableVectorToSCFConversion = false;
+};
+
+/// Options to control the application of enabling transformations.
+/// Hoisting transformations are always deemed beneficial and must be disabled
+/// explicitly.
+struct LinalgEnablingOptions {
+  bool enableLICM = true;
+  bool enableHoistRedundantVectorTransfers = true;
+  bool enableHoistRedundantVectorTransfersOnTensor = true;
+};
+
+/// Vector lowering options control how ops are lowered down to 1-D and scf.for
+/// form.
+struct LinalgVectorLoweringOptions {
+  bool enableVectorTransferPartialRewrite = false;
+  bool enableVectorContractLowering = false;
+  bool enableVectorToSCFConversion = false;
+  vector::VectorTransformsOptions vectorTransformOptions;
+  VectorTransferToSCFOptions vectorTransferToSCFOptions;
+};
+
 /// Trait to check if T provides a `getOperationName` method.
 template <typename T, typename... Args>
 using has_get_operation_name = decltype(T::getOperationName());
@@ -929,8 +999,8 @@ struct GeneralizePadTensorOpPattern : public OpRewritePattern<PadTensorOp> {
 /// scattering magic constants throughout the code base, the patterns must be
 /// added with this function. `baseBenefit` can be used to offset the benefit
 /// of all PadTensorOp vectorization patterns by a certain value.
-void populatePadTensorOpVectorizationPatterns(
-    RewritePatternSet &patterns, PatternBenefit baseBenefit = 1);
+void populatePadTensorOpVectorizationPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit baseBenefit = 1);
 
 /// Match and rewrite for the pattern:
 /// ```

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index a2a9108e83900..41d59c435ec25 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -218,17 +218,17 @@ class ContractionOpToMatmulOpLowering
   }
 
   ContractionOpToMatmulOpLowering(
-      vector::VectorTransformsOptions vectorTransformsOptions,
+      vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, FilterConstraintType constraint = defaultFilter)
       : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
 
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformsOptions;
+  vector::VectorTransformsOptions vectorTransformOptions;
   FilterConstraintType filter;
 };
 
@@ -259,17 +259,17 @@ class ContractionOpToOuterProductOpLowering
   }
 
   ContractionOpToOuterProductOpLowering(
-      vector::VectorTransformsOptions vectorTransformsOptions,
+      vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, FilterConstraintType constraint = defaultFilter)
       : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
 
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformsOptions;
+  vector::VectorTransformsOptions vectorTransformOptions;
   FilterConstraintType filter;
 };
 
@@ -303,18 +303,17 @@ class ContractionOpToDotLowering
   }
 
   ContractionOpToDotLowering(
-      vector::VectorTransformsOptions vectorTransformsOptions,
+      vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, FilterConstraintType constraint = defaultFilter)
       : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions),
-        filter(defaultFilter) {}
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
 
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformsOptions;
+  vector::VectorTransformsOptions vectorTransformOptions;
   FilterConstraintType filter;
 };
 
@@ -342,18 +341,18 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
     return success();
   }
 
-  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
+  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                         MLIRContext *context,
                         FilterConstraintType constraint = defaultFilter)
       : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
 
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformsOptions;
+  vector::VectorTransformsOptions vectorTransformOptions;
   FilterConstraintType filter;
   // Lower one parallel dimension.
   Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index a05cd5858f42b..5dfb419ffd406 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
+  LinalgStrategyPasses.cpp
   Promotion.cpp
   Tiling.cpp
   Transforms.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index 8192551a5c392..cd0e75cc8a17b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
 
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
@@ -27,85 +28,56 @@ using namespace mlir::linalg;
 
 #define DEBUG_TYPE "linalg-codegen-strategy"
 
-void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
-  MLIRContext *context = func.getContext();
-  // Emplace patterns one at a time while also maintaining a simple chained
-  // state transition.
-  unsigned stepCount = 0;
-  SmallVector<FrozenRewritePatternSet, 4> stage1Patterns;
-  auto zeroState = Identifier::get(std::to_string(stepCount), context);
-  auto currentState = zeroState;
-  for (const std::unique_ptr<Transformation> &t : transformationSequence) {
-    auto nextState = Identifier::get(std::to_string(++stepCount), context);
-    auto marker = (currentState == zeroState)
+/// Populate `pm` with the passes of every transformation in
+/// `transformationSequence`, each followed by an enabling pass, and with a
+/// final vector-lowering pass configured from `lateCodegenStrategyOptions`.
+void mlir::linalg::CodegenStrategy::configurePassPipeline(
+    OpPassManager &pm, MLIRContext *context) const {
+  // Forward the late codegen options to every enabling pass; relying on a
+  // default-constructed LinalgEnablingOptions would silently ignore them.
+  LinalgEnablingOptions enablingOptions;
+  enablingOptions.enableLICM = lateCodegenStrategyOptions.enableLICM;
+  enablingOptions.enableHoistRedundantVectorTransfers =
+      lateCodegenStrategyOptions.enableHoistRedundantVectorTransfers;
+  enablingOptions.enableHoistRedundantVectorTransfersOnTensor =
+      lateCodegenStrategyOptions.enableHoistRedundantVectorTransfersOnTensor;
+  for (unsigned stepCount = 0, e = transformationSequence.size(); stepCount < e;
+       ++stepCount) {
+    const std::unique_ptr<Transformation> &t =
+        transformationSequence[stepCount];
+    // Chain the steps via markers: step N's filter matches marker "N" (step 0
+    // uses an empty match list) and carries "N+1" as its replacement marker.
+    auto currentState = Identifier::get(std::to_string(stepCount), context);
+    auto nextState = Identifier::get(std::to_string(stepCount + 1), context);
+    auto filter = (stepCount == 0)
                       ? linalg::LinalgTransformationFilter(
                             t->filter, ArrayRef<Identifier>{}, nextState)
                       : linalg::LinalgTransformationFilter(
                             t->filter, currentState, nextState);
-    stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker));
-    currentState = nextState;
-  }
-
-  RewritePatternSet stage2Patterns =
-      linalg::getLinalgTilingCanonicalizationPatterns(context);
-  scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns);
-
-  auto stage3Transforms = [&](Operation *op) {
-    // Some of these may be too aggressive as a stage 3 that is applied on each
-    // stage 1 application and may have to be split out to post staged patterns
-    // application (in which case they could just be passes, TBD).
-    if (lateCodegenStrategyOptions.enableLICM) {
-      op->walk([&](LoopLikeOpInterface loopLike) {
-        LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n"));
-        if (failed(moveLoopInvariantCode(loopLike)))
-          llvm_unreachable("unexpected LICM failure");
-      });
-    }
-    promoteSingleIterationLoops(cast<FuncOp>(op));
-    if (lateCodegenStrategyOptions.enableHoistRedundantVectorTransfers)
-      hoistRedundantVectorTransfers(cast<FuncOp>(op));
-    if (lateCodegenStrategyOptions.enableHoistRedundantVectorTransfersOnTensor)
-      hoistRedundantVectorTransfersOnTensor(cast<FuncOp>(op));
-    return success();
-  };
-  (void)linalg::applyStagedPatterns(
-      func, stage1Patterns, std::move(stage2Patterns), stage3Transforms);
-
-  //===--------------------------------------------------------------------===//
-  // Post staged patterns transforms
-  //===--------------------------------------------------------------------===//
-
-  // Programmatic splitting of slow/fast path vector transfers.
-  if (lateCodegenStrategyOptions.enableVectorTransferPartialRewrite) {
-    RewritePatternSet patterns(context);
-    patterns.add<vector::VectorTransferFullPartialRewriter>(
-        context, vectorTransformsOptions);
-    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
-  }
-
-  // Programmatic controlled lowering of vector.contract only.
-  if (lateCodegenStrategyOptions.enableVectorContractLowering) {
-    RewritePatternSet vectorContractLoweringPatterns(context);
-    vectorContractLoweringPatterns
-        .add<ContractionOpToOuterProductOpLowering,
-             ContractionOpToMatmulOpLowering, ContractionOpLowering>(
-            vectorTransformsOptions, context);
-    vector::populateVectorTransferPermutationMapLoweringPatterns(
-        vectorContractLoweringPatterns);
-    (void)applyPatternsAndFoldGreedily(
-        func, std::move(vectorContractLoweringPatterns));
-  }
-
-  // Programmatic controlled lowering of vector.transfer only.
-  if (lateCodegenStrategyOptions.enableVectorToSCFConversion) {
-    RewritePatternSet vectorToLoopsPatterns(context);
-    populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
-                                          vectorToSCFOptions);
-    (void)applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));
+    t->addToPassPipeline(pm, filter);
+    pm.addPass(createLinalgStrategyEnablePass(enablingOptions));
   }
+  LinalgVectorLoweringOptions vectorLoweringOptions;
+  vectorLoweringOptions.enableVectorTransferPartialRewrite =
+      lateCodegenStrategyOptions.enableVectorTransferPartialRewrite;
+  vectorLoweringOptions.enableVectorContractLowering =
+      lateCodegenStrategyOptions.enableVectorContractLowering;
+  vectorLoweringOptions.enableVectorToSCFConversion =
+      lateCodegenStrategyOptions.enableVectorToSCFConversion;
+  vectorLoweringOptions.vectorTransformOptions = vectorTransformOptions;
+  vectorLoweringOptions.vectorTransferToSCFOptions = vectorToSCFOptions;
+  pm.addPass(createLinalgStrategyLowerVectorsPass(vectorLoweringOptions));
+}
 
+/// Run the configured pipeline on `funcOp`, then drop the transformation
+/// marker from every LinalgOp. Returns the pass manager's result.
+LogicalResult mlir::linalg::CodegenStrategy::transform(FuncOp funcOp) const {
+  PassManager pm(funcOp.getContext(), funcOp.getOperationName());
+  configurePassPipeline(pm, funcOp.getContext());
+  LogicalResult res = pm.run(funcOp);
   // Ensure we drop the marker in the end.
-  func.walk([](LinalgOp op) {
+  funcOp.walk([](LinalgOp op) {
     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
   });
+  return res;
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
new file mode 100644
index 0000000000000..db0a24d4f4b3f
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -0,0 +1,277 @@
+//===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements configurable passes that apply Linalg strategy patterns
+// (tiling, promotion, vectorization, enabling transformations and vector
+// lowering) and can be plugged in a pass pipeline.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Utils.h"
+
+using namespace mlir;
+using namespace linalg;
+
+namespace {
+
+/// Configurable pass to apply pattern-based linalg tiling.
+/// The pass does nothing on functions other than `anchorFuncName` when that
+/// option is set. With a non-empty `anchorOpName` the tiling pattern is
+/// anchored on that op name; otherwise it is registered without one.
+struct LinalgStrategyTilePass
+    : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
+
+  LinalgStrategyTilePass() = default;
+
+  LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
+                         LinalgTransformationFilter filt)
+      : options(opt), filter(filt) {
+    this->anchorOpName.setValue(opName.str());
+  }
+
+  void runOnFunction() override {
+    auto funcOp = getFunction();
+    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+      return;
+
+    RewritePatternSet tilingPattern(funcOp.getContext());
+    if (!anchorOpName.empty()) {
+      tilingPattern.add<LinalgGenericTilingPattern>(
+          anchorOpName, funcOp.getContext(), options, filter);
+    } else {
+      tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter,
+                                                    options);
+    }
+    (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
+  }
+
+  LinalgTilingOptions options;
+  LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to apply pattern-based linalg promotion.
+/// The pass does nothing on functions other than `anchorFuncName` when that
+/// option is set. With a non-empty `anchorOpName` the promotion pattern is
+/// anchored on that op name; otherwise it is registered without one.
+struct LinalgStrategyPromotePass
+    : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
+
+  LinalgStrategyPromotePass() = default;
+
+  LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt,
+                            LinalgTransformationFilter filt)
+      : options(opt), filter(filt) {
+    this->anchorOpName.setValue(opName.str());
+  }
+
+  void runOnFunction() override {
+    auto funcOp = getFunction();
+    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+      return;
+
+    RewritePatternSet promotionPattern(funcOp.getContext());
+    if (!anchorOpName.empty()) {
+      promotionPattern.add<LinalgBasePromotionPattern>(
+          anchorOpName, funcOp.getContext(), options, filter);
+    } else {
+      promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(),
+                                                       filter, options);
+    }
+    (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern));
+  }
+
+  LinalgPromotionOptions options;
+  LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to apply pattern-based linalg vectorization.
+/// The pass does nothing on functions other than `anchorFuncName` when that
+/// option is set. With a non-empty `anchorOpName` the vectorization pattern is
+/// anchored on that op name; otherwise it is registered without one. The copy
+/// forwarding patterns are always added, without `filter`, at benefit 2.
+struct LinalgStrategyVectorizePass
+    : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
+
+  LinalgStrategyVectorizePass() = default;
+
+  LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
+                              LinalgTransformationFilter filt)
+      : options(opt), filter(filt) {
+    this->anchorOpName.setValue(opName.str());
+  }
+
+  void runOnFunction() override {
+    auto funcOp = getFunction();
+    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+      return;
+
+    RewritePatternSet vectorizationPatterns(funcOp.getContext());
+    if (!anchorOpName.empty()) {
+      vectorizationPatterns.add<LinalgVectorizationPattern>(
+          anchorOpName, funcOp.getContext(), options, filter);
+    } else {
+      vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
+                                                            filter, options);
+    }
+    vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
+                              linalg::LinalgCopyVTWForwardingPattern>(
+        funcOp.getContext(), /*benefit=*/2);
+    (void)applyPatternsAndFoldGreedily(funcOp,
+                                       std::move(vectorizationPatterns));
+  }
+
+  LinalgVectorizationOptions options;
+  LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to enable the application of other pattern-based linalg
+/// passes: it applies the linalg tiling and scf.for canonicalization patterns,
+/// runs LICM if `options.enableLICM` is set, always promotes single-iteration
+/// loops, and optionally hoists redundant vector transfers. The pass fails if
+/// the greedy rewrite or LICM reports failure.
+struct LinalgStrategyEnablePass
+    : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> {
+
+  LinalgStrategyEnablePass(LinalgEnablingOptions opt,
+                           LinalgTransformationFilter filt)
+      : options(opt), filter(filt) {}
+
+  void runOnFunction() override {
+    auto funcOp = getFunction();
+    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+      return;
+
+    MLIRContext *context = funcOp.getContext();
+    RewritePatternSet patterns =
+        linalg::getLinalgTilingCanonicalizationPatterns(context);
+    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
+      return signalPassFailure();
+
+    if (options.enableLICM) {
+      if (funcOp
+              ->walk([&](LoopLikeOpInterface loopLike) {
+                if (failed(moveLoopInvariantCode(loopLike)))
+                  return WalkResult::interrupt();
+                return WalkResult::advance();
+              })
+              .wasInterrupted())
+        return signalPassFailure();
+    }
+
+    promoteSingleIterationLoops(funcOp);
+    if (options.enableHoistRedundantVectorTransfers)
+      hoistRedundantVectorTransfers(funcOp);
+
+    if (options.enableHoistRedundantVectorTransfersOnTensor)
+      hoistRedundantVectorTransfersOnTensor(funcOp);
+  }
+
+  LinalgEnablingOptions options;
+  // NOTE(review): `filter` is stored but not read by runOnFunction.
+  LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to lower vector operations. The pattern groups enabled in
+/// `options` are merged into one set and applied in a single greedy rewrite
+/// whose result is ignored.
+/// NOTE(review): these groups used to run as three separate greedy rewrites
+/// (partial transfer rewrite, contract lowering, vector-to-SCF, in that
+/// order); confirm that merging them cannot change which pattern fires first.
+struct LinalgStrategyLowerVectorsPass
+    : public LinalgStrategyLowerVectorsPassBase<
+          LinalgStrategyLowerVectorsPass> {
+
+  LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
+                                 LinalgTransformationFilter filt)
+      : options(opt), filter(filt) {}
+
+  void runOnFunction() override {
+    auto funcOp = getFunction();
+    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+      return;
+
+    MLIRContext *context = funcOp.getContext();
+    RewritePatternSet patterns(context);
+    if (options.enableVectorTransferPartialRewrite) {
+      patterns.add<vector::VectorTransferFullPartialRewriter>(
+          context, options.vectorTransformOptions);
+    }
+    if (options.enableVectorContractLowering) {
+      patterns.add<ContractionOpToOuterProductOpLowering,
+                   ContractionOpToMatmulOpLowering, ContractionOpLowering>(
+          options.vectorTransformOptions, context);
+      vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+    }
+    if (options.enableVectorToSCFConversion) {
+      populateVectorToSCFConversionPatterns(patterns,
+                                            options.vectorTransferToSCFOptions);
+    }
+    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  }
+
+  LinalgVectorLoweringOptions options;
+  // NOTE(review): `filter` is stored but not read by runOnFunction.
+  LinalgTransformationFilter filter;
+};
+} // namespace
+
+/// Create a LinalgStrategyTilePass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
+                                   LinalgTransformationFilter filter) {
+  return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
+}
+
+/// Create a LinalgStrategyPromotePass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyPromotePass(StringRef opName,
+                                      LinalgPromotionOptions opt,
+                                      LinalgTransformationFilter filter) {
+  return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
+}
+
+/// Create a LinalgStrategyVectorizePass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyVectorizePass(StringRef opName,
+                                        LinalgVectorizationOptions opt,
+                                        LinalgTransformationFilter filter) {
+  return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter);
+}
+
+/// Create a LinalgStrategyEnablePass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt,
+                                     LinalgTransformationFilter filter) {
+  return std::make_unique<LinalgStrategyEnablePass>(opt, filter);
+}
+
+/// Create a LinalgStrategyLowerVectorsPass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
+                                           LinalgTransformationFilter filter) {
+  return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
+}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 1d28451ae05ef..c471459da5c25 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -488,6 +488,12 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
   return success();
 }
 
+mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
+    MLIRContext *context, LinalgTransformationFilter filter,
+    LinalgPromotionOptions options, PatternBenefit benefit)
+    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
+      options(options) {}
+
 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
     StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
     LinalgTransformationFilter filter, PatternBenefit benefit)

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 7d6e6d2ba53d5..ba97f99972a06 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -672,10 +672,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 public:
   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
 
-  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
+  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                       MLIRContext *context)
       : OpRewritePattern<vector::TransposeOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions) {}
+        vectorTransformOptions(vectorTransformOptions) {}
 
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
@@ -689,7 +689,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       transp.push_back(attr.cast<IntegerAttr>().getInt());
 
     // Handle a true 2-D matrix transpose differently when requested.
-    if (vectorTransformsOptions.vectorTransposeLowering ==
+    if (vectorTransformOptions.vectorTransposeLowering ==
             vector::VectorTransposeLowering::Flat &&
         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
       Type flattenedType =
@@ -739,7 +739,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
   }
 
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformsOptions;
+  vector::VectorTransformsOptions vectorTransformOptions;
 };
 
 /// Progressive lowering of OuterProductOp.
@@ -1151,7 +1151,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
   // TODO: implement masks
   if (llvm::size(op.masks()) != 0)
     return failure();
-  if (vectorTransformsOptions.vectorContractLowering !=
+  if (vectorTransformOptions.vectorContractLowering !=
       vector::VectorContractLowering::Matmul)
     return failure();
   if (failed(filter(op)))
@@ -1314,7 +1314,7 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
   if (llvm::size(op.masks()) != 0)
     return failure();
 
-  if (vectorTransformsOptions.vectorContractLowering !=
+  if (vectorTransformOptions.vectorContractLowering !=
       vector::VectorContractLowering::OuterProduct)
     return failure();
 
@@ -1419,7 +1419,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
   if (failed(filter(op)))
     return failure();
 
-  if (vectorTransformsOptions.vectorContractLowering !=
+  if (vectorTransformOptions.vectorContractLowering !=
       vector::VectorContractLowering::Dot)
     return failure();
 
@@ -1560,13 +1560,13 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
 
   // TODO: implement benefits, cost models.
   MLIRContext *ctx = op.getContext();
-  ContractionOpToMatmulOpLowering pat1(vectorTransformsOptions, ctx);
+  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
   if (succeeded(pat1.matchAndRewrite(op, rewriter)))
     return success();
-  ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
+  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
     return success();
-  ContractionOpToDotLowering pat3(vectorTransformsOptions, ctx);
+  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
     return success();
 
@@ -1835,8 +1835,9 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
 /// Operates under a scoped context to build the intersection between the
 /// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
 // TODO: view intersection/union/differences should be a proper std op.
-static std::pair<Value, Value> createSubViewIntersection(
-    OpBuilder &b, VectorTransferOpInterface xferOp, Value alloc) {
+static std::pair<Value, Value>
+createSubViewIntersection(OpBuilder &b, VectorTransferOpInterface xferOp,
+                          Value alloc) {
   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
   int64_t memrefRank = xferOp.getShapedType().getRank();
   // TODO: relax this precondition, will require rank-reducing subviews.
@@ -2195,6 +2196,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
   MemRefType compatibleMemRefType =
       getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                   alloc.getType().cast<MemRefType>());
+  if (!compatibleMemRefType)
+    return failure();
+
   SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                    b.getIndexType());
   returnTypes[0] = compatibleMemRefType;

diff  --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
index e08d99eb03ef9..bbfdaf371a8c7 100644
--- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir
+++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
@@ -1,9 +1,9 @@
 // Test that both anchor-op name and MatmulOp-based codegen strategy produce the same result.
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=2,4,8 tile-interchange=1,2,0 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 tile-interchange=1,2,0 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
 
 // CHECK-LABEL: func @matmul(
 // OUTER-LABEL: func @matmul(

diff  --git a/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir
index 12f44c6dbc4eb..7ddfdaf538fed 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir
@@ -1,8 +1,8 @@
 // RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \
 // RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \
 // RUN: mlir-opt -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul register-tile-sizes=12,32,16 vectorize" | \
-// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \
-// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \
 
 // RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \
 // RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \

diff  --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
index aed639419e009..bf5cf5a166f21 100644
--- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
@@ -93,7 +93,7 @@ void TestConvVectorization::runOnOperation() {
   // Post staged patterns transforms
   //===--------------------------------------------------------------------===//
 
-  VectorTransformsOptions vectorTransformsOptions{
+  VectorTransformsOptions vectorTransformOptions{
       VectorContractLowering::Dot, VectorTransposeLowering::EltWise};
 
   RewritePatternSet vectorTransferPatterns(context);
@@ -101,7 +101,7 @@ void TestConvVectorization::runOnOperation() {
   // supported as can be seen in splitFullAndPartialTransferPrecondition,
   // VectorTransforms.cpp
   vectorTransferPatterns.add<VectorTransferFullPartialRewriter>(
-      context, vectorTransformsOptions);
+      context, vectorTransformOptions);
   (void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
 
   // Programmatic controlled lowering of linalg.copy and linalg.fill.
@@ -113,9 +113,9 @@ void TestConvVectorization::runOnOperation() {
   // Programmatic controlled lowering of vector.contract only.
   RewritePatternSet vectorContractLoweringPatterns(context);
   populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
-                                         vectorTransformsOptions);
+                                         vectorTransformOptions);
   populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns,
-                                          vectorTransformsOptions);
+                                          vectorTransformOptions);
   (void)applyPatternsAndFoldGreedily(module,
                                      std::move(vectorContractLoweringPatterns));
 

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
index 8f2cd6c689a0d..679cc9375aea3 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -52,7 +52,6 @@ struct TestLinalgCodegenStrategy
 
   void runOnFunction() override;
 
-  template <typename OpType>
   void runStrategy(LinalgTilingOptions tilingOptions,
                    LinalgTilingOptions registerTilingOptions,
                    vector::VectorContractLowering vectorContractLowering,
@@ -127,26 +126,23 @@ struct TestLinalgCodegenStrategy
       llvm::cl::init("")};
 };
 
-template <>
-void TestLinalgCodegenStrategy::runStrategy<LinalgOp>(
+void TestLinalgCodegenStrategy::runStrategy(
     LinalgTilingOptions tilingOptions,
     LinalgTilingOptions registerTilingOptions,
     vector::VectorContractLowering vectorContractLowering,
     vector::VectorTransferSplit vectorTransferSplit) {
   assert(!anchorOpName.empty());
   CodegenStrategy strategy;
-  strategy.tileIf<LinalgOp>(!tileSizes.empty(), anchorOpName, tilingOptions)
-      .promoteIf<LinalgOp>(promote, anchorOpName,
-                           LinalgPromotionOptions()
-                               .setAlignment(16)
-                               .setUseFullTileBuffersByDefault(promoteFullTile))
-      .tileIf<LinalgOp>(!registerTileSizes.empty(), anchorOpName,
-                        registerTilingOptions)
-      .promoteIf<LinalgOp>(
-          registerPromote, anchorOpName,
-          LinalgPromotionOptions()
-              .setAlignment(16)
-              .setUseFullTileBuffersByDefault(registerPromoteFullTile))
+  strategy.tileIf(!tileSizes.empty(), anchorOpName, tilingOptions)
+      .promoteIf(promote, anchorOpName,
+                 LinalgPromotionOptions()
+                     .setAlignment(16)
+                     .setUseFullTileBuffersByDefault(promoteFullTile))
+      .tileIf(!registerTileSizes.empty(), anchorOpName, registerTilingOptions)
+      .promoteIf(registerPromote, anchorOpName,
+                 LinalgPromotionOptions()
+                     .setAlignment(16)
+                     .setUseFullTileBuffersByDefault(registerPromoteFullTile))
       .vectorizeIf(vectorize, anchorOpName)
       .setEnableVectorTransferPartialRewrite(true)
       .setEnableVectorContractLowering(true)
@@ -157,39 +153,7 @@ void TestLinalgCodegenStrategy::runStrategy<LinalgOp>(
               .setVectorTransferSplit(vectorTransferSplit))
       .setVectorTransferToSCFOptions(
           VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
-  strategy.transform(getFunction());
-}
-
-template <typename OpType>
-void TestLinalgCodegenStrategy::runStrategy(
-    LinalgTilingOptions tilingOptions,
-    LinalgTilingOptions registerTilingOptions,
-    vector::VectorContractLowering vectorContractLowering,
-    vector::VectorTransferSplit vectorTransferSplit) {
-  CodegenStrategy strategy;
-  strategy.tileIf<OpType>(!tileSizes.empty(), tilingOptions)
-      .template promoteIf<OpType>(
-          promote, LinalgPromotionOptions()
-                       .setAlignment(16)
-                       .setUseFullTileBuffersByDefault(promoteFullTile))
-      .template tileIf<OpType>(!registerTileSizes.empty(),
-                               registerTilingOptions)
-      .template promoteIf<OpType>(
-          registerPromote,
-          LinalgPromotionOptions()
-              .setAlignment(16)
-              .setUseFullTileBuffersByDefault(registerPromoteFullTile))
-      .template vectorizeIf<OpType>(vectorize)
-      .setEnableVectorTransferPartialRewrite(true)
-      .setEnableVectorContractLowering(true)
-      .setEnableVectorToSCFConversion(true)
-      .setVectorTransformsOptions(
-          vector::VectorTransformsOptions()
-              .setVectorTransformsOptions(vectorContractLowering)
-              .setVectorTransferSplit(vectorTransferSplit))
-      .setVectorTransferToSCFOptions(
-          VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
-  strategy.transform(getFunction());
+  (void)strategy.transform(getFunction());
 }
 } // end anonymous namespace
 
@@ -224,14 +188,8 @@ void TestLinalgCodegenStrategy::runOnFunction() {
           .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
           .Default(vector::VectorTransferSplit::None);
 
-  // If no anchorOpNameis specified, just test that strategy applies properly to
-  // linalg::MatmulOp.
-  if (anchorOpName.empty())
-    runStrategy<linalg::MatmulOp>(tilingOptions, registerTilingOptions,
-                                  vectorContractLowering, vectorTransferSplit);
-  else
-    runStrategy<LinalgOp>(tilingOptions, registerTilingOptions,
-                          vectorContractLowering, vectorTransferSplit);
+  runStrategy(tilingOptions, registerTilingOptions, vectorContractLowering,
+              vectorTransferSplit);
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list