[Mlir-commits] [mlir] f4ad1b6 - [mlir][Linalg] Quarantine usage of LinalgTransformationFilter to TestTilingInterface.
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Oct 12 08:37:02 PDT 2022
Author: Nicolas Vasilache
Date: 2022-10-12T08:36:51-07:00
New Revision: f4ad1b6f697cc80b1a72b3b24fdae5a4db54e304
URL: https://github.com/llvm/llvm-project/commit/f4ad1b6f697cc80b1a72b3b24fdae5a4db54e304
DIFF: https://github.com/llvm/llvm-project/commit/f4ad1b6f697cc80b1a72b3b24fdae5a4db54e304.diff
LOG: [mlir][Linalg] Quarantine usage of LinalgTransformationFilter to TestTilingInterface.
This revision also retires code that has now become dead.
Context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785
Differential Revision: https://reviews.llvm.org/D135771
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index bbac0899338df..af95ed7544714 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -362,67 +362,6 @@ LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp);
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
-// Marker used as attribute name in generated Linalg rewriting transformations.
-struct LinalgTransforms {
- static const StringLiteral kLinalgTransformMarker;
-};
-
-/// Helper class to control application of linalg transformation patterns.
-/// Control comes in 2 forms:
-/// 1. attribute matching and setting behavior using the attribute named
-/// `kLinalgTransformMarker`. This can be used to build a state machine
-/// using attributes and incrementally applying patterns to advance states.
-/// 2. filter function, which is a simple lambda on the Operation* that
-/// returns a LogicalResult.
-struct LinalgTransformationFilter {
- using FilterFunction = std::function<LogicalResult(Operation *)>;
-
- explicit LinalgTransformationFilter(
- ArrayRef<StringAttr> matchDisjunction = {},
- Optional<StringAttr> replacement = None);
-
- explicit LinalgTransformationFilter(
- const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction = {},
- Optional<StringAttr> replacement = None);
-
- LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
- LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
- LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
- void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
- Operation *op) const;
- bool hasReplacementFilter(Operation *op) const;
-
- LinalgTransformationFilter &addFilter(const FilterFunction &f) {
- if (f)
- filters.push_back(f);
- return *this;
- }
-
- template <typename... OpTypes>
- LinalgTransformationFilter &addOpFilter() {
- return addFilter(
- [](Operation *op) { return success(isa<OpTypes...>(op)); });
- }
-
- LinalgTransformationFilter &addOpNameFilter(StringRef opName) {
- return addFilter([opName](Operation *op) {
- return success(op->getName().getStringRef() == opName);
- });
- }
-
- LinalgTransformationFilter &setMatchByDefault() {
- matchByDefault = true;
- return *this;
- }
-
-private:
- SmallVector<FilterFunction> filters;
- SmallVector<StringAttr> matchDisjunction;
- Optional<StringAttr> replacement;
- /// When set to true, if the attribute is not set, it will be treated as
- /// a match. Default is false.
- bool matchByDefault;
-};
using TileSizeComputationFunction =
std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
@@ -793,14 +732,7 @@ struct LinalgGeneralizationPattern
}
};
-///
-/// Linalg vectorization patterns.
-///
-/// Empty for now, used for SFINAE purposes only.
-struct LinalgVectorizationOptions {};
-
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `vectorizeLinalgOp` for more details.
+/// Vectorization pattern for memref::CopyOp.
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
@@ -811,34 +743,6 @@ struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
/// Return vector::CombiningKind for the given op.
llvm::Optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);
-//===----------------------------------------------------------------------===//
-// Transformation and lowering options exposed as auxiliary structs.
-//===----------------------------------------------------------------------===//
-/// Options to control the application of enabling transformations.
-/// Hoisting transformations are always deemed beneficial and must be disabled
-/// explicitly.
-struct LinalgEnablingOptions {
- /// Enable LICM.
- bool licm = true;
- LinalgEnablingOptions &enableLICM(bool val = true) {
- licm = val;
- return *this;
- }
- /// Enable hoisting of redundant vector transfer ops.
- bool hoistRedundantVectorTransfers = true;
- LinalgEnablingOptions &enableHoistRedundantVectorTransfers(bool val = true) {
- hoistRedundantVectorTransfers = val;
- return *this;
- }
- /// Enable hoisting of redundant vector transfer ops on tensor.
- bool hoistRedundantVectorTransfersOnTensor = true;
- LinalgEnablingOptions &
- enableHoistRedundantVectorTransfersOnTensor(bool val = true) {
- hoistRedundantVectorTransfersOnTensor = val;
- return *this;
- }
-};
-
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
@@ -971,24 +875,6 @@ struct LinalgCopyVTWForwardingPattern
PatternRewriter &rewriter) const override;
};
-//===----------------------------------------------------------------------===//
-// Support for staged pattern application.
-//===----------------------------------------------------------------------===//
-/// Helper function to allow applying rewrite patterns, interleaved with more
-/// global transformations, in a staged fashion:
-/// 1. the first stage consists of a list of FrozenRewritePatternSet. Each
-/// FrozenRewritePatternSet in this list is applied once, in order.
-/// 2. the second stage consists of a single RewritePattern that is applied
-/// greedily until convergence.
-/// 3. the third stage consists of applying a lambda, generally used for
-/// non-local transformation effects. This allows creating custom fused
-/// transformations where patterns can be ordered and applied at a finer
-/// granularity than a sequence of traditional compiler passes.
-LogicalResult applyStagedPatterns(
- Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
- const FrozenRewritePatternSet &stage2Patterns,
- function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
-
/// Rewrite extract_slice(tensor.pad(x)) into tensor.pad(extract_slice(x)).
struct ExtractSliceOfPadTensorSwapPattern
: public OpRewritePattern<tensor::ExtractSliceOp> {
@@ -1015,20 +901,6 @@ struct ExtractSliceOfPadTensorSwapPattern
ControlFn controlFn;
};
-//===----------------------------------------------------------------------===//
-// Helper classes for type list expansion.
-//===----------------------------------------------------------------------===//
-template <typename... OpTypes>
-class VectorizationPatterns;
-
-template <>
-class VectorizationPatterns<> {
-public:
- static void insert(RewritePatternSet &patterns,
- const LinalgVectorizationOptions &options,
- const LinalgTransformationFilter &f) {}
-};
-
/// Split Reduction options.
struct SplitReductionOptions {
// Ratio used to split the reduction dimension. If the ratio is <= 1, nothing
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 34df2c81dc1ca..dd04d00bee523 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -659,14 +659,10 @@ struct PadOpTilingPattern : public OpRewritePattern<tensor::PadOp> {
LogicalResult matchAndRewrite(tensor::PadOp op,
PatternRewriter &rewriter) const override {
- if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker))
- return failure();
tensor::PadOp newPadOp;
LoopNest loopNest;
if (failed(tilePadOp(rewriter, op, newPadOp, loopNest, options)))
return failure();
- newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker,
- rewriter.getUnitAttr());
// Replace all uses of the original tensor::PadOp.
rewriter.replaceOp(op, loopNest.getResults()[0]);
return success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 63f74a320f805..8eb41c5d88b42 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -47,75 +47,6 @@ using namespace mlir::linalg;
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
-// Marker used as attribute name in generated Linalg rewriting transformations.
-const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
- "__internal_linalg_transform__";
-
-mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
- ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
- : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
- replacement(replacement), matchByDefault(false) {}
-
-mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
- const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
- Optional<StringAttr> replacement)
- : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
- replacement(replacement), matchByDefault(false) {
- if (f)
- filters.push_back(f);
-}
-
-LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
- PatternRewriter &rewriter, Operation *op) const {
- if (llvm::any_of(filters,
- [&](const FilterFunction &f) { return failed(f(op)); }))
- return failure();
-
- auto attr = op->template getAttrOfType<StringAttr>(
- LinalgTransforms::kLinalgTransformMarker);
-
- if (!attr) {
- // 1. Has no filter case and matchDisjunction is empty.
- if (matchDisjunction.empty() || matchByDefault)
- return success();
-
- // 2. Has no filter but was expecting a filter.
- return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
- diag << " does not have any filter from list: ";
- interleaveComma(matchDisjunction, diag);
- });
- }
-
- // 4. Match explicit filter.
- for (auto filter : matchDisjunction)
- if (attr.getValue() == filter)
- return success();
-
- // 5. Fail to match.
- return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
- diag << " does not have any filter from list: ";
- interleaveComma(matchDisjunction, diag);
- });
-}
-
-void mlir::linalg::LinalgTransformationFilter::
- replaceLinalgTransformationFilter(PatternRewriter &rewriter,
- Operation *op) const {
- if (replacement.has_value())
- op->setAttr(LinalgTransforms::kLinalgTransformMarker, replacement.value());
- else
- op->removeAttr(
- rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
-}
-
-bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
- Operation *op) const {
- if (!replacement)
- return false;
- auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
- .dyn_cast<StringAttr>();
- return attr && attr == *replacement;
-}
LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
@@ -432,37 +363,6 @@ LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
return vectorizeCopy(rewriter, copyOp);
}
-LogicalResult mlir::linalg::applyStagedPatterns(
- Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
- const FrozenRewritePatternSet &stage2Patterns,
- function_ref<LogicalResult(Operation *)> stage3Lambda) {
- unsigned iteration = 0;
- (void)iteration;
- for (const auto &patterns : stage1Patterns) {
- LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
- << *op);
- if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
- LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
- return failure();
- }
- LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
- << *op);
- if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
- LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
- return failure();
- }
- LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
- << *op);
- if (stage3Lambda) {
- if (failed(stage3Lambda(op)))
- return failure();
- LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
- << *op);
- }
- }
- return success();
-}
-
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 146a32edac521..7313c309214cb 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -127,11 +127,6 @@ static void applyPatterns(func::FuncOp funcOp) {
patterns.add<CopyVectorizationPattern>(ctx);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-
- // Drop the marker.
- funcOp.walk([](LinalgOp op) {
- op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
- });
}
static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
@@ -182,13 +177,6 @@ static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation() {
- auto lambda = [&](void *) {
- getOperation().walk([](LinalgOp op) {
- op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
- });
- };
- std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
-
if (testPatterns)
return applyPatterns(getOperation());
if (testVectorTransferForwardingPatterns)
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 8e3b9765f0882..31e3c1a529a7c 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -31,26 +31,149 @@
using namespace mlir;
+// TODO: this file should disappear and instead tests should make use of the
+// transform dialect.
namespace {
+/// Marker used as attribute name in generated Linalg rewriting transformations.
+const StringLiteral kLinalgTransformMarker = "__internal_linalg_transform__";
+
+/// Helper class to control application of linalg transformation patterns.
+/// Control comes in 2 forms:
+/// 1. attribute matching and setting behavior using the attribute named
+/// `kLinalgTransformMarker`. This can be used to build a state machine
+/// using attributes and incrementally applying patterns to advance states.
+/// 2. filter function, which is a simple lambda on the Operation* that
+/// returns a LogicalResult.
+struct LinalgTransformationFilter {
+ using FilterFunction = std::function<LogicalResult(Operation *)>;
+
+ explicit LinalgTransformationFilter(
+ ArrayRef<StringAttr> matchDisjunction = {},
+ Optional<StringAttr> replacement = None);
+
+ explicit LinalgTransformationFilter(
+ const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction = {},
+ Optional<StringAttr> replacement = None);
+
+ LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
+ LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
+ LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
+ void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
+ Operation *op) const;
+ bool hasReplacementFilter(Operation *op) const;
+
+ LinalgTransformationFilter &addFilter(const FilterFunction &f) {
+ if (f)
+ filters.push_back(f);
+ return *this;
+ }
+
+ template <typename... OpTypes>
+ LinalgTransformationFilter &addOpFilter() {
+ return addFilter(
+ [](Operation *op) { return success(isa<OpTypes...>(op)); });
+ }
+
+ LinalgTransformationFilter &addOpNameFilter(StringRef opName) {
+ return addFilter([opName](Operation *op) {
+ return success(op->getName().getStringRef() == opName);
+ });
+ }
+
+ LinalgTransformationFilter &setMatchByDefault() {
+ matchByDefault = true;
+ return *this;
+ }
+
+private:
+ SmallVector<FilterFunction> filters;
+ SmallVector<StringAttr> matchDisjunction;
+ Optional<StringAttr> replacement;
+ /// When set to true, if the attribute is not set, it will be treated as
+ /// a match. Default is false.
+ bool matchByDefault;
+};
+
+LinalgTransformationFilter::LinalgTransformationFilter(
+ ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
+ : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
+ replacement(replacement), matchByDefault(false) {}
+
+LinalgTransformationFilter::LinalgTransformationFilter(
+ const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
+ Optional<StringAttr> replacement)
+ : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
+ replacement(replacement), matchByDefault(false) {
+ if (f)
+ filters.push_back(f);
+}
+
+LogicalResult
+LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter,
+ Operation *op) const {
+ if (llvm::any_of(filters,
+ [&](const FilterFunction &f) { return failed(f(op)); }))
+ return failure();
+
+ auto attr = op->template getAttrOfType<StringAttr>(kLinalgTransformMarker);
+
+ if (!attr) {
+ // 1. Has no filter case and matchDisjunction is empty.
+ if (matchDisjunction.empty() || matchByDefault)
+ return success();
+
+ // 2. Has no filter but was expecting a filter.
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << " does not have any filter from list: ";
+ interleaveComma(matchDisjunction, diag);
+ });
+ }
+
+ // 4. Match explicit filter.
+ for (auto filter : matchDisjunction)
+ if (attr.getValue() == filter)
+ return success();
+
+ // 5. Fail to match.
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << " does not have any filter from list: ";
+ interleaveComma(matchDisjunction, diag);
+ });
+}
+
+void LinalgTransformationFilter::replaceLinalgTransformationFilter(
+ PatternRewriter &rewriter, Operation *op) const {
+ if (replacement.has_value())
+ op->setAttr(kLinalgTransformMarker, replacement.value());
+ else
+ op->removeAttr(rewriter.getStringAttr(kLinalgTransformMarker));
+}
+
+bool LinalgTransformationFilter::hasReplacementFilter(Operation *op) const {
+ if (!replacement)
+ return false;
+ auto attr = op->getAttr(kLinalgTransformMarker).dyn_cast<StringAttr>();
+ return attr && attr == *replacement;
+}
+
/// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
/// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while
/// using a `filter` to avoid recursive application.
struct TestTileUsingSCFForOp
: public OpInterfaceRewritePattern<TilingInterface> {
- TestTileUsingSCFForOp(MLIRContext *context, scf::SCFTilingOptions options,
- linalg::LinalgTransformationFilter filter =
- linalg::LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
+ TestTileUsingSCFForOp(
+ MLIRContext *context, scf::SCFTilingOptions options,
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
options(std::move(options)), filter(std::move(filter)) {}
/// Construct a generic pattern applied to `opName`.
- TestTileUsingSCFForOp(StringRef opName, MLIRContext *context,
- scf::SCFTilingOptions options,
- linalg::LinalgTransformationFilter filter =
- linalg::LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
+ TestTileUsingSCFForOp(
+ StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
options(std::move(options)), filter(std::move(filter)) {}
@@ -76,7 +199,7 @@ struct TestTileUsingSCFForOp
private:
scf::SCFTilingOptions options;
- linalg::LinalgTransformationFilter filter;
+ LinalgTransformationFilter filter;
};
/// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
@@ -87,8 +210,7 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
: public OpInterfaceRewritePattern<TilingInterface> {
TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp(
MLIRContext *context, scf::SCFTileAndFuseOptions options,
- linalg::LinalgTransformationFilter filter =
- linalg::LinalgTransformationFilter(),
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
options(std::move(options)), filter(std::move(filter)) {}
@@ -97,8 +219,7 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp(
StringRef opName, MLIRContext *context,
scf::SCFTileAndFuseOptions options,
- linalg::LinalgTransformationFilter filter =
- linalg::LinalgTransformationFilter(),
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
options(std::move(options)), filter(std::move(filter)) {}
@@ -129,7 +250,7 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
private:
scf::SCFTileAndFuseOptions options;
- linalg::LinalgTransformationFilter filter;
+ LinalgTransformationFilter filter;
};
/// Pattern to lower operations that implement the `TilingInterface` to
@@ -202,8 +323,8 @@ static void addPatternForTiling(MLIRContext *context,
ArrayRef<int64_t> interchange = {}) {
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
- linalg::LinalgTransformationFilter filter(
- StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
+ LinalgTransformationFilter filter(StringAttr::get(context, filterName),
+ StringAttr::get(context, "tiled"));
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
}
@@ -215,8 +336,8 @@ static void addPatternForTileAndFuse(MLIRContext *context,
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange(
interchange);
- linalg::LinalgTransformationFilter filter(
- StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
+ LinalgTransformationFilter filter(StringAttr::get(context, filterName),
+ StringAttr::get(context, "tiled"));
patterns.add<TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp>(
context, tileAndFuseOptions, filter);
}
More information about the Mlir-commits mailing list