[Mlir-commits] [mlir] 1ebd197 - [mlir][linalg] Add generalization to CodegenStrategy.

Tobias Gysi llvmlistbot at llvm.org
Thu Oct 7 23:31:53 PDT 2021


Author: Tobias Gysi
Date: 2021-10-08T06:31:19Z
New Revision: 1ebd197bc53b25aff3d5e8cc0a7c00c5fe8c1223

URL: https://github.com/llvm/llvm-project/commit/1ebd197bc53b25aff3d5e8cc0a7c00c5fe8c1223
DIFF: https://github.com/llvm/llvm-project/commit/1ebd197bc53b25aff3d5e8cc0a7c00c5fe8c1223.diff

LOG: [mlir][linalg] Add generalization to CodegenStrategy.

Add a generalization pass and integrate it with CodegenStrategy.

This patch depends on https://reviews.llvm.org/D110728.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110746

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
    mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
    mlir/test/Dialect/Linalg/codegen-strategy.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 867921c51a51e..1b318e2ed1e3e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -95,6 +95,12 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyPromotePass(
     linalg::LinalgTransformationFilter filter =
         linalg::LinalgTransformationFilter());
 
+/// Create a LinalgStrategyGeneralizePass.
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgStrategyGeneralizePass(StringRef opName = "",
+                                   linalg::LinalgTransformationFilter filter =
+                                       linalg::LinalgTransformationFilter());
+
 /// Create a LinalgStrategyVectorizePass.
 std::unique_ptr<OperationPass<FuncOp>>
 createLinalgStrategyVectorizePass(StringRef opName = "",

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 32327cd968096..2494248826f09 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -255,6 +255,19 @@ def LinalgStrategyPromotePass
   ];
 }
 
+def LinalgStrategyGeneralizePass
+    : FunctionPass<"linalg-strategy-generalize-pass"> {
+  let summary = "Configurable pass to apply pattern-based generalization.";
+  let constructor = "mlir::createLinalgStrategyGeneralizePass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+      "Which func op is the anchor to latch on.">,
+    Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+      "Which linalg op within the func is the anchor to latch on.">,
+  ];
+}
+
 def LinalgStrategyVectorizePass
     : FunctionPass<"linalg-strategy-vectorize-pass"> {
   let summary = "Configurable pass to apply pattern-based linalg vectorization.";

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index ff372079f690d..d8e8674452f24 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -62,6 +62,21 @@ struct Promote : public Transformation {
   linalg::LinalgPromotionOptions options;
 };
 
+/// Represent one application of createLinalgStrategyGeneralizePass.
+struct Generalize : public Transformation {
+  explicit Generalize(StringRef name,
+                      LinalgTransformationFilter::FilterFunction f = nullptr)
+      : Transformation(f), opName(name) {}
+
+  void addToPassPipeline(OpPassManager &pm,
+                         LinalgTransformationFilter m) const override {
+    pm.addPass(createLinalgStrategyGeneralizePass(opName, m));
+  }
+
+private:
+  std::string opName;
+};
+
 /// Represent one application of createLinalgStrategyVectorizePass.
 struct Vectorize : public Transformation {
   explicit Vectorize(linalg::LinalgVectorizationOptions options,
@@ -117,6 +132,20 @@ struct CodegenStrategy {
     return b ? promote(opName, options, f) : *this;
     return *this;
   }
+  /// Append a pattern to generalize named operations.
+  CodegenStrategy &
+  generalize(StringRef opName,
+             LinalgTransformationFilter::FilterFunction f = nullptr) {
+    transformationSequence.emplace_back(
+        std::make_unique<Generalize>(opName, f));
+    return *this;
+  }
+  /// Conditionally append a pattern to generalize named operations.
+  CodegenStrategy &
+  generalizeIf(bool b, StringRef opName,
+               LinalgTransformationFilter::FilterFunction f = nullptr) {
+    return b ? generalize(opName, f) : *this;
+  }
   /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
   CodegenStrategy &
   vectorize(StringRef opName,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index db0a24d4f4b3f..7dd112e69229e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -68,6 +68,39 @@ struct LinalgStrategyTilePass
   LinalgTransformationFilter filter;
 };
 
+/// Configurable pass to apply pattern-based linalg generalization.
+struct LinalgStrategyGeneralizePass
+    : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {
+
+  LinalgStrategyGeneralizePass() = default;
+
+  LinalgStrategyGeneralizePass(StringRef opName,
+                               LinalgTransformationFilter filter)
+      : filter(filter) {
+    this->anchorOpName.setValue(opName.str());
+  }
+
+  void runOnFunction() override {
+    auto funcOp = getFunction();
+    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+      return;
+
+    RewritePatternSet generalizationPattern(funcOp.getContext());
+    if (!anchorOpName.empty()) {
+      generalizationPattern.add<LinalgGeneralizationPattern>(
+          anchorOpName, funcOp.getContext(), filter);
+    } else {
+      generalizationPattern.add<LinalgGeneralizationPattern>(
+          funcOp.getContext(), filter);
+    }
+    if (failed(applyPatternsAndFoldGreedily(funcOp,
+                                            std::move(generalizationPattern))))
+      signalPassFailure();
+  }
+
+  LinalgTransformationFilter filter;
+};
+
 /// Configurable pass to apply pattern-based linalg promotion.
 struct LinalgStrategyPromotePass
     : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
@@ -233,6 +266,13 @@ mlir::createLinalgStrategyPromotePass(StringRef opName,
   return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
 }
 
+/// Create a LinalgStrategyGeneralizePass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyGeneralizePass(StringRef opName,
+                                         LinalgTransformationFilter filter) {
+  return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
+}
+
 /// Create a LinalgStrategyVectorizePass.
 std::unique_ptr<OperationPass<FuncOp>>
 mlir::createLinalgStrategyVectorizePass(StringRef opName,

diff  --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
index bbfdaf371a8c7..7328e45500068 100644
--- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir
+++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
@@ -4,9 +4,12 @@
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize" | FileCheck %s --check-prefix=GENER
+
 
 // CHECK-LABEL: func @matmul(
 // OUTER-LABEL: func @matmul(
+// GENER-LABEL: func @matmul(
 func @matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
   linalg.matmul
    ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
@@ -17,6 +20,7 @@ func @matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<15
   // CHECK-SAME: (vector<16xf32>, vector<32xf32>) -> vector<8xf32>
 
   // OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
+  // GENER: linalg.generic
   return
 }
 

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
index 679cc9375aea3..7906c1c9afbae 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -86,6 +86,9 @@ struct TestLinalgCodegenStrategy
       *this, "register-promote-full-tile-pad",
       llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
       llvm::cl::init(false)};
+  Option<bool> generalize{*this, "generalize",
+                          llvm::cl::desc("Generalize named operations."),
+                          llvm::cl::init(false)};
   Option<bool> vectorize{
       *this, "vectorize",
       llvm::cl::desc("Rewrite the linalg op as a vector operation."),
@@ -133,6 +136,7 @@ void TestLinalgCodegenStrategy::runStrategy(
     vector::VectorTransferSplit vectorTransferSplit) {
   assert(!anchorOpName.empty());
   CodegenStrategy strategy;
+  StringRef genericOpName = GenericOp::getOperationName();
   strategy.tileIf(!tileSizes.empty(), anchorOpName, tilingOptions)
       .promoteIf(promote, anchorOpName,
                  LinalgPromotionOptions()
@@ -143,7 +147,8 @@ void TestLinalgCodegenStrategy::runStrategy(
                  LinalgPromotionOptions()
                      .setAlignment(16)
                      .setUseFullTileBuffersByDefault(registerPromoteFullTile))
-      .vectorizeIf(vectorize, anchorOpName)
+      .generalizeIf(generalize, anchorOpName)
+      .vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
       .setEnableVectorTransferPartialRewrite(true)
       .setEnableVectorContractLowering(true)
       .setEnableVectorToSCFConversion(true)


        


More information about the Mlir-commits mailing list