[Mlir-commits] [mlir] e625aea - [mlir][Linalg] Retire Linalg generic interchange pattern and pass
Guray Ozen
llvmlistbot at llvm.org
Tue Aug 23 00:40:07 PDT 2022
Author: Guray Ozen
Date: 2022-08-23T09:28:16+02:00
New Revision: e625aea33a653d23d83aab8ea30e6bf7dd0b6b51
URL: https://github.com/llvm/llvm-project/commit/e625aea33a653d23d83aab8ea30e6bf7dd0b6b51
DIFF: https://github.com/llvm/llvm-project/commit/e625aea33a653d23d83aab8ea30e6bf7dd0b6b51.diff
LOG: [mlir][Linalg] Retire Linalg generic interchange pattern and pass
This revision removes the Linalg generic interchange pattern and pass.
It also updates the transform-patterns test to use the transform dialect (`transform.structured.interchange`) in place of the removed pattern.
Context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785
Differential Revision: https://reviews.llvm.org/D132368
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index ec7db843b3449..ecf684c918d82 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -108,13 +108,6 @@ std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyDecomposePass(
const linalg::LinalgTransformationFilter &filter =
linalg::LinalgTransformationFilter());
-/// Create a LinalgStrategyInterchangePass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createLinalgStrategyInterchangePass(
- ArrayRef<int64_t> iteratorInterchange = {},
- const linalg::LinalgTransformationFilter &filter =
- linalg::LinalgTransformationFilter());
-
/// Create a LinalgStrategyPeelPass.
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyPeelPass(
StringRef opName = "",
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index c497135959998..85cbdd83874ab 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -235,17 +235,6 @@ def LinalgStrategyDecomposePass
];
}
-def LinalgStrategyInterchangePass
- : Pass<"linalg-strategy-interchange-pass", "func::FuncOp"> {
- let summary = "Configurable pass to apply pattern-based iterator interchange.";
- let constructor = "mlir::createLinalgStrategyInterchangePass()";
- let dependentDialects = ["linalg::LinalgDialect"];
- let options = [
- Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
- "Which func op is the anchor to latch on.">,
- ];
-}
-
def LinalgStrategyPeelPass
: Pass<"linalg-strategy-peel-pass", "func::FuncOp"> {
let summary = "Configurable pass to apply pattern-based linalg peeling.";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index 82a49487458a8..d28f1ccb3de3d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -96,23 +96,6 @@ struct Generalize : public Transformation {
std::string opName;
};
-/// Represent one application of createLinalgStrategyInterchangePass.
-struct Interchange : public Transformation {
- explicit Interchange(ArrayRef<int64_t> iteratorInterchange,
- LinalgTransformationFilter::FilterFunction f = nullptr)
- : Transformation(std::move(f)),
- iteratorInterchange(iteratorInterchange.begin(),
- iteratorInterchange.end()) {}
-
- void addToPassPipeline(OpPassManager &pm,
- LinalgTransformationFilter m) const override {
- pm.addPass(createLinalgStrategyInterchangePass(iteratorInterchange, m));
- }
-
-private:
- SmallVector<int64_t> iteratorInterchange;
-};
-
/// Represent one application of createLinalgStrategyDecomposePass.
struct Decompose : public Transformation {
explicit Decompose(LinalgTransformationFilter::FilterFunction f = nullptr)
@@ -250,20 +233,6 @@ struct CodegenStrategy {
LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? generalize(opName, std::move(f)) : *this;
}
- /// Append a pattern to interchange iterators.
- CodegenStrategy &
- interchange(ArrayRef<int64_t> iteratorInterchange,
- const LinalgTransformationFilter::FilterFunction &f = nullptr) {
- transformationSequence.emplace_back(
- std::make_unique<Interchange>(iteratorInterchange, f));
- return *this;
- }
- /// Conditionally append a pattern to interchange iterators.
- CodegenStrategy &
- interchangeIf(bool b, ArrayRef<int64_t> iteratorInterchange,
- LinalgTransformationFilter::FilterFunction f = nullptr) {
- return b ? interchange(iteratorInterchange, std::move(f)) : *this;
- }
/// Append patterns to decompose convolutions.
CodegenStrategy &
decompose(const LinalgTransformationFilter::FilterFunction &f = nullptr) {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5a4c41a938175..8f53d5a8fea8c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -828,38 +828,6 @@ struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
LinalgTilingAndFusionOptions options;
};
-///
-/// Linalg generic interchange pattern.
-///
-/// Apply the `interchange` transformation on a RewriterBase.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `interchange` for more details.
-struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
-
- /// GenericOp-specific constructor with an optional `filter`.
- GenericOpInterchangePattern(
- MLIRContext *context, ArrayRef<unsigned> interchangeVector,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1);
-
- /// `matchAndRewrite` implementation that returns the significant transformed
- /// pieces of IR.
- FailureOr<GenericOp>
- returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const;
-
- LogicalResult matchAndRewrite(GenericOp op,
- PatternRewriter &rewriter) const override {
- return returningMatchAndRewrite(op, rewriter);
- }
-
-private:
- /// LinalgTransformMarker handles special attribute manipulations.
- LinalgTransformationFilter filter;
- /// The interchange vector to reorder the iterators and indexing_maps dims.
- SmallVector<unsigned, 8> interchangeVector;
-};
-
///
/// Linalg generalization pattern.
///
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index ee0846fe7a55d..22b97ffcf78ba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -199,37 +199,6 @@ struct LinalgStrategyDecomposePass
LinalgTransformationFilter filter;
};
-/// Configurable pass to apply pattern-based linalg generalization.
-struct LinalgStrategyInterchangePass
- : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
-
- LinalgStrategyInterchangePass() = default;
-
- LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
- LinalgTransformationFilter filter)
- : iteratorInterchange(iteratorInterchange.begin(),
- iteratorInterchange.end()),
- filter(std::move(filter)) {}
-
- void runOnOperation() override {
- auto funcOp = getOperation();
- if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
- return;
-
- SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
- iteratorInterchange.end());
- RewritePatternSet interchangePattern(funcOp.getContext());
- interchangePattern.add<GenericOpInterchangePattern>(
- funcOp.getContext(), interchangeVector, filter);
- if (failed(applyPatternsAndFoldGreedily(funcOp,
- std::move(interchangePattern))))
- signalPassFailure();
- }
-
- SmallVector<int64_t> iteratorInterchange;
- LinalgTransformationFilter filter;
-};
-
/// Configurable pass to apply pattern-based linalg peeling.
struct LinalgStrategyPeelPass
: public LinalgStrategyPeelPassBase<LinalgStrategyPeelPass> {
@@ -491,15 +460,6 @@ mlir::createLinalgStrategyDecomposePass(
return std::make_unique<LinalgStrategyDecomposePass>(filter);
}
-/// Create a LinalgStrategyInterchangePass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createLinalgStrategyInterchangePass(
- ArrayRef<int64_t> iteratorInterchange,
- const LinalgTransformationFilter &filter) {
- return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
- filter);
-}
-
/// Create a LinalgStrategyPeelPass.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createLinalgStrategyPeelPass(StringRef opName,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 11152f177be19..2fcbe680259fb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -537,29 +537,6 @@ mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
return tileLoopNest;
}
-/// Linalg generic interchange pattern.
-mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
- MLIRContext *context, ArrayRef<unsigned> interchangeVector,
- LinalgTransformationFilter f, PatternBenefit benefit)
- : OpRewritePattern(context, benefit), filter(std::move(f)),
- interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
-
-FailureOr<GenericOp>
-mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
- GenericOp genericOp, PatternRewriter &rewriter) const {
- if (failed(filter.checkAndNotify(rewriter, genericOp)))
- return failure();
-
- FailureOr<GenericOp> transformedOp =
- interchangeGenericOp(rewriter, genericOp, interchangeVector);
- if (failed(transformedOp))
- return failure();
-
- // New filter if specified.
- filter.replaceLinalgTransformationFilter(rewriter, genericOp);
- return transformedOp;
-}
-
/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 3a704e409d8e2..e7053c3a3383f 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns -split-input-file -test-transform-dialect-interpreter | FileCheck %s
// CHECK-DAG: #[[$STRIDED_1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
// Map corresponding to a 2D memory access where the stride along the last dim is known to be 1.
@@ -114,6 +114,14 @@ func.func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
}
return
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ transform.structured.interchange %0 { iterator_interchange = [1, 2, 0]}
+ }
+}
// CHECK-LABEL: func @permute_generic
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index fc988eaa267c1..576082a572e10 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -230,15 +230,6 @@ static void applyPatterns(func::FuncOp funcOp) {
.addOpFilter<MatmulOp, FillOp, GenericOp>());
patterns.add<CopyVectorizationPattern>(ctx);
- //===--------------------------------------------------------------------===//
- // Linalg generic interchange pattern.
- //===--------------------------------------------------------------------===//
- patterns.add<GenericOpInterchangePattern>(
- ctx,
- /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
- LinalgTransformationFilter(ArrayRef<StringAttr>{},
- StringAttr::get(ctx, "PERMUTED")));
-
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
// Drop the marker.
More information about the Mlir-commits
mailing list