[Mlir-commits] [mlir] 299cc5d - [mlir][Linalg] Further improve codegen strategy and add a linalg.matmul_i8_i8_i32
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jan 28 05:06:37 PST 2021
Author: Nicolas Vasilache
Date: 2021-01-28T13:02:42Z
New Revision: 299cc5da6df6be9f3c81c54e2e952c6d3519f63b
URL: https://github.com/llvm/llvm-project/commit/299cc5da6df6be9f3c81c54e2e952c6d3519f63b
DIFF: https://github.com/llvm/llvm-project/commit/299cc5da6df6be9f3c81c54e2e952c6d3519f63b.diff
LOG: [mlir][Linalg] Further improve codegen strategy and add a linalg.matmul_i8_i8_i32
This revision adds a layer of SFINAE to the composable codegen strategy so that it
no longer requires statically defined ops and can instead also be used with OpInterfaces, Operation* and an op name string.
A linalg.matmul_i8_i8_i32 is added to the .tc spec to demonstrate how all this works end to end.
Differential Revision: https://reviews.llvm.org/D95600
Added:
mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index e6d1e1935367..fc09243b46fe 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -8,6 +8,13 @@ def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
C(n, m) = std_addf<k>(std_mulf(A(k, m), B(n, k)));
}
+ods_def<MatmulI8I8I32Op>:
+def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
+ // TODO: ideally something closer to
+ // C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
+ C(m, n) = std_addi<k>(std_sexti32(std_muli(A(m, k), B(k, n))));
+}
+
ods_def<MatvecOp>:
def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index 25a98e3187e1..21bad4acec7c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -21,27 +21,63 @@ namespace linalg {
/// Abstract Transformation class applied in a sequence that also handles state
/// through markers.
struct Transformation {
+ explicit Transformation(linalg::LinalgTransformationFilter::FilterFunction f)
+ : filter(f) {}
virtual ~Transformation() = default;
virtual OwningRewritePatternList
- buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) = 0;
- linalg::LinalgMarker marker;
+ buildRewritePatterns(MLIRContext *context,
+ linalg::LinalgTransformationFilter m) = 0;
+ linalg::LinalgTransformationFilter::FilterFunction filter = nullptr;
};
+/// SFINAE: Enqueue helper for ConcreteOpType that have a `getOperationName`.
+template <template <typename> class PatternType, typename ConcreteOpType,
+ typename OptionsType,
+ typename std::enable_if<std::is_member_function_pointer<
+ decltype(&ConcreteOpType::getOperationName)>::value>>
+void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options,
+ MLIRContext *context, StringRef opName,
+ linalg::LinalgTransformationFilter m) {
+ assert(opName.empty() ||
+ opName == ConcreteOpType::getOperationName() &&
+ "explicit name must match ConcreteOpType::getOperationName");
+ patterList.insert<PatternType<ConcreteOpType>>(context, options, m);
+}
+
+/// SFINAE: Enqueue helper for OpType that do not have a `getOperationName`
+/// (e.g. LinalgOp, other interfaces, Operation*).
+template <template <typename> class PatternType, typename OpType,
+ typename OptionsType>
+void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options,
+ MLIRContext *context, StringRef opName,
+ linalg::LinalgTransformationFilter m) {
+ assert(!opName.empty() && "opName must not be empty");
+ patterList.insert<PatternType<OpType>>(opName, context, options, m);
+}
+
/// Promotion transformation enqueues a particular stage-1 pattern for
/// `Tile<LinalgOpType>`with the appropriate `options`.
template <typename LinalgOpType>
struct Tile : public Transformation {
- explicit Tile(linalg::LinalgTilingOptions options) : options(options) {}
+ explicit Tile(linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(f), opName(""), options(options) {}
+
+ Tile(StringRef name, linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(f), opName(name), options(options) {}
OwningRewritePatternList
- buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
+ buildRewritePatterns(MLIRContext *context,
+ linalg::LinalgTransformationFilter m) override {
OwningRewritePatternList tilingPatterns;
- tilingPatterns.insert<linalg::LinalgTilingPattern<LinalgOpType>>(
- context, options, m);
+ sfinae_enqueue<linalg::LinalgTilingPattern, LinalgOpType>(
+ tilingPatterns, options, context, opName, m);
return tilingPatterns;
}
private:
+ std::string opName;
linalg::LinalgTilingOptions options;
};
@@ -49,17 +85,26 @@ struct Tile : public Transformation {
/// `Promote<LinalgOpType>`with the appropriate `options`.
template <typename LinalgOpType>
struct Promote : public Transformation {
- explicit Promote(linalg::LinalgPromotionOptions options) : options(options) {}
+ explicit Promote(
+ linalg::LinalgPromotionOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(f), opName(""), options(options) {}
+
+ Promote(StringRef name, linalg::LinalgPromotionOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(f), opName(name), options(options) {}
OwningRewritePatternList
- buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
+ buildRewritePatterns(MLIRContext *context,
+ linalg::LinalgTransformationFilter m) override {
OwningRewritePatternList promotionPatterns;
- promotionPatterns.insert<linalg::LinalgPromotionPattern<LinalgOpType>>(
- context, options, m);
+ sfinae_enqueue<linalg::LinalgPromotionPattern, LinalgOpType>(
+ promotionPatterns, options, context, opName, m);
return promotionPatterns;
}
private:
+ std::string opName;
linalg::LinalgPromotionOptions options;
};
@@ -68,25 +113,36 @@ struct Promote : public Transformation {
/// transfer rewrite forwarding patterns.
template <typename LinalgOpType>
struct Vectorize : public Transformation {
+ explicit Vectorize(
+ linalg::LinalgVectorizationOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(f), opName(""), options(options) {}
+
+ Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(f), opName(name), options(options) {}
+
OwningRewritePatternList
- buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
+ buildRewritePatterns(MLIRContext *context,
+ linalg::LinalgTransformationFilter m) override {
OwningRewritePatternList vectorizationPatterns;
- // FillOp may interfere with forwarding patterns atm, so we bump up the
- // priority of LinalgCopyVTRForwardingPattern /
- // LinalgCopyVTWForwardingPattern.
- vectorizationPatterns
- .insert<linalg::LinalgVectorizationPattern<LinalgOpType>>(context, m);
+ sfinae_enqueue<linalg::LinalgVectorizationPattern, LinalgOpType>(
+ vectorizationPatterns, options, context, opName, m);
vectorizationPatterns.insert<linalg::LinalgCopyVTRForwardingPattern,
linalg::LinalgCopyVTWForwardingPattern>(
context, /*benefit=*/2);
return vectorizationPatterns;
}
+
+private:
+ std::string opName;
+ linalg::LinalgVectorizationOptions options;
};
/// Codegen strategy controls how a Linalg op is progressively lowered.
/// The application uses a 3-level staged patterns strategy which allows
-/// ordering transformations by using the Linalg `applyStagedPatterns` function,
-/// where:
+/// ordering transformations by using the Linalg `applyStagedPatterns`
+/// function, where:
/// 1. The first stage consists of the successive `tile`, `promote` and
/// `vectorize` patterns, applied sequentially.
/// 2. The second stage consists of common local canonicalization patterns
@@ -97,41 +153,112 @@ struct CodegenStrategy {
/// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
/// `options`.
template <typename LinalgOpType>
- CodegenStrategy &tile(linalg::LinalgTilingOptions options) {
- transformationSequence.emplace_back(new Tile<LinalgOpType>(options));
+ CodegenStrategy &
+ tile(linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ transformationSequence.emplace_back(
+ std::make_unique<Tile<LinalgOpType>>(options, f));
+ return *this;
+ }
+ /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
+ /// `options`.
+ template <typename LinalgOpType>
+ CodegenStrategy &
+ tile(StringRef opName, linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ transformationSequence.emplace_back(
+ std::make_unique<Tile<LinalgOpType>>(opName, options, f));
return *this;
}
- /// Conditionally append a pattern to add a level of tiling for `LinalgOpType`
- /// with tiling `options`.
+ /// Conditionally append a pattern to add a level of tiling for
+ /// `LinalgOpType` with tiling `options`.
template <typename LinalgOpType>
- CodegenStrategy &tileIf(bool b, linalg::LinalgTilingOptions options) {
+ CodegenStrategy &
+ tileIf(bool b, linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? tile<LinalgOpType>(options) : *this;
}
+ /// Conditionally append a pattern to add a level of tiling for
+ /// `LinalgOpType` with tiling `options`.
+ template <typename LinalgOpType>
+ CodegenStrategy &
+ tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ return b ? tile<LinalgOpType>(opName, options) : *this;
+ }
+ /// Append a pattern to add a level of promotion for `LinalgOpType` with
+ /// promotion `options`.
+ template <typename LinalgOpType>
+ CodegenStrategy &
+ promote(linalg::LinalgPromotionOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ transformationSequence.emplace_back(
+ std::make_unique<Promote<LinalgOpType>>(options, f));
+ return *this;
+ }
/// Append a pattern to add a level of promotion for `LinalgOpType` with
/// promotion `options`.
template <typename LinalgOpType>
- CodegenStrategy &promote(linalg::LinalgPromotionOptions options) {
- transformationSequence.emplace_back(new Promote<LinalgOpType>(options));
+ CodegenStrategy &
+ promote(StringRef opName, linalg::LinalgPromotionOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ transformationSequence.emplace_back(
+ std::make_unique<Promote<LinalgOpType>>(opName, options, f));
+ return *this;
+ }
+ /// Conditionally append a pattern to add a level of promotion for
+ /// `LinalgOpType` with promotion `options`.
+ template <typename LinalgOpType>
+ CodegenStrategy &
+ promoteIf(bool b, StringRef opName, linalg::LinalgPromotionOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ return b ? promote<LinalgOpType>(opName, options, f) : *this;
return *this;
}
/// Conditionally append a pattern to add a level of promotion for
/// `LinalgOpType` with promotion `options`.
template <typename LinalgOpType>
- CodegenStrategy &promoteIf(bool b, linalg::LinalgPromotionOptions options) {
- return b ? promote<LinalgOpType>(options) : *this;
+ CodegenStrategy &
+ promoteIf(bool b, linalg::LinalgPromotionOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ return b ? promote<LinalgOpType>(options, f) : *this;
+ return *this;
+ }
+ /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
+ template <typename LinalgOpType>
+ CodegenStrategy &
+ vectorize(linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ transformationSequence.emplace_back(
+ std::make_unique<Vectorize<LinalgOpType>>(
+ linalg::LinalgVectorizationOptions(), f));
return *this;
}
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
template <typename LinalgOpType>
- CodegenStrategy &vectorize() {
- transformationSequence.emplace_back(new Vectorize<LinalgOpType>());
+ CodegenStrategy &
+ vectorize(StringRef opName,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ transformationSequence.emplace_back(
+ std::make_unique<Vectorize<LinalgOpType>>(
+ opName, linalg::LinalgVectorizationOptions(), f));
+ return *this;
+ }
+ /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
+ /// operation.
+ template <typename LinalgOpType>
+ CodegenStrategy &
+ vectorizeIf(bool b,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ return b ? vectorize<LinalgOpType>(f) : *this;
return *this;
}
/// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
/// operation.
template <typename LinalgOpType>
- CodegenStrategy &vectorizeIf(bool b) {
- return b ? vectorize<LinalgOpType>() : *this;
+ CodegenStrategy &
+ vectorizeIf(bool b, StringRef opName,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ return b ? vectorize<LinalgOpType>(opName, f) : *this;
return *this;
}
/// Configure the post staged-patterns late vector transformations.
@@ -140,15 +267,22 @@ struct CodegenStrategy {
vectorTransformsOptions = options;
return *this;
}
- /// Configure the post staged-patterns late vector.transfer to scf conversion.
+ /// Configure the post staged-patterns late vector.transfer to scf
+ /// conversion.
CodegenStrategy &
setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
vectorToSCFOptions = options;
return *this;
}
+ /// Configure the post staged-patterns late vector.transfer to scf
+ /// conversion.
+ CodegenStrategy &setHoistInvariantCode(bool enableLICM) {
+ this->enableLICM = enableLICM;
+ return *this;
+ }
- /// Apply the transformation patterns in sequence with cleanup transformations
- /// interleaved.
+ /// Apply the transformation patterns in sequence with cleanup
+ /// transformations interleaved.
void transform(FuncOp func) const;
private:
@@ -157,6 +291,7 @@ struct CodegenStrategy {
vector::VectorTransformsOptions vectorTransformsOptions;
VectorTransferToSCFOptions vectorToSCFOptions;
SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
+ bool enableLICM = true;
};
} // namespace linalg
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f359992e5ff1..18cb91e3200a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -316,16 +316,32 @@ struct LinalgTransforms {
static const StringLiteral kLinalgTransformMarker;
};
-/// Helper class to control common attribute matching and setting behavior.
-struct LinalgMarker {
- explicit LinalgMarker(ArrayRef<Identifier> matchDisjunction = {},
- Optional<Identifier> replacement = None);
- LinalgMarker(LinalgMarker &&) = default;
- LinalgMarker(const LinalgMarker &) = default;
+/// Helper class to control application of linalg transformation patterns.
+/// Control comes in 2 forms:
+/// 1. attribute matching and setting behavior using the attribute named
+/// `kLinalgTransformMarker`. This can be used to build a state machine
+/// using attributes and incrementally applying patterns to advance states.
+/// 2. filter function, which is a simple lambda on the Operation* that
+/// returns a LogicalResult.
+struct LinalgTransformationFilter {
+ using FilterFunction = std::function<LogicalResult(Operation *)>;
+
+ explicit LinalgTransformationFilter(
+ ArrayRef<Identifier> matchDisjunction = {},
+ Optional<Identifier> replacement = None);
+
+ explicit LinalgTransformationFilter(
+ FilterFunction f, ArrayRef<Identifier> matchDisjunction = {},
+ Optional<Identifier> replacement = None);
+
+ LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
+ LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
- void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;
+ void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
+ Operation *op) const;
private:
+ FilterFunction filter;
SmallVector<Identifier, 4> matchDisjunction;
Optional<Identifier> replacement;
};
@@ -425,31 +441,44 @@ void populateLinalgTilingCanonicalizationPatterns(
/// and some operand shape cannot be bounded statically.
struct LinalgBaseTilingPattern : public RewritePattern {
// Entry point to match any LinalgOp OpInterface.
- LinalgBaseTilingPattern(LinalgTilingOptions options,
- LinalgMarker marker = LinalgMarker(),
- PatternBenefit benefit = 1);
+ LinalgBaseTilingPattern(
+ LinalgTilingOptions options,
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
// Entry point to match a specific Linalg op.
- LinalgBaseTilingPattern(StringRef opName, MLIRContext *context,
- LinalgTilingOptions options,
- LinalgMarker marker = LinalgMarker(),
- PatternBenefit benefit = 1);
+ LinalgBaseTilingPattern(
+ StringRef opName, MLIRContext *context, LinalgTilingOptions options,
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
TiledLinalgOp &result) const;
private:
/// LinalgTransformMarker handles special attribute manipulations.
- LinalgMarker marker;
+ LinalgTransformationFilter marker;
/// Options to control tiling;
LinalgTilingOptions options;
};
template <typename OpTy>
struct LinalgTilingPattern : public LinalgBaseTilingPattern {
- LinalgTilingPattern(MLIRContext *context, LinalgTilingOptions options,
- LinalgMarker marker = LinalgMarker(),
- PatternBenefit benefit = 1)
- : LinalgBaseTilingPattern(OpTy::getOperationName(), context, options,
- marker, benefit) {}
+ /// SFINAE: This constructor can only trigger for concrete ops that have a
+ /// static `getOperationName` method.
+ template <typename ConcreateOpTy = OpTy>
+ LinalgTilingPattern(
+ MLIRContext *context, LinalgTilingOptions options,
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : LinalgBaseTilingPattern(ConcreateOpTy::getOperationName(), context,
+ options, marker, benefit) {}
+
+ /// This constructor is available to anyone.
+ LinalgTilingPattern(
+ StringRef opName, MLIRContext *context, LinalgTilingOptions options,
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : LinalgBaseTilingPattern(opName, context, options, marker, benefit) {}
+
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
TiledLinalgOp tiledLinalgOp;
@@ -474,14 +503,15 @@ struct LinalgFusionOptions {
};
struct LinalgBaseTileAndFusePattern : public RewritePattern {
- LinalgBaseTileAndFusePattern(StringRef opName, MLIRContext *context,
- const LinalgDependenceGraph &dependenceGraph,
- LinalgTilingOptions tilingOptions,
- LinalgFusionOptions fusionOptions,
- LinalgMarker marker = LinalgMarker(),
- LinalgMarker fusedOpMarker = LinalgMarker(),
- LinalgMarker originalOpMarker = LinalgMarker(),
- PatternBenefit benefit = 1);
+ LinalgBaseTileAndFusePattern(
+ StringRef opName, MLIRContext *context,
+ const LinalgDependenceGraph &dependenceGraph,
+ LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
+ LinalgTransformationFilter originalOpMarker =
+ LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
@@ -493,27 +523,27 @@ struct LinalgBaseTileAndFusePattern : public RewritePattern {
/// Options to control fusion.
LinalgFusionOptions fusionOptions;
/// Marker to control application of the pattern.
- LinalgMarker marker;
+ LinalgTransformationFilter marker;
/// Marker set on the fused op after tile and fuse.
- LinalgMarker fusedOpMarker;
+ LinalgTransformationFilter fusedOpMarker;
/// The dependenceGraph is not modifiable, i.e. if the Linalg operations used
/// to build the dependence graph changes then the dependenceGraph needs to be
/// recomputed right now. To not invalidate the dependenceGraph as
/// transformation happens, the original producer can be tagged with a marker
/// that can be later used to delete the original operations.
- LinalgMarker originalOpMarker;
+ LinalgTransformationFilter originalOpMarker;
};
template <typename OpTy>
struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
- LinalgTileAndFusePattern(MLIRContext *context,
- const LinalgDependenceGraph &dependenceGraph,
- LinalgTilingOptions tilingOptions,
- LinalgFusionOptions fusionOptions,
- LinalgMarker marker = LinalgMarker(),
- LinalgMarker fusedOpMarker = LinalgMarker(),
- LinalgMarker originalOpMarker = LinalgMarker(),
- PatternBenefit benefit = 1)
+ LinalgTileAndFusePattern(
+ MLIRContext *context, const LinalgDependenceGraph &dependenceGraph,
+ LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
+ LinalgTransformationFilter originalOpMarker =
+ LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
: LinalgBaseTileAndFusePattern(
OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
fusionOptions, marker, fusedOpMarker, originalOpMarker, benefit) {}
@@ -526,26 +556,27 @@ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `interchange` for more details.
struct LinalgBaseInterchangePattern : public RewritePattern {
- LinalgBaseInterchangePattern(StringRef opName, MLIRContext *context,
- ArrayRef<unsigned> interchangeVector,
- LinalgMarker marker = LinalgMarker(),
- PatternBenefit benefit = 1);
+ LinalgBaseInterchangePattern(
+ StringRef opName, MLIRContext *context,
+ ArrayRef<unsigned> interchangeVector,
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
- LinalgMarker marker;
+ LinalgTransformationFilter marker;
/// The interchange vector to reorder the iterators and indexing_maps dims.
SmallVector<unsigned, 8> interchangeVector;
};
template <typename OpTy>
struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
- LinalgInterchangePattern(MLIRContext *context,
- ArrayRef<unsigned> interchangeVector,
- LinalgMarker marker = LinalgMarker(),
- PatternBenefit benefit = 1)
+ LinalgInterchangePattern(
+ MLIRContext *context, ArrayRef<unsigned> interchangeVector,
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
: LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
interchangeVector, marker, benefit) {}
};
@@ -557,27 +588,38 @@ struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `promoteSubViews` for more details.
struct LinalgBasePromotionPattern : public RewritePattern {
- LinalgBasePromotionPattern(StringRef opName, MLIRContext *context,
- LinalgPromotionOptions options,
- LinalgMarker marker = LinalgMarker(),
- PatternBenefit benefit = 1);
+ LinalgBasePromotionPattern(
+ StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
- LinalgMarker marker;
+ LinalgTransformationFilter marker;
/// Promotion options.
LinalgPromotionOptions options;
};
template <typename OpTy>
struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
- LinalgPromotionPattern(MLIRContext *context, LinalgPromotionOptions options,
- LinalgMarker marker = LinalgMarker(),
- PatternBenefit benefit = 1)
+ /// SFINAE: This constructor can only trigger for concrete ops that have a
+ /// static `getOperationName` method.
+ template <typename ConcreateOpTy = OpTy>
+ LinalgPromotionPattern(
+ MLIRContext *context, LinalgPromotionOptions options,
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
: LinalgBasePromotionPattern(OpTy::getOperationName(), context, options,
marker, benefit) {}
+ /// This constructor is available to anyone.
+ LinalgPromotionPattern(
+ StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : LinalgBasePromotionPattern(opName, context, options, marker, benefit) {}
};
///
@@ -586,25 +628,42 @@ struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
/// Apply the `vectorizeLinalgOp` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
+
+/// Empty for now, used for SFINAE purposes only.
+struct LinalgVectorizationOptions {};
+
struct LinalgBaseVectorizationPattern : public RewritePattern {
- LinalgBaseVectorizationPattern(StringRef opName, MLIRContext *context,
- LinalgMarker marker = LinalgMarker(),
- PatternBenefit benefit = 1);
+ LinalgBaseVectorizationPattern(
+ StringRef opName, MLIRContext *context,
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
- LinalgMarker marker;
+ LinalgTransformationFilter marker;
};
template <typename OpTy>
struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
- LinalgVectorizationPattern(MLIRContext *context,
- LinalgMarker marker = LinalgMarker(),
- PatternBenefit benefit = 1)
+ /// SFINAE: This constructor can only trigger for concrete ops that have a
+ /// static `getOperationName` method.
+ template <typename ConcreateOpTy = OpTy>
+ LinalgVectorizationPattern(
+ MLIRContext *context,
+ LinalgVectorizationOptions options = LinalgVectorizationOptions(),
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
: LinalgBaseVectorizationPattern(OpTy::getOperationName(), context,
marker, benefit) {}
+ /// This constructor is available to anyone.
+ LinalgVectorizationPattern(
+ StringRef opName, MLIRContext *context,
+ LinalgVectorizationOptions options = LinalgVectorizationOptions(),
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : LinalgBaseVectorizationPattern(opName, context, marker, benefit) {}
};
///
@@ -622,10 +681,10 @@ enum class LinalgLoweringType {
template <typename OpTy>
struct LinalgLoweringPattern : public RewritePattern {
- LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType,
- LinalgMarker marker = LinalgMarker(),
- ArrayRef<unsigned> interchangeVector = {},
- PatternBenefit benefit = 1)
+ LinalgLoweringPattern(
+ MLIRContext *context, LinalgLoweringType loweringType,
+ LinalgTransformationFilter marker = LinalgTransformationFilter(),
+ ArrayRef<unsigned> interchangeVector = {}, PatternBenefit benefit = 1)
: RewritePattern(OpTy::getOperationName(), {}, benefit, context),
marker(marker), loweringType(loweringType),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
@@ -663,7 +722,7 @@ struct LinalgLoweringPattern : public RewritePattern {
private:
/// LinalgTransformMarker handles special attribute manipulations.
- LinalgMarker marker;
+ LinalgTransformationFilter marker;
/// Controls whether the pattern lowers to library calls, scf.for, affine.for
/// or scf.parallel.
LinalgLoweringType loweringType;
@@ -677,13 +736,13 @@ struct LinalgLoweringPattern : public RewritePattern {
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
- LinalgMarker marker = LinalgMarker());
+ LinalgTransformationFilter marker = LinalgTransformationFilter());
/// Populates `patterns` with patterns to convert linalg.conv ops to
/// linalg.generic ops.
void populateLinalgConvGeneralizationPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
- LinalgMarker marker = LinalgMarker());
+ LinalgTransformationFilter marker = LinalgTransformationFilter());
//===----------------------------------------------------------------------===//
// Op-specific patterns.
diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
index 5d4668d7b5fc..ae5d5ad357f5 100644
--- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
@@ -38,6 +38,7 @@ using std_ret = OperationBuilder<ReturnOp>;
using std_rsqrt = ValueBuilder<RsqrtOp>;
using std_select = ValueBuilder<SelectOp>;
using std_load = ValueBuilder<LoadOp>;
+using std_sign_extendi = ValueBuilder<SignExtendIOp>;
using std_splat = ValueBuilder<SplatOp>;
using std_store = OperationBuilder<StoreOp>;
using std_subf = ValueBuilder<SubFOp>;
@@ -48,9 +49,19 @@ using std_tensor_load = ValueBuilder<TensorLoadOp>;
using std_tensor_store = OperationBuilder<TensorStoreOp>;
using std_view = ValueBuilder<ViewOp>;
using std_zero_extendi = ValueBuilder<ZeroExtendIOp>;
-using std_sign_extendi = ValueBuilder<SignExtendIOp>;
using tensor_extract = ValueBuilder<tensor::ExtractOp>;
+template <int N>
+struct SExtiValueBuilder : public ValueBuilder<SignExtendIOp> {
+ using ValueBuilder<SignExtendIOp>::ValueBuilder;
+ template <typename... Args>
+ SExtiValueBuilder(Args... args)
+ : ValueBuilder<SignExtendIOp>(ScopedContext::getBuilderRef().getI32Type(),
+ args...) {}
+};
+
+using std_sexti32 = SExtiValueBuilder<32>;
+
/// Branches into `block` with `operands`.
BranchOp std_br(Block *block, ValueRange operands);
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
index 3c589d163857..3e7560ef0867 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
@@ -14,9 +14,12 @@
// RUN: tee -a /dev/stderr | FileCheck %s
-!row_major_A = type memref<${M}x${K}xf32>
-!row_major_B = type memref<${K}x${N}xf32>
-!row_major_C = type memref<${M}x${N}xf32>
+!elem_type_a = type f32
+!elem_type_b = type f32
+!elem_type_c = type f32
+!row_major_A = type memref<${M}x${K}x!elem_type_a>
+!row_major_B = type memref<${K}x${N}x!elem_type_b>
+!row_major_C = type memref<${M}x${N}x!elem_type_c>
func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C)
// TODO: activate manually for now.
@@ -48,16 +51,16 @@ func @print_perf(%iters: index, %total_time: f64) {
}
func @main() {
- %f0 = constant 0.0 : f32
- %f1 = constant 1.0 : f32
+ %v0 = constant 0.0 : !elem_type_a
+ %v1 = constant 1.0 : !elem_type_a
%A = alloc() : !row_major_A
%B = alloc() : !row_major_B
%C = alloc() : !row_major_C
- linalg.fill(%A, %f1) : !row_major_A, f32
- linalg.fill(%B, %f1) : !row_major_B, f32
- linalg.fill(%C, %f0) : !row_major_C, f32
+ linalg.fill(%A, %v1) : !row_major_A, !elem_type_a
+ linalg.fill(%B, %v1) : !row_major_B, !elem_type_b
+ linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
%c0 = constant 0: index
%c1 = constant 1: index
@@ -66,7 +69,8 @@ func @main() {
/// Run and dump performance for matmul.
/// Preheating run:
scf.for %arg0 = %c0 to %iters step %c1 {
- linalg.fill(%C, %f0) : !row_major_C, f32
+ %z = constant 0.0 : !elem_type_c
+ linalg.fill(%C, %z) : !row_major_C, !elem_type_c
call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
}
%t_start_matmul = call @rtclock() : () -> f64
@@ -75,7 +79,8 @@ func @main() {
// This accounts for about 10-15% perf hit on small sizes.
// Once linalg on tensors is ready, fusing fill at the register level will
// be easy.
- linalg.fill(%C, %f0) : !row_major_C, f32
+ %z = constant 0.0 : !elem_type_c
+ linalg.fill(%C, %z) : !row_major_C, !elem_type_c
call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
}
%t_end_matmul = call @rtclock() : () -> f64
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
index a71643fde480..03e51b4b1a91 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
@@ -15,12 +15,15 @@
// Use tee to both print to stderr and FileCheck
// RUN: tee -a /dev/stderr | FileCheck %s
-!row_major_A = type memref<${M}x${K}xf32>
-!row_major_B = type memref<${K}x${N}xf32>
-!row_major_C = type memref<${M}x${N}xf32>
-!column_major_A = type memref<${K}x${M}xf32>
-!column_major_B = type memref<${N}x${K}xf32>
-!column_major_C = type memref<${N}x${M}xf32>
+!elem_type_a = type f32
+!elem_type_b = type f32
+!elem_type_c = type f32
+!row_major_A = type memref<${M}x${K}x!elem_type_a>
+!row_major_B = type memref<${K}x${N}x!elem_type_b>
+!row_major_C = type memref<${M}x${N}x!elem_type_c>
+!column_major_A = type memref<${K}x${M}x!elem_type_a>
+!column_major_B = type memref<${N}x${K}x!elem_type_b>
+!column_major_C = type memref<${N}x${M}x!elem_type_c>
func @matmul_column_major(%a: !column_major_A, %b: !column_major_B, %c: !column_major_C)
// TODO: activate manually for now.
@@ -52,16 +55,16 @@ func @print_perf(%iters: index, %total_time: f64) {
}
func @main() {
- %f0 = constant 0.0 : f32
- %f1 = constant 1.0 : f32
+ %f0 = constant 0.0 : !elem_type_c
+ %f1 = constant 1.0 : !elem_type_a
%cA = alloc() : !column_major_A
%cB = alloc() : !column_major_B
%cC = alloc() : !column_major_C
- linalg.fill(%cA, %f1) : !column_major_A, f32
- linalg.fill(%cB, %f1) : !column_major_B, f32
- linalg.fill(%cC, %f0) : !column_major_C, f32
+ linalg.fill(%cA, %f1) : !column_major_A, !elem_type_a
+ linalg.fill(%cB, %f1) : !column_major_B, !elem_type_b
+ linalg.fill(%cC, %f0) : !column_major_C, !elem_type_c
%c0 = constant 0: index
%c1 = constant 1: index
@@ -74,7 +77,7 @@ func @main() {
// This accounts for about 10-15% perf hit on small sizes.
// Once linalg on tensors is ready, fusing fill at the register level will
// be easy.
- linalg.fill(%cC, %f0) : !column_major_C, f32
+ linalg.fill(%cC, %f0) : !column_major_C, !elem_type_c
call @matmul_column_major(%cA, %cB, %cC) : (!column_major_A, !column_major_B, !column_major_C) -> ()
}
%t_end_matmul_column_major = call @rtclock() : () -> f64
@@ -83,7 +86,7 @@ func @main() {
%res = load %cC[%c0, %c0]: !column_major_C
// CHECK: 64
- vector.print %res: f32
+ vector.print %res: !elem_type_c
dealloc %cA : !column_major_A
dealloc %cB : !column_major_B
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
index c8f3fe4b95d4..f672829ac432 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
@@ -16,12 +16,15 @@
// Use tee to both print to stderr and FileCheck
// RUN: tee -a /dev/stderr | FileCheck %s
-!row_major_A = type memref<${M}x${K}xf32>
-!row_major_B = type memref<${K}x${N}xf32>
-!row_major_C = type memref<${M}x${N}xf32>
-!column_major_A = type memref<${K}x${M}xf32>
-!column_major_B = type memref<${N}x${K}xf32>
-!column_major_C = type memref<${N}x${M}xf32>
+!elem_type_a = type f32
+!elem_type_b = type f32
+!elem_type_c = type f32
+!row_major_A = type memref<${M}x${K}x!elem_type_a>
+!row_major_B = type memref<${K}x${N}x!elem_type_b>
+!row_major_C = type memref<${M}x${N}x!elem_type_c>
+!column_major_A = type memref<${K}x${M}x!elem_type_a>
+!column_major_B = type memref<${N}x${K}x!elem_type_b>
+!column_major_C = type memref<${N}x${M}x!elem_type_c>
func @matmul_column_major_as_row_major(
%ca: !column_major_A, %cb: !column_major_B, %cc: !column_major_C,
@@ -58,16 +61,16 @@ func @print_perf(%iters: index, %total_time: f64) {
}
func @main() {
- %f0 = constant 0.0 : f32
- %f1 = constant 1.0 : f32
+ %f0 = constant 0.0 : !elem_type_c
+ %f1 = constant 1.0 : !elem_type_a
%cA = alloc() : !column_major_A
%cB = alloc() : !column_major_B
%cC = alloc() : !column_major_C
- linalg.fill(%cA, %f1) : !column_major_A, f32
- linalg.fill(%cB, %f1) : !column_major_B, f32
- linalg.fill(%cC, %f0) : !column_major_C, f32
+ linalg.fill(%cA, %f1) : !column_major_A, !elem_type_a
+ linalg.fill(%cB, %f1) : !column_major_B, !elem_type_b
+ linalg.fill(%cC, %f0) : !column_major_C, !elem_type_c
%c0 = constant 0: index
%c1 = constant 1: index
@@ -83,7 +86,7 @@ func @main() {
// This accounts for about 10-15% perf hit on small sizes.
// Once linalg on tensors is ready, fusing fill at the register level will
// be easy.
- linalg.fill(%C, %f0) : !row_major_C, f32
+ linalg.fill(%C, %f0) : !row_major_C, !elem_type_c
call @matmul_column_major_as_row_major(%cA, %cB, %cC, %A, %B, %C) :
(!column_major_A, !column_major_B, !column_major_C,
!row_major_A, !row_major_B, !row_major_C) -> ()
@@ -94,10 +97,10 @@ func @main() {
%res = load %cC[%c0, %c0]: !column_major_C
// CHECK: 64
- vector.print %res: f32
+ vector.print %res: !elem_type_c
%res2 = load %C[%c0, %c0]: !row_major_C
// CHECK: 64
- vector.print %res2: f32
+ vector.print %res2: !elem_type_c
dealloc %A : !row_major_A
dealloc %B : !row_major_B
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir
new file mode 100644
index 000000000000..9243ebbae4eb
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_i8_i8_i32.mlir
@@ -0,0 +1,103 @@
+// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \
+// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \
+// TODO: extend vectorization with interfaces so that it works with sexti
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_i8_i8_i32 register-tile-sizes=12,32,16" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \
+// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \
+
+// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// Activate to dump assembly
+// R_UN: -dump-object-file -object-filename=/tmp/a.o \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \
+// Use tee to both print to stderr and FileCheck
+// RUN: tee -a /dev/stderr | FileCheck %s
+
+
+!elem_type_a = type i8
+!elem_type_b = type i8
+!elem_type_c = type i32
+!row_major_A = type memref<${M}x${K}x!elem_type_a>
+!row_major_B = type memref<${K}x${N}x!elem_type_b>
+!row_major_C = type memref<${M}x${N}x!elem_type_c>
+
+func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C)
+// TODO: activate manually for now.
+// attributes { passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]}
+{
+ linalg.matmul_i8_i8_i32 ins(%a, %b : !row_major_A, !row_major_B)
+ outs(%c: !row_major_C)
+ return
+}
+
+func @print_perf(%iters: index, %total_time: f64) {
+ %c2 = constant 2 : index
+ %cM = constant ${M} : index
+ %cN = constant ${N} : index
+ %cK = constant ${K} : index
+
+ %mn = muli %cM, %cN : index
+ %mnk = muli %mn, %cK : index
+
+ // 2*M*N*K.
+ %flops_per_iter = muli %c2, %mnk : index
+ %flops = muli %iters, %flops_per_iter : index
+ %flops_i64 = index_cast %flops : index to i64
+ %flops_f = sitofp %flops_i64 : i64 to f64
+ %flops_per_s = divf %flops_f, %total_time : f64
+ vector.print %flops_per_s : f64
+
+ return
+}
+
+func @main() {
+ %v0 = constant 0 : !elem_type_c
+ %v1 = constant 1 : !elem_type_a
+
+ %A = alloc() : !row_major_A
+ %B = alloc() : !row_major_B
+ %C = alloc() : !row_major_C
+
+ linalg.fill(%A, %v1) : !row_major_A, !elem_type_a
+ linalg.fill(%B, %v1) : !row_major_B, !elem_type_b
+ linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
+
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %iters = constant ${ITERS}: index
+
+ /// Run and dump performance for matmul.
+ /// Preheating run:
+ scf.for %arg0 = %c0 to %iters step %c1 {
+ linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
+ call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
+ }
+ %t_start_matmul = call @rtclock() : () -> f64
+ scf.for %arg0 = %c0 to %iters step %c1 {
+ // linalg.matmul writes %C in place, need to reset it to zero every time.
+ // This accounts for about 10-15% perf hit on small sizes.
+ // Once linalg on tensors is ready, fusing fill at the register level will
+ // be easy.
+ linalg.fill(%C, %v0) : !row_major_C, !elem_type_c
+ call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
+ }
+ %t_end_matmul = call @rtclock() : () -> f64
+ %tmatmul = subf %t_end_matmul, %t_start_matmul: f64
+ call @print_perf(%iters, %tmatmul) : (index, f64) -> ()
+
+ %res = load %C[%c0, %c0]: !row_major_C
+ // CHECK: 64
+ vector.print %res: !elem_type_c
+
+ dealloc %A : !row_major_A
+ dealloc %B : !row_major_B
+ dealloc %C : !row_major_C
+
+ return
+}
+
+func private @rtclock() -> f64
+
+// TODO: init with random, run and check output.
+// func private @fill_random_f32(memref<*xf32>)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index 02058f886451..5c9d1df0c056 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -37,8 +37,10 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
for (const std::unique_ptr<Transformation> &t : transformationSequence) {
auto nextState = Identifier::get(std::to_string(++stepCount), context);
auto marker = (currentState == zeroState)
- ? linalg::LinalgMarker({}, nextState)
- : linalg::LinalgMarker(currentState, nextState);
+ ? linalg::LinalgTransformationFilter(
+ t->filter, ArrayRef<Identifier>{}, nextState)
+ : linalg::LinalgTransformationFilter(
+ t->filter, currentState, nextState);
stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker));
currentState = nextState;
}
@@ -47,15 +49,17 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
linalg::getLinalgTilingCanonicalizationPatterns(context);
stage2Patterns.insert<AffineMinSCFCanonicalizationPattern>(context);
- auto stage3Transforms = [](Operation *op) {
+ auto stage3Transforms = [&](Operation *op) {
// Some of these may be too aggressive as a stage 3 that is applied on each
// stage 1 application and may have to be split out to post staged patterns
// application (in which case they could just be passes, TBD).
- op->walk([&](LoopLikeOpInterface loopLike) {
- LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n"));
- if (failed(moveLoopInvariantCode(loopLike)))
- llvm_unreachable("unexpected LICM failure");
- });
+ if (enableLICM) {
+ op->walk([&](LoopLikeOpInterface loopLike) {
+ LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n"));
+ if (failed(moveLoopInvariantCode(loopLike)))
+ llvm_unreachable("unexpected LICM failure");
+ });
+ }
promoteSingleIterationLoops(cast<FuncOp>(op));
hoistViewAllocOps(cast<FuncOp>(op));
hoistRedundantVectorTransfers(cast<FuncOp>(op));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 454bbbe3578a..997fa692c2b1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -63,7 +63,8 @@ namespace {
// into auto-generated ones.
template <typename ConcretePattern, typename RootOp>
struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
- LinalgGeneralizationPattern(MLIRContext *context, linalg::LinalgMarker marker,
+ LinalgGeneralizationPattern(MLIRContext *context,
+ linalg::LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
: OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
@@ -81,12 +82,13 @@ struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
return failure();
rewriter.replaceOp(rootOp, genericOp.getResults());
- marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
+ marker.replaceLinalgTransformationFilter(rewriter,
+ genericOp.getOperation());
return success();
}
private:
- linalg::LinalgMarker marker;
+ linalg::LinalgTransformationFilter marker;
};
struct GeneralizeConvOp
@@ -100,7 +102,7 @@ struct GeneralizeConvOp
/// linalg.generic.
struct LinalgNamedOpGeneralizationPattern : RewritePattern {
LinalgNamedOpGeneralizationPattern(MLIRContext *context,
- linalg::LinalgMarker marker,
+ linalg::LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
: RewritePattern(benefit, MatchAnyOpTypeTag()),
marker(std::move(marker)) {}
@@ -123,12 +125,13 @@ struct LinalgNamedOpGeneralizationPattern : RewritePattern {
return failure();
rewriter.replaceOp(rootOp, genericOp.getResults());
- marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
+ marker.replaceLinalgTransformationFilter(rewriter,
+ genericOp.getOperation());
return success();
}
private:
- linalg::LinalgMarker marker;
+ linalg::LinalgTransformationFilter marker;
};
struct LinalgGeneralizationPass
@@ -165,13 +168,13 @@ linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
void mlir::linalg::populateLinalgConvGeneralizationPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
- linalg::LinalgMarker marker) {
+ linalg::LinalgTransformationFilter marker) {
patterns.insert<GeneralizeConvOp>(context, marker);
}
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
- linalg::LinalgMarker marker) {
+ linalg::LinalgTransformationFilter marker) {
patterns.insert<LinalgNamedOpGeneralizationPattern>(context, marker);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index d723dc47ac57..ce41560f7557 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -536,7 +536,9 @@ class RewritePatternList<OpTy, OpTypes...> {
static void insert(OwningRewritePatternList &patterns,
const LinalgTilingOptions &options, MLIRContext *ctx) {
patterns.insert<LinalgTilingPattern<OpTy>>(
- ctx, options, LinalgMarker({}, Identifier::get("tiled", ctx)));
+ ctx, options,
+ LinalgTransformationFilter(ArrayRef<Identifier>{},
+ Identifier::get("tiled", ctx)));
RewritePatternList<OpTypes...>::insert(patterns, options, ctx);
}
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index b0cb51516e25..b4c94ae53937 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -46,14 +46,23 @@ using namespace mlir::linalg;
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
"__internal_linalg_transform__";
-mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
- Optional<Identifier> replacement)
- : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
+mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
+ ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
+ : LinalgTransformationFilter([](Operation *) { return success(); },
+ matchDisjunction, replacement) {}
+
+mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
+ FilterFunction f, ArrayRef<Identifier> matchDisjunction,
+ Optional<Identifier> replacement)
+ : filter(f),
+ matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
replacement(replacement) {}
-LogicalResult
-mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
- Operation *op) const {
+LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
+ PatternRewriter &rewriter, Operation *op) const {
+ if (filter && failed(filter(op)))
+ return failure();
+
auto attr = op->template getAttrOfType<StringAttr>(
LinalgTransforms::kLinalgTransformMarker);
@@ -81,8 +90,9 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
});
}
-void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter,
- Operation *op) const {
+void mlir::linalg::LinalgTransformationFilter::
+ replaceLinalgTransformationFilter(PatternRewriter &rewriter,
+ Operation *op) const {
if (replacement.hasValue())
op->setAttr(LinalgTransforms::kLinalgTransformMarker,
rewriter.getStringAttr(replacement.getValue()));
@@ -219,12 +229,13 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
- LinalgMarker marker, PatternBenefit benefit)
+ LinalgTransformationFilter marker, PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), marker(marker),
options(options) {}
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
- LinalgTilingOptions options, LinalgMarker marker, PatternBenefit benefit)
+ LinalgTilingOptions options, LinalgTransformationFilter marker,
+ PatternBenefit benefit)
: RewritePattern(benefit, MatchAnyOpTypeTag()), marker(marker),
options(options) {}
@@ -250,9 +261,9 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
// Return relevant information to derived pattern.
result = *res;
// Replace marker on both tiledOp and tiledAndPaddedOp, if necessary.
- marker.replaceLinalgMarker(rewriter, tiledOp);
+ marker.replaceLinalgTransformationFilter(rewriter, tiledOp);
if (tiledOp != res->op)
- marker.replaceLinalgMarker(rewriter, res->op);
+ marker.replaceLinalgTransformationFilter(rewriter, res->op);
});
// Consider padding on the fly only if the op has tensor semantics.
@@ -276,8 +287,8 @@ mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
StringRef opName, MLIRContext *context,
const LinalgDependenceGraph &dependenceGraph,
LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
- LinalgMarker marker, LinalgMarker fusedOpMarker,
- LinalgMarker originalOpMarker, PatternBenefit benefit)
+ LinalgTransformationFilter marker, LinalgTransformationFilter fusedOpMarker,
+ LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context),
dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
fusionOptions(fusionOptions), marker(marker),
@@ -352,23 +363,26 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
tiledAndFusedOps->op = unfusedTiledOp->op;
}
- marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
+ marker.replaceLinalgTransformationFilter(rewriter,
+ tiledAndFusedOps->op.getOperation());
for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
- fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
+ fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
+ fusedOp.getOperation());
}
for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
- originalOpMarker.replaceLinalgMarker(rewriter,
- origProducerOp.getOperation());
+ originalOpMarker.replaceLinalgTransformationFilter(
+ rewriter, origProducerOp.getOperation());
}
- rewriter.updateRootInPlace(
- op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
+ rewriter.updateRootInPlace(op, [&]() {
+ originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
+ });
return success();
}
/// Linalg base interchange pattern.
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
StringRef opName, MLIRContext *context,
- ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
+ ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter marker,
PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), marker(marker),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
@@ -388,14 +402,14 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
rewriter.updateRootInPlace(op, [&]() {
interchange(linalgOp, interchangeVector);
// New marker if specified.
- marker.replaceLinalgMarker(rewriter, op);
+ marker.replaceLinalgTransformationFilter(rewriter, op);
});
return success();
}
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
- LinalgMarker marker, PatternBenefit benefit)
+ LinalgTransformationFilter marker, PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), marker(marker),
options(options) {}
@@ -417,12 +431,12 @@ LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
return op->emitError("subview promotion failed");
}
rewriter.finalizeRootUpdate(op);
- marker.replaceLinalgMarker(rewriter, op);
+ marker.replaceLinalgTransformationFilter(rewriter, op);
return success();
}
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
- StringRef opName, MLIRContext *context, LinalgMarker marker,
+ StringRef opName, MLIRContext *context, LinalgTransformationFilter marker,
PatternBenefit benefit)
: RewritePattern(opName, {}, benefit, context), marker(marker) {}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 76a5bb56a4b6..fa1aba8fd157 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -607,12 +607,13 @@ populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
constexpr static StringRef kPromotedMarker = "PROMOTED";
tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
context, LinalgTilingOptions().setTileSizes(tileSizes),
- LinalgMarker({}, Identifier::get(kTiledMarker, context)));
+ LinalgTransformationFilter(ArrayRef<Identifier>{},
+ Identifier::get(kTiledMarker, context)));
promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
- LinalgMarker(Identifier::get(kTiledMarker, context),
- Identifier::get(kPromotedMarker, context)));
+ LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
+ Identifier::get(kPromotedMarker, context)));
SmallVector<bool, 4> mask(N);
int offset = tileSizes.size() - N;
diff --git a/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
index 8d80de793658..34ee46a91ac1 100644
--- a/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
@@ -107,8 +107,12 @@ struct TestLinalgCodegenStrategy
};
} // end anonymous namespace
-template <typename LinalgNamedOp>
-void TestLinalgCodegenStrategy::applyStrategyToNamedLinalgOp() {
+/// Apply transformations specified as patterns.
+void TestLinalgCodegenStrategy::runOnFunction() {
+ linalg::LinalgTransformationFilter::FilterFunction filterOpName =
+ [&](Operation *op) -> LogicalResult {
+ return success(op->getName().getStringRef() == anchorOpName);
+ };
LinalgTilingOptions tilingOptions;
if (!tileSizes.empty())
tilingOptions = tilingOptions.setTileSizes(tileSizes);
@@ -134,19 +138,20 @@ void TestLinalgCodegenStrategy::applyStrategyToNamedLinalgOp() {
.Default(vector::VectorTransferSplit::None);
CodegenStrategy strategy;
- strategy.template tileIf<LinalgNamedOp>(!tileSizes.empty(), tilingOptions)
- .template promoteIf<LinalgNamedOp>(
- promote, LinalgPromotionOptions()
- .setAlignment(16)
- .setUseFullTileBuffersByDefault(promoteFullTile))
- .template tileIf<LinalgNamedOp>(!registerTileSizes.empty(),
- registerTilingOptions)
- .template promoteIf<LinalgNamedOp>(
- registerPromote,
+ strategy.tileIf<LinalgOp>(!tileSizes.empty(), anchorOpName, tilingOptions)
+ .promoteIf<LinalgOp>(promote, anchorOpName,
+ LinalgPromotionOptions()
+ .setAlignment(16)
+ .setUseFullTileBuffersByDefault(promoteFullTile),
+ filterOpName)
+ .tileIf<LinalgOp>(!registerTileSizes.empty(), anchorOpName,
+ registerTilingOptions)
+ .promoteIf<LinalgOp>(
+ registerPromote, anchorOpName,
LinalgPromotionOptions()
.setAlignment(16)
.setUseFullTileBuffersByDefault(registerPromoteFullTile))
- .template vectorizeIf<LinalgNamedOp>(vectorize)
+ .vectorizeIf<LinalgOp>(vectorize, anchorOpName)
.setVectorTransformsOptions(
vector::VectorTransformsOptions()
.setVectorTransformsOptions(vectorContractLowering)
@@ -156,20 +161,6 @@ void TestLinalgCodegenStrategy::applyStrategyToNamedLinalgOp() {
strategy.transform(getFunction());
}
-/// Apply transformations specified as patterns.
-void TestLinalgCodegenStrategy::runOnFunction() {
- if (anchorOpName == MatmulOp::getOperationName())
- applyStrategyToNamedLinalgOp<MatmulOp>();
- else if (anchorOpName == MatmulColumnMajorOp::getOperationName())
- applyStrategyToNamedLinalgOp<MatmulColumnMajorOp>();
- else if (anchorOpName == CopyOp::getOperationName())
- applyStrategyToNamedLinalgOp<CopyOp>();
- else if (anchorOpName == FillOp::getOperationName())
- applyStrategyToNamedLinalgOp<FillOp>();
- else
- llvm_unreachable("Unsupported anchor op");
-}
-
namespace mlir {
namespace test {
void registerTestLinalgCodegenStrategy() {
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 4ed00e4fbefc..f2c9067d5cc2 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -45,12 +45,15 @@ static void fillFusionPatterns(MLIRContext *context,
.setTileSizes({32, 64, 16})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
LinalgFusionOptions().setIndicesToFuse({2}),
- LinalgMarker(Identifier::get("basic_fusion", context),
- Identifier::get("after_basic_fusion", context)),
- LinalgMarker(ArrayRef<Identifier>(),
- Identifier::get("after_basic_fusion_producer", context)),
- LinalgMarker(ArrayRef<Identifier>(),
- Identifier::get("after_basic_fusion_original", context)));
+ LinalgTransformationFilter(
+ Identifier::get("basic_fusion", context),
+ Identifier::get("after_basic_fusion", context)),
+ LinalgTransformationFilter(
+ ArrayRef<Identifier>(),
+ Identifier::get("after_basic_fusion_producer", context)),
+ LinalgTransformationFilter(
+ ArrayRef<Identifier>(),
+ Identifier::get("after_basic_fusion_original", context)));
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
context, dependenceGraph,
@@ -58,12 +61,14 @@ static void fillFusionPatterns(MLIRContext *context,
.setTileSizes({32, 64, 16})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
LinalgFusionOptions().setIndicesToFuse({0}),
- LinalgMarker(Identifier::get("lhs_fusion", context),
- Identifier::get("after_lhs_fusion", context)),
- LinalgMarker(ArrayRef<Identifier>(),
- Identifier::get("after_lhs_fusion_producer", context)),
- LinalgMarker(ArrayRef<Identifier>(),
- Identifier::get("after_lhs_fusion_original", context)));
+ LinalgTransformationFilter(Identifier::get("lhs_fusion", context),
+ Identifier::get("after_lhs_fusion", context)),
+ LinalgTransformationFilter(
+ ArrayRef<Identifier>(),
+ Identifier::get("after_lhs_fusion_producer", context)),
+ LinalgTransformationFilter(
+ ArrayRef<Identifier>(),
+ Identifier::get("after_lhs_fusion_original", context)));
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
context, dependenceGraph,
@@ -71,12 +76,14 @@ static void fillFusionPatterns(MLIRContext *context,
.setTileSizes({32, 64, 16})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
LinalgFusionOptions().setIndicesToFuse({1}),
- LinalgMarker(Identifier::get("rhs_fusion", context),
- Identifier::get("after_rhs_fusion", context)),
- LinalgMarker(ArrayRef<Identifier>(),
- Identifier::get("after_rhs_fusion_producer", context)),
- LinalgMarker(ArrayRef<Identifier>(),
- Identifier::get("after_rhs_fusion_original", context)));
+ LinalgTransformationFilter(Identifier::get("rhs_fusion", context),
+ Identifier::get("after_rhs_fusion", context)),
+ LinalgTransformationFilter(
+ ArrayRef<Identifier>(),
+ Identifier::get("after_rhs_fusion_producer", context)),
+ LinalgTransformationFilter(
+ ArrayRef<Identifier>(),
+ Identifier::get("after_rhs_fusion_original", context)));
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
context, dependenceGraph,
@@ -84,12 +91,13 @@ static void fillFusionPatterns(MLIRContext *context,
.setTileSizes({32, 64, 16})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
LinalgFusionOptions().setIndicesToFuse({0, 2}),
- LinalgMarker(Identifier::get("two_operand_fusion", context),
- Identifier::get("after_two_operand_fusion", context)),
- LinalgMarker(
+ LinalgTransformationFilter(
+ Identifier::get("two_operand_fusion", context),
+ Identifier::get("after_two_operand_fusion", context)),
+ LinalgTransformationFilter(
ArrayRef<Identifier>(),
Identifier::get("after_two_operand_fusion_producer", context)),
- LinalgMarker(
+ LinalgTransformationFilter(
ArrayRef<Identifier>(),
Identifier::get("after_two_operand_fusion_original", context)));
@@ -98,11 +106,13 @@ static void fillFusionPatterns(MLIRContext *context,
LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(
LinalgTilingLoopType::ParallelLoops),
LinalgFusionOptions().setIndicesToFuse({0, 1}),
- LinalgMarker(Identifier::get("transpose_fusion", context),
- Identifier::get("after_transpose_fusion", context)),
- LinalgMarker(ArrayRef<Identifier>(),
- Identifier::get("after_transpose_fusion_producer", context)),
- LinalgMarker(
+ LinalgTransformationFilter(
+ Identifier::get("transpose_fusion", context),
+ Identifier::get("after_transpose_fusion", context)),
+ LinalgTransformationFilter(
+ ArrayRef<Identifier>(),
+ Identifier::get("after_transpose_fusion_producer", context)),
+ LinalgTransformationFilter(
ArrayRef<Identifier>(),
Identifier::get("after_transpose_fusion_original", context)));
}
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index a322b627756e..db05d60ad8c7 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -98,29 +98,35 @@ static void applyPatterns(FuncOp funcOp) {
//===--------------------------------------------------------------------===//
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
- LinalgMarker(Identifier::get("MEM", ctx), Identifier::get("L3", ctx)));
+ LinalgTransformationFilter(Identifier::get("MEM", ctx),
+ Identifier::get("L3", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
- LinalgMarker(Identifier::get("L3", ctx), Identifier::get("L2", ctx)));
+ LinalgTransformationFilter(Identifier::get("L3", ctx),
+ Identifier::get("L2", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
- LinalgMarker(Identifier::get("L2", ctx), Identifier::get("L1", ctx)));
+ LinalgTransformationFilter(Identifier::get("L2", ctx),
+ Identifier::get("L1", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
- LinalgMarker(Identifier::get("L1", ctx), Identifier::get("REG", ctx)));
+ LinalgTransformationFilter(Identifier::get("L1", ctx),
+ Identifier::get("REG", ctx)));
patterns.insert<LinalgTilingPattern<MatvecOp>>(
ctx,
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
LinalgTilingLoopType::ParallelLoops),
- LinalgMarker({}, Identifier::get("L1", ctx)));
+ LinalgTransformationFilter(ArrayRef<Identifier>{},
+ Identifier::get("L1", ctx)));
patterns.insert<LinalgTilingPattern<DotOp>>(
ctx, LinalgTilingOptions().setTileSizes(8000),
- LinalgMarker(ArrayRef<Identifier>{Identifier::get("MEM", ctx),
- Identifier::get("L3", ctx),
- Identifier::get("L2", ctx)},
- Identifier::get("REG", ctx)));
+ LinalgTransformationFilter(
+ ArrayRef<Identifier>{Identifier::get("MEM", ctx),
+ Identifier::get("L3", ctx),
+ Identifier::get("L2", ctx)},
+ Identifier::get("REG", ctx)));
//===--------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
@@ -130,24 +136,24 @@ static void applyPatterns(FuncOp funcOp) {
LinalgTilingOptions()
.setTileSizes({2000, 3000, 4000})
.setInterchange({1, 2, 0}),
- LinalgMarker(Identifier::get("__with_perm__", ctx),
- Identifier::get("L2__with_perm__", ctx)));
+ LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
+ Identifier::get("L2__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions()
.setTileSizes({200, 300, 400})
.setInterchange({1, 0, 2}),
- LinalgMarker(Identifier::get("L2__with_perm__", ctx),
- Identifier::get("L1__with_perm__", ctx)));
+ LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
+ Identifier::get("L1__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
- LinalgMarker(Identifier::get("L1__with_perm__", ctx),
- Identifier::get("REG__with_perm__", ctx)));
+ LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
+ Identifier::get("REG__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatvecOp>>(
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
- LinalgMarker(Identifier::get("__with_perm__", ctx),
- Identifier::get("L1__with_perm__", ctx)));
+ LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
+ Identifier::get("L1__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx,
@@ -155,8 +161,9 @@ static void applyPatterns(FuncOp funcOp) {
.setTileSizes({16, 8, 4})
.setInterchange({1, 2, 0})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
- LinalgMarker(Identifier::get("par__with_perm__", ctx),
- Identifier::get("after_par__with_perm__", ctx)));
+ LinalgTransformationFilter(
+ Identifier::get("par__with_perm__", ctx),
+ Identifier::get("after_par__with_perm__", ctx)));
//===--------------------------------------------------------------------===//
// Linalg to loops patterns.
@@ -164,7 +171,7 @@ static void applyPatterns(FuncOp funcOp) {
patterns.insert<LinalgLoweringPattern<DotOp>>(
ctx,
/*loweringType=*/LinalgLoweringType::Loops,
- LinalgMarker(Identifier::get("REG", ctx)));
+ LinalgTransformationFilter(Identifier::get("REG", ctx)));
//===--------------------------------------------------------------------===//
// Linalg distribution patterns.
@@ -178,7 +185,8 @@ static void applyPatterns(FuncOp funcOp) {
LinalgVectorizationPattern<FillOp>,
LinalgVectorizationPattern<CopyOp>,
LinalgVectorizationPattern<GenericOp>>(
- ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
+ ctx, LinalgVectorizationOptions(),
+ LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)));
//===--------------------------------------------------------------------===//
// Linalg generic permutation patterns.
@@ -186,34 +194,38 @@ static void applyPatterns(FuncOp funcOp) {
patterns.insert<LinalgInterchangePattern<GenericOp>>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
- LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
+ LinalgTransformationFilter(ArrayRef<Identifier>{},
+ Identifier::get("PERMUTED", ctx)));
patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
- LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
+ LinalgTransformationFilter(ArrayRef<Identifier>{},
+ Identifier::get("PERMUTED", ctx)));
//===--------------------------------------------------------------------===//
// Linalg subview operands promotion.
//===--------------------------------------------------------------------===//
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
- LinalgMarker(Identifier::get("_promote_views_", ctx),
- Identifier::get("_views_promoted_", ctx)));
+ LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
+ Identifier::get("_views_promoted_", ctx)));
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
ctx,
LinalgPromotionOptions()
.setOperandsToPromote({0})
.setUseFullTileBuffersByDefault(true),
- LinalgMarker(Identifier::get("_promote_first_view_", ctx),
- Identifier::get("_first_view_promoted_", ctx)));
+ LinalgTransformationFilter(
+ Identifier::get("_promote_first_view_", ctx),
+ Identifier::get("_first_view_promoted_", ctx)));
patterns.insert<LinalgPromotionPattern<FillOp>>(
ctx,
LinalgPromotionOptions()
.setOperandsToPromote({0})
.setUseFullTileBuffers({true})
.setAlignment(32),
- LinalgMarker(Identifier::get("_promote_views_aligned_", ctx),
- Identifier::get("_views_aligned_promoted_", ctx)));
+ LinalgTransformationFilter(
+ Identifier::get("_promote_views_aligned_", ctx),
+ Identifier::get("_views_aligned_promoted_", ctx)));
applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
@@ -230,18 +242,19 @@ static void fillL1TilingAndMatmulToVectorPatterns(
patternsVector.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
- LinalgMarker(Identifier::get(startMarker, ctx),
- Identifier::get("L1", ctx))));
+ LinalgTransformationFilter(Identifier::get(startMarker, ctx),
+ Identifier::get("L1", ctx))));
patternsVector.emplace_back(
std::make_unique<LinalgPromotionPattern<MatmulOp>>(
ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
- LinalgMarker(Identifier::get("L1", ctx),
- Identifier::get("VEC", ctx))));
+ LinalgTransformationFilter(Identifier::get("L1", ctx),
+ Identifier::get("VEC", ctx))));
patternsVector.emplace_back(
std::make_unique<LinalgVectorizationPattern<MatmulOp>>(
- ctx, LinalgMarker(Identifier::get("VEC", ctx))));
+ ctx, LinalgVectorizationOptions(),
+ LinalgTransformationFilter(Identifier::get("VEC", ctx))));
patternsVector.back()
.insert<LinalgVectorizationPattern<FillOp>,
LinalgVectorizationPattern<CopyOp>>(ctx);
@@ -289,8 +302,8 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx,
OwningRewritePatternList &patterns) {
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
- LinalgMarker(Identifier::get("START", ctx),
- Identifier::get("PROMOTE", ctx)));
+ LinalgTransformationFilter(Identifier::get("START", ctx),
+ Identifier::get("PROMOTE", ctx)));
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
ctx,
LinalgPromotionOptions()
@@ -306,7 +319,7 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx,
copyCallBackFn(b, src, dst, true);
return success();
}),
- LinalgMarker(Identifier::get("PROMOTE", ctx)));
+ LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
}
template <typename IdOp, typename NProcsOp>
@@ -335,8 +348,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
.setDistributionOptions(cyclicNprocsEqNiters),
- LinalgMarker(Identifier::get("distribute1", context),
- Identifier::get("after_distribute1", context)));
+ LinalgTransformationFilter(
+ Identifier::get("distribute1", context),
+ Identifier::get("after_distribute1", context)));
}
{
@@ -351,8 +365,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
.setDistributionOptions(cyclicNprocsGeNiters),
- LinalgMarker(Identifier::get("distribute2", context),
- Identifier::get("after_distribute2", context)));
+ LinalgTransformationFilter(
+ Identifier::get("distribute2", context),
+ Identifier::get("after_distribute2", context)));
}
{
@@ -367,8 +382,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
.setDistributionOptions(cyclicNprocsDefault),
- LinalgMarker(Identifier::get("distribute3", context),
- Identifier::get("after_distribute3", context)));
+ LinalgTransformationFilter(
+ Identifier::get("distribute3", context),
+ Identifier::get("after_distribute3", context)));
}
{
@@ -383,8 +399,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
.setDistributionOptions(cyclicNprocsMixed1),
- LinalgMarker(Identifier::get("distribute4", context),
- Identifier::get("after_distribute4", context)));
+ LinalgTransformationFilter(
+ Identifier::get("distribute4", context),
+ Identifier::get("after_distribute4", context)));
}
{
@@ -399,8 +416,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
.setDistributionOptions(cyclicNprocsMixed2),
- LinalgMarker(Identifier::get("distribute5", context),
- Identifier::get("after_distribute5", context)));
+ LinalgTransformationFilter(
+ Identifier::get("distribute5", context),
+ Identifier::get("after_distribute5", context)));
}
{
@@ -416,8 +434,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
.setDistributionOptions(cyclicNprocsMixed3),
- LinalgMarker(Identifier::get("distribute6", context),
- Identifier::get("after_distribute6", context)));
+ LinalgTransformationFilter(
+ Identifier::get("distribute6", context),
+ Identifier::get("after_distribute6", context)));
}
{
@@ -432,8 +451,9 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::Loops)
.setDistributionOptions(cyclicNprocsEqNiters),
- LinalgMarker(Identifier::get("tensors_distribute1", context),
- Identifier::get("tensors_after_distribute1", context)));
+ LinalgTransformationFilter(
+ Identifier::get("tensors_distribute1", context),
+ Identifier::get("tensors_after_distribute1", context)));
}
}
@@ -452,8 +472,8 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
LinalgTilingOptions()
.setTileSizes({768, 264, 768})
.setInterchange({1, 2, 0}),
- LinalgMarker(Identifier::get("START", ctx),
- Identifier::get("L2", ctx))));
+ LinalgTransformationFilter(Identifier::get("START", ctx),
+ Identifier::get("L2", ctx))));
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
stage1Patterns);
}
@@ -511,7 +531,8 @@ static void applyTileAndPadPattern(FuncOp funcOp) {
.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
tilingPattern.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>>(
context, linalgTilingOptions,
- linalg::LinalgMarker(Identifier::get("tile-and-pad", context)));
+ linalg::LinalgTransformationFilter(
+ Identifier::get("tile-and-pad", context)));
applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}
More information about the Mlir-commits
mailing list