[Mlir-commits] [mlir] 98835e3 - [mlir][Linalg] Enable TileAndFusePattern to work with tensors.

llvmlistbot at llvm.org
Thu Jan 28 14:13:30 PST 2021


Author: MaheshRavishankar
Date: 2021-01-28T14:13:01-08:00
New Revision: 98835e3d9849a17e9f8eb5dcd8aee3c9d32e1e07

URL: https://github.com/llvm/llvm-project/commit/98835e3d9849a17e9f8eb5dcd8aee3c9d32e1e07
DIFF: https://github.com/llvm/llvm-project/commit/98835e3d9849a17e9f8eb5dcd8aee3c9d32e1e07.diff

LOG: [mlir][Linalg] Enable TileAndFusePattern to work with tensors.

Differential Revision: https://reviews.llvm.org/D94531
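
In brief, LinalgTileAndFusePattern previously required buffer semantics:
matchAndRewrite bailed out on tensor ops, and the tiled-and-fused op was
simply erased. With this change, the tensor results of the outermost
generated loop (or of the tiled op itself, when no loops are created)
replace all uses of the original op, so the same tile-and-fuse pattern now
also applies to Linalg ops on tensors. Below is a minimal sketch of wiring
the pattern up for tensors, modeled on the updated test pass in this
commit; the helper name populateTensorTileAndFusePatterns is illustrative
and not part of the commit:

    // Sketch only; mirrors fillFusionPatterns in
    // TestLinalgFusionTransforms.cpp as updated by this commit.
    static void populateTensorTileAndFusePatterns(
        MLIRContext *context, const LinalgDependenceGraph &dependenceGraph,
        OwningRewritePatternList &patterns) {
      patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
          context, dependenceGraph,
          // Tensor semantics are exercised with sequential scf.for loops,
          // which carry the partial tensor results through iter_args.
          LinalgTilingOptions()
              .setTileSizes({32, 64, 16})
              .setLoopType(LinalgTilingLoopType::Loops),
          // Fuse the producer of the LHS operand (operand index 0).
          LinalgFusionOptions().setIndicesToFuse({0}),
          LinalgTransformationFilter(
              Identifier::get("lhs_fusion", context),
              Identifier::get("after_lhs_fusion", context)));
    }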

Added: 
    mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 4fe90897873b..7441a54cdc05 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -732,45 +732,6 @@ collectFusableLoops(ArrayRef<LinalgOp> ops,
   return fusableLoops;
 }
 
-// /// For `consumer` with tensor semantics, find the Linalg operation on
-// tensors
-// /// producer the operand at position `consumerIdx`. This is a simple use-def
-// /// chain using the SSA value, but returned as an element of the
-// /// `LinalgDependenceGraphElem` to use the same analysis for both tensors and
-// /// buffers.
-// static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-// findFusableProducerForTensorOp(OpOperand &consumerOpOperand) {
-//   // For now only looking for cases where the operand is produced by another
-//   // Linalg structured operation.
-//   LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
-//   if (!consumer || !consumer.hasTensorSemantics())
-//     return llvm::None;
-//   unsigned consumerIdx = consumerOpOperand.getOperandNumber();
-//   Value value = consumerOpOperand.get();
-//   if (auto linalgOp = value.getDefiningOp<LinalgOp>()) {
-//     return LinalgDependenceGraph::LinalgDependenceGraphElem{
-//         &(linalgOp
-//               .getOutputOpOperands()[value.cast<OpResult>().getResultNumber()]),
-//         &(consumer.getInputOpOperands()[consumerIdx]),
-//         LinalgDependenceGraph::DependenceType::RAW};
-//   }
-//   return llvm::None;
-// }
-
-// static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-// findFusableProducer(OpOperand &consumerOpOperand,
-//                     const LinalgDependenceGraph &dependenceGraph) {
-//   LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
-//   if (!consumer)
-//     return llvm::None;
-//   if (consumer.hasBufferSemantics())
-//     return findFusableProducerForBufferOp(consumerOpOperand,
-//     dependenceGraph);
-//   if (consumer.hasTensorSemantics())
-//     return findFusableProducerForTensorOp(consumerOpOperand);
-//   return llvm::None;
-// }
-
 /// Find all dependences that are fusable.
 FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
     ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7260bb4aca6b..fc647ea45478 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -283,6 +283,19 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
   return success();
 }
 
+static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
+  if (tiledOp.loops.empty())
+    return tiledOp.op.getOperation()->getResults();
+  return tiledOp.loops.front()->getResults();
+}
+
+static ValueRange
+getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
+  if (tiledAndFusedOp.fusedLoops.empty())
+    return tiledAndFusedOp.op.getOperation()->getResults();
+  return tiledAndFusedOp.fusedLoops.front()->getResults();
+}
+
 mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
     StringRef opName, MLIRContext *context,
     const LinalgDependenceGraph &dependenceGraph,
@@ -301,8 +314,6 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
     return failure();
   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
     return failure();
-  if (!linalgOp.hasBufferSemantics())
-    return failure();
 
   DenseSet<Operation *> producers;
   producers.insert(linalgOp);
@@ -359,9 +370,11 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
         tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
     if (!unfusedTiledOp)
       return failure();
-    rewriter.eraseOp(tiledAndFusedOps->op);
+    rewriter.replaceOp(tiledAndFusedOps->op,
+                       getTiledOpResult(unfusedTiledOp.getValue()));
     tiledAndFusedOps->op = unfusedTiledOp->op;
   }
+  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));
 
   marker.replaceLinalgTransformationFilter(rewriter,
                                            tiledAndFusedOps->op.getOperation());

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
new file mode 100644
index 000000000000..10cd5b454a4a
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -0,0 +1,142 @@
+// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
+
+module {
+  func @matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                      %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
+                      %arg4: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>   // <MxN1> <N1xN2>
+    %1 = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
+      ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32>   // <MxN2> <N2xN3>
+    return %1 : tensor<?x?xf32>
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (32, d0 - d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (64, d0 - d1)>
+//      CHECK: func @matmul_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[C32:.+]] = constant 32 : index
+//  CHECK-DAG:   %[[C64:.+]] = constant 64 : index
+//  CHECK-DAG:   %[[C16:.+]] = constant 16 : index
+//  CHECK-DAG:   %[[M:.+]] = dim %[[ARG0]], %[[C0]]
+//      CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] =
+// CHECK-SAME:     %[[C0]] to %[[M]] step %[[C32]]
+// CHECK-SAME:     iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<?x?xf32>) {
+//      CHECK:     %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+//      CHECK:     %[[M_2:.+]] = dim %[[ARG6]], %[[C0]]
+//      CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[M_2]], %[[IV0]])
+//      CHECK:     %[[N3:.+]] = dim %[[ARG6]], %[[C1]]
+//      CHECK:     %[[ST_ARG6:.+]] = subtensor %[[ARG6]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M_2]], %[[N3]]]
+//      CHECK:     %[[N2:.+]] = dim %[[ARG1]], %[[C1]]
+//      CHECK:     %[[N1:.+]] = dim %[[ARG0]], %[[C1]]
+//      CHECK:     %[[ST_ARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[N1]]]
+//      CHECK:     %[[ST_ARG1:.+]] = subtensor %[[ARG1]][0, 0]
+// CHECK-SAME:       [%[[N1]], %[[N2]]]
+//      CHECK:     %[[ST_ARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[N2]]]
+//      CHECK:     %[[LHS:.+]] = linalg.matmul
+// CHECK-SAME:       __internal_linalg_transform__ = "after_lhs_fusion_producer"
+// CHECK-SAME:       ins(%[[ST_ARG0]], %[[ST_ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME:       outs(%[[ST_ARG2]] : tensor<?x?xf32>)
+//      CHECK:     %[[N3_2:.+]] = dim %[[ARG3]], %[[C1]]
+//      CHECK:     %[[YIELD0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] =
+// CHECK-SAME:       %[[C0]] to %[[N3_2]] step %[[C64]]
+// CHECK-SAME:       iter_args(%[[ARG8:.+]] = %[[ST_ARG6]]) -> (tensor<?x?xf32>) {
+//      CHECK:       %[[YIELD1:.+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] =
+// CHECK-SAME:         %[[C0]] to %[[N2]] step %[[C16]]
+// CHECK-SAME:         iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor<?x?xf32>) {
+//      CHECK:         %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2]]]
+//      CHECK:         %[[ST_LHS:.+]] = subtensor %[[LHS]][0, %[[IV2]]]
+// CHECK-SAME:           [%[[TILE_M]], %[[TILE_N2]]]
+//      CHECK:         %[[N2_3:.+]] = dim %[[ARG3]], %[[C0]]
+//      CHECK:         %[[TILE_N2_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2_3]]]
+//      CHECK:         %[[TILE_N3:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N3_2]]]
+//      CHECK:         %[[ST_ARG3:.+]] = subtensor %[[ARG3]][%[[IV2]], %[[IV1]]]
+// CHECK-SAME:           [%[[TILE_N2_2]], %[[TILE_N3]]]
+//      CHECK:         %[[M_4:.+]] = dim %[[ARG10]], %[[C0]]
+//      CHECK:         %[[N3_3:.+]] = dim %[[ARG10]], %[[C1]]
+//      CHECK:         %[[TILE_N3_2:.+]] = affine.min #[[MAP4]](%[[N3_3]], %[[IV1]])
+//      CHECK:         %[[ST_ARG4:.+]] = subtensor %[[ARG10]][0, %[[IV1]]]
+// CHECK-SAME:           [%[[M_4]], %[[TILE_N3_2]]]
+//      CHECK:         %[[ST_RESULT:.+]] = linalg.matmul
+// CHECK-SAME:           __internal_linalg_transform__ = "after_lhs_fusion"
+// CHECK-SAME:           ins(%[[ST_LHS]], %[[ST_ARG3]]
+// CHECK-SAME:             : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME:           outs(%[[ST_ARG4]] : tensor<?x?xf32>)
+//      CHECK:         %[[UPDATE1:.+]] = subtensor_insert %[[ST_RESULT]]
+// CHECK-SAME:           into %[[ARG10]][0, %[[IV1]]] [%[[M_4]], %[[TILE_N3_2]]]
+//      CHECK:         scf.yield %[[UPDATE1]]
+//      CHECK:       }
+//      CHECK:       scf.yield %[[YIELD1]]
+//      CHECK:     }
+//      CHECK:     %[[UPDATE0:.+]] = subtensor_insert %[[YIELD0]] into
+// CHECK-SAME:       %[[ARG6]][%[[IV0]], 0] [%[[TILE_M_2]], %[[N3]]]
+//      CHECK:     scf.yield %[[UPDATE0]]
+//      CHECK:   }
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+module {
+  func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                           %arg2: tensor<?x?xf32>) -> tensor<?x?xf32>{
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %0 = dim %arg2, %c0 : tensor<?x?xf32>
+    %1 = dim %arg2, %c1 : tensor<?x?xf32>
+    %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %3 = dim %2, %c0 : tensor<?x?xf32>
+    %4 = dim %2, %c1 : tensor<?x?xf32>
+    %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
+    %6 = linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"],
+       __internal_linalg_transform__ = "transpose_fusion"}
+      ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%5 : tensor<?x?xf32>) {
+      ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
+        %7 = addf %arg3, %arg4 : f32
+        linalg.yield %7 : f32
+      } -> tensor<?x?xf32>
+    return %6 : tensor<?x?xf32>
+  }
+}
+//       CHECK: func @matmul_plus_matmul
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//       CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
+//       CHECK:     %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:       iter_args(%[[ARG6:.+]] = %[[ARG4]])
+//       CHECK:       %[[ST_ARG6:.+]] = subtensor %[[ARG6]][%[[IV0]], %[[IV1]]]
+//       CHECK:       %[[ST_ARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
+//       CHECK:       %[[ST_ARG1:.+]] = subtensor %[[ARG1]][0, %[[IV1]]]
+//       CHECK:       %[[ST_ARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], %[[IV1]]]
+//       CHECK:       %[[LHS:.+]] = linalg.matmul
+//  CHECK-SAME:         ins(%[[ST_ARG0]], %[[ST_ARG1]]
+//  CHECK-SAME:           : tensor<?x?xf32>, tensor<?x?xf32>)
+//  CHECK-SAME:         outs(%[[ST_ARG2]] : tensor<?x?xf32>)
+//       CHECK:       %[[ST_RESULT:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[LHS]] : tensor<?x?xf32>)
+//  CHECK-SAME:         outs(%[[ST_ARG6]] : tensor<?x?xf32>)
+//       CHECK:       %[[UPDATE:.+]] = subtensor_insert %[[ST_RESULT]]
+//  CHECK-SAME:         into %[[ARG6]][%[[IV0]], %[[IV1]]]
+//       CHECK:       scf.yield %[[UPDATE]]
+//       CHECK:     scf.yield %[[YIELD]]
+//       CHECK:   return %[[RESULT]]

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index f2c9067d5cc2..20e0166d7331 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -20,30 +20,14 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-namespace {
-struct TestLinalgFusionTransforms
-    : public PassWrapper<TestLinalgFusionTransforms, FunctionPass> {
-  TestLinalgFusionTransforms() = default;
-  TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect,
-                    StandardOpsDialect>();
-  }
-
-  void runOnFunction() override;
-};
-} // namespace
-
+template <LinalgTilingLoopType LoopType>
 static void fillFusionPatterns(MLIRContext *context,
                                const LinalgDependenceGraph &dependenceGraph,
                                OwningRewritePatternList &patterns) {
   patterns.insert<LinalgTileAndFusePattern<MatmulOp>,
                   LinalgTileAndFusePattern<ConvOp>>(
       context, dependenceGraph,
-      LinalgTilingOptions()
-          .setTileSizes({32, 64, 16})
-          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
       LinalgFusionOptions().setIndicesToFuse({2}),
       LinalgTransformationFilter(
           Identifier::get("basic_fusion", context),
@@ -57,9 +41,7 @@ static void fillFusionPatterns(MLIRContext *context,
 
   patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
       context, dependenceGraph,
-      LinalgTilingOptions()
-          .setTileSizes({32, 64, 16})
-          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
       LinalgFusionOptions().setIndicesToFuse({0}),
       LinalgTransformationFilter(Identifier::get("lhs_fusion", context),
                                  Identifier::get("after_lhs_fusion", context)),
@@ -72,9 +54,7 @@ static void fillFusionPatterns(MLIRContext *context,
 
   patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
       context, dependenceGraph,
-      LinalgTilingOptions()
-          .setTileSizes({32, 64, 16})
-          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
       LinalgFusionOptions().setIndicesToFuse({1}),
       LinalgTransformationFilter(Identifier::get("rhs_fusion", context),
                                  Identifier::get("after_rhs_fusion", context)),
@@ -87,9 +67,7 @@ static void fillFusionPatterns(MLIRContext *context,
 
   patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
       context, dependenceGraph,
-      LinalgTilingOptions()
-          .setTileSizes({32, 64, 16})
-          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
       LinalgFusionOptions().setIndicesToFuse({0, 2}),
       LinalgTransformationFilter(
           Identifier::get("two_operand_fusion", context),
@@ -103,8 +81,7 @@ static void fillFusionPatterns(MLIRContext *context,
 
   patterns.insert<LinalgTileAndFusePattern<GenericOp>>(
       context, dependenceGraph,
-      LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(
-          LinalgTilingLoopType::ParallelLoops),
+      LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
       LinalgFusionOptions().setIndicesToFuse({0, 1}),
       LinalgTransformationFilter(
           Identifier::get("transpose_fusion", context),
@@ -117,18 +94,30 @@ static void fillFusionPatterns(MLIRContext *context,
           Identifier::get("after_transpose_fusion_original", context)));
 }
 
-static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) {
-  OwningRewritePatternList fusionPatterns;
-  Aliases alias;
-  LinalgDependenceGraph dependenceGraph =
-      LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
-  fillFusionPatterns(context, dependenceGraph, fusionPatterns);
-  applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
-}
+namespace {
+template <LinalgTilingLoopType LoopType = LinalgTilingLoopType::ParallelLoops>
+struct TestLinalgFusionTransforms
+    : public PassWrapper<TestLinalgFusionTransforms<LoopType>, FunctionPass> {
+  TestLinalgFusionTransforms() = default;
+  TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
 
-void TestLinalgFusionTransforms::runOnFunction() {
-  applyFusionPatterns(&getContext(), getFunction());
-}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect,
+                    StandardOpsDialect>();
+  }
+
+  void runOnFunction() override {
+    MLIRContext *context = &this->getContext();
+    FuncOp funcOp = this->getFunction();
+    OwningRewritePatternList fusionPatterns;
+    Aliases alias;
+    LinalgDependenceGraph dependenceGraph =
+        LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
+    fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
+    applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
+  }
+};
+} // namespace
 
 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
   OpBuilder b(f);
@@ -237,7 +226,7 @@ struct TestLinalgTileAndFuseSequencePass
     LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
     OpBuilder builder(funcOp.getContext());
     linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
-    if (llvm::all_of(linalgOps, [](LinalgOp linalgOp) {
+    if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
           return linalgOp.hasTensorSemantics();
         }))
       loopType = LinalgTilingLoopType::Loops;
@@ -260,10 +249,17 @@ struct TestLinalgTileAndFuseSequencePass
 namespace mlir {
 namespace test {
 void registerTestLinalgFusionTransforms() {
-  PassRegistration<TestLinalgFusionTransforms> testFusionTransformsPass(
+  PassRegistration<TestLinalgFusionTransforms<>> testFusionTransformsPass(
       "test-linalg-fusion-transform-patterns",
       "Test Linalg fusion transformation patterns by applying them greedily.");
 }
+void registerTestLinalgTensorFusionTransforms() {
+  PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::Loops>>
+      testTensorFusionTransformsPass(
+          "test-linalg-tensor-fusion-transform-patterns",
+          "Test Linalg on tensor fusion transformation "
+          "patterns by applying them greedily.");
+}
 void registerTestLinalgGreedyFusion() {
   PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
       "test-linalg-greedy-fusion",

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index dc68f8f4d778..7336ae18e3c6 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -74,6 +74,7 @@ void registerTestGpuParallelLoopMappingPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
 void registerTestLinalgFusionTransforms();
+void registerTestLinalgTensorFusionTransforms();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgHoisting();
 void registerTestLinalgTileAndFuseSequencePass();
@@ -145,6 +146,7 @@ void registerTestPasses() {
   test::registerTestInterfaces();
   test::registerTestLinalgCodegenStrategy();
   test::registerTestLinalgFusionTransforms();
+  test::registerTestLinalgTensorFusionTransforms();
   test::registerTestLinalgGreedyFusion();
   test::registerTestLinalgHoisting();
   test::registerTestLinalgTileAndFuseSequencePass();

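As a usage note, the newly registered pass can be exercised through
mlir-opt exactly as the RUN line of the added test does:

    mlir-opt -test-linalg-tensor-fusion-transform-patterns -canonicalize \
      -cse -split-input-file -verify-diagnostics fusion-tensor-pattern.mlir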
