[Mlir-commits] [mlir] 9b17bf2 - [mlir][Linalg] Make Linalg fusion a test pass

Nicolas Vasilache llvmlistbot at llvm.org
Thu Oct 29 08:19:53 PDT 2020


Author: Nicolas Vasilache
Date: 2020-10-29T15:18:51Z
New Revision: 9b17bf2e54c71b36bf28fbab05698fb73ea8dda9

URL: https://github.com/llvm/llvm-project/commit/9b17bf2e54c71b36bf28fbab05698fb73ea8dda9
DIFF: https://github.com/llvm/llvm-project/commit/9b17bf2e54c71b36bf28fbab05698fb73ea8dda9.diff

LOG: [mlir][Linalg] Make Linalg fusion a test pass

Linalg "tile-and-fuse" is currently exposed as a Linalg pass "-linalg-fusion", but only the mechanics of the transformation are currently relevant.
Instead, turn it into a "-test-linalg-greedy-fusion" test pass, which performs canonicalizations so that more fusions can compose.
This allows dropping the OperationFolder, which is not meant to be used with the pattern rewrite infrastructure.

Differential Revision: https://reviews.llvm.org/D90394

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/fusion-2-level.mlir
    mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
    mlir/test/Dialect/Linalg/fusion.mlir
    mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index a0235cf87fdb..24570d3c4ec6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -18,7 +18,6 @@
 namespace mlir {
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
 
-std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass();
 std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
 std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 2df2051255c2..7446ca8f6636 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -23,12 +23,6 @@ def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
-def LinalgFusion : FunctionPass<"linalg-fusion"> {
-  let summary = "Fuse operations in the linalg dialect";
-  let constructor = "mlir::createLinalgFusionPass()";
-  let dependentDialects = ["linalg::LinalgDialect"];
-}
-
 def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
   let summary = "Fuse operations on RankedTensorType in linalg dialect";
   let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 9b343b3b04ab..d8c595fb91fd 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -88,27 +88,22 @@ bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
 /// transformation and thus requires the `consumerdIdx`^th operand of `consumer`
 /// to be a `subview` op (generally obtained by applying the tiling
 /// transformation).
-/// When non-null, the optional pointer `folder` is used to call into the
-/// `createAndFold` builder method. If `folder` is null, the regular `create`
-/// method is called.
 Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
                                           unsigned consumerIdx,
-                                          const LinalgDependenceGraph &graph,
-                                          OperationFolder *folder = nullptr);
+                                          const LinalgDependenceGraph &graph);
 /// Tensor counterpart of `fuseProducerOfBuffer`.
 /// This implements the fusion part of the "tileAndFuse on tensors"
 /// transformation and thus requires the `consumerdIdx`^th operand of `consumer`
 /// to be the result of a `subtensor` op (generally obtained by applying the
 /// tiling transformation).
 Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
-                                          unsigned consumerIdx,
-                                          OperationFolder *folder);
+                                          unsigned consumerIdx);
 
 /// Fuse linalg operation on tensors, with the producer of the operand at
 /// position `consumerIdx` of the consumer.
-Optional<SmallVector<Value, 1>>
-fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
-              unsigned consumerIdx, OperationFolder *folder = nullptr);
+Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
+                                              Operation *consumer,
+                                              unsigned consumerIdx);
 
 /// Returns the linearized list of all shape dimensions in a `linalgOp`.
 /// Applying the inverse, concatenated loopToOperandRangeMaps to this list
@@ -122,17 +117,12 @@ SmallVector<Value, 8> getShape(OpBuilder &builder, ConcreteOpTy linalgOp) {
 /// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
 /// concatenated indexing maps to the result of `getShape`. Returns None if
 /// the bounds computation fails.
-Optional<SmallVector<Value, 4>>
-getLoopRanges(OpBuilder &builder, LinalgOp linalgOp,
-              OperationFolder *folder = nullptr);
+Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
+                                              LinalgOp linalgOp);
 
 /// Returns the values obtained by applying `map` to the list of values.
-/// When non-null, the optional pointer `folder` is used to call into the
-/// `createAndFold` builder method. If `folder` is null, the regular `create`
-/// method is called.
 SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
-                                       AffineMap map, ValueRange values,
-                                       OperationFolder *folder = nullptr);
+                                       AffineMap map, ValueRange values);
 
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 00eb9a2fe834..ac35d87a8413 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -13,7 +13,6 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
-#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
@@ -24,7 +23,6 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
@@ -37,8 +35,6 @@ using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
 using namespace mlir::linalg;
 
-using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
-
 using llvm::dbgs;
 
 /// Implements a simple high-level fusion pass on linalg structured operations.
@@ -201,8 +197,7 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
 ///   2. Tensor case: `producerIdx` is the index of the tensor in
 ///      `producer.getResults()`.
 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
-                     LinalgOp consumer, unsigned consumerIdx,
-                     OperationFolder *folder = nullptr) {
+                     LinalgOp consumer, unsigned consumerIdx) {
   Operation *shapeProducingOp =
       consumer.getShapedOperand(consumerIdx).getDefiningOp();
   assert((isa<SubViewOp>(shapeProducingOp) ||
@@ -244,9 +239,9 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
                  << "existing LoopRange: " << loopRanges[i] << "\n");
     else {
       auto shapeDim = getShapeDefiningLoopRange(producer, i);
-      loopRanges[i] = Range{folded_std_constant_index(folder, 0),
+      loopRanges[i] = Range{std_constant_index(0),
                             std_dim(shapeDim.shape, shapeDim.dimension),
-                            folded_std_constant_index(folder, 1)};
+                            std_constant_index(1)};
       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
     }
   }
@@ -396,15 +391,21 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
   return {};
 }
 
-Optional<FusionInfo> mlir::linalg::fuseProducerOfBuffer(
-    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
-    const LinalgDependenceGraph &graph, OperationFolder *folder) {
+Optional<FusionInfo>
+mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
+                                   unsigned consumerIdx,
+                                   const LinalgDependenceGraph &graph) {
   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
       findFusableProducer(consumer, consumerIdx, graph);
   if (!fusableDependence)
     return {};
 
   LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
+  // If producer is already in the same block as consumer, we are done.
+  if (consumer.getOperation()->getBlock() ==
+      producerOp.getOperation()->getBlock())
+    return {};
+
   Value producerView = fusableDependence->dependentOpView.view;
   Value consumerView = fusableDependence->indexingView;
 
@@ -427,8 +428,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfBuffer(
   assert(producerIdxOpt.hasValue() && "incorrect operand index");
   unsigned producerIdx = producerIdxOpt.getValue();
 
-  auto fusedProducer =
-      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
+  auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
   return FusionInfo{producerOp, fusedProducer};
 }
 
@@ -459,10 +459,9 @@ static void getProducerOfTensor(Value tensor, LinalgOp &producer,
   }
 }
 
-Optional<FusionInfo>
-mlir::linalg::fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
-                                   unsigned consumerIdx,
-                                   OperationFolder *folder) {
+Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
+                                                        LinalgOp consumer,
+                                                        unsigned consumerIdx) {
   Value inputTensor = consumer.getInput(consumerIdx);
   LinalgOp producerOp;
   unsigned producerIdx;
@@ -475,13 +474,18 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
     return {};
   }
 
+  // If producer is already in the same block as consumer, we are done.
+  if (consumer.getOperation()->getBlock() ==
+      producerOp.getOperation()->getBlock())
+    return {};
+
   // Insert fused `producer` just before `consumer`.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(consumer.getOperation());
   ScopedContext scope(b, consumer.getLoc());
   LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
   LinalgOp fusedProducer =
-      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
+      fuse(b, producerOp, producerIdx, consumer, consumerIdx);
 
   // Replace use.
   // Canonicalizations are not guaranteed to have happened before constructing
@@ -796,72 +800,3 @@ mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
   }
   return llvm::None;
 }
-
-static void fuseLinalgOpsGreedily(FuncOp f) {
-  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
-
-  OpBuilder b(f);
-  OperationFolder folder(f.getContext());
-  DenseSet<Operation *> eraseSet;
-
-  // Save original Linalg ops, we only want to make a pass over those.
-  SmallVector<Operation *, 8> linalgOps;
-  f.walk([&](LinalgOp op) {
-    // TODO: support multi-results.
-    if (op.getOperation()->getNumResults() <= 1)
-      linalgOps.push_back(op);
-  });
-
-  // Tile and Fuse for tensors inputs (TODO: all tensor operands).
-  for (auto *op : llvm::reverse(linalgOps)) {
-    LinalgOp linalgOp = cast<LinalgOp>(op);
-    for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
-      if (en.value().getType().isa<MemRefType>()) {
-        // TODO: LinalgDependenceGraph should be able to update itself.
-        // The current naive and expensive reconstruction of the graph should be
-        // removed.
-        linalg::Aliases aliases;
-        linalg::LinalgDependenceGraph graph(aliases, linalgOps);
-        if (auto info =
-                fuseProducerOfBuffer(b, op, en.index(), graph, &folder)) {
-          auto *originalOp = info->originalProducer.getOperation();
-          eraseSet.insert(originalOp);
-          auto *originalOpInLinalgOpsVector =
-              std::find(linalgOps.begin(), linalgOps.end(), originalOp);
-          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
-        }
-      } else {
-        assert(en.value().getType().isa<RankedTensorType>());
-        // Tile and Fuse tensor input (TODO: init_tensors too).
-        if (en.index() >= linalgOp.getNumInputs())
-          continue;
-        if (auto info = fuseProducerOfTensor(b, op, en.index(), &folder)) {
-          auto *originalOp = info->originalProducer.getOperation();
-          auto *originalOpInLinalgOpsVector =
-              std::find(linalgOps.begin(), linalgOps.end(), originalOp);
-          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
-          // Don't mark for erasure in the tensor case, let DCE handle this.
-        }
-      }
-    }
-  }
-  // The `fuseProducerOfBuffer` function performs structural checks and in
-  // particular that no covering read or write exist between the consumer and
-  // the producer. As a consequence, the only fusions that may occur preserve
-  // subsequent dependences and are guaranteed by construction to produce the
-  // whole view. We may thus erase the producer once it is fused.
-  for (auto *e : eraseSet)
-    e->erase();
-
-  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
-}
-
-namespace {
-struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
-  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
-};
-} // namespace
-
-std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
-  return std::make_unique<LinalgFusionPass>();
-}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index d30713d76e36..3e3392d84975 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -177,8 +177,7 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
 
 static Optional<SmallVector<Value, 1>>
 fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx,
-                  PatternRewriter &rewriter,
-                  OperationFolder *folder = nullptr) {
+                  PatternRewriter &rewriter) {
   if (!areTensorOpsFusable(producer, consumer, consumerIdx))
     return llvm::None;
 
@@ -440,8 +439,8 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
 /// conditions have been satisfied.
 static Optional<SmallVector<Value, 1>>
 fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
-                           unsigned fusedTensorIndex, PatternRewriter &rewriter,
-                           OperationFolder *folder = nullptr) {
+                           unsigned fusedTensorIndex,
+                           PatternRewriter &rewriter) {
   assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) &&
          "preconditions for fuse operation failed");
   // Check if reshape is expanding or collapsing.
@@ -929,7 +928,7 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
 
 Optional<SmallVector<Value, 1>>
 mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
-                            unsigned consumerIdx, OperationFolder *folder) {
+                            unsigned consumerIdx) {
   if (consumerIdx >= consumer->getNumOperands())
     return llvm::None;
   Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
@@ -942,7 +941,7 @@ mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
     return llvm::None;
 
   return fuseTensorOpsImpl(cast<LinalgOp>(producer), cast<LinalgOp>(consumer),
-                           consumerIdx, rewriter, folder);
+                           consumerIdx, rewriter);
 }
 
 namespace {

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 3f29949ffe63..210d17516718 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -24,7 +24,6 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/FoldUtils.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -57,30 +56,27 @@ RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
   return llvm::None;
 }
 
-static Value emitOrFoldComposedAffineApply(OpBuilder &b, Location loc,
-                                           AffineMap map,
-                                           ValueRange operandsRef,
-                                           OperationFolder *folder) {
+static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
+                                             AffineMap map,
+                                             ValueRange operandsRef) {
   SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
   fullyComposeAffineMapAndOperands(&map, &operands);
   canonicalizeMapAndOperands(&map, &operands);
-  return folder ? folder->create<AffineApplyOp>(b, loc, map, operands)
-                : b.create<AffineApplyOp>(loc, map, operands);
+  return b.createOrFold<AffineApplyOp>(loc, map, operands);
 }
 
 SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
                                                      AffineMap map,
-                                                     ValueRange values,
-                                                     OperationFolder *folder) {
+                                                     ValueRange values) {
   SmallVector<Value, 4> res;
   res.reserve(map.getNumResults());
   unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
   // For each `expr` in `map`, applies the `expr` to the values extracted from
   // ranges. If the resulting application can be folded into a Value, the
-  // folding occurs eagerly. Otherwise, an affine.apply operation is emitted.
+  // folding occurs eagerly.
   for (auto expr : map.getResults()) {
     AffineMap map = AffineMap::get(numDims, numSym, expr);
-    res.push_back(emitOrFoldComposedAffineApply(b, loc, map, values, folder));
+    res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
   }
   return res;
 }
@@ -159,15 +155,14 @@ SmallVector<Value, 8> getShape(OpBuilder &builder, LinalgOp linalgOp) {
   return res;
 }
 
-Optional<SmallVector<Value, 4>>
-getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, OperationFolder *folder) {
+Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
+                                              LinalgOp linalgOp) {
   SmallVector<Value, 8> viewSizes = getShape(builder, linalgOp);
   AffineMap invertedMap =
       inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
   if (!invertedMap)
     return {};
-  return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes,
-                          folder);
+  return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes);
 }
 
 /// Specialization to build an scf "for" nest.

diff  --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
index 0c9b9ca0dca7..27154a5e277c 100644
--- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-greedy-fusion | FileCheck %s
 
 func @f1(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, %B: memref<?x?xf32, offset: ?, strides: [?, 1]>, %C: memref<?x?xf32, offset: ?, strides: [?, 1]>, %D: memref<?x?xf32, offset: ?, strides: [?, 1]>, %E: memref<?x?xf32, offset: ?, strides: [?, 1]>) -> memref<?x?xf32, offset: ?, strides: [?, 1]> {
   %c1 = constant 1 : index

diff  --git a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
index ee256170bad0..3b4948a3b4f1 100644
--- a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
 
 #map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
 #id_2d = affine_map<(d0, d1) -> (d0, d1)>
@@ -82,8 +82,11 @@ func @fuse_indexed_generic_producer(%A: memref<?x?xf32>,
     ^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors
       %i_int = index_cast %i: index to i32
       %i_float = sitofp %i_int : i32 to f32
+      %j_int = index_cast %j: index to i32
+      %j_float = sitofp %j_int : i32 to f32
       %ab = addf %a, %b : f32
-      %out = addf %ab, %i_float : f32
+      %tmp = addf %ab, %i_float : f32
+      %out = addf %tmp, %j_float : f32
       linalg.yield %out : f32
   }
   %C_X = dim %C, %c0 : memref<?x?xf32>
@@ -115,6 +118,7 @@ func @fuse_indexed_generic_producer(%A: memref<?x?xf32>,
 // CHECK:          [[i_new:%.*]] = addi [[i]], [[I]] : index
 // CHECK:          [[j_new:%.*]] = addi [[j]], [[J]] : index
 // CHECK:          {{.*}} = index_cast [[i_new]] : index to i32
+// CHECK:          {{.*}} = index_cast [[j_new]] : index to i32
 // CHECK:      linalg.generic
 // CHECK:          addf
 
@@ -137,10 +141,13 @@ func @fuse_indexed_generic_producer_tile_second_dim_only(%A: memref<?x?xf32>,
     ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
    outs(%C : memref<?x?xf32>) {
     ^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors
+      %i_int = index_cast %i: index to i32
+      %i_float = sitofp %i_int : i32 to f32
       %j_int = index_cast %j: index to i32
       %j_float = sitofp %j_int : i32 to f32
       %ab = addf %a, %b : f32
-      %out = addf %ab, %j_float : f32
+      %tmp = addf %ab, %i_float : f32
+      %out = addf %tmp, %j_float : f32
       linalg.yield %out : f32
   }
   %C_X = dim %C, %c0 : memref<?x?xf32>
@@ -176,8 +183,8 @@ func @fuse_indexed_generic_producer_tile_second_dim_only(%A: memref<?x?xf32>,
 // CHECK-NOT:  scf.parallel
 // CHECK:      linalg.indexed_generic
 // CHECK:        ^bb0([[i:%.*]]: index, [[j:%.*]]: index
-// CHECK:          [[i_new:%.*]] = addi [[i]], [[C0]] : index
 // CHECK:          [[j_new:%.*]] = addi [[j]], [[J]] : index
+// CHECK:          {{.*}} = index_cast [[i]] : index to i32
 // CHECK:          {{.*}} = index_cast [[j_new]] : index to i32
 // CHECK:      linalg.generic
 // CHECK:          addf

diff  --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index 788cb89b40f7..9a4a1c5f3f6f 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
 
 func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, 1]>,
          %B: memref<?x?xf32, offset: 0, strides: [?, 1]>,
@@ -98,6 +98,8 @@ func @f2(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
 
 // -----
 
+// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+
 func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
          %B: memref<?x?xf32, offset: 0, strides: [?, ?]>,
          %C: memref<?x?xf32, offset: 0, strides: [?, ?]>,
@@ -137,9 +139,11 @@ func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
 }
 // CHECK-LABEL: func @f3
 // CHECK:  (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK:  %[[D_0:.*]] = dim %[[D]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK:  %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK:  %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:  %[[C1:.*]] = constant 1 : index
+// CHECK:  %[[D_0:.*]] = dim %[[D]], %[[C0]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK:  %[[D_1:.*]] = dim %[[D]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK:  %[[C_1:.*]] = dim %[[C]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
 // CHECK:  scf.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
 // CHECK:    scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
 // CHECK:      scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
@@ -148,6 +152,8 @@ func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
 
 // -----
 
+// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+
 func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
          %B: memref<?x?xf32, offset: 0, strides: [?, ?]>,
          %C: memref<?x?xf32, offset: 0, strides: [?, ?]>,
@@ -190,9 +196,11 @@ func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
 }
 // CHECK-LABEL: func @f4
 // CHECK:  (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK:  %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK:  %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK:  %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:  %[[C1:.*]] = constant 1 : index
+// CHECK:  %[[C_0:.*]] = dim %[[C]], %[[C0:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK:  %[[C_1:.*]] = dim %[[C]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK:  %[[D_1:.*]] = dim %[[D]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
 // CHECK:  scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
 // CHECK:    scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
 // CHECK:      scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
@@ -246,26 +254,24 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
 }
 // CHECK-LABEL: func @f5
 // CHECK:  (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK-DAG:  %[[B_1:.*]] = dim %[[B]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK-DAG:  %[[D_0:.*]] = dim %[[D]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK-DAG:  %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK:  scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
-// CHECK:    scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
-// CHECK:      scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
-// CHECK-DAG:    %[[D_IK:.*]] = subview %[[D]][%[[I]], %[[K]]]
-// CHECK-DAG:    %[[B_KJ:.*]] = subview %[[B]][%[[K]], %[[J]]]
-// CHECK-DAG:    %[[E_IJ:.*]] = subview %[[E]][%[[I]], %[[J]]]
-// CHECK:        dim
-// CHECK-DAG:    %[[C_I0:.*]] = subview %[[C]][%[[I]], %{{.*}}]
-// CHECK-DAG:    %[[B_0K:.*]] = subview %[[B]][%{{.*}}, %[[K]]]
-// CHECK-DAG:    %[[D_IK_:.*]] = subview %[[D]][%[[I]], %[[K]]]
-// CHECK:        dim
-// CHECK-DAG:    %[[A_I0:.*]] = subview %[[A]][%[[I]], %{{.*}}]
-// CHECK-DAG:    %[[B_00:.*]] = subview %[[B]][%{{.*}}, %{{.*}}]
-// CHECK-DAG:    %[[C_I0_:.*]] = subview %[[C]][%[[I]], %{{.*}}]
-// CHECK:        linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_]]
-// CHECK:        linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_]]
-// CHECK:        linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
+// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:  %[[C1:.*]] = constant 1 : index
+// CHECK-DAG:  %[[B_1:.*]] = dim %[[B]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG:  %[[D_0:.*]] = dim %[[D]], %[[C0:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG:  %[[D_1:.*]] = dim %[[D]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG:  %[[B_00:.*]] = subview %[[B]][0, 0]{{.*}}
+//     CHECK:  scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
+// CHECK-DAG:    %[[A_I0:.*]] = subview %[[A]][%[[I]], 0]
+// CHECK-DAG:    %[[C_I0:.*]] = subview %[[C]][%[[I]], 0]
+//     CHECK:    scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
+//     CHECK:      %[[E_IJ:.*]] = subview %[[E]][%[[I]], %[[J]]]
+//     CHECK:      scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
+// CHECK-DAG:        %[[D_IK:.*]] = subview %[[D]][%[[I]], %[[K]]]
+// CHECK-DAG:        %[[B_0K:.*]] = subview %[[B]][0, %[[K]]]
+// CHECK-DAG:        %[[B_KJ:.*]] = subview %[[B]][%[[K]], %[[J]]]
+//     CHECK:        linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0]]
+//     CHECK:        linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK]]
+//     CHECK:        linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
 
 // -----
 
@@ -390,11 +396,13 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
 }
 // CHECK-LABEL: func @f7
 // CHECK:  (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK:  %[[A_0:.*]] = dim %[[A]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK:  %[[A_1:.*]] = dim %[[A]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK:  %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK:  %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK:  %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:  %[[C1:.*]] = constant 1 : index
+// CHECK:  %[[A_0:.*]] = dim %[[A]], %[[C0:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK:  %[[A_1:.*]] = dim %[[A]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK:  %[[C_1:.*]] = dim %[[C]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK:  %[[C_0:.*]] = dim %[[C]], %[[C0:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK:  %[[D_1:.*]] = dim %[[D]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
 // CHECK:  linalg.matmul ins(%[[A]], %[[C]]{{.*}} outs(%[[E]]
 // CHECK:  scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
 // CHECK:    scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {

diff  --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index e43f261632e9..41adff7d46c3 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -linalg-fusion -canonicalize -cse -split-input-file | FileCheck %s --check-prefix=CANONICALIZED
+// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
 
 #map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
 #map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
@@ -41,44 +40,19 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
 //  CHECK-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
 //  CHECK-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
 //  CHECK-SAME: %[[C:[0-9a-z]*]]: tensor<?x?xf32>
-//       CHECK: %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG: %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG: %[[C1:.*]] = constant 1 : index
+//   CHECK-DAG: %[[dA1:.*]] = dim %[[A]], %[[C1]] : tensor<?x?xf32>
 //       CHECK: scf.for %[[I:[0-9a-z]*]]
+//       CHECK:     %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1]  : tensor<?x?xf32> to tensor<2x?xf32>
 //  CHECK-NEXT:   scf.for %[[J:[0-9a-z]*]]
-//  CHECK-NEXT:     scf.for %[[K:[0-9a-z]*]]
-//
-// subtensor of the original program, first one refers to the unfused matmul and becomes a dead SSA value.
-//       CHECK:     subtensor %{{.*}}[%[[I]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x4xf32>
-//       CHECK:     %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] {{.*}} : tensor<?x?xf32> to tensor<4x?xf32>
-//       CHECK:     %[[stF:.*]] = subtensor %{{.*}}[%[[I]], %[[J]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-//
-// subtensors of the producing matmul.
-//       CHECK:     %[[stA:.*]] = subtensor %[[A]][%[[I]], %[[C0]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-//  CHECK-NEXT:     %[[stB2:.*]] = subtensor %[[B]][%[[C0]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-//  CHECK-NEXT:     %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-//  CHECK-NEXT:     %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) init(%[[stC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
-//  CHECK-NEXT:     %[[stD2:.*]] = tensor_cast %[[stD]] : tensor<?x?xf32> to tensor<?x4xf32>
-//  CHECK-NEXT:     %[[stG:.*]] = linalg.matmul ins(%[[stD2]], %[[stB1]] : tensor<?x4xf32>, tensor<4x?xf32>) init(%[[stF]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
-//  CHECK-NEXT:     subtensor_insert %[[stG]]
-
-
-// CANONICALIZED-LABEL: func @matmul_tensors(
-//  CANONICALIZED-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
-//  CANONICALIZED-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
-//  CANONICALIZED-SAME: %[[C:[0-9a-z]*]]: tensor<?x?xf32>
-//       CANONICALIZED: %[[C0:.*]] = constant 0 : index
-//       CANONICALIZED: %[[C1:.*]] = constant 1 : index
-//       CANONICALIZED: scf.for %[[I:[0-9a-z]*]]
-//  CANONICALIZED-NEXT:   scf.for %[[J:[0-9a-z]*]]
-//  CANONICALIZED-NEXT:     scf.for %[[K:[0-9a-z]*]]
-//
-//       CANONICALIZED:     %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1]  : tensor<?x?xf32> to tensor<4x3xf32>
-//       CANONICALIZED:     %[[stF:.*]] = subtensor %{{.*}}[%[[I]], %[[J]]] [2, 3] [1, 1]  : tensor<?x?xf32> to tensor<2x3xf32>
+//  CHECK-NEXT:     scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
+//   CHECK-DAG:       %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1]  : tensor<?x?xf32> to tensor<4x3xf32>
+//   CHECK-DAG:       %[[stF:.*]] = subtensor %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1]  : tensor<?x?xf32> to tensor<2x3xf32>
 //
 // subtensors of the producing matmul.
-//       CANONICALIZED:     %[[dA1:.*]] = dim %[[A]], %[[C1]] : tensor<?x?xf32>
-//       CANONICALIZED:     %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1]  : tensor<?x?xf32> to tensor<2x?xf32>
-//  CANONICALIZED-NEXT:     %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1]  : tensor<?x?xf32> to tensor<?x4xf32>
-//  CANONICALIZED-NEXT:     %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1]  : tensor<?x?xf32> to tensor<2x4xf32>
-//  CANONICALIZED-NEXT:     %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) init(%[[stC]] : tensor<2x4xf32>)  -> tensor<2x4xf32>
-//  CANONICALIZED-NEXT:     %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) init(%[[stF]] : tensor<2x3xf32>)  -> tensor<2x3xf32>
-//  CANONICALIZED-NEXT:     subtensor_insert %[[stG]]
+//   CHECK-DAG:       %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1]  : tensor<?x?xf32> to tensor<?x4xf32>
+//   CHECK-DAG:       %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1]  : tensor<?x?xf32> to tensor<2x4xf32>
+//       CHECK:       %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) init(%[[stC]] : tensor<2x4xf32>)  -> tensor<2x4xf32>
+//  CHECK-NEXT:       %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) init(%[[stF]] : tensor<2x3xf32>)  -> tensor<2x3xf32>
+//  CHECK-NEXT:       subtensor_insert %[[stG]] into %[[RES]][%[[I]], %[[J]]]

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 4dfb653ac858..33f9429bdb27 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -13,7 +13,9 @@
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -104,10 +106,96 @@ void TestLinalgFusionTransforms::runOnFunction() {
   applyFusionPatterns(&getContext(), getFunction());
 }
 
+/// Performs one greedy round of tile-and-fuse over the single-result Linalg ops
+/// found in `f`, visiting them in reverse walk order.
+///
+/// * Memref operands: the producer is fused with `fuseProducerOfBuffer`; the
+///   original producer is recorded and erased after the traversal.
+/// * Ranked tensor operands: only input operands are considered (init tensors
+///   are skipped); the producer is fused with `fuseProducerOfTensor` and the
+///   original producer is left in place for DCE to clean up.
+///
+/// Returns success() if at least one fusion happened and failure() otherwise,
+/// so callers can iterate until no more fusion applies.
+static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
+  OpBuilder b(f);
+  DenseSet<Operation *> eraseSet;
+
+  // Save original Linalg ops, we only want to make a pass over those.
+  SmallVector<Operation *, 8> linalgOps;
+  f.walk([&](LinalgOp op) {
+    // TODO: support multi-results.
+    if (op.getOperation()->getNumResults() <= 1)
+      linalgOps.push_back(op);
+  });
+
+  // Keep `linalgOps` in sync with the IR by substituting the fused producer for
+  // the original one. Nothing here guarantees the original producer is tracked
+  // (multi-result ops are filtered out above, for instance), and dereferencing
+  // the end iterator returned by `std::find` in that case is undefined
+  // behavior, so only replace an entry that was actually found.
+  auto replaceTrackedOp = [&](Operation *originalOp, Operation *fusedOp) {
+    auto it = std::find(linalgOps.begin(), linalgOps.end(), originalOp);
+    if (it != linalgOps.end())
+      *it = fusedOp;
+  };
+
+  // Tile and fuse producers of memref operands and of tensor inputs
+  // (TODO: all tensor operands).
+  bool changed = false;
+  for (auto *op : llvm::reverse(linalgOps)) {
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
+      if (en.value().getType().isa<MemRefType>()) {
+        // TODO: LinalgDependenceGraph should be able to update itself.
+        // The current naive and expensive reconstruction of the graph should be
+        // removed.
+        linalg::Aliases aliases;
+        linalg::LinalgDependenceGraph graph(aliases, linalgOps);
+        if (auto info = fuseProducerOfBuffer(b, op, en.index(), graph)) {
+          auto *originalOp = info->originalProducer.getOperation();
+          // Deferred erasure: the op is erased after the traversal below.
+          eraseSet.insert(originalOp);
+          replaceTrackedOp(originalOp, info->fusedProducer.getOperation());
+          changed = true;
+        }
+      } else {
+        assert(en.value().getType().isa<RankedTensorType>());
+        // Tile and Fuse tensor input (TODO: init_tensors too).
+        if (en.index() >= linalgOp.getNumInputs())
+          continue;
+        if (auto info = fuseProducerOfTensor(b, op, en.index())) {
+          replaceTrackedOp(info->originalProducer.getOperation(),
+                           info->fusedProducer.getOperation());
+          // Don't mark for erasure in the tensor case, let DCE handle this.
+          changed = true;
+        }
+      }
+    }
+  }
+  // The `fuseProducerOfBuffer` function performs structural checks and in
+  // particular that no covering read or write exist between the consumer and
+  // the producer. As a consequence, the only fusions that may occur preserve
+  // subsequent dependences and are guaranteed by construction to produce the
+  // whole view. We may thus erase the producer once it is fused.
+  for (auto *e : eraseSet)
+    e->erase();
+
+  return changed ? success() : failure();
+}
+
+namespace {
+/// Test pass that alternates greedy Linalg tile-and-fuse
+/// (`fuseLinalgOpsGreedily`) with cleanups (Linalg tiling canonicalization
+/// patterns, affine.min SCF canonicalization, LICM, canonicalization and CSE)
+/// for as long as a fusion round reports progress.
+struct TestLinalgGreedyFusion
+    : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
+  void runOnFunction() override {
+    MLIRContext *context = &getContext();
+    OwningRewritePatternList patterns =
+        linalg::getLinalgTilingCanonicalizationPatterns(context);
+    patterns.insert<AffineMinSCFCanonicalizationPattern>(context);
+    FrozenRewritePatternList frozenPatterns(std::move(patterns));
+
+    // The cleanup pipeline does not change between iterations: build it once
+    // and rerun it after every successful fusion round.
+    PassManager pm(context);
+    pm.addPass(createLoopInvariantCodeMotionPass());
+    pm.addPass(createCanonicalizerPass());
+    pm.addPass(createCSEPass());
+
+    while (succeeded(fuseLinalgOpsGreedily(getFunction()))) {
+      applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
+      // NOTE(review): this runs the pipeline on the whole parent module from
+      // inside a function pass, i.e. it touches IR outside the current
+      // function; the MLIR pass infrastructure disallows that for operation
+      // passes (sibling functions may be processed concurrently). Consider a
+      // function-level nested pipeline run on `getFunction()` instead.
+      if (failed(pm.run(getFunction().getParentOfType<ModuleOp>()))) {
+        // Stop here: there is no point fusing further once the cleanup
+        // pipeline has failed on the IR.
+        this->signalPassFailure();
+        return;
+      }
+    }
+  }
+};
+} // namespace
+
 namespace mlir {
 /// Registers `TestLinalgFusionTransforms` under the
 /// `test-linalg-fusion-transform-patterns` pass argument.
 void registerTestLinalgFusionTransforms() {
   PassRegistration<TestLinalgFusionTransforms> testFusionTransformsPass(
       "test-linalg-fusion-transform-patterns",
       "Test Linalg fusion transformation patterns by applying them greedily.");
 }
+/// Registers `TestLinalgGreedyFusion` under the `test-linalg-greedy-fusion`
+/// pass argument (used as `mlir-opt -test-linalg-greedy-fusion`).
+void registerTestLinalgGreedyFusion() {
+  const char *passArgument = "test-linalg-greedy-fusion";
+  const char *passDescription =
+      "Test Linalg fusion by applying a greedy test transformation.";
+  PassRegistration<TestLinalgGreedyFusion> greedyFusionRegistration(
+      passArgument, passDescription);
+}
 } // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index ac29305e3c6f..196bda69dbaf 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -61,6 +61,7 @@ void registerTestGpuParallelLoopMappingPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
 void registerTestLinalgFusionTransforms();
+void registerTestLinalgGreedyFusion();
 void registerTestLinalgHoisting();
 void registerTestLinalgTransforms();
 void registerTestLivenessPass();
@@ -121,6 +122,7 @@ void registerTestPasses() {
   registerTestInterfaces();
   registerTestLinalgCodegenStrategy();
   registerTestLinalgFusionTransforms();
+  registerTestLinalgGreedyFusion();
   registerTestLinalgHoisting();
   registerTestLinalgTransforms();
   registerTestLivenessPass();


        


More information about the Mlir-commits mailing list