[Mlir-commits] [mlir] e3d386e - [mlir][linalg] Add a tile and fuse on tensors pattern.

Tobias Gysi llvmlistbot at llvm.org
Mon Nov 22 03:14:14 PST 2021


Author: Tobias Gysi
Date: 2021-11-22T11:13:21Z
New Revision: e3d386ea27336edc04ae4fd324ab4337b9f3cf16

URL: https://github.com/llvm/llvm-project/commit/e3d386ea27336edc04ae4fd324ab4337b9f3cf16
DIFF: https://github.com/llvm/llvm-project/commit/e3d386ea27336edc04ae4fd324ab4337b9f3cf16.diff

LOG: [mlir][linalg] Add a tile and fuse on tensors pattern.

Add a pattern to apply the new tile and fuse on tensors method. Integrate the pattern into the CodegenStrategy and use the CodegenStrategy to implement the tests.

Depends On D114012

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114067

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
    mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index c0173ec2f443a..b25ba69a42a38 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -81,6 +81,15 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFuseTensorOpsPass();
 //===----------------------------------------------------------------------===//
 /// Linalg strategy passes.
 //===----------------------------------------------------------------------===//
+/// Create a LinalgStrategyTileAndFusePass that tiles the ops named `opName`
+/// (any op if `opName` is empty) with the tile sizes and interchange in `opt`
+/// and fuses their producers. `filter` controls the transformation marker
+/// matching and update.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyTileAndFusePass(
+    StringRef opName = "", linalg::LinalgTilingAndFusionOptions opt = {},
+    linalg::LinalgTransformationFilter filter =
+        linalg::LinalgTransformationFilter());
+
 /// Create a LinalgStrategyTilePass.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyTilePass(
     StringRef opName = "",

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index c9bcfebecb022..4c55b4e349d74 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -235,6 +235,20 @@ def LinalgTileAndFuseTensorOps
   let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"];
 }
 
+// Applies LinalgTileAndFuseTensorOpsPattern in functions named `anchor-func`
+// (all functions if empty) to ops named `anchor-op` (any LinalgOp if empty).
+def LinalgStrategyTileAndFusePass
+    : FunctionPass<"linalg-strategy-tile-and-fuse-pass"> {
+  let summary = "Configurable pass to apply pattern-based tiling and fusion.";
+  let constructor = "mlir::createLinalgStrategyTileAndFusePass()";
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+      "Which func op is the anchor to latch on.">,
+    Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+      "Which linalg op within the func is the anchor to latch on.">,
+  ];
+}
+
 def LinalgStrategyTilePass
     : FunctionPass<"linalg-strategy-tile-pass"> {
   let summary = "Configurable pass to apply pattern-based linalg tiling.";

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index f4579a6150489..ea2afa01889b0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -30,6 +30,24 @@ struct Transformation {
   LinalgTransformationFilter::FilterFunction filter = nullptr;
 };
 
+/// Represent one application of LinalgStrategyTileAndFusePass.
+/// `name` is the anchor op name (empty matches any op) and `options` holds
+/// the tile sizes and interchange forwarded to the pass.
+struct TileAndFuse : public Transformation {
+  TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options,
+              LinalgTransformationFilter::FilterFunction f = nullptr)
+      : Transformation(f), opName(name), options(options) {}
+
+  void addToPassPipeline(OpPassManager &pm,
+                         LinalgTransformationFilter m) const override {
+    pm.addPass(createLinalgStrategyTileAndFusePass(opName, options, m));
+  }
+
+private:
+  std::string opName;
+  linalg::LinalgTilingAndFusionOptions options;
+};
+
 /// Represent one application of LinalgStrategyTilePass.
 struct Tile : public Transformation {
   Tile(StringRef name, linalg::LinalgTilingOptions options,
@@ -147,6 +165,22 @@ struct VectorLowering : public Transformation {
 
 /// Codegen strategy controls how a Linalg op is progressively lowered.
 struct CodegenStrategy {
+  /// Append a pattern to tile the Op `opName` and fuse its producers with
+  /// tiling and fusion `options`.
+  CodegenStrategy &
+  tileAndFuse(StringRef opName, LinalgTilingAndFusionOptions options,
+              LinalgTransformationFilter::FilterFunction f = nullptr) {
+    transformationSequence.emplace_back(
+        std::make_unique<TileAndFuse>(opName, options, f));
+    return *this;
+  }
+  /// Conditionally append a pattern to tile the Op `opName` and fuse its
+  /// producers with tiling and fusion `options`.
+  CodegenStrategy &
+  tileAndFuseIf(bool b, StringRef opName, LinalgTilingAndFusionOptions options,
+                LinalgTransformationFilter::FilterFunction f = nullptr) {
+    return b ? tileAndFuse(opName, options, f) : *this;
+  }
   /// Append a pattern to add a level of tiling for Op `opName` with tiling
   /// `options`.
   CodegenStrategy &
@@ -161,7 +195,7 @@ struct CodegenStrategy {
   CodegenStrategy &
   tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options,
          LinalgTransformationFilter::FilterFunction f = nullptr) {
-    return b ? tile(opName, options) : *this;
+    return b ? tile(opName, options, f) : *this;
   }
   /// Append a pattern to pad and hoist the operands of Op `opName` with padding
   /// `options`.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6c044bc26c934..cfdd560f0cacc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -517,6 +517,17 @@ struct LinalgPaddingOptions {
   }
 };
 
+/// Options for LinalgTileAndFuseTensorOpsPattern, which tiles a root
+/// operation and fuses its producers.
+struct LinalgTilingAndFusionOptions {
+  /// Tile sizes used to tile the root operation. Must provide at least one
+  /// entry per root loop; only the leading #loops entries are used.
+  SmallVector<int64_t> tileSizes;
+  /// Tile interchange used to permute the tile loops. Either empty (identity
+  /// order) or of the same length as `tileSizes`.
+  SmallVector<int64_t> tileInterchange;
+};
+
 struct LinalgTilingOptions {
   /// Computation function that returns the tile sizes for each operation.
   /// Delayed construction of constant tile sizes should occur to interoperate
@@ -767,6 +778,34 @@ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
             fusionOptions, filter, fusedOpMarker, originalOpMarker, benefit) {}
 };
 
+///
+/// Linalg tile and fuse tensor ops pattern.
+///
+/// Apply tiling and fusion as a pattern.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `tileConsumerAndFuseProducers` for more details.
+struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
+  // Entry point to match any LinalgOp.
+  LinalgTileAndFuseTensorOpsPattern(
+      MLIRContext *context, LinalgTilingAndFusionOptions options,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+  // Entry point to match a specific LinalgOp.
+  LinalgTileAndFuseTensorOpsPattern(
+      StringRef opName, MLIRContext *context,
+      LinalgTilingAndFusionOptions options,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
+  /// Tile sizes and interchange used to tile the root operation.
+  LinalgTilingAndFusionOptions options;
+};
+
 ///
 /// Linalg generic interchage pattern.
 ///

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 25aee5f23a513..c1cdd3eda2cb3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -235,6 +235,10 @@ class TileLoopNest {
   /// Returns the tiled root operation.
   LinalgOp getRootOp() { return rootOp; }
 
+  /// Returns the tiled root operation and the fused producers. They are
+  /// collected from an internal map, so callers should not rely on the order.
+  SmallVector<LinalgOp> getAllTiledAndFusedOps();
+
   /// Returns the loop ops generated from tiling.
   ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 00904a48712b7..68d02db728df7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -390,6 +390,20 @@ ValueRange TileLoopNest::getRootOpReplacementResults() {
   return tileLoopOps.front()->getOpResults();
 }
 
+/// Collect the keys of `tiledRootAndFusedOpsLoops`, i.e. the tiled root
+/// operation and the fused producers, in the map's iteration order. Every key
+/// is expected to be a LinalgOp.
+SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
+  SmallVector<LinalgOp> result;
+  for (const auto &kvp : tiledRootAndFusedOpsLoops) {
+    auto linalgOp = dyn_cast<LinalgOp>(kvp.getFirst());
+    assert(linalgOp &&
+           "expect all tiled and fused operations are linalg operations");
+    result.push_back(linalgOp);
+  }
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // Tile and fuse entry-points.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 24cec12cec62d..acbc30a93c6c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -36,6 +36,48 @@ using namespace linalg;
 
 namespace {
 
+/// Configurable pass to apply pattern-based tiling and fusion.
+struct LinalgStrategyTileAndFusePass
+    : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {
+
+  LinalgStrategyTileAndFusePass() = default;
+
+  LinalgStrategyTileAndFusePass(StringRef opName,
+                                LinalgTilingAndFusionOptions opt,
+                                LinalgTransformationFilter filt)
+      : options(opt), filter(filt) {
+    this->anchorOpName.setValue(opName.str());
+  }
+
+  void runOnFunction() override {
+    auto funcOp = getFunction();
+    // Skip every function except the anchor function, if one is specified.
+    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+      return;
+
+    // Anchor the pattern on `anchorOpName` if specified, else on any op.
+    RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
+    if (!anchorOpName.empty()) {
+      tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
+          anchorOpName, funcOp.getContext(), options, filter);
+    } else {
+      tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
+          funcOp.getContext(), options, filter);
+    }
+    // Search the root operation using a bottom-up traversal, which visits
+    // consumers before the producers that may be fused into them.
+    GreedyRewriteConfig grc;
+    grc.useTopDownTraversal = false;
+    (void)applyPatternsAndFoldGreedily(funcOp,
+                                       std::move(tilingAndFusionPattern), grc);
+  }
+
+  /// Tile sizes and interchange forwarded to the pattern.
+  LinalgTilingAndFusionOptions options;
+  /// Transformation filter forwarded to the pattern.
+  LinalgTransformationFilter filter;
+};
+
 /// Configurable pass to apply pattern-based linalg tiling.
 struct LinalgStrategyTilePass
     : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
@@ -380,6 +422,15 @@ struct LinalgStrategyRemoveMarkersPass
 };
 } // namespace
 
+/// Create a LinalgStrategyTileAndFusePass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyTileAndFusePass(StringRef opName,
+                                          LinalgTilingAndFusionOptions options,
+                                          LinalgTransformationFilter filter) {
+  return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
+                                                         filter);
+}
+
 /// Create a LinalgStrategyTilePass.
 std::unique_ptr<OperationPass<FuncOp>>
 mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 36bb0171823f7..20073359a68ae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -520,6 +520,92 @@ LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
   return success();
 }
 
+/// Linalg tile and fuse tensor ops pattern.
+/// This constructor matches any operation; `matchAndRewrite` rejects the ones
+/// that are not LinalgOps.
+mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
+    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
+                                      LinalgTilingAndFusionOptions options,
+                                      LinalgTransformationFilter filter,
+                                      PatternBenefit benefit)
+    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
+      options(options) {}
+
+/// This constructor matches only operations named `opName`.
+mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
+    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
+                                      LinalgTilingAndFusionOptions options,
+                                      LinalgTransformationFilter filter,
+                                      PatternBenefit benefit)
+    : RewritePattern(opName, benefit, context), filter(filter),
+      options(options) {}
+
+/// Tile the matched LinalgOp with the leading #loops entries of
+/// `options.tileSizes` and `options.tileInterchange`, fuse its producers,
+/// update the transformation markers of all tiled and fused ops via `filter`,
+/// and replace the original op with `getRootOpReplacementResults()`.
+LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
+  if (!rootOp)
+    return failure();
+  if (failed(filter.checkAndNotify(rewriter, op)))
+    return failure();
+
+  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
+  if (options.tileSizes.size() < rootOp.getNumLoops())
+    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");
+
+  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
+  if (!options.tileInterchange.empty() &&
+      options.tileInterchange.size() != options.tileSizes.size())
+    return rewriter.notifyMatchFailure(
+        op, "expect the number of tile sizes and interchange dims to match");
+
+  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
+  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
+                                     options.tileSizes.begin() +
+                                         rootOp.getNumLoops());
+  SmallVector<int64_t> rootInterchange =
+      options.tileInterchange.empty()
+          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
+          : SmallVector<int64_t>(options.tileInterchange.begin(),
+                                 options.tileInterchange.begin() +
+                                     rootOp.getNumLoops());
+
+  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
+  // It has to be a permutation since the tiling cannot tile the same loop
+  // dimension multiple times.
+  if (!isPermutation(rootInterchange))
+    return rewriter.notifyMatchFailure(
+        op, "expect the tile interchange permutes the root loops");
+
+  // Tile `rootOp` and fuse its producers.
+  FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
+      rewriter, rootOp, rootTileSizes, rootInterchange);
+  if (failed(tileLoopNest))
+    return rewriter.notifyMatchFailure(
+        op, "tileConsumerAndFuseProducers failed unexpectedly");
+
+  // `getRootOpReplacementResults` accesses the first tile loop and thus
+  // requires a non-empty loop nest (presumably empty if all root tile sizes
+  // are zero).
+  if (tileLoopNest->getLoopOps().empty())
+    return rewriter.notifyMatchFailure(op, "expect a non-empty tile loop nest");
+
+  // Update the transformation markers of the tiled and fused operations.
+  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
+    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
+
+  // Replace the original root operation through the rewriter, rather than
+  // with a plain replaceAllUsesWith, so the pattern driver is notified. This
+  // also erases the now unused original operation.
+  rewriter.replaceOp(op, tileLoopNest->getRootOpReplacementResults());
+
+  // The IR has been modified, so the rewrite has to report success.
+  return success();
+}
+
 /// Linalg generic interchange pattern.
 mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
     MLIRContext *context, ArrayRef<unsigned> interchangeVector,

diff  --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
index 90b9ad60b97f3..e46218995ffb8 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -1,12 +1,13 @@
-// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=5,4,7 tile-interchange=1,0,2" -cse -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=5,4,7 tile-interchange=1,0,2 run-enable-pass=false" -cse -split-input-file | FileCheck --check-prefix=MATMUL %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.generic fuse tile-sizes=5,4,7 tile-interchange=1,0,2 run-enable-pass=false" -cse -split-input-file | FileCheck --check-prefix=GENERIC %s
 
-//  CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
-//  CHECK-DAG:  #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
-//  CHECK-DAG:  #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 24)>
-//  CHECK-DAG:  #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
+//  MATMUL-DAG:  #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
+//  MATMUL-DAG:  #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
+//  MATMUL-DAG:  #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 24)>
+//  MATMUL-DAG:  #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
 
-//      CHECK:  fuse_input
-// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+//      MATMUL:  fuse_input
+// MATMUL-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
 builtin.func @fuse_input(%arg0: tensor<24x12xf32>,
                          %arg1: tensor<12x25xf32>,
                          %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
@@ -18,31 +19,31 @@ builtin.func @fuse_input(%arg0: tensor<24x12xf32>,
   %cst = arith.constant 0.000000e+00 : f32
   %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32>
 
-  //      CHECK:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
-  //      CHECK:    scf.for %[[IV1:[0-9a-zA-Z]*]] =
-  //      CHECK:      %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
-  //      CHECK:      scf.for %[[IV2:[0-9a-zA-Z]*]] =
-  //      CHECK:        %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
+  //      MATMUL:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  //      MATMUL:    scf.for %[[IV1:[0-9a-zA-Z]*]] =
+  //      MATMUL:      %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
+  //      MATMUL:      scf.for %[[IV2:[0-9a-zA-Z]*]] =
+  //      MATMUL:        %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
 
   // Tile both input operand dimensions.
-  //      CHECK:        %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]])
-  //      CHECK:        %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
-  //      CHECK:        %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
-  // CHECK-SAME:                                          %[[IV1]], %[[IV2]]
-  // CHECK-SAME:                                          %[[UB1]], %[[UB2]]
-  //      CHECK:        %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
-  //      CHECK:        %{{.*}} = linalg.matmul ins(%[[T1]]
+  //      MATMUL:        %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]])
+  //      MATMUL:        %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
+  //      MATMUL:        %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  // MATMUL-SAME:                                          %[[IV1]], %[[IV2]]
+  // MATMUL-SAME:                                          %[[UB1]], %[[UB2]]
+  //      MATMUL:        %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  //      MATMUL:        %{{.*}} = linalg.matmul ins(%[[T1]]
   %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
   return %1 : tensor<24x25xf32>
 }
 
 // -----
 
-//  CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
-//  CHECK-DAG:  #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
+//  MATMUL-DAG:  #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
+//  MATMUL-DAG:  #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
 
-//      CHECK:  fuse_output
-// CHECK-SAME:    %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+//      MATMUL:  fuse_output
+// MATMUL-SAME:    %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
 builtin.func @fuse_output(%arg0: tensor<24x12xf32>,
                           %arg1: tensor<12x25xf32>,
                           %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
@@ -55,34 +56,34 @@ builtin.func @fuse_output(%arg0: tensor<24x12xf32>,
   %0 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32>
 
   // Update the iteration argument of the outermost tile loop.
-  //      CHECK:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
-  //      CHECK:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
-  //      CHECK:      %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
-  //      CHECK:      %[[TS0:.*]] = affine.min #[[MAP1]](%[[IV0]])
+  //      MATMUL:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+  //      MATMUL:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
+  //      MATMUL:      %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
+  //      MATMUL:      %[[TS0:.*]] = affine.min #[[MAP1]](%[[IV0]])
 
   // Tile the both output operand dimensions.
-  //      CHECK:      %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
-  // CHECK-SAME:                                        %[[IV1]], %[[IV0]]
-  // CHECK-SAME:                                        %[[TS1]], %[[TS0]]
-  //      CHECK:      %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
-  //      CHECK:        scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
-  //      CHECK:          %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG5]]
+  //      MATMUL:      %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
+  // MATMUL-SAME:                                        %[[IV1]], %[[IV0]]
+  // MATMUL-SAME:                                        %[[TS1]], %[[TS0]]
+  //      MATMUL:      %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  //      MATMUL:        scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
+  //      MATMUL:          %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG5]]
   %1 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%0 : tensor<24x25xf32>) -> tensor<24x25xf32>
   return %1 : tensor<24x25xf32>
 }
 
 // -----
 
-//  CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
-//  CHECK-DAG:  #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
-//  CHECK-DAG:  #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 25)>
-//  CHECK-DAG:  #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
+//  MATMUL-DAG:  #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
+//  MATMUL-DAG:  #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
+//  MATMUL-DAG:  #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 25)>
+//  MATMUL-DAG:  #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 
-//      CHECK:  fuse_reduction
-// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
-// CHECK-SAME:    %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x7x25xf32>
+//      MATMUL:  fuse_reduction
+// MATMUL-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+// MATMUL-SAME:    %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x7x25xf32>
 builtin.func @fuse_reduction(%arg0: tensor<24x12xf32>,
                              %arg1: tensor<12x25xf32>,
                              %arg2: tensor<24x25xf32>,
@@ -98,23 +99,23 @@ builtin.func @fuse_reduction(%arg0: tensor<24x12xf32>,
     linalg.yield %2 : f32
   } -> tensor<12x25xf32>
 
-  //      CHECK:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
-  //      CHECK:    scf.for %[[IV1:[0-9a-zA-Z]*]] =
-  //      CHECK:      %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]])
-  //      CHECK:      scf.for %[[IV2:[0-9a-zA-Z]*]] =
-  //      CHECK:        %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
-  //      CHECK:        %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
-  //      CHECK:        %[[UB0:.*]] = affine.min #[[MAP2]](%[[TS0]], %[[IV0]])
+  //      MATMUL:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  //      MATMUL:    scf.for %[[IV1:[0-9a-zA-Z]*]] =
+  //      MATMUL:      %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]])
+  //      MATMUL:      scf.for %[[IV2:[0-9a-zA-Z]*]] =
+  //      MATMUL:        %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
+  //      MATMUL:        %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
+  //      MATMUL:        %[[UB0:.*]] = affine.min #[[MAP2]](%[[TS0]], %[[IV0]])
 
   // Tile only the parallel dimensions but not the reduction dimension.
-  //      CHECK:        %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
-  // CHECK-SAME:                                          %[[IV2]], 0, %[[IV0]]
-  // CHECK-SAME:                                          %[[UB2]], 7, %[[UB0]]
-  //      CHECK:        %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
-  // CHECK-SAME:                                          %[[IV2]], %[[IV0]]
-  // CHECK-SAME:                                          %[[UB2]], %[[UB0]]
-  //      CHECK:        %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
-  //      CHECK:        %{{.*}} = linalg.matmul ins(%{{.*}}, %[[T2]]
+  //      MATMUL:        %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
+  // MATMUL-SAME:                                          %[[IV2]], 0, %[[IV0]]
+  // MATMUL-SAME:                                          %[[UB2]], 7, %[[UB0]]
+  //      MATMUL:        %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+  // MATMUL-SAME:                                          %[[IV2]], %[[IV0]]
+  // MATMUL-SAME:                                          %[[UB2]], %[[UB0]]
+  //      MATMUL:        %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
+  //      MATMUL:        %{{.*}} = linalg.matmul ins(%{{.*}}, %[[T2]]
   %1 = linalg.matmul ins(%arg0, %0 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
   return %1 : tensor<24x25xf32>
 }
@@ -124,9 +125,9 @@ builtin.func @fuse_reduction(%arg0: tensor<24x12xf32>,
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
 
-//      CHECK:  fuse_transposed
-// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
-// CHECK-SAME:    %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x24xf32>
+//      MATMUL:  fuse_transposed
+// MATMUL-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// MATMUL-SAME:    %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x24xf32>
 builtin.func @fuse_transposed(%arg0: tensor<24x12xf32>,
                               %arg1: tensor<12x25xf32>,
                               %arg2: tensor<24x25xf32>,
@@ -142,26 +143,26 @@ builtin.func @fuse_transposed(%arg0: tensor<24x12xf32>,
     linalg.yield %2 : f32
   } -> tensor<24x12xf32>
 
-  //      CHECK:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
-  //      CHECK:    scf.for %[[IV1:[0-9a-zA-Z]*]] =
-  //      CHECK:      scf.for %[[IV2:[0-9a-zA-Z]*]] =
+  //      MATMUL:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  //      MATMUL:    scf.for %[[IV1:[0-9a-zA-Z]*]] =
+  //      MATMUL:      scf.for %[[IV2:[0-9a-zA-Z]*]] =
 
   // Swap the input operand slice offsets due to the transposed indexing map.
-  //      CHECK:        %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
-  // CHECK-SAME:                                          %[[IV2]], %[[IV1]]
-  //      CHECK:        %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
-  // CHECK-SAME:                                          %[[IV1]], %[[IV2]]
-  //      CHECK:        %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
-  //      CHECK:        %{{.*}} = linalg.matmul ins(%[[T2]]
+  //      MATMUL:        %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
+  // MATMUL-SAME:                                          %[[IV2]], %[[IV1]]
+  //      MATMUL:        %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
+  // MATMUL-SAME:                                          %[[IV1]], %[[IV2]]
+  //      MATMUL:        %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
+  //      MATMUL:        %{{.*}} = linalg.matmul ins(%[[T2]]
   %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
   return %1 : tensor<24x25xf32>
 }
 
 // -----
 
-//      CHECK:  fuse_input_and_output
-// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
-// CHECK-SAME:    %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+//      MATMUL:  fuse_input_and_output
+// MATMUL-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// MATMUL-SAME:    %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
 builtin.func @fuse_input_and_output(%arg0: tensor<24x12xf32>,
                                     %arg1: tensor<12x25xf32>,
                                     %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
@@ -175,27 +176,27 @@ builtin.func @fuse_input_and_output(%arg0: tensor<24x12xf32>,
   %1 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32>
 
   // Fuse both producers to the appropriate tile loops.
-  //      CHECK:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
-  //      CHECK:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
-  //      CHECK:      %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
-  // CHECK-SAME:                                        %[[IV1]], %[[IV0]]
-  //      CHECK:      %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
-  //      CHECK:        scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
-  //      CHECK:          %[[T2:.*]] = tensor.extract_slice %[[ARG0]]
-  // CHECK-SAME:                                            %[[IV1]], %[[IV2]]
-  //      CHECK:          %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
-  //      CHECK:          %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[ARG5]]
+  //      MATMUL:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+  //      MATMUL:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
+  //      MATMUL:      %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
+  // MATMUL-SAME:                                        %[[IV1]], %[[IV0]]
+  //      MATMUL:      %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  //      MATMUL:        scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
+  //      MATMUL:          %[[T2:.*]] = tensor.extract_slice %[[ARG0]]
+  // MATMUL-SAME:                                            %[[IV1]], %[[IV2]]
+  //      MATMUL:          %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
+  //      MATMUL:          %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[ARG5]]
   %2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32>
   return %2 : tensor<24x25xf32>
 }
 
 // -----
 
-//  CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+//  MATMUL-DAG:  #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 
-//      CHECK:  fuse_indexed
-// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xi32>
+//      MATMUL:  fuse_indexed
+// MATMUL-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xi32>
 builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>,
                            %arg1: tensor<12x25xi32>,
                            %arg2: tensor<24x25xi32>) -> tensor<24x25xi32> {
@@ -213,19 +214,19 @@ builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>,
     linalg.yield %9 : i32
   } -> tensor<12x25xi32>
 
-  //      CHECK:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
-  //      CHECK:    scf.for %[[IV1:[0-9a-zA-Z]*]] =
-  //      CHECK:      scf.for %[[IV2:[0-9a-zA-Z]*]] =
+  //      MATMUL:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  //      MATMUL:    scf.for %[[IV1:[0-9a-zA-Z]*]] =
+  //      MATMUL:      scf.for %[[IV2:[0-9a-zA-Z]*]] =
 
   // Shift the indexes by the slice offsets and swap the offsets due to the transposed indexing map.
-  //      CHECK:        %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
-  // CHECK-SAME:                                          %[[IV2]], %[[IV0]]
-  //      CHECK:  linalg.generic {{.*}} outs(%[[T1]]
-  //      CHECK:  %[[IDX0:.*]] = linalg.index 0
-  //      CHECK:  %[[IDX0_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX0]], %[[IV0]])
-  //      CHECK:  %[[IDX1:.*]] = linalg.index 1
-  //      CHECK:  %[[IDX1_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX1]], %[[IV2]])
-  //      CHECK:  %{{.*}} = arith.addi %[[IDX0_SHIFTED]], %[[IDX1_SHIFTED]]
+  //      MATMUL:        %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+  // MATMUL-SAME:                                          %[[IV2]], %[[IV0]]
+  //      MATMUL:  linalg.generic {{.*}} outs(%[[T1]]
+  //      MATMUL:  %[[IDX0:.*]] = linalg.index 0
+  //      MATMUL:  %[[IDX0_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX0]], %[[IV0]])
+  //      MATMUL:  %[[IDX1:.*]] = linalg.index 1
+  //      MATMUL:  %[[IDX1_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX1]], %[[IV2]])
+  //      MATMUL:  %{{.*}} = arith.addi %[[IDX0_SHIFTED]], %[[IDX1_SHIFTED]]
   %1 = linalg.matmul ins(%arg0, %0 : tensor<24x12xi32>, tensor<12x25xi32>) outs(%arg2 : tensor<24x25xi32>) -> tensor<24x25xi32>
   return %1 : tensor<24x25xi32>
 }
@@ -235,28 +236,28 @@ builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>,
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
 
-//      CHECK:  fuse_outermost_reduction
-// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
-// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<10xf32>
+//      GENERIC:  fuse_outermost_reduction
+// GENERIC-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
+// GENERIC-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<10xf32>
 func @fuse_outermost_reduction(%arg0: tensor<10x17xf32>,
                                %arg1: tensor<10xf32>) -> tensor<10xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = linalg.fill(%cst, %arg0) : f32, tensor<10x17xf32> -> tensor<10x17xf32>
 
   // Cannot fuse the output fill since the reduction loop is the outermost loop.
-  //      CHECK:      %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG1]])
+  //      GENERIC:      %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG1]])
   %1 = linalg.fill(%cst, %arg1) : f32, tensor<10xf32> -> tensor<10xf32>
 
-  //      CHECK:  scf.for %[[IV0:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[T0]]
-  //      CHECK:    scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+  //      GENERIC:  scf.for %[[IV0:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[T0]]
+  //      GENERIC:    scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
 
-  // Check the input fill has been fused.
-  //      CHECK:      %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
-  // CHECK-SAME:                                        %[[IV1]], %[[IV0]]
-  //      CHECK:      %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
-  //      CHECK:      %[[T3:.*]] = tensor.extract_slice %[[ARG3]]
-  // CHECK-SAME:                                        %[[IV1]]
-  //      CHECK:  linalg.generic {{.*}} ins(%[[T2]] {{.*}} outs(%[[T3]]
+  // Check the input fill has been fused.
+  //      GENERIC:      %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
+  // GENERIC-SAME:                                        %[[IV1]], %[[IV0]]
+  //      GENERIC:      %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
+  //      GENERIC:      %[[T3:.*]] = tensor.extract_slice %[[ARG3]]
+  // GENERIC-SAME:                                        %[[IV1]]
+  //      GENERIC:  linalg.generic {{.*}} ins(%[[T2]] {{.*}} outs(%[[T3]]
   %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<10x17xf32>) outs(%1 : tensor<10xf32>) {
   ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
     %3 = arith.addf %arg2, %arg3 : f32
@@ -267,39 +268,39 @@ func @fuse_outermost_reduction(%arg0: tensor<10x17xf32>,
 
 // -----
 
-//  CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-//  CHECK-DAG:  #[[MAP1:.*]] = affine_map<(d0, d1) -> (8, -d0 - d1 + 17)>
-//  CHECK-DAG:  #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, -d1 - d2 + 17)>
+//  GENERIC-DAG:  #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+//  GENERIC-DAG:  #[[MAP1:.*]] = affine_map<(d0, d1) -> (8, -d0 - d1 + 17)>
+//  GENERIC-DAG:  #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, -d1 - d2 + 17)>
 #map0 = affine_map<(d0, d1) -> (d0, d0 + d1)>
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
 
-//      CHECK:  fuse_non_rectangular
-// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
+//      GENERIC:  fuse_non_rectangular
+// GENERIC-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
 func @fuse_non_rectangular(%arg0: tensor<10x17xf32>,
                            %arg1: tensor<10x8xf32>) -> tensor<10x8xf32> {
 
-  //  CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
-  //  CHECK-DAG:  %[[C4:.*]] = arith.constant 4 : index
-  //  CHECK-DAG:  %[[C5:.*]] = arith.constant 5 : index
-  //  CHECK-DAG:  %[[C8:.*]] = arith.constant 8 : index
-  //  CHECK-DAG:  %[[C10:.*]] = arith.constant 10 : index
+  //  GENERIC-DAG:  %[[C0:.*]] = arith.constant 0 : index
+  //  GENERIC-DAG:  %[[C4:.*]] = arith.constant 4 : index
+  //  GENERIC-DAG:  %[[C5:.*]] = arith.constant 5 : index
+  //  GENERIC-DAG:  %[[C8:.*]] = arith.constant 8 : index
+  //  GENERIC-DAG:  %[[C10:.*]] = arith.constant 10 : index
   %cst = arith.constant 0.000000e+00 : f32
   %0 = linalg.fill(%cst, %arg0) : f32, tensor<10x17xf32> -> tensor<10x17xf32>
 
-  //      CHECK:  scf.for %[[IV0:[0-9a-zA-Z]*]] = %[[C0]] to %[[C8]] step %[[C4]]
-  //      CHECK:    scf.for %[[IV1:[0-9a-zA-Z]*]] = %[[C0]] to %[[C10]] step %[[C5]]
+  //      GENERIC:  scf.for %[[IV0:[0-9a-zA-Z]*]] = %[[C0]] to %[[C8]] step %[[C4]]
+  //      GENERIC:    scf.for %[[IV1:[0-9a-zA-Z]*]] = %[[C0]] to %[[C10]] step %[[C5]]
 
   // Compute producer on a hyper rectangular bounding box. Along the second dimenson,
   // the offset is set to the sum of the induction variables, and the upper bound
   // to either 8 (tile size) or 17 (sum of max indices (9+7) then + 1) minus the
   // induction variables.
-  //      CHECK:      %[[SUM:.*]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV0]]
-  //      CHECK:      %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]], %[[IV0]]
-  //      CHECK:      %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]], %[[IV0]]
-  //      CHECK:      %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
-  // CHECK-SAME:                                        %[[IV1]], %[[SUM]]
-  // CHECK-SAME:                                                , %[[UB1]]
-  //      CHECK:      %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  //  GENERIC-DAG:      %[[SUM:.*]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV0]]
+  //  GENERIC-DAG:      %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]], %[[IV0]]
+  //  GENERIC-DAG:      %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]], %[[IV0]]
+  //      GENERIC:      %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  // GENERIC-SAME:                                        %[[IV1]], %[[SUM]]
+  // GENERIC-SAME:                                                , %[[UB1]]
+  //      GENERIC:      %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
   %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x17xf32>) outs(%arg1 : tensor<10x8xf32>) {
   ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
     %2 = arith.addf %arg2, %arg3 : f32

diff  --git a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
index 1578d230017ba..ef28641f4062c 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
@@ -1,11 +1,12 @@
-// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=4,4,0,0 tile-interchange=0,1,2,3" -cse --canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.conv_2d fuse tile-sizes=4,4,0,0 tile-interchange=0,1,2,3 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=CONV %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=4,4,0 tile-interchange=0,1,2 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=MATMUL %s
 
-//      CHECK:  fuse_conv_chain
-// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<2x2xf32>
-// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<11x11xf32>
-// CHECK-SAME:    %[[ARG2:[0-9a-zA-Z]*]]: tensor<10x10xf32>
-// CHECK-SAME:    %[[ARG3:[0-9a-zA-Z]*]]: tensor<9x9xf32>
-// CHECK-SAME:    %[[ARG4:[0-9a-zA-Z]*]]: tensor<8x8xf32>
+//      CONV:  fuse_conv_chain
+// CONV-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<2x2xf32>
+// CONV-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<11x11xf32>
+// CONV-SAME:    %[[ARG2:[0-9a-zA-Z]*]]: tensor<10x10xf32>
+// CONV-SAME:    %[[ARG3:[0-9a-zA-Z]*]]: tensor<9x9xf32>
+// CONV-SAME:    %[[ARG4:[0-9a-zA-Z]*]]: tensor<8x8xf32>
 builtin.func @fuse_conv_chain(%arg0: tensor<2x2xf32>,
                               %arg1: tensor<11x11xf32>,
                               %arg2: tensor<10x10xf32>,
@@ -14,34 +15,34 @@ builtin.func @fuse_conv_chain(%arg0: tensor<2x2xf32>,
   %cst = arith.constant 1.0 : f32
 
   // Do not tile the filter fill since the filter dimensions are not tiled.
-  //      CHECK:  %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
+  //      CONV:  %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
   %0 = linalg.fill(%cst, %arg0) : f32, tensor<2x2xf32> -> tensor<2x2xf32>
 
   // Fuse all other operations.
-  //      CHECK:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[ARG4]]
-  //      CHECK:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG6:.*]] = %[[ARG5]]
+  //      CONV:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[ARG4]]
+  //      CONV:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG6:.*]] = %[[ARG5]]
 
-  //      CHECK:          %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
-  // CHECK-SAME:                                            %[[IV0]], %[[IV1]]
-  //      CHECK:          %[[T2:.*]] = tensor.extract_slice %[[ARG2]]
-  // CHECK-SAME:                                            %[[IV0]], %[[IV1]]
-  //      CHECK:          %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
-  //      CHECK:          %[[T4:.*]] = linalg.conv_2d ins(%[[T1]], %[[T0]] : {{.*}} outs(%[[T3]]
+  //      CONV:          %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+  // CONV-SAME:                                            %[[IV0]], %[[IV1]]
+  //      CONV:          %[[T2:.*]] = tensor.extract_slice %[[ARG2]]
+  // CONV-SAME:                                            %[[IV0]], %[[IV1]]
+  //      CONV:          %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
+  //      CONV:          %[[T4:.*]] = linalg.conv_2d ins(%[[T1]], %[[T0]] : {{.*}} outs(%[[T3]]
   %1 = linalg.fill(%cst, %arg2) : f32, tensor<10x10xf32> -> tensor<10x10xf32>
   %2 = linalg.conv_2d ins(%arg1, %0 : tensor<11x11xf32>, tensor<2x2xf32>) outs(%1 : tensor<10x10xf32>) -> tensor<10x10xf32>
 
-  //      CHECK:          %[[T5:.*]] = tensor.extract_slice %[[ARG3]]
-  // CHECK-SAME:                                            %[[IV0]], %[[IV1]]
-  //      CHECK:          %[[T6:.*]] = linalg.fill(%{{.*}}, %[[T5]])
-  //      CHECK:          %[[T7:.*]] = linalg.conv_2d ins(%[[T4]], %[[T0]] : {{.*}} outs(%[[T6]]
+  //      CONV:          %[[T5:.*]] = tensor.extract_slice %[[ARG3]]
+  // CONV-SAME:                                            %[[IV0]], %[[IV1]]
+  //      CONV:          %[[T6:.*]] = linalg.fill(%{{.*}}, %[[T5]])
+  //      CONV:          %[[T7:.*]] = linalg.conv_2d ins(%[[T4]], %[[T0]] : {{.*}} outs(%[[T6]]
   %3 = linalg.fill(%cst, %arg3) : f32, tensor<9x9xf32> -> tensor<9x9xf32>
   %4 = linalg.conv_2d ins(%2, %0 : tensor<10x10xf32>, tensor<2x2xf32>) outs(%3 : tensor<9x9xf32>) -> tensor<9x9xf32>
 
   // Use the argument passed in by iteration argument.
-  //      CHECK:          %[[T8:.*]] = tensor.extract_slice %[[ARG6]]
-  // CHECK-SAME:                                            %[[IV0]], %[[IV1]]
-  //      CHECK:          %[[T9:.*]] = linalg.fill(%{{.*}}, %[[T8]])
-  //      CHECK:          %[[T5:.*]] = linalg.conv_2d ins(%[[T7]], %[[T0]] {{.*}} outs(%[[T9]]
+  //      CONV:          %[[T8:.*]] = tensor.extract_slice %[[ARG6]]
+  // CONV-SAME:                                            %[[IV0]], %[[IV1]]
+  //      CONV:          %[[T9:.*]] = linalg.fill(%{{.*}}, %[[T8]])
+  //      CONV:          %[[T5:.*]] = linalg.conv_2d ins(%[[T7]], %[[T0]] {{.*}} outs(%[[T9]]
   %5 = linalg.fill(%cst, %arg4) : f32, tensor<8x8xf32> -> tensor<8x8xf32>
   %6 = linalg.conv_2d ins(%4, %0 : tensor<9x9xf32>, tensor<2x2xf32>) outs(%5 : tensor<8x8xf32>) -> tensor<8x8xf32>
   return %6 : tensor<8x8xf32>
@@ -49,8 +50,8 @@ builtin.func @fuse_conv_chain(%arg0: tensor<2x2xf32>,
 
 // -----
 
-//      CHECK:  fuse_matmul_chain
-// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<8x8xf32>
+//      MATMUL:  fuse_matmul_chain
+// MATMUL-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<8x8xf32>
 builtin.func @fuse_matmul_chain(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
   %c0 = arith.constant 0 : index
   %c12 = arith.constant 12 : index
@@ -60,24 +61,24 @@ builtin.func @fuse_matmul_chain(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
   %cst = arith.constant 0.000000e+00 : f32
 
   // Do not tile rhs fill of the producer matmul since none of its loop dimension is tiled.
-  //      CHECK:  %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
+  //      MATMUL:  %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
   %0 = linalg.fill(%cst, %arg0) : f32, tensor<8x8xf32> -> tensor<8x8xf32>
 
-  //      CHECK:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG1:.*]] = %[[ARG0]]
-  //      CHECK:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[ARG1]]
+  //      MATMUL:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG1:.*]] = %[[ARG0]]
+  //      MATMUL:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[ARG1]]
 
   // Only the outermost loop of the producer matmul is tiled.
-  //      CHECK:      %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
-  // CHECK-SAME:                                        %[[IV0]], 0
-  //      CHECK:      %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
-  //      CHECK:      %[[T3:.*]] = linalg.matmul ins(%[[T2]], %[[T0]] {{.*}}
+  //      MATMUL:      %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
+  // MATMUL-SAME:                                        %[[IV0]], 0
+  //      MATMUL:      %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
+  //      MATMUL:      %[[T3:.*]] = linalg.matmul ins(%[[T2]], %[[T0]] {{.*}}
   %1 = linalg.matmul ins(%0, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32>
 
   // Use the argument passed in by iteration argument.
-  //      CHECK:      %[[T4:.*]] = tensor.extract_slice %[[ARG2]]
-  // CHECK-SAME:                                        %[[IV0]], %[[IV1]]
-  //      CHECK:      %[[T5:.*]] = linalg.fill(%{{.*}}, %[[T4]])
-  //      CHECK:      %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T5]]
+  //      MATMUL:      %[[T4:.*]] = tensor.extract_slice %[[ARG2]]
+  // MATMUL-SAME:                                        %[[IV0]], %[[IV1]]
+  //      MATMUL:      %[[T5:.*]] = linalg.fill(%{{.*}}, %[[T4]])
+  //      MATMUL:      %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T5]]
   %2 = linalg.matmul ins(%1, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32>
   return %2 : tensor<8x8xf32>
 }

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
index 49b2bb9b4bcd6..c2c563abe4cfc 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -52,16 +52,21 @@ struct TestLinalgCodegenStrategy
 
   void runOnFunction() override;
 
-  void runStrategy(LinalgTilingOptions tilingOptions,
+  void runStrategy(LinalgTilingAndFusionOptions tilingAndFusionOptions,
+                   LinalgTilingOptions tilingOptions,
                    LinalgTilingOptions registerTilingOptions,
                    LinalgPaddingOptions paddingOptions,
                    vector::VectorContractLowering vectorContractLowering,
                    vector::VectorTransferSplit vectorTransferSplit);
 
+  Option<bool> fuse{
+      *this, "fuse",
+      llvm::cl::desc("Fuse the producers after tiling the root op."),
+      llvm::cl::init(false)};
   ListOption<int64_t> tileSizes{*this, "tile-sizes",
                                 llvm::cl::MiscFlags::CommaSeparated,
                                 llvm::cl::desc("Specifies the tile sizes.")};
-  ListOption<unsigned> tileInterchange{
+  ListOption<int64_t> tileInterchange{
       *this, "tile-interchange", llvm::cl::MiscFlags::CommaSeparated,
       llvm::cl::desc("Specifies the tile interchange.")};
 
@@ -148,6 +153,7 @@ struct TestLinalgCodegenStrategy
 };
 
 void TestLinalgCodegenStrategy::runStrategy(
+    LinalgTilingAndFusionOptions tilingAndFusionOptions,
     LinalgTilingOptions tilingOptions,
     LinalgTilingOptions registerTilingOptions,
     LinalgPaddingOptions paddingOptions,
@@ -156,7 +162,10 @@ void TestLinalgCodegenStrategy::runStrategy(
   assert(!anchorOpName.empty());
   CodegenStrategy strategy;
   StringRef genericOpName = GenericOp::getOperationName();
-  strategy.tileIf(!tileSizes.empty(), anchorOpName, tilingOptions)
+  strategy
+      .tileAndFuseIf(fuse && !tileSizes.empty(), anchorOpName,
+                     tilingAndFusionOptions)
+      .tileIf(!fuse && !tileSizes.empty(), anchorOpName, tilingOptions)
       .promoteIf(promote, anchorOpName,
                  LinalgPromotionOptions()
                      .setAlignment(16)
@@ -204,11 +213,17 @@ void TestLinalgCodegenStrategy::runOnFunction() {
   if (!anchorFuncOpName.empty() && anchorFuncOpName != getFunction().getName())
     return;
 
+  LinalgTilingAndFusionOptions tilingAndFusionOptions;
+  tilingAndFusionOptions.tileSizes = {tileSizes.begin(), tileSizes.end()};
+  tilingAndFusionOptions.tileInterchange = {tileInterchange.begin(),
+                                            tileInterchange.end()};
+
   LinalgTilingOptions tilingOptions;
   if (!tileSizes.empty())
     tilingOptions = tilingOptions.setTileSizes(tileSizes);
   if (!tileInterchange.empty())
-    tilingOptions = tilingOptions.setInterchange(tileInterchange);
+    tilingOptions = tilingOptions.setInterchange(
+        SmallVector<unsigned>(tileInterchange.begin(), tileInterchange.end()));
 
   LinalgTilingOptions registerTilingOptions;
   if (!registerTileSizes.empty())
@@ -245,8 +260,8 @@ void TestLinalgCodegenStrategy::runOnFunction() {
           .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
           .Default(vector::VectorTransferSplit::None);
 
-  runStrategy(tilingOptions, registerTilingOptions, paddingOptions,
-              vectorContractLowering, vectorTransferSplit);
+  runStrategy(tilingAndFusionOptions, tilingOptions, registerTilingOptions,
+              paddingOptions, vectorContractLowering, vectorTransferSplit);
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list