[Mlir-commits] [mlir] 23800b0 - [mlir][linalg] Add loop interchange to CodegenStrategy.
Tobias Gysi
llvmlistbot at llvm.org
Thu Oct 7 23:49:00 PDT 2021
Author: Tobias Gysi
Date: 2021-10-08T06:39:22Z
New Revision: 23800b05be2be19451f6930ab7bd56aad4dab5e1
URL: https://github.com/llvm/llvm-project/commit/23800b05be2be19451f6930ab7bd56aad4dab5e1
DIFF: https://github.com/llvm/llvm-project/commit/23800b05be2be19451f6930ab7bd56aad4dab5e1.diff
LOG: [mlir][linalg] Add loop interchange to CodegenStrategy.
Add a loop interchange pass and integrate it with CodegenStrategy.
This patch depends on https://reviews.llvm.org/D110728 and https://reviews.llvm.org/D110746.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D110748
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/test/Dialect/Linalg/codegen-strategy.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 1b318e2ed1e3..3633fba85393 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -101,6 +101,12 @@ createLinalgStrategyGeneralizePass(StringRef opName = "",
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter());
+/// Create a LinalgStrategyInterchangePass that applies the iterator
+/// permutation `iteratorInterchange` to the ops accepted by `filter`.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyInterchangePass(
+    ArrayRef<int64_t> iteratorInterchange = {},
+    linalg::LinalgTransformationFilter filter =
+        linalg::LinalgTransformationFilter());
+
/// Create a LinalgStrategyVectorizePass.
std::unique_ptr<OperationPass<FuncOp>>
createLinalgStrategyVectorizePass(StringRef opName = "",
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 2494248826f0..1df4458753db 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -268,6 +268,17 @@ def LinalgStrategyGeneralizePass
];
}
+// NOTE: only the anchor function is configurable from the command line; the
+// iterator permutation is supplied through the C++ factory
+// createLinalgStrategyInterchangePass(iteratorInterchange, filter).
+def LinalgStrategyInterchangePass
+    : FunctionPass<"linalg-strategy-interchange-pass"> {
+  let summary = "Configurable pass to apply pattern-based iterator interchange.";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let constructor = "mlir::createLinalgStrategyInterchangePass()";
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+           "Which func op is the anchor to latch on.">,
+  ];
+}
+
def LinalgStrategyVectorizePass
: FunctionPass<"linalg-strategy-vectorize-pass"> {
let summary = "Configurable pass to apply pattern-based linalg vectorization.";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index d8e8674452f2..1745b797e547 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -77,6 +77,22 @@ struct Generalize : public Transformation {
std::string opName;
};
+/// Represent one application of createLinalgStrategyInterchangePass: a
+/// strategy step that owns a copy of the iterator permutation and forwards
+/// it, together with the step's filter, to the pass it adds.
+struct Interchange : public Transformation {
+  explicit Interchange(ArrayRef<int64_t> iteratorInterchange,
+                       LinalgTransformationFilter::FilterFunction f = nullptr)
+      : Transformation(f) {
+    // Copy the permutation: the ArrayRef argument does not own its storage.
+    permutation.assign(iteratorInterchange.begin(), iteratorInterchange.end());
+  }
+
+  void addToPassPipeline(OpPassManager &pm,
+                         LinalgTransformationFilter m) const override {
+    pm.addPass(createLinalgStrategyInterchangePass(permutation, m));
+  }
+
+private:
+  SmallVector<int64_t> permutation;
+};
+
/// Represent one application of createLinalgStrategyVectorizePass.
struct Vectorize : public Transformation {
explicit Vectorize(linalg::LinalgVectorizationOptions options,
@@ -147,6 +163,21 @@ struct CodegenStrategy {
return b ? generalize(opName, f) : *this;
return *this;
}
+  /// Append an Interchange step that permutes iterators according to
+  /// `iteratorInterchange`, optionally restricted by the filter function `f`.
+  CodegenStrategy &
+  interchange(ArrayRef<int64_t> iteratorInterchange,
+              LinalgTransformationFilter::FilterFunction f = nullptr) {
+    auto step = std::make_unique<Interchange>(iteratorInterchange, f);
+    transformationSequence.emplace_back(std::move(step));
+    return *this;
+  }
+  /// Conditionally append a pattern to interchange iterators: when `b` is
+  /// false the strategy is returned unchanged.
+  CodegenStrategy &
+  interchangeIf(bool b, ArrayRef<int64_t> iteratorInterchange,
+                LinalgTransformationFilter::FilterFunction f = nullptr) {
+    return b ? interchange(iteratorInterchange, f) : *this;
+  }
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
CodegenStrategy &
vectorize(StringRef opName,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 7dd112e69229..e21785d5cbb4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -101,6 +101,37 @@ struct LinalgStrategyGeneralizePass
LinalgTransformationFilter filter;
};
+/// Configurable pass to apply pattern-based iterator interchange.
+///
+/// Runs GenericOpInterchangePattern with the permutation `iteratorInterchange`
+/// on the ops accepted by `filter`, restricted to the function named by the
+/// `anchor-func` option when that option is set.
+struct LinalgStrategyInterchangePass
+    : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
+
+  // NOTE(review): this is the constructor reached through the Passes.td
+  // registration; it leaves `iteratorInterchange` empty, so the pattern is
+  // built with an empty permutation — presumably it then matches nothing.
+  LinalgStrategyInterchangePass() = default;
+
+  LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
+                                LinalgTransformationFilter filter)
+      : iteratorInterchange(iteratorInterchange.begin(),
+                            iteratorInterchange.end()),
+        filter(filter) {}
+
+  void runOnFunction() override {
+    auto funcOp = getFunction();
+    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+      return;
+
+    // The pattern is constructed from an `unsigned` permutation, so narrow the
+    // stored int64_t entries. NOTE(review): negative entries wrap around here;
+    // presumably rejected by the pattern's precondition — confirm.
+    SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
+                                            iteratorInterchange.end());
+    RewritePatternSet interchangePattern(funcOp.getContext());
+    interchangePattern.add<GenericOpInterchangePattern>(
+        funcOp.getContext(), interchangeVector, filter);
+    if (failed(applyPatternsAndFoldGreedily(funcOp,
+                                            std::move(interchangePattern))))
+      signalPassFailure();
+  }
+
+  SmallVector<int64_t> iteratorInterchange;
+  LinalgTransformationFilter filter;
+};
+
/// Configurable pass to apply pattern-based linalg promotion.
struct LinalgStrategyPromotePass
: public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
@@ -273,6 +304,14 @@ mlir::createLinalgStrategyGeneralizePass(StringRef opName,
return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
}
+/// Create a LinalgStrategyInterchangePass configured with the given iterator
+/// permutation and transformation filter.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
+                                          LinalgTransformationFilter filter) {
+  std::unique_ptr<OperationPass<FuncOp>> pass =
+      std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
+                                                      filter);
+  return pass;
+}
+
/// Create a LinalgStrategyVectorizePass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyVectorizePass(StringRef opName,
diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
index 7328e4550006..33ec6279d619 100644
--- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir
+++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
@@ -4,7 +4,7 @@
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize" | FileCheck %s --check-prefix=GENER
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize iterator-interchange=0,2,1" | FileCheck %s --check-prefix=GENER
// CHECK-LABEL: func @matmul(
@@ -19,8 +19,10 @@ func @matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<15
// CHECK-SAME: {lhs_columns = 8 : i32, lhs_rows = 2 : i32, rhs_columns = 4 : i32}
// CHECK-SAME: (vector<16xf32>, vector<32xf32>) -> vector<8xf32>
- // OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
- // GENER: linalg.generic
+ // OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
+
+ // GENER: linalg.generic
+ // GENER-SAME: iterator_types = ["parallel", "reduction", "parallel"]
return
}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
index 7906c1c9afba..ff213d32ddfc 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -89,6 +89,9 @@ struct TestLinalgCodegenStrategy
Option<bool> generalize{*this, "generalize",
llvm::cl::desc("Generalize named operations."),
llvm::cl::init(false)};
+ ListOption<int64_t> iteratorInterchange{
+ *this, "iterator-interchange", llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::desc("Specifies the iterator interchange.")};
Option<bool> vectorize{
*this, "vectorize",
llvm::cl::desc("Rewrite the linalg op as a vector operation."),
@@ -148,6 +151,7 @@ void TestLinalgCodegenStrategy::runStrategy(
.setAlignment(16)
.setUseFullTileBuffersByDefault(registerPromoteFullTile))
.generalizeIf(generalize, anchorOpName)
+ .interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
.vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
.setEnableVectorTransferPartialRewrite(true)
.setEnableVectorContractLowering(true)
More information about the Mlir-commits
mailing list