[Mlir-commits] [mlir] e625aea - [mlir][Linalg] Retire Linalg generic interchange pattern and pass

Guray Ozen llvmlistbot at llvm.org
Tue Aug 23 00:40:07 PDT 2022


Author: Guray Ozen
Date: 2022-08-23T09:28:16+02:00
New Revision: e625aea33a653d23d83aab8ea30e6bf7dd0b6b51

URL: https://github.com/llvm/llvm-project/commit/e625aea33a653d23d83aab8ea30e6bf7dd0b6b51
DIFF: https://github.com/llvm/llvm-project/commit/e625aea33a653d23d83aab8ea30e6bf7dd0b6b51.diff

LOG: [mlir][Linalg] Retire Linalg generic interchange pattern and pass

This revision removes the Linalg generic interchange pattern and pass.

It also changes the transform-patterns test to use the transform dialect (`transform.structured.interchange`) instead of the removed pattern.

Context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785

Differential Revision: https://reviews.llvm.org/D132368

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index ec7db843b3449..ecf684c918d82 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -108,13 +108,6 @@ std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyDecomposePass(
     const linalg::LinalgTransformationFilter &filter =
         linalg::LinalgTransformationFilter());
 
-/// Create a LinalgStrategyInterchangePass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createLinalgStrategyInterchangePass(
-    ArrayRef<int64_t> iteratorInterchange = {},
-    const linalg::LinalgTransformationFilter &filter =
-        linalg::LinalgTransformationFilter());
-
 /// Create a LinalgStrategyPeelPass.
 std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyPeelPass(
     StringRef opName = "",

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index c497135959998..85cbdd83874ab 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -235,17 +235,6 @@ def LinalgStrategyDecomposePass
   ];
 }
 
-def LinalgStrategyInterchangePass
-    : Pass<"linalg-strategy-interchange-pass", "func::FuncOp"> {
-  let summary = "Configurable pass to apply pattern-based iterator interchange.";
-  let constructor = "mlir::createLinalgStrategyInterchangePass()";
-  let dependentDialects = ["linalg::LinalgDialect"];
-  let options = [
-    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
-      "Which func op is the anchor to latch on.">,
-  ];
-}
-
 def LinalgStrategyPeelPass
     : Pass<"linalg-strategy-peel-pass", "func::FuncOp"> {
   let summary = "Configurable pass to apply pattern-based linalg peeling.";

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index 82a49487458a8..d28f1ccb3de3d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -96,23 +96,6 @@ struct Generalize : public Transformation {
   std::string opName;
 };
 
-/// Represent one application of createLinalgStrategyInterchangePass.
-struct Interchange : public Transformation {
-  explicit Interchange(ArrayRef<int64_t> iteratorInterchange,
-                       LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(std::move(f)),
-        iteratorInterchange(iteratorInterchange.begin(),
-                            iteratorInterchange.end()) {}
-
-  void addToPassPipeline(OpPassManager &pm,
-                         LinalgTransformationFilter m) const override {
-    pm.addPass(createLinalgStrategyInterchangePass(iteratorInterchange, m));
-  }
-
-private:
-  SmallVector<int64_t> iteratorInterchange;
-};
-
 /// Represent one application of createLinalgStrategyDecomposePass.
 struct Decompose : public Transformation {
   explicit Decompose(LinalgTransformationFilter::FilterFunction f = nullptr)
@@ -250,20 +233,6 @@ struct CodegenStrategy {
                LinalgTransformationFilter::FilterFunction f = nullptr) {
     return b ? generalize(opName, std::move(f)) : *this;
   }
-  /// Append a pattern to interchange iterators.
-  CodegenStrategy &
-  interchange(ArrayRef<int64_t> iteratorInterchange,
-              const LinalgTransformationFilter::FilterFunction &f = nullptr) {
-    transformationSequence.emplace_back(
-        std::make_unique<Interchange>(iteratorInterchange, f));
-    return *this;
-  }
-  /// Conditionally append a pattern to interchange iterators.
-  CodegenStrategy &
-  interchangeIf(bool b, ArrayRef<int64_t> iteratorInterchange,
-                LinalgTransformationFilter::FilterFunction f = nullptr) {
-    return b ? interchange(iteratorInterchange, std::move(f)) : *this;
-  }
   /// Append patterns to decompose convolutions.
   CodegenStrategy &
   decompose(const LinalgTransformationFilter::FilterFunction &f = nullptr) {

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5a4c41a938175..8f53d5a8fea8c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -828,38 +828,6 @@ struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
   LinalgTilingAndFusionOptions options;
 };
 
-///
-/// Linalg generic interchange pattern.
-///
-/// Apply the `interchange` transformation on a RewriterBase.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `interchange` for more details.
-struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-
-  /// GenericOp-specific constructor with an optional `filter`.
-  GenericOpInterchangePattern(
-      MLIRContext *context, ArrayRef<unsigned> interchangeVector,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1);
-
-  /// `matchAndRewrite` implementation that returns the significant transformed
-  /// pieces of IR.
-  FailureOr<GenericOp>
-  returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const;
-
-  LogicalResult matchAndRewrite(GenericOp op,
-                                PatternRewriter &rewriter) const override {
-    return returningMatchAndRewrite(op, rewriter);
-  }
-
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter filter;
-  /// The interchange vector to reorder the iterators and indexing_maps dims.
-  SmallVector<unsigned, 8> interchangeVector;
-};
-
 ///
 /// Linalg generalization pattern.
 ///

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index ee0846fe7a55d..22b97ffcf78ba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -199,37 +199,6 @@ struct LinalgStrategyDecomposePass
   LinalgTransformationFilter filter;
 };
 
-/// Configurable pass to apply pattern-based linalg generalization.
-struct LinalgStrategyInterchangePass
-    : public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
-
-  LinalgStrategyInterchangePass() = default;
-
-  LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
-                                LinalgTransformationFilter filter)
-      : iteratorInterchange(iteratorInterchange.begin(),
-                            iteratorInterchange.end()),
-        filter(std::move(filter)) {}
-
-  void runOnOperation() override {
-    auto funcOp = getOperation();
-    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
-      return;
-
-    SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
-                                            iteratorInterchange.end());
-    RewritePatternSet interchangePattern(funcOp.getContext());
-    interchangePattern.add<GenericOpInterchangePattern>(
-        funcOp.getContext(), interchangeVector, filter);
-    if (failed(applyPatternsAndFoldGreedily(funcOp,
-                                            std::move(interchangePattern))))
-      signalPassFailure();
-  }
-
-  SmallVector<int64_t> iteratorInterchange;
-  LinalgTransformationFilter filter;
-};
-
 /// Configurable pass to apply pattern-based linalg peeling.
 struct LinalgStrategyPeelPass
     : public LinalgStrategyPeelPassBase<LinalgStrategyPeelPass> {
@@ -491,15 +460,6 @@ mlir::createLinalgStrategyDecomposePass(
   return std::make_unique<LinalgStrategyDecomposePass>(filter);
 }
 
-/// Create a LinalgStrategyInterchangePass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createLinalgStrategyInterchangePass(
-    ArrayRef<int64_t> iteratorInterchange,
-    const LinalgTransformationFilter &filter) {
-  return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
-                                                         filter);
-}
-
 /// Create a LinalgStrategyPeelPass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::createLinalgStrategyPeelPass(StringRef opName,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 11152f177be19..2fcbe680259fb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -537,29 +537,6 @@ mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
   return tileLoopNest;
 }
 
-/// Linalg generic interchange pattern.
-mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
-    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
-    LinalgTransformationFilter f, PatternBenefit benefit)
-    : OpRewritePattern(context, benefit), filter(std::move(f)),
-      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
-
-FailureOr<GenericOp>
-mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
-    GenericOp genericOp, PatternRewriter &rewriter) const {
-  if (failed(filter.checkAndNotify(rewriter, genericOp)))
-    return failure();
-
-  FailureOr<GenericOp> transformedOp =
-      interchangeGenericOp(rewriter, genericOp, interchangeVector);
-  if (failed(transformedOp))
-    return failure();
-
-  // New filter if specified.
-  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
-  return transformedOp;
-}
-
 /// Linalg generalization pattern.
 mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
     MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 3a704e409d8e2..e7053c3a3383f 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns -split-input-file -test-transform-dialect-interpreter | FileCheck %s
 
 // CHECK-DAG: #[[$STRIDED_1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 // Map corresponding to a 2D memory access where the stride along the last dim is known to be 1.
@@ -114,6 +114,14 @@ func.func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
   }
   return
 }
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    transform.structured.interchange %0 { iterator_interchange = [1, 2, 0]}
+  }
+}
 // CHECK-LABEL:  func @permute_generic
 // CHECK:        linalg.generic {
 // CHECK-SAME:   indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index fc988eaa267c1..576082a572e10 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -230,15 +230,6 @@ static void applyPatterns(func::FuncOp funcOp) {
                .addOpFilter<MatmulOp, FillOp, GenericOp>());
   patterns.add<CopyVectorizationPattern>(ctx);
 
-  //===--------------------------------------------------------------------===//
-  // Linalg generic interchange pattern.
-  //===--------------------------------------------------------------------===//
-  patterns.add<GenericOpInterchangePattern>(
-      ctx,
-      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
-      LinalgTransformationFilter(ArrayRef<StringAttr>{},
-                                 StringAttr::get(ctx, "PERMUTED")));
-
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 
   // Drop the marker.


        


More information about the Mlir-commits mailing list