[Mlir-commits] [mlir] 23800b0 - [mlir][linalg] Add loop interchange to CodegenStrategy.

Tobias Gysi llvmlistbot at llvm.org
Thu Oct 7 23:49:00 PDT 2021


Author: Tobias Gysi
Date: 2021-10-08T06:39:22Z
New Revision: 23800b05be2be19451f6930ab7bd56aad4dab5e1

URL: https://github.com/llvm/llvm-project/commit/23800b05be2be19451f6930ab7bd56aad4dab5e1
DIFF: https://github.com/llvm/llvm-project/commit/23800b05be2be19451f6930ab7bd56aad4dab5e1.diff

LOG: [mlir][linalg] Add loop interchange to CodegenStrategy.

Add a loop interchange pass and integrate it with CodegenStrategy.

This patch depends on https://reviews.llvm.org/D110728 and https://reviews.llvm.org/D110746.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110748

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
    mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
    mlir/test/Dialect/Linalg/codegen-strategy.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 1b318e2ed1e3..3633fba85393 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -101,6 +101,14 @@ createLinalgStrategyGeneralizePass(StringRef opName = "",
                                    linalg::LinalgTransformationFilter filter =
                                        linalg::LinalgTransformationFilter());
 
+/// Create a LinalgStrategyInterchangePass.
+/// `iteratorInterchange` is handed to the pass's GenericOpInterchangePattern
+/// as the iterator interchange; `filter` constrains which ops it may rewrite.
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange = {},
+                                    linalg::LinalgTransformationFilter filter =
+                                        linalg::LinalgTransformationFilter());
+
 /// Create a LinalgStrategyVectorizePass.
 std::unique_ptr<OperationPass<FuncOp>>
 createLinalgStrategyVectorizePass(StringRef opName = "",

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 2494248826f0..1df4458753db 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -268,6 +268,20 @@ def LinalgStrategyGeneralizePass
   ];
 }
 
+// NOTE: only `anchor-func` is exposed as a pass option. The iterator
+// interchange can only be set through createLinalgStrategyInterchangePass(...);
+// the registered constructor below passes no arguments, i.e. an empty one.
+def LinalgStrategyInterchangePass
+    : FunctionPass<"linalg-strategy-interchange-pass"> {
+  let summary = "Configurable pass to apply pattern-based iterator interchange.";
+  let constructor = "mlir::createLinalgStrategyInterchangePass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+      "Which func op is the anchor to latch on.">,
+  ];
+}
+
 def LinalgStrategyVectorizePass
     : FunctionPass<"linalg-strategy-vectorize-pass"> {
   let summary = "Configurable pass to apply pattern-based linalg vectorization.";

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index d8e8674452f2..1745b797e547 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -77,6 +77,24 @@ struct Generalize : public Transformation {
   std::string opName;
 };
 
+/// Represent one application of createLinalgStrategyInterchangePass.
+/// The interchange is copied into owned storage so the ArrayRef passed to the
+/// constructor does not need to outlive this transformation.
+struct Interchange : public Transformation {
+  explicit Interchange(ArrayRef<int64_t> iteratorInterchange,
+                       LinalgTransformationFilter::FilterFunction f = nullptr)
+      : Transformation(f), iteratorInterchange(iteratorInterchange.begin(),
+                                               iteratorInterchange.end()) {}
+
+  void addToPassPipeline(OpPassManager &pm,
+                         LinalgTransformationFilter m) const override {
+    pm.addPass(createLinalgStrategyInterchangePass(iteratorInterchange, m));
+  }
+
+private:
+  SmallVector<int64_t> iteratorInterchange;
+};
+
 /// Represent one application of createLinalgStrategyVectorizePass.
 struct Vectorize : public Transformation {
   explicit Vectorize(linalg::LinalgVectorizationOptions options,
@@ -147,6 +165,22 @@ struct CodegenStrategy {
     return b ? generalize(opName, f) : *this;
     return *this;
   }
+  /// Append a pattern to interchange iterators. `iteratorInterchange` and the
+  /// optional filter function `f` are forwarded to a new Interchange step.
+  CodegenStrategy &
+  interchange(ArrayRef<int64_t> iteratorInterchange,
+              LinalgTransformationFilter::FilterFunction f = nullptr) {
+    transformationSequence.emplace_back(
+        std::make_unique<Interchange>(iteratorInterchange, f));
+    return *this;
+  }
+  /// Conditionally append a pattern to interchange iterators: when `b` is
+  /// false nothing is appended and `*this` is returned as-is.
+  CodegenStrategy &
+  interchangeIf(bool b, ArrayRef<int64_t> iteratorInterchange,
+                LinalgTransformationFilter::FilterFunction f = nullptr) {
+    return b ? interchange(iteratorInterchange, f) : *this;
+  }
   /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
   CodegenStrategy &
   vectorize(StringRef opName,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 7dd112e69229..e21785d5cbb4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -101,6 +101,41 @@ struct LinalgStrategyGeneralizePass
   LinalgTransformationFilter filter;
 };
 
+/// Configurable pass to apply pattern-based iterator interchange.
+/// Only the function named `anchorFuncName` is processed (every function when
+/// it is empty); GenericOpInterchangePattern is applied under `filter`.
+struct LinalgStrategyInterchangePass
+    : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
+
+  LinalgStrategyInterchangePass() = default;
+
+  LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
+                                LinalgTransformationFilter filter)
+      : iteratorInterchange(iteratorInterchange.begin(),
+                            iteratorInterchange.end()),
+        filter(filter) {}
+
+  void runOnFunction() override {
+    auto funcOp = getFunction();
+    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+      return;
+
+    // The pattern is built from an unsigned vector. NOTE(review): entries are
+    // narrowed from int64_t here without a range check (negatives would wrap).
+    SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
+                                            iteratorInterchange.end());
+    RewritePatternSet interchangePattern(funcOp.getContext());
+    interchangePattern.add<GenericOpInterchangePattern>(
+        funcOp.getContext(), interchangeVector, filter);
+    if (failed(applyPatternsAndFoldGreedily(funcOp,
+                                            std::move(interchangePattern))))
+      signalPassFailure();
+  }
+
+  SmallVector<int64_t> iteratorInterchange;
+  LinalgTransformationFilter filter;
+};
+
 /// Configurable pass to apply pattern-based linalg promotion.
 struct LinalgStrategyPromotePass
     : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
@@ -273,6 +308,14 @@ mlir::createLinalgStrategyGeneralizePass(StringRef opName,
   return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
 }
 
+/// Create a LinalgStrategyInterchangePass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
+                                          LinalgTransformationFilter filter) {
+  return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
+                                                         filter);
+}
+
 /// Create a LinalgStrategyVectorizePass.
 std::unique_ptr<OperationPass<FuncOp>>
 mlir::createLinalgStrategyVectorizePass(StringRef opName,

diff  --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
index 7328e4550006..33ec6279d619 100644
--- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir
+++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
@@ -4,7 +4,7 @@
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize" | FileCheck %s --check-prefix=GENER
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize iterator-interchange=0,2,1" | FileCheck %s --check-prefix=GENER
 
 
 // CHECK-LABEL: func @matmul(
@@ -19,8 +19,12 @@ func @matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<15
   // CHECK-SAME: {lhs_columns = 8 : i32, lhs_rows = 2 : i32, rhs_columns = 4 : i32}
   // CHECK-SAME: (vector<16xf32>, vector<32xf32>) -> vector<8xf32>
 
-  // OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
-  // GENER: linalg.generic
+  //      OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
+
+  // The generalize run also passes iterator-interchange=0,2,1, so the generic
+  // op is expected with its last two iterators swapped.
+  //      GENER: linalg.generic
+  // GENER-SAME: iterator_types = ["parallel", "reduction", "parallel"]
   return
 }
 

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
index 7906c1c9afba..ff213d32ddfc 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -89,6 +89,11 @@ struct TestLinalgCodegenStrategy
   Option<bool> generalize{*this, "generalize",
                           llvm::cl::desc("Generalize named operations."),
                           llvm::cl::init(false)};
+  // Iterator interchange (e.g. 0,2,1) forwarded to interchangeIf below; the
+  // interchange step is only appended when this list is non-empty.
+  ListOption<int64_t> iteratorInterchange{
+      *this, "iterator-interchange", llvm::cl::MiscFlags::CommaSeparated,
+      llvm::cl::desc("Specifies the iterator interchange.")};
   Option<bool> vectorize{
       *this, "vectorize",
       llvm::cl::desc("Rewrite the linalg op as a vector operation."),
@@ -148,6 +153,7 @@ void TestLinalgCodegenStrategy::runStrategy(
                     .setAlignment(16)
                     .setUseFullTileBuffersByDefault(registerPromoteFullTile))
       .generalizeIf(generalize, anchorOpName)
+      .interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
       .vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
       .setEnableVectorTransferPartialRewrite(true)
       .setEnableVectorContractLowering(true)


        


More information about the Mlir-commits mailing list