[Mlir-commits] [mlir] 92ea624 - [mlir][Linalg] Rewrite CodegenStrategy to populate a pass pipeline.
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Sep 29 06:39:04 PDT 2021
Author: Nicolas Vasilache
Date: 2021-09-29T13:35:45Z
New Revision: 92ea624a1345fc9f0512bab2bd5d0d1ebeb8cf21
URL: https://github.com/llvm/llvm-project/commit/92ea624a1345fc9f0512bab2bd5d0d1ebeb8cf21
DIFF: https://github.com/llvm/llvm-project/commit/92ea624a1345fc9f0512bab2bd5d0d1ebeb8cf21.diff
LOG: [mlir][Linalg] Rewrite CodegenStrategy to populate a pass pipeline.
This revision retires a good portion of the complexity of the codegen strategy and moves its logic into dedicated passes that are assembled into a pass pipeline.
Differential revision: https://reviews.llvm.org/D110678
Added:
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Linalg/codegen-strategy.mlir
mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir
mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 56c709b543517..867921c51a51e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_LINALG_PASSES_H_
#define MLIR_DIALECT_LINALG_PASSES_H_
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Pass/Pass.h"
@@ -77,6 +78,43 @@ std::unique_ptr<Pass> createLinalgDetensorizePass();
/// Create a pass to tile a LinalgOp and fuse its producers.
std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFuseTensorOpsPass();
+//===----------------------------------------------------------------------===//
+/// Linalg strategy passes.
+//===----------------------------------------------------------------------===//
+/// Create a LinalgStrategyTilePass.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyTilePass(
+ StringRef opName = "",
+ linalg::LinalgTilingOptions opt = linalg::LinalgTilingOptions(),
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyPromotePass.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyPromotePass(
+ StringRef opName = "",
+ linalg::LinalgPromotionOptions opt = linalg::LinalgPromotionOptions(),
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyVectorizePass.
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgStrategyVectorizePass(StringRef opName = "",
+ linalg::LinalgVectorizationOptions opt =
+ linalg::LinalgVectorizationOptions(),
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyEnablePass.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyEnablePass(
+ linalg::LinalgEnablingOptions opt = linalg::LinalgEnablingOptions(),
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyLowerVectorsPass.
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgStrategyLowerVectorsPass(linalg::LinalgVectorLoweringOptions opt =
+ linalg::LinalgVectorLoweringOptions(),
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter());
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 3f331b1fff502..32327cd968096 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -229,4 +229,66 @@ def LinalgTileAndFuseTensorOps
let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"];
}
+def LinalgStrategyTilePass
+ : FunctionPass<"linalg-strategy-tile-pass"> {
+ let summary = "Configurable pass to apply pattern-based linalg tiling.";
+ let constructor = "mlir::createLinalgStrategyTilePass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+ "Which linalg op within the func is the anchor to latch on.">,
+ ];
+}
+
+def LinalgStrategyPromotePass
+ : FunctionPass<"linalg-strategy-promote-pass"> {
+ let summary = "Configurable pass to apply pattern-based linalg promotion.";
+ let constructor = "mlir::createLinalgStrategyPromotePass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+ "Which linalg op within the func is the anchor to latch on.">,
+ ];
+}
+
+def LinalgStrategyVectorizePass
+ : FunctionPass<"linalg-strategy-vectorize-pass"> {
+ let summary = "Configurable pass to apply pattern-based linalg vectorization.";
+ let constructor = "mlir::createLinalgStrategyVectorizePass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+ "Which linalg op within the func is the anchor to latch on.">,
+ ];
+}
+
+def LinalgStrategyEnablePass
+ : FunctionPass<"linalg-strategy-enable-pass"> {
+ let summary = "Configurable pass to enable the application of other "
+ "pattern-based linalg passes.";
+ let constructor = "mlir::createLinalgStrategyEnablePass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ ];
+}
+
+def LinalgStrategyLowerVectorsPass
+ : FunctionPass<"linalg-strategy-lower-vectors-pass"> {
+ let summary = "Configurable pass to lower vector operations.";
+ let constructor = "mlir::createLinalgStrategyLowerVectorsPass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ ];
+}
+
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index d33dd81fa323a..ff372079f690d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -10,7 +10,8 @@
#define MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Pass/PassManager.h"
namespace mlir {
@@ -21,69 +22,23 @@ namespace linalg {
/// Abstract Transformation class applied in a sequence that also handles state
/// through markers.
struct Transformation {
- explicit Transformation(linalg::LinalgTransformationFilter::FilterFunction f)
+ explicit Transformation(LinalgTransformationFilter::FilterFunction f)
: filter(f) {}
virtual ~Transformation() = default;
- virtual RewritePatternSet
- buildRewritePatterns(MLIRContext *context,
- linalg::LinalgTransformationFilter m) = 0;
- linalg::LinalgTransformationFilter::FilterFunction filter = nullptr;
+ virtual void addToPassPipeline(OpPassManager &pm,
+ LinalgTransformationFilter m) const = 0;
+ LinalgTransformationFilter::FilterFunction filter = nullptr;
};
-/// SFINAE: Enqueue helper for ConcreteOpType that have a `getOperationName`.
-template <template <typename> class PatternType, typename ConcreteOpType,
- typename OptionsType,
- typename = std::enable_if_t<std::is_member_function_pointer<
- decltype(&ConcreteOpType::getOperationName)>::value>>
-void sfinae_enqueue(RewritePatternSet &patternList, OptionsType options,
- StringRef opName, linalg::LinalgTransformationFilter m) {
- assert(opName == ConcreteOpType::getOperationName() &&
- "explicit name must match ConcreteOpType::getOperationName");
- patternList.add<PatternType<ConcreteOpType>>(patternList.getContext(),
- options, m);
-}
-
-/// SFINAE: Enqueue helper for OpType that do not have a `getOperationName`
-/// (e.g. LinalgOp, other interfaces, Operation*).
-template <template <typename> class PatternType, typename OpType,
- typename OptionsType>
-void sfinae_enqueue(RewritePatternSet &patternList, OptionsType options,
- StringRef opName, linalg::LinalgTransformationFilter m) {
- assert(!opName.empty() && "opName must not be empty");
- patternList.add<PatternType<OpType>>(opName, patternList.getContext(),
- options, m);
-}
-
-template <typename PatternType, typename OpType, typename OptionsType>
-void enqueue(RewritePatternSet &patternList, OptionsType options,
- StringRef opName, linalg::LinalgTransformationFilter m) {
- if (!opName.empty())
- patternList.add<PatternType>(opName, patternList.getContext(), options, m);
- else
- patternList.add<PatternType>(patternList.getContext(),
- m.addOpFilter<OpType>(), options);
-}
-
-/// Promotion transformation enqueues a particular stage-1 pattern for
-/// `Tile<LinalgOpType>`with the appropriate `options`.
-template <typename LinalgOpType>
+/// Represent one application of LinalgStrategyTilePass.
struct Tile : public Transformation {
- explicit Tile(linalg::LinalgTilingOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
- : Transformation(f), opName(LinalgOpType::getOperationName()),
- options(options) {}
-
Tile(StringRef name, linalg::LinalgTilingOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(name), options(options) {}
- RewritePatternSet
- buildRewritePatterns(MLIRContext *context,
- linalg::LinalgTransformationFilter m) override {
- RewritePatternSet tilingPatterns(context);
- sfinae_enqueue<linalg::LinalgTilingPattern, LinalgOpType>(
- tilingPatterns, options, opName, m);
- return tilingPatterns;
+ void addToPassPipeline(OpPassManager &pm,
+ LinalgTransformationFilter m) const override {
+ pm.addPass(createLinalgStrategyTilePass(opName, options, m));
}
private:
@@ -91,27 +46,15 @@ struct Tile : public Transformation {
linalg::LinalgTilingOptions options;
};
-/// Promotion transformation enqueues a particular stage-1 pattern for
-/// `Promote<LinalgOpType>`with the appropriate `options`.
-template <typename LinalgOpType>
+/// Represent one application of LinalgStrategyPromotePass.
struct Promote : public Transformation {
- explicit Promote(
- linalg::LinalgPromotionOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
- : Transformation(f), opName(LinalgOpType::getOperationName()),
- options(options) {}
-
Promote(StringRef name, linalg::LinalgPromotionOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(name), options(options) {}
- RewritePatternSet
- buildRewritePatterns(MLIRContext *context,
- linalg::LinalgTransformationFilter m) override {
- RewritePatternSet promotionPatterns(context);
- sfinae_enqueue<linalg::LinalgPromotionPattern, LinalgOpType>(
- promotionPatterns, options, opName, m);
- return promotionPatterns;
+ void addToPassPipeline(OpPassManager &pm,
+ LinalgTransformationFilter m) const override {
+ pm.addPass(createLinalgStrategyPromotePass(opName, options, m));
}
private:
@@ -119,30 +62,19 @@ struct Promote : public Transformation {
linalg::LinalgPromotionOptions options;
};
-/// Vectorization transformation enqueues a particular stage-1 pattern for
-/// `LinalgVectorizationPattern<LinalgOpType>` as well as copy to vector
-/// transfer rewrite forwarding patterns.
-template <typename LinalgOpType = LinalgOp>
+/// Represent one application of LinalgStrategyVectorizePass.
struct Vectorize : public Transformation {
- explicit Vectorize(
- linalg::LinalgVectorizationOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ explicit Vectorize(linalg::LinalgVectorizationOptions options,
+ LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(), options(options) {}
Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(name), options(options) {}
- RewritePatternSet
- buildRewritePatterns(MLIRContext *context,
- linalg::LinalgTransformationFilter m) override {
- RewritePatternSet vectorizationPatterns(context);
- enqueue<linalg::LinalgVectorizationPattern, LinalgOpType>(
- vectorizationPatterns, options, opName, m);
- vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
- linalg::LinalgCopyVTWForwardingPattern>(
- context, /*benefit=*/2);
- return vectorizationPatterns;
+ void addToPassPipeline(OpPassManager &pm,
+ LinalgTransformationFilter m) const override {
+ pm.addPass(createLinalgStrategyVectorizePass(opName, options, m));
}
private:
@@ -150,129 +82,47 @@ struct Vectorize : public Transformation {
linalg::LinalgVectorizationOptions options;
};
-/// Options to control the application of late transformations.
-struct LateCodegenStrategyOptions {
- /// Hoisting transformations are always deemed beneficial and must disabled
- /// explicitly.
- bool enableLICM = true;
- bool enableHoistRedundantVectorTransfers = true;
- bool enableHoistRedundantVectorTransfersOnTensor = true;
- /// Vector lowering operations may result in surprising behavior when
- /// composing multiple codegen strategies and must be enabled explicitly.
- bool enableVectorTransferPartialRewrite = false;
- bool enableVectorContractLowering = false;
- bool enableVectorToSCFConversion = false;
-};
-
/// Codegen strategy controls how a Linalg op is progressively lowered.
-/// The application uses a 3-level staged patterns strategy which allows
-/// ordering transformations by using the Linalg `applyStagedPatterns`
-/// function, where:
-/// 1. The first stage consists of the successive `tile`, `promote` and
-/// `vectorize` patterns, applied sequentially.
-/// 2. The second stage consists of common local canonicalization patterns
-/// that are applied eagerly after each stage-1 pattern.
-/// 3. the third stage consists of more global transformation, also applied
-/// eagerly, after all stage-2 patterns. Such more global transformations
struct CodegenStrategy {
- /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
+ /// Append a pattern to add a level of tiling for Op `opName` with tiling
/// `options`.
- template <typename LinalgOpType>
- CodegenStrategy &
- tile(linalg::LinalgTilingOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
- transformationSequence.emplace_back(
- std::make_unique<Tile<LinalgOpType>>(options, f));
- return *this;
- }
- /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
- /// `options`.
- template <typename LinalgOpType>
CodegenStrategy &
tile(StringRef opName, linalg::LinalgTilingOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ LinalgTransformationFilter::FilterFunction f = nullptr) {
transformationSequence.emplace_back(
- std::make_unique<Tile<LinalgOpType>>(opName, options, f));
+ std::make_unique<Tile>(opName, options, f));
return *this;
}
/// Conditionally append a pattern to add a level of tiling for
/// `LinalgOpType` with tiling `options`.
- template <typename LinalgOpType>
- CodegenStrategy &
- tileIf(bool b, linalg::LinalgTilingOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
- return b ? tile<LinalgOpType>(options) : *this;
- }
- /// Conditionally append a pattern to add a level of tiling for
- /// `LinalgOpType` with tiling `options`.
- template <typename LinalgOpType>
CodegenStrategy &
tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
- return b ? tile<LinalgOpType>(opName, options) : *this;
+ LinalgTransformationFilter::FilterFunction f = nullptr) {
+ return b ? tile(opName, options) : *this;
}
/// Append a pattern to add a level of promotion for `LinalgOpType` with
/// promotion `options`.
- template <typename LinalgOpType>
- CodegenStrategy &
- promote(linalg::LinalgPromotionOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
- transformationSequence.emplace_back(
- std::make_unique<Promote<LinalgOpType>>(options, f));
- return *this;
- }
- /// Append a pattern to add a level of promotion for `LinalgOpType` with
- /// promotion `options`.
- template <typename LinalgOpType>
CodegenStrategy &
promote(StringRef opName, linalg::LinalgPromotionOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ LinalgTransformationFilter::FilterFunction f = nullptr) {
transformationSequence.emplace_back(
- std::make_unique<Promote<LinalgOpType>>(opName, options, f));
+ std::make_unique<Promote>(opName, options, f));
return *this;
}
/// Conditionally append a pattern to add a level of promotion for
/// `LinalgOpType` with promotion `options`.
- template <typename LinalgOpType>
CodegenStrategy &
promoteIf(bool b, StringRef opName, linalg::LinalgPromotionOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
- return b ? promote<LinalgOpType>(opName, options, f) : *this;
- return *this;
- }
- /// Conditionally append a pattern to add a level of promotion for
- /// `LinalgOpType` with promotion `options`.
- template <typename LinalgOpType>
- CodegenStrategy &
- promoteIf(bool b, linalg::LinalgPromotionOptions options,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
- return b ? promote<LinalgOpType>(options, f) : *this;
- return *this;
- }
- /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
- template <typename LinalgOpType>
- CodegenStrategy &
- vectorize(linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
- transformationSequence.emplace_back(
- std::make_unique<Vectorize<LinalgOpType>>(
- linalg::LinalgVectorizationOptions(), f));
- return *this;
- }
- /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
- /// operation.
- template <typename LinalgOpType>
- CodegenStrategy &
- vectorizeIf(bool b,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
- return b ? vectorize<LinalgOpType>(f) : *this;
+ LinalgTransformationFilter::FilterFunction f = nullptr) {
+ return b ? promote(opName, options, f) : *this;
return *this;
}
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
CodegenStrategy &
vectorize(StringRef opName,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ LinalgTransformationFilter::FilterFunction f = nullptr) {
assert(!opName.empty() && "expected an op name");
- transformationSequence.emplace_back(std::make_unique<Vectorize<LinalgOp>>(
+ transformationSequence.emplace_back(std::make_unique<Vectorize>(
opName, linalg::LinalgVectorizationOptions(), f));
return *this;
}
@@ -280,14 +130,14 @@ struct CodegenStrategy {
/// operation.
CodegenStrategy &
vectorizeIf(bool b, StringRef opName,
- linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? vectorize(opName, f) : *this;
return *this;
}
/// Configure the post staged-patterns late vector transformations.
CodegenStrategy &
setVectorTransformsOptions(vector::VectorTransformsOptions options) {
- vectorTransformsOptions = options;
+ vectorTransformOptions = options;
return *this;
}
/// Configure the post staged-patterns late vector.transfer to scf
@@ -328,12 +178,13 @@ struct CodegenStrategy {
/// Apply the transformation patterns in sequence with cleanup
/// transformations interleaved.
- void transform(FuncOp func) const;
+ LogicalResult transform(FuncOp func) const;
+ void configurePassPipeline(OpPassManager &pm, MLIRContext *context) const;
private:
LogicalResult postPatternTransforms(Operation *func) const;
- vector::VectorTransformsOptions vectorTransformsOptions;
+ vector::VectorTransformsOptions vectorTransformOptions;
VectorTransferToSCFOptions vectorToSCFOptions;
SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
LateCodegenStrategyOptions lateCodegenStrategyOptions;
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 03843bdfaa012..4a76b92573209 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -9,6 +9,7 @@
#ifndef DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
#define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils.h"
@@ -593,6 +594,35 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
}
};
+struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
+ /// Entry point to match any LinalgOp OpInterface.
+ /// MatchAnyOpTag-based constructor with a mandatory `filter`.
+ LinalgGenericTilingPattern(
+ MLIRContext *context, LinalgTransformationFilter filter,
+ LinalgTilingOptions options = LinalgTilingOptions(),
+ PatternBenefit benefit = 1)
+ : LinalgBaseTilingPattern(context, options, filter, benefit) {}
+ /// Entry point to match a specific Linalg op.
+ LinalgGenericTilingPattern(
+ StringRef opName, MLIRContext *context, LinalgTilingOptions options,
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : LinalgBaseTilingPattern(opName, context, options, filter, benefit) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ TiledLinalgOp tiledLinalgOp;
+ if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
+ tiledLinalgOp)))
+ return failure();
+ if (tiledLinalgOp.tensorResults.empty())
+ rewriter.eraseOp(op);
+ else
+ rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
+ return success();
+ }
+};
+
struct LinalgFusionOptions {
/// List of operands indices to use for fusion.
llvm::SmallSet<unsigned, 1> indicesToFuse = {};
@@ -678,6 +708,13 @@ struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `promoteSubViews` for more details.
struct LinalgBasePromotionPattern : public RewritePattern {
+ /// Entry point to match any LinalgOp OpInterface.
+ /// MatchAnyOpTag-based constructor with a mandatory `filter`.
+ LinalgBasePromotionPattern(
+ MLIRContext *context, LinalgTransformationFilter filter,
+ LinalgPromotionOptions options = LinalgPromotionOptions(),
+ PatternBenefit benefit = 1);
+ /// Entry point to match a specific Linalg op.
LinalgBasePromotionPattern(
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
@@ -757,6 +794,39 @@ struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
: LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
};
+/// Options to control the application of late transformations.
+struct LateCodegenStrategyOptions {
+ /// Hoisting transformations are always deemed beneficial and must be disabled
+ /// explicitly.
+ bool enableLICM = true;
+ bool enableHoistRedundantVectorTransfers = true;
+ bool enableHoistRedundantVectorTransfersOnTensor = true;
+ /// Vector lowering operations may result in surprising behavior when
+ /// composing multiple codegen strategies and must be enabled explicitly.
+ bool enableVectorTransferPartialRewrite = false;
+ bool enableVectorContractLowering = false;
+ bool enableVectorToSCFConversion = false;
+};
+
+/// Options to control the application of enabling transformations.
+/// Hoisting transformations are always deemed beneficial and must be disabled
+/// explicitly.
+struct LinalgEnablingOptions {
+ bool enableLICM = true;
+ bool enableHoistRedundantVectorTransfers = true;
+ bool enableHoistRedundantVectorTransfersOnTensor = true;
+};
+
+/// Vector lowering options control how ops are lowered down to 1-D and scf.for
+/// form.
+struct LinalgVectorLoweringOptions {
+ bool enableVectorTransferPartialRewrite = false;
+ bool enableVectorContractLowering = false;
+ bool enableVectorToSCFConversion = false;
+ vector::VectorTransformsOptions vectorTransformOptions;
+ VectorTransferToSCFOptions vectorTransferToSCFOptions;
+};
+
/// Trait to check if T provides a `getOperationName` method.
template <typename T, typename... Args>
using has_get_operation_name = decltype(T::getOperationName());
@@ -929,8 +999,8 @@ struct GeneralizePadTensorOpPattern : public OpRewritePattern<PadTensorOp> {
/// scattering magic constants throughout the code base, the patterns must be
/// added with this function. `baseBenefit` can be used to offset the benefit
/// of all PadTensorOp vectorization patterns by a certain value.
-void populatePadTensorOpVectorizationPatterns(
- RewritePatternSet &patterns, PatternBenefit baseBenefit = 1);
+void populatePadTensorOpVectorizationPatterns(RewritePatternSet &patterns,
+ PatternBenefit baseBenefit = 1);
/// Match and rewrite for the pattern:
/// ```
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index a2a9108e83900..41d59c435ec25 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -218,17 +218,17 @@ class ContractionOpToMatmulOpLowering
}
ContractionOpToMatmulOpLowering(
- vector::VectorTransformsOptions vectorTransformsOptions,
+ vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
+ vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformsOptions;
+ vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
@@ -259,17 +259,17 @@ class ContractionOpToOuterProductOpLowering
}
ContractionOpToOuterProductOpLowering(
- vector::VectorTransformsOptions vectorTransformsOptions,
+ vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
+ vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformsOptions;
+ vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
@@ -303,18 +303,17 @@ class ContractionOpToDotLowering
}
ContractionOpToDotLowering(
- vector::VectorTransformsOptions vectorTransformsOptions,
+ vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformsOptions(vectorTransformsOptions),
- filter(defaultFilter) {}
+ vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformsOptions;
+ vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
@@ -342,18 +341,18 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
return success();
}
- ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
+ ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context,
FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
+ vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformsOptions;
+ vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
// Lower one parallel dimension.
Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index a05cd5858f42b..5dfb419ffd406 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
+ LinalgStrategyPasses.cpp
Promotion.cpp
Tiling.cpp
Transforms.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index 8192551a5c392..cd0e75cc8a17b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
+#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Vector/VectorOps.h"
@@ -27,85 +28,43 @@ using namespace mlir::linalg;
#define DEBUG_TYPE "linalg-codegen-strategy"
-void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
- MLIRContext *context = func.getContext();
- // Emplace patterns one at a time while also maintaining a simple chained
- // state transition.
- unsigned stepCount = 0;
- SmallVector<FrozenRewritePatternSet, 4> stage1Patterns;
- auto zeroState = Identifier::get(std::to_string(stepCount), context);
- auto currentState = zeroState;
- for (const std::unique_ptr<Transformation> &t : transformationSequence) {
- auto nextState = Identifier::get(std::to_string(++stepCount), context);
- auto marker = (currentState == zeroState)
+void mlir::linalg::CodegenStrategy::configurePassPipeline(
+ OpPassManager &pm, MLIRContext *context) const {
+ for (unsigned stepCount = 0, e = transformationSequence.size(); stepCount < e;
+ ++stepCount) {
+ const std::unique_ptr<Transformation> &t =
+ transformationSequence[stepCount];
+ std::string currentStr = std::to_string(stepCount);
+ auto currentState = Identifier::get(currentStr, context);
+ std::string nextStr = std::to_string(stepCount + 1);
+ auto nextState = Identifier::get(nextStr, context);
+ auto filter = (currentState.str() == std::to_string(0))
? linalg::LinalgTransformationFilter(
t->filter, ArrayRef<Identifier>{}, nextState)
: linalg::LinalgTransformationFilter(
t->filter, currentState, nextState);
- stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker));
- currentState = nextState;
- }
-
- RewritePatternSet stage2Patterns =
- linalg::getLinalgTilingCanonicalizationPatterns(context);
- scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns);
-
- auto stage3Transforms = [&](Operation *op) {
- // Some of these may be too aggressive as a stage 3 that is applied on each
- // stage 1 application and may have to be split out to post staged patterns
- // application (in which case they could just be passes, TBD).
- if (lateCodegenStrategyOptions.enableLICM) {
- op->walk([&](LoopLikeOpInterface loopLike) {
- LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n"));
- if (failed(moveLoopInvariantCode(loopLike)))
- llvm_unreachable("unexpected LICM failure");
- });
- }
- promoteSingleIterationLoops(cast<FuncOp>(op));
- if (lateCodegenStrategyOptions.enableHoistRedundantVectorTransfers)
- hoistRedundantVectorTransfers(cast<FuncOp>(op));
- if (lateCodegenStrategyOptions.enableHoistRedundantVectorTransfersOnTensor)
- hoistRedundantVectorTransfersOnTensor(cast<FuncOp>(op));
- return success();
- };
- (void)linalg::applyStagedPatterns(
- func, stage1Patterns, std::move(stage2Patterns), stage3Transforms);
-
- //===--------------------------------------------------------------------===//
- // Post staged patterns transforms
- //===--------------------------------------------------------------------===//
-
- // Programmatic splitting of slow/fast path vector transfers.
- if (lateCodegenStrategyOptions.enableVectorTransferPartialRewrite) {
- RewritePatternSet patterns(context);
- patterns.add<vector::VectorTransferFullPartialRewriter>(
- context, vectorTransformsOptions);
- (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
- }
-
- // Programmatic controlled lowering of vector.contract only.
- if (lateCodegenStrategyOptions.enableVectorContractLowering) {
- RewritePatternSet vectorContractLoweringPatterns(context);
- vectorContractLoweringPatterns
- .add<ContractionOpToOuterProductOpLowering,
- ContractionOpToMatmulOpLowering, ContractionOpLowering>(
- vectorTransformsOptions, context);
- vector::populateVectorTransferPermutationMapLoweringPatterns(
- vectorContractLoweringPatterns);
- (void)applyPatternsAndFoldGreedily(
- func, std::move(vectorContractLoweringPatterns));
- }
-
- // Programmatic controlled lowering of vector.transfer only.
- if (lateCodegenStrategyOptions.enableVectorToSCFConversion) {
- RewritePatternSet vectorToLoopsPatterns(context);
- populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
- vectorToSCFOptions);
- (void)applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));
+ t->addToPassPipeline(pm, filter);
+ pm.addPass(createLinalgStrategyEnablePass());
}
+ LinalgVectorLoweringOptions vectorLoweringOptions;
+ vectorLoweringOptions.enableVectorTransferPartialRewrite =
+ lateCodegenStrategyOptions.enableVectorTransferPartialRewrite;
+ vectorLoweringOptions.enableVectorContractLowering =
+ lateCodegenStrategyOptions.enableVectorContractLowering;
+ vectorLoweringOptions.enableVectorToSCFConversion =
+ lateCodegenStrategyOptions.enableVectorToSCFConversion;
+ vectorLoweringOptions.vectorTransformOptions = vectorTransformOptions;
+ vectorLoweringOptions.vectorTransferToSCFOptions = vectorToSCFOptions;
+ pm.addPass(createLinalgStrategyLowerVectorsPass(vectorLoweringOptions));
+}
+LogicalResult mlir::linalg::CodegenStrategy::transform(FuncOp funcOp) const {
+ PassManager pm(funcOp.getContext(), funcOp.getOperationName());
+ configurePassPipeline(pm, funcOp.getContext());
+ LogicalResult res = pm.run(funcOp);
// Ensure we drop the marker in the end.
- func.walk([](LinalgOp op) {
+ funcOp.walk([](LinalgOp op) {
op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
});
+ return res;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
new file mode 100644
index 0000000000000..db0a24d4f4b3f
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -0,0 +1,256 @@
+//===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a configurable pass that can apply patterns liberally
+// and be plugged in a pass pipeline.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Utils.h"
+
+using namespace mlir;
+using namespace linalg;
+
+namespace {
+
+/// Configurable pass to apply pattern-based linalg tiling.
+struct LinalgStrategyTilePass
+ : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
+
+ LinalgStrategyTilePass() = default;
+
+ LinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
+ LinalgTransformationFilter filt)
+ : options(opt), filter(filt) {
+ this->anchorOpName.setValue(opName.str());
+ }
+
+ void runOnFunction() override {
+ auto funcOp = getFunction();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+
+ RewritePatternSet tilingPattern(funcOp.getContext());
+ if (!anchorOpName.empty()) {
+ tilingPattern.add<LinalgGenericTilingPattern>(
+ anchorOpName, funcOp.getContext(), options, filter);
+ } else {
+ tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter,
+ options);
+ }
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
+ }
+
+ LinalgTilingOptions options;
+ LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to apply pattern-based linalg promotion.
+struct LinalgStrategyPromotePass
+ : public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
+
+ LinalgStrategyPromotePass() = default;
+
+ LinalgStrategyPromotePass(StringRef opName, LinalgPromotionOptions opt,
+ LinalgTransformationFilter filt)
+ : options(opt), filter(filt) {
+ this->anchorOpName.setValue(opName.str());
+ }
+
+ void runOnFunction() override {
+ auto funcOp = getFunction();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+
+ RewritePatternSet promotionPattern(funcOp.getContext());
+ if (!anchorOpName.empty()) {
+ promotionPattern.add<LinalgBasePromotionPattern>(
+ anchorOpName, funcOp.getContext(), options, filter);
+ } else {
+ promotionPattern.add<LinalgBasePromotionPattern>(funcOp.getContext(),
+ filter, options);
+ }
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPattern));
+ }
+
+ LinalgPromotionOptions options;
+ LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to apply pattern-based linalg vectorization.
+struct LinalgStrategyVectorizePass
+ : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
+
+ LinalgStrategyVectorizePass() = default;
+
+ LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
+ LinalgTransformationFilter filt)
+ : options(opt), filter(filt) {
+ this->anchorOpName.setValue(opName.str());
+ }
+
+ void runOnFunction() override {
+ auto funcOp = getFunction();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+
+ RewritePatternSet vectorizationPatterns(funcOp.getContext());
+ if (!anchorOpName.empty()) {
+ vectorizationPatterns.add<LinalgVectorizationPattern>(
+ anchorOpName, funcOp.getContext(), options, filter);
+ } else {
+ vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
+ filter, options);
+ }
+ vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
+ linalg::LinalgCopyVTWForwardingPattern>(
+ funcOp.getContext(), /*benefit=*/2);
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(vectorizationPatterns));
+ }
+
+ LinalgVectorizationOptions options;
+ LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to enable the application of other pattern-based linalg
+/// passes.
+struct LinalgStrategyEnablePass
+ : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> {
+
+ LinalgStrategyEnablePass(LinalgEnablingOptions opt,
+ LinalgTransformationFilter filt)
+ : options(opt), filter(filt) {}
+
+ void runOnFunction() override {
+ auto funcOp = getFunction();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+
+ MLIRContext *context = funcOp.getContext();
+ RewritePatternSet patterns =
+ linalg::getLinalgTilingCanonicalizationPatterns(context);
+ scf::populateSCFForLoopCanonicalizationPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
+ return signalPassFailure();
+
+ if (options.enableLICM) {
+ if (funcOp
+ ->walk([&](LoopLikeOpInterface loopLike) {
+ if (failed(moveLoopInvariantCode(loopLike)))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ })
+ .wasInterrupted())
+ return signalPassFailure();
+ }
+
+ promoteSingleIterationLoops(funcOp);
+ if (options.enableHoistRedundantVectorTransfers)
+ hoistRedundantVectorTransfers(funcOp);
+
+ if (options.enableHoistRedundantVectorTransfersOnTensor)
+ hoistRedundantVectorTransfersOnTensor(funcOp);
+ }
+
+ LinalgEnablingOptions options;
+ LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to lower vector operations.
+struct LinalgStrategyLowerVectorsPass
+ : public LinalgStrategyLowerVectorsPassBase<
+ LinalgStrategyLowerVectorsPass> {
+
+ LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
+ LinalgTransformationFilter filt)
+ : options(opt), filter(filt) {}
+
+ void runOnFunction() override {
+ auto funcOp = getFunction();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+
+ MLIRContext *context = funcOp.getContext();
+ RewritePatternSet patterns(context);
+ if (options.enableVectorTransferPartialRewrite) {
+ patterns.add<vector::VectorTransferFullPartialRewriter>(
+ context, options.vectorTransformOptions);
+ }
+ if (options.enableVectorContractLowering) {
+ patterns.add<ContractionOpToOuterProductOpLowering,
+ ContractionOpToMatmulOpLowering, ContractionOpLowering>(
+ options.vectorTransformOptions, context);
+ vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+ }
+ if (options.enableVectorToSCFConversion) {
+ populateVectorToSCFConversionPatterns(patterns,
+ options.vectorTransferToSCFOptions);
+ }
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+ }
+
+ LinalgVectorLoweringOptions options;
+ LinalgTransformationFilter filter;
+};
+} // namespace
+
+/// Create a LinalgStrategyTilePass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
+ LinalgTransformationFilter filter) {
+ return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
+}
+
+/// Create a LinalgStrategyPromotePass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyPromotePass(StringRef opName,
+ LinalgPromotionOptions opt,
+ LinalgTransformationFilter filter) {
+ return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
+}
+
+/// Create a LinalgStrategyVectorizePass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyVectorizePass(StringRef opName,
+ LinalgVectorizationOptions opt,
+ LinalgTransformationFilter filter) {
+ return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter);
+}
+
+/// Create a LinalgStrategyEnablePass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyEnablePass(LinalgEnablingOptions opt,
+ LinalgTransformationFilter filter) {
+ return std::make_unique<LinalgStrategyEnablePass>(opt, filter);
+}
+
+/// Create a LinalgStrategyLowerVectorsPass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
+ LinalgTransformationFilter filter) {
+ return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 1d28451ae05ef..c471459da5c25 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -488,6 +488,12 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
return success();
}
+mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
+ MLIRContext *context, LinalgTransformationFilter filter,
+ LinalgPromotionOptions options, PatternBenefit benefit)
+ : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
+ options(options) {}
+
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
LinalgTransformationFilter filter, PatternBenefit benefit)
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 7d6e6d2ba53d5..ba97f99972a06 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -672,10 +672,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
- TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
+ TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context)
: OpRewritePattern<vector::TransposeOp>(context),
- vectorTransformsOptions(vectorTransformsOptions) {}
+ vectorTransformOptions(vectorTransformOptions) {}
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
@@ -689,7 +689,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
transp.push_back(attr.cast<IntegerAttr>().getInt());
// Handle a true 2-D matrix transpose differently when requested.
- if (vectorTransformsOptions.vectorTransposeLowering ==
+ if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Flat &&
resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
Type flattenedType =
@@ -739,7 +739,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
}
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformsOptions;
+ vector::VectorTransformsOptions vectorTransformOptions;
};
/// Progressive lowering of OuterProductOp.
@@ -1151,7 +1151,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
// TODO: implement masks
if (llvm::size(op.masks()) != 0)
return failure();
- if (vectorTransformsOptions.vectorContractLowering !=
+ if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::Matmul)
return failure();
if (failed(filter(op)))
@@ -1314,7 +1314,7 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
if (llvm::size(op.masks()) != 0)
return failure();
- if (vectorTransformsOptions.vectorContractLowering !=
+ if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::OuterProduct)
return failure();
@@ -1419,7 +1419,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
if (failed(filter(op)))
return failure();
- if (vectorTransformsOptions.vectorContractLowering !=
+ if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::Dot)
return failure();
@@ -1560,13 +1560,13 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
// TODO: implement benefits, cost models.
MLIRContext *ctx = op.getContext();
- ContractionOpToMatmulOpLowering pat1(vectorTransformsOptions, ctx);
+ ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
if (succeeded(pat1.matchAndRewrite(op, rewriter)))
return success();
- ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
+ ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
if (succeeded(pat2.matchAndRewrite(op, rewriter)))
return success();
- ContractionOpToDotLowering pat3(vectorTransformsOptions, ctx);
+ ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
if (succeeded(pat3.matchAndRewrite(op, rewriter)))
return success();
@@ -1835,8 +1835,9 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
-static std::pair<Value, Value> createSubViewIntersection(
- OpBuilder &b, VectorTransferOpInterface xferOp, Value alloc) {
+static std::pair<Value, Value>
+createSubViewIntersection(OpBuilder &b, VectorTransferOpInterface xferOp,
+ Value alloc) {
ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
int64_t memrefRank = xferOp.getShapedType().getRank();
// TODO: relax this precondition, will require rank-reducing subviews.
@@ -2195,6 +2196,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
MemRefType compatibleMemRefType =
getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
alloc.getType().cast<MemRefType>());
+ if (!compatibleMemRefType)
+ return failure();
+
SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
b.getIndexType());
returnTypes[0] = compatibleMemRefType;
diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
index e08d99eb03ef9..bbfdaf371a8c7 100644
--- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir
+++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
@@ -1,9 +1,9 @@
// Test that both anchor-op name and MatmulOp-based codegen strategy produce the same result.
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=2,4,8 tile-interchange=1,2,0 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 tile-interchange=1,2,0 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
// CHECK-LABEL: func @matmul(
// OUTER-LABEL: func @matmul(
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir
index 12f44c6dbc4eb..7ddfdaf538fed 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir
@@ -1,8 +1,8 @@
// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \
// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \
// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul register-tile-sizes=12,32,16 vectorize" | \
-// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \
-// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \
// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \
// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
diff --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
index aed639419e009..bf5cf5a166f21 100644
--- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
@@ -93,7 +93,7 @@ void TestConvVectorization::runOnOperation() {
// Post staged patterns transforms
//===--------------------------------------------------------------------===//
- VectorTransformsOptions vectorTransformsOptions{
+ VectorTransformsOptions vectorTransformOptions{
VectorContractLowering::Dot, VectorTransposeLowering::EltWise};
RewritePatternSet vectorTransferPatterns(context);
@@ -101,7 +101,7 @@ void TestConvVectorization::runOnOperation() {
// supported as can be seen in splitFullAndPartialTransferPrecondition,
// VectorTransforms.cpp
vectorTransferPatterns.add<VectorTransferFullPartialRewriter>(
- context, vectorTransformsOptions);
+ context, vectorTransformOptions);
(void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
// Programmatic controlled lowering of linalg.copy and linalg.fill.
@@ -113,9 +113,9 @@ void TestConvVectorization::runOnOperation() {
// Programmatic controlled lowering of vector.contract only.
RewritePatternSet vectorContractLoweringPatterns(context);
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
- vectorTransformsOptions);
+ vectorTransformOptions);
populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns,
- vectorTransformsOptions);
+ vectorTransformOptions);
(void)applyPatternsAndFoldGreedily(module,
std::move(vectorContractLoweringPatterns));
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
index 8f2cd6c689a0d..679cc9375aea3 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -52,7 +52,6 @@ struct TestLinalgCodegenStrategy
void runOnFunction() override;
- template <typename OpType>
void runStrategy(LinalgTilingOptions tilingOptions,
LinalgTilingOptions registerTilingOptions,
vector::VectorContractLowering vectorContractLowering,
@@ -127,26 +126,23 @@ struct TestLinalgCodegenStrategy
llvm::cl::init("")};
};
-template <>
-void TestLinalgCodegenStrategy::runStrategy<LinalgOp>(
+void TestLinalgCodegenStrategy::runStrategy(
LinalgTilingOptions tilingOptions,
LinalgTilingOptions registerTilingOptions,
vector::VectorContractLowering vectorContractLowering,
vector::VectorTransferSplit vectorTransferSplit) {
assert(!anchorOpName.empty());
CodegenStrategy strategy;
- strategy.tileIf<LinalgOp>(!tileSizes.empty(), anchorOpName, tilingOptions)
- .promoteIf<LinalgOp>(promote, anchorOpName,
- LinalgPromotionOptions()
- .setAlignment(16)
- .setUseFullTileBuffersByDefault(promoteFullTile))
- .tileIf<LinalgOp>(!registerTileSizes.empty(), anchorOpName,
- registerTilingOptions)
- .promoteIf<LinalgOp>(
- registerPromote, anchorOpName,
- LinalgPromotionOptions()
- .setAlignment(16)
- .setUseFullTileBuffersByDefault(registerPromoteFullTile))
+ strategy.tileIf(!tileSizes.empty(), anchorOpName, tilingOptions)
+ .promoteIf(promote, anchorOpName,
+ LinalgPromotionOptions()
+ .setAlignment(16)
+ .setUseFullTileBuffersByDefault(promoteFullTile))
+ .tileIf(!registerTileSizes.empty(), anchorOpName, registerTilingOptions)
+ .promoteIf(registerPromote, anchorOpName,
+ LinalgPromotionOptions()
+ .setAlignment(16)
+ .setUseFullTileBuffersByDefault(registerPromoteFullTile))
.vectorizeIf(vectorize, anchorOpName)
.setEnableVectorTransferPartialRewrite(true)
.setEnableVectorContractLowering(true)
@@ -157,39 +153,7 @@ void TestLinalgCodegenStrategy::runStrategy<LinalgOp>(
.setVectorTransferSplit(vectorTransferSplit))
.setVectorTransferToSCFOptions(
VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
- strategy.transform(getFunction());
-}
-
-template <typename OpType>
-void TestLinalgCodegenStrategy::runStrategy(
- LinalgTilingOptions tilingOptions,
- LinalgTilingOptions registerTilingOptions,
- vector::VectorContractLowering vectorContractLowering,
- vector::VectorTransferSplit vectorTransferSplit) {
- CodegenStrategy strategy;
- strategy.tileIf<OpType>(!tileSizes.empty(), tilingOptions)
- .template promoteIf<OpType>(
- promote, LinalgPromotionOptions()
- .setAlignment(16)
- .setUseFullTileBuffersByDefault(promoteFullTile))
- .template tileIf<OpType>(!registerTileSizes.empty(),
- registerTilingOptions)
- .template promoteIf<OpType>(
- registerPromote,
- LinalgPromotionOptions()
- .setAlignment(16)
- .setUseFullTileBuffersByDefault(registerPromoteFullTile))
- .template vectorizeIf<OpType>(vectorize)
- .setEnableVectorTransferPartialRewrite(true)
- .setEnableVectorContractLowering(true)
- .setEnableVectorToSCFConversion(true)
- .setVectorTransformsOptions(
- vector::VectorTransformsOptions()
- .setVectorTransformsOptions(vectorContractLowering)
- .setVectorTransferSplit(vectorTransferSplit))
- .setVectorTransferToSCFOptions(
- VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
- strategy.transform(getFunction());
+ (void)strategy.transform(getFunction());
}
} // end anonymous namespace
@@ -224,14 +188,8 @@ void TestLinalgCodegenStrategy::runOnFunction() {
.Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
.Default(vector::VectorTransferSplit::None);
- // If no anchorOpNameis specified, just test that strategy applies properly to
- // linalg::MatmulOp.
- if (anchorOpName.empty())
- runStrategy<linalg::MatmulOp>(tilingOptions, registerTilingOptions,
- vectorContractLowering, vectorTransferSplit);
- else
- runStrategy<LinalgOp>(tilingOptions, registerTilingOptions,
- vectorContractLowering, vectorTransferSplit);
+ runStrategy(tilingOptions, registerTilingOptions, vectorContractLowering,
+ vectorTransferSplit);
}
namespace mlir {
More information about the Mlir-commits
mailing list