[Mlir-commits] [mlir] 9ecc817 - [mlir] Add support for fusion into TiledLoopOp.

Alexander Belyaev llvmlistbot at llvm.org
Fri May 21 09:14:18 PDT 2021


Author: Alexander Belyaev
Date: 2021-05-21T18:13:45+02:00
New Revision: 9ecc8178d72097c8f9e31ea7c50085748d187aff

URL: https://github.com/llvm/llvm-project/commit/9ecc8178d72097c8f9e31ea7c50085748d187aff
DIFF: https://github.com/llvm/llvm-project/commit/9ecc8178d72097c8f9e31ea7c50085748d187aff.diff

LOG: [mlir] Add support for fusion into TiledLoopOp.

Differential Revision: https://reviews.llvm.org/D102722

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index dd5c2d5a51d98..334cecb3eb477 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -584,6 +584,15 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
   ];
 
   let extraClassDeclaration = [{
+    /// Number of loops, i.e. the number of `step` operands.
+    unsigned getNumLoops() { return step().size(); }
+
+    /// Number of `inputs` operands.
+    unsigned getNumInputs() { return inputs().size(); }
+
+    /// Number of `outputs` operands.
+    unsigned getNumOutputs() { return outputs().size(); }
+
     /// Number of operands controlling the loop: lbs, ubs, steps
     unsigned getNumControlOperands() { return 3 * getNumLoops(); }
 
@@ -597,7 +606,6 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
       return getBody()->getArguments().take_back(outputs().size());
     }
 
-
     void setLowerBounds(ValueRange lowerBounds) {
       unsigned numLoops = getNumLoops();
       assert(lowerBounds.size() == numLoops &&
@@ -622,6 +630,16 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
         setOperand(pos, steps[i]);
     }
 
+    /// Returns the body block argument tied to an `input`/`output` operand.
+    BlockArgument getTiedBlockArgument(OpOperand& operand) {
+      auto operandIndex = operand.getOperandNumber();
+      assert(
+          operandIndex >= getNumControlOperands() &&
+          operandIndex < getNumOperands() &&
+          "tied block arg is defined only for `input` and `output` arguments");
+      return getBody()->getArgument(operandIndex - 2 * getNumLoops());
+    }
+
    /// Result that corresponds to the `outputs` argument of tensor type.
    OpResult getTiedOpResult(OpOperand& opOperand) {
       // No result can correspond to a memref argument.
@@ -642,7 +660,76 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
       return getOperation()->getResult(tensorId);
     }
 
-    unsigned getNumLoops() { return step().size(); }
+    /// Appends `operand` to `inputs` plus a block arg; returns the new operand.
+    OpOperand& appendInputOperand(OpBuilder& builder, Value operand) {
+      int numLoops = getNumLoops();
+      int numInputs = getNumInputs();
+      int numOutputs = getNumOutputs();
+      // Place it after the existing inputs (operands and body args alike).
+      getOperation()->insertOperands(getNumControlOperands() + numInputs,
+                                     operand);
+      getBody()->insertArgument(numLoops + numInputs, operand.getType());
+      getOperation()->setAttr(
+          TiledLoopOp::getOperandSegmentSizeAttr(),
+          builder.getI32VectorAttr(
+              {numLoops, numLoops, numLoops, numInputs + 1, numOutputs}));
+      return getOperation()->getOpOperand(getNumControlOperands() + numInputs);
+    }
+
+    /// Appends `operand` to `outputs`; no loop result/yield operand is added.
+    OpOperand& appendOutputOperand(OpBuilder& builder, Value operand) {
+      int numLoops = getNumLoops();
+      int numInputs = getNumInputs();
+      int numOutputs = getNumOutputs();
+      // Place it after the existing outputs (operands and body args alike).
+      getOperation()->insertOperands(
+          getNumControlOperands() + numInputs + numOutputs, operand);
+      getBody()->insertArgument(numLoops + numInputs + numOutputs,
+                                operand.getType());
+      getOperation()->setAttr(
+          TiledLoopOp::getOperandSegmentSizeAttr(),
+          builder.getI32VectorAttr(
+              {numLoops, numLoops, numLoops, numInputs, numOutputs + 1}));
+      return getOperation()->getOpOperand(getNumControlOperands() + numInputs +
+                                          numOutputs);
+    }
+
+    /// Erase `operand` and its block arg; loop results are not updated.
+    void eraseOperand(OpBuilder& builder, OpOperand& operand) {
+      int numInputs = getNumInputs();
+      int numLoops = getNumLoops();
+      int numOutputs = getNumOutputs();
+      // Unsigned, like `operandIndex`, so the comparisons below don't mix signs.
+      unsigned numControlOperands = getNumControlOperands();
+      auto operandIndex = operand.getOperandNumber();
+      assert(operandIndex >= numControlOperands &&
+             operandIndex < getNumOperands() &&
+             "Can erase only `input` or `output` operand");
+
+      if (operandIndex >= numControlOperands + numInputs)
+        --numOutputs;
+      else
+        --numInputs;
+
+      getOperation()->eraseOperand(operandIndex);
+      getBody()->eraseArgument(operandIndex - 2 * numLoops);
+      getOperation()->setAttr(
+          TiledLoopOp::getOperandSegmentSizeAttr(),
+          builder.getI32VectorAttr(
+              {numLoops, numLoops, numLoops, numInputs, numOutputs}));
+    }
+
+    /// Returns the `inputs` operand holding `value`, or nullptr if absent.
+    OpOperand* findInputOperand(Value value) {
+      OperandRange::iterator pos = llvm::find(inputs(), value);
+      return pos == inputs().end() ? nullptr : pos.getBase();
+    }
+
+    /// Returns the `outputs` operand holding `value`, or nullptr if absent.
+    OpOperand* findOutputOperand(Value value) {
+      OperandRange::iterator pos = llvm::find(outputs(), value);
+      return pos == outputs().end() ? nullptr : pos.getBase();
+    }
   }];
 
   let hasCanonicalizer = 1;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 108d0414c34bf..465f933f862cf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -107,6 +107,66 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
   llvm_unreachable("Expect to be able to extract a shape defining loop range");
 }
 
+// Return tiled operands for the fused producer op. When fusing into
+// `linalg.tiled_loop`, its `input`/`output` arguments are updated accordingly.
+// Each input tensor of the producer op has to be added to `inputs` of the
+// `tiled_loop` if it is not present there already. Each output tensor has to
+// be added either to `inputs` or to `outputs` of `linalg.tiled_loop` depending
+// on whether the corresponding result is an input or an output to the loop.
+//
+// NOTE: This way of updating the arguments of the `tiled_loop` assumes that the
+// intermediate result is not used by any other operation but the consumer. A
+// more generic way is to append all missing output tensors of the producer to
+// the tiled loop outputs and hence modify the number of the results, since we
+// would need to add the intermediate results to `linalg.yield`. After that a
+// canonicalization pass would move the unused output args of the `tiled_loop`
+// to the `input` section.
+static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
+  auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
+  if (!tiledLoop)
+    return llvm::to_vector<4>(producer.getShapedOperands());
+  SmallVector<Value, 4> tiledOperands;
+  assert(producer.hasTensorSemantics() &&
+         "only fusion on tensors is currently supported for TiledLinalgOp");
+  for (auto producerInput : producer.getInputTensors()) {
+    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput);
+    if (addedInput == nullptr)
+      addedInput = &tiledLoop.appendInputOperand(b, producerInput);
+    BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
+    tiledOperands.push_back(addedBlockArg);
+  }
+  for (auto &en : llvm::enumerate(producer.getOutputTensors())) {
+    Value producerOutput = en.value();
+    Value result = producer->getResult(en.index());
+    OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
+    OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
+    assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
+           "The result should be present in `input` or `output` args of "
+           "`tiled_loop`");
+    bool isInput = resultInputOperand != nullptr;
+    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput);
+    if (!isInput && addedOutput == nullptr) {
+      // `outputs` tie to loop results/yield by position: retarget in place.
+      resultOutputOperand->set(producerOutput);
+      tiledOperands.push_back(
+          tiledLoop.getTiedBlockArgument(*resultOutputOperand));
+      continue;
+    }
+    int opNumber = isInput ? resultInputOperand->getOperandNumber()
+                           : resultOutputOperand->getOperandNumber();
+    if (addedOutput == nullptr)
+      addedOutput = &tiledLoop.appendInputOperand(b, producerOutput);
+
+    OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
+    auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
+    auto resultOperandBlockArg = tiledLoop.getTiedBlockArgument(resultOperand);
+    resultOperandBlockArg.replaceAllUsesWith(addedBlockArg);
+    tiledLoop.eraseOperand(b, resultOperand);
+    tiledOperands.push_back(addedBlockArg);
+  }
+  return tiledOperands;
+}
+
 /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
 /// provides the loop range information for the fused loops. The rest are
 /// obtained from the producer itself, since they are not tiled + fused.
@@ -143,8 +203,8 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
   clonedShapes.reserve(producer.getNumShapedOperands());
 
   // Compute subranges for all tensor input/output operands.
-  auto tiledOperands = llvm::to_vector<4>(producer.getShapedOperands());
-  clonedShapes.append(makeTiledShapes(b, loc, producer, tiledOperands, ivs,
+  clonedShapes.append(makeTiledShapes(b, loc, producer,
+                                      getTiledOperands(b, producer), ivs,
                                       tileSizes, sizeBounds));
 
   // Append the other operands.
@@ -808,7 +868,7 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
     origOpToFusedOp[origOp.getOperation()] = fusedOp;
     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
 
-    // Prepare the b for the next insertion point.
+    // Prepare the builder for the next insertion point.
     auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); });
     if (!origOp.hasTensorSemantics())
       continue;
@@ -844,16 +904,16 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
     //  2. encode destructive updates that may be inplaceable by bufferization.
     // To keep the second type of information while letting the unfused op die
     // unused, we need to forward the producer output operand.
-    for (auto &operand :
-         cast<scf::ForOp>(tiledLinalgOp.loops.front()).getIterOpOperands())
-      if (auto opResult = operand.get().dyn_cast<OpResult>())
-        if (opResult.getOwner() == origOp)
-          operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+    if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
+      for (auto &operand : forOp.getIterOpOperands())
+        if (auto opResult = operand.get().dyn_cast<OpResult>())
+          if (opResult.getOwner() == origOp)
+            operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+    }
   }
   return fusedOps;
 }
 
-template <typename LoopType>
 static Optional<TiledAndFusedLinalgOps>
 tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
                          const LinalgDependenceGraph &dependenceGraph,
@@ -928,11 +988,9 @@ mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
                                    const LinalgTilingOptions &tilingOptions) {
   switch (tilingOptions.loopType) {
   case LinalgTilingLoopType::Loops:
-    return tileAndFuseLinalgOpsImpl<scf::ForOp>(b, ops, dependenceGraph,
-                                                tilingOptions);
   case LinalgTilingLoopType::ParallelLoops:
-    return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(b, ops, dependenceGraph,
-                                                     tilingOptions);
+  case LinalgTilingLoopType::TiledLoops:
+    return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions);
   default:;
   }
   return llvm::None;

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index 3775b67c43548..dc8a4acb4b4f9 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -1,15 +1,16 @@
-// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -canonicalize -cse  --split-input-file | FileCheck %s --check-prefix=TLOOP
 
 module {
-  func @matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
-                      %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
-                      %arg4: tensor<?x?xf32>) -> tensor<?x?xf32> {
-    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>   // <MxN1> <N1xN2>
-    %1 = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
-      ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32>   // <MxN2> <N2xN3>
-    return %1 : tensor<?x?xf32>
+  func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+                      %AB_init: tensor<?x?xf32>, %C: tensor<?x?xf32>,
+                      %ABC_init: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%AB_init : tensor<?x?xf32>) -> tensor<?x?xf32>   // <MxN1> <N1xN2>
+    %ABC = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
+      ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%ABC_init : tensor<?x?xf32>) -> tensor<?x?xf32>   // <MxN2> <N2xN3>
+    return %ABC : tensor<?x?xf32>
   }
 }
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (32, d0 - d1)>
@@ -90,6 +91,64 @@ module {
 //      CHECK:   }
 //      CHECK:   return %[[RESULT]]
 
+// TLOOP-LABEL:  func @matmul_fusion(
+// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[AB_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[C:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME: %[[ABC_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+// TLOOP:  %[[C32:.*]] = constant 32 : index
+// TLOOP:  %[[C64:.*]] = constant 64 : index
+// TLOOP:  %[[C16:.*]] = constant 16 : index
+// TLOOP:  %[[C0:.*]] = constant 0 : index
+// TLOOP:  %[[C1:.*]] = constant 1 : index
+
+// TLOOP:  %[[DIM_A0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+
+// TLOOP:  %[[ABC:.*]] = linalg.tiled_loop (%[[IV0:.*]]) = (%[[C0]]) 
+// TLOOP-SAME: to (%[[DIM_A0]]) step (%[[C32]]) 
+// TLOOP-SAME: ins (%[[C_:.*]] = %[[C]]: tensor<?x?xf32>,
+// TLOOP-SAME:      %[[A_:.*]] = %[[A]]: tensor<?x?xf32>,
+// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: tensor<?x?xf32>,
+// TLOOP-SAME:      %[[AB_INIT_:.*]] = %[[AB_INIT]]: tensor<?x?xf32>)
+// TLOOP-SAME: outs (%[[ABC_INIT_:.*]] = %[[ABC_INIT]]: tensor<?x?xf32>) {
+
+// TLOOP:    %[[ABC_INIT_SUB:.*]] = subtensor %[[ABC_INIT_]][%[[IV0]], 0]
+// TLOOP:    %[[A_SUB:.*]] = subtensor %[[A_]][%[[IV0]], 0]
+// TLOOP:    %[[AB_INIT_SUB:.*]] = subtensor %[[AB_INIT_]][%[[IV0]], 0]
+
+// TLOOP:    %[[AB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME:  ins(%[[A_SUB]], %[[B_]] : {{.*}}) outs(%[[AB_INIT_SUB]]
+
+// TLOOP:    %[[DIM_B_1:.*]] = memref.dim %[[B_]], %[[C1]] : [[TY]]
+// TLOOP:    %[[DIM_C_1:.*]] = memref.dim %[[C_]], %[[C1]] : [[TY]]
+
+// TLOOP:    %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) = 
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C64]], %[[C16]]) 
+// TLOOP-SAME: ins (%[[AB_SUB_:.*]] = %[[AB_SUB]]: [[TY]],
+// TLOOP-SAME:      %[[C__:.*]] = %[[C_]]: [[TY]])
+// TLOOP-SAME: outs (%[[ABC_INIT_SUB_:.*]] = %[[ABC_INIT_SUB]]: [[TY]])
+// TLOOP-SAME: iterators["parallel", "reduction"] {
+
+// TLOOP:      %[[AB_SUB_SUB:.*]] = subtensor %[[AB_SUB_]][0, %[[IV2]]]
+// TLOOP:      %[[C__SUB:.*]] = subtensor %[[C__]][%[[IV2]], %[[IV1]]]
+// TLOOP:      %[[ABS_INIT_SUB_SUB:.*]] = subtensor %[[ABC_INIT_SUB_]][0, %[[IV1]]]
+
+// TLOOP:      %[[ABC_SUB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME:  ins(%[[AB_SUB_SUB]], %[[C__SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME:  outs(%[[ABS_INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
+
+// TLOOP:      %[[RES0:.*]] = subtensor_insert %[[ABC_SUB_SUB]]
+// TLOOP-SAME:   into %[[ABC_INIT_SUB_]][0, %[[IV1]]]
+// TLOOP:      linalg.yield %[[RES0]] : [[TY]]
+// TLOOP:    }
+// TLOOP:    %[[RES1:.*]] = subtensor_insert %[[ABC_SUB_]] into %[[ABC_INIT_]][%[[IV0]], 0]
+// TLOOP:    linalg.yield %[[RES1]] : [[TY]]
+// TLOOP:  }
+// TLOOP:  return %[[ABC]] : [[TY]]
+
 // -----
 
 module {
@@ -144,6 +203,48 @@ module {
 //       CHECK:     scf.yield %[[YIELD]]
 //       CHECK:   return %[[RESULT]]
 
+// TLOOP-LABEL: func @matmul_plus_matmul
+// TLOOP-SAME:    %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME:    %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// TLOOP-SAME:    %[[AB:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+
+// TLOOP:  %[[C32:.*]] = constant 32 : index
+// TLOOP:  %[[C64:.*]] = constant 64 : index
+// TLOOP:  %[[C0:.*]] = constant 0 : index
+// TLOOP:  %[[C1:.*]] = constant 1 : index
+
+// TLOOP:  %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+// TLOOP:  %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
+
+// TLOOP:  %[[INIT:.*]] = linalg.init_tensor [%[[DIM_A_0]], %[[DIM_B_1]]]
+
+// TLOOP:  %[[RESULT:.*]] = linalg.tiled_loop (%[[IV0:.*]], %[[IV1:.*]]) =
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C32]], %[[C64]])
+// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
+// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: [[TY]],
+// TLOOP-SAME:      %[[AB_:.*]] = %[[AB]]: [[TY]])
+// TLOOP-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: [[TY]]) {
+
+// TLOOP:    %[[INIT_SUB:.*]] = subtensor %[[INIT_]][%[[IV0]], %[[IV1]]]
+// TLOOP:    %[[A_SUB:.*]] = subtensor %[[A_]][%[[IV0]], 0]
+// TLOOP:    %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[IV1]]]
+// TLOOP:    %[[AB_SUB_INIT:.*]] = subtensor %[[AB_]][%[[IV0]], %[[IV1]]]
+
+// TLOOP:    %[[AB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME:  ins(%[[A_SUB]], %[[B_SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME:  outs(%[[AB_SUB_INIT]] : [[TY]])
+
+// TLOOP:    %[[DOUBLE_AB:.*]] = linalg.generic
+// TLOOP-SAME:  ins(%[[AB_SUB]] : [[TY]]) outs(%[[INIT_SUB]] : [[TY]])
+
+// TLOOP:    %[[RESULT_SUB:.*]] = subtensor_insert
+// TLOOP-SAME:  %[[DOUBLE_AB:.*]] into %[[INIT_]][%[[IV0]], %[[IV1]]]
+
+// TLOOP:    linalg.yield %[[RESULT_SUB]] : [[TY]]
+// TLOOP:  }
+// TLOOP:  return %[[RESULT]] : [[TY]]
+
 // -----
 
 module {
@@ -174,3 +275,53 @@ module {
 //       CHECK:       scf.yield %[[ST_MM]] : tensor<?x?xf32>
 //       CHECK:     %[[MM:.*]] = subtensor_insert %[[ST_MM_RES]] into {{.*}}
 //       CHECK:     scf.yield %[[MM]] : tensor<?x?xf32>
+
+
+// TLOOP-LABEL: func @matmul_out_fusion(
+// TLOOP-SAME:    %[[OUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// TLOOP-SAME:    %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// TLOOP-SAME:    %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+
+// TLOOP-DAG:  %[[C0_F32:.*]] = constant 0.0
+// TLOOP-DAG:  %[[C32:.*]] = constant 32 : index
+// TLOOP-DAG:  %[[C64:.*]] = constant 64 : index
+// TLOOP-DAG:  %[[C16:.*]] = constant 16 : index
+// TLOOP-DAG:  %[[C0:.*]] = constant 0 : index
+// TLOOP-DAG:  %[[C1:.*]] = constant 1 : index
+
+// TLOOP:  %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
+// TLOOP:  %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
+
+// TLOOP:  %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) = 
+// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
+// TLOOP-SAME: step (%[[C32]], %[[C64]])
+// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
+// TLOOP-SAME:      %[[B_:.*]] = %[[B]]: [[TY]])
+// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
+
+// TLOOP:    %[[DIM_A__1:.*]] = memref.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP:    %[[A_SUB:.*]] = subtensor %[[A_]][%[[I]], 0]
+// TLOOP:    %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[J]]]
+// TLOOP:    %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP:    %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32]])
+
+// TLOOP:    %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]]) 
+// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
+// TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]],
+// TLOOP-SAME:      %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]])
+// TLOOP-SAME: outs (%[[INIT_SUB_:.*]] = %[[INIT_SUB]]: [[TY]])
+// TLOOP-SAME: iterators["reduction"] {
+
+// TLOOP:      %[[A_SUB_SUB:.*]] = subtensor %[[A_SUB_]][0, %[[K]]]
+// TLOOP:      %[[B_SUB_SUB:.*]] = subtensor %[[B_SUB_]][%[[K]], 0]
+
+// TLOOP:      %[[AB_SUB_SUB:.*]] = linalg.matmul
+// TLOOP-SAME:   ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
+// TLOOP-SAME:   outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
+// TLOOP:      linalg.yield %[[AB_SUB_SUB]] : [[TY]]
+// TLOOP:    }
+// TLOOP:    %[[SUB_RESULT:.*]] = subtensor_insert %[[AB_SUB]]
+// TLOOP-SAME:  into %[[OUT_]][%[[I]], %[[J]]]
+// TLOOP:    linalg.yield %[[SUB_RESULT]] : [[TY]]
+// TLOOP:  }
+// TLOOP:  return %[[AB]] : [[TY]]

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 3e67774ba13a4..4413faca5dc61 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -278,6 +278,13 @@ void registerTestLinalgTensorFusionTransforms() {
           "Test Linalg on tensor fusion transformation "
           "patterns by applying them greedily.");
 }
+void registerTestLinalgTiledLoopFusionTransforms() {
+  PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops>>
+      testTiledLoopFusionTransformsPass(
+          "test-linalg-tiled-loop-fusion-transform-patterns",
+          "Test Linalg on tensor fusion into linalg.tiled_loop "
+          "transformation patterns by applying them greedily.");
+}
 void registerTestLinalgGreedyFusion() {
   PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
       "test-linalg-greedy-fusion",

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 1ee3b6a0f2ace..23bfe775cae2c 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -81,6 +81,7 @@ void registerTestLinalgElementwiseFusion();
 void registerTestPushExpandingReshape();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
+void registerTestLinalgTiledLoopFusionTransforms();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgHoisting();
 void registerTestLinalgTileAndFuseSequencePass();
@@ -159,6 +160,7 @@ void registerTestPasses() {
   test::registerTestPushExpandingReshape();
   test::registerTestLinalgFusionTransforms();
   test::registerTestLinalgTensorFusionTransforms();
+  test::registerTestLinalgTiledLoopFusionTransforms();
   test::registerTestLinalgGreedyFusion();
   test::registerTestLinalgHoisting();
   test::registerTestLinalgTileAndFuseSequencePass();


        


More information about the Mlir-commits mailing list