[Mlir-commits] [mlir] 7ee3455 - [mlir][TilingInterface] Fix `iter_args` handling in tile (and fuse).

Mahesh Ravishankar llvmlistbot at llvm.org
Mon Sep 26 12:14:54 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-09-26T19:09:29Z
New Revision: 7ee34550f5495479428098256d0685c498036ec2

URL: https://github.com/llvm/llvm-project/commit/7ee34550f5495479428098256d0685c498036ec2
DIFF: https://github.com/llvm/llvm-project/commit/7ee34550f5495479428098256d0685c498036ec2.diff

LOG: [mlir][TilingInterface] Fix `iter_args` handling in tile (and fuse).

The previous approach for handling `iter_args` was to replace all uses
of the value that is used as the `init` value with the corresponding
region block argument within the `scf.for`. This is not always
correct. Instead, a more deliberate approach needs to be taken to
handle these. If the slice being fused represents a slice of the
destination operand of the untiled op, then:
- Make the destination of the fused producer the `init` value of the
  loop nest
- For the tiled and fused producer op created, replace the slice of
  the destination operand with a slice of the corresponding region
  iter arg of the innermost loop of the generated loop nest

Differential Revision: https://reviews.llvm.org/D134411

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/include/mlir/Dialect/SCF/Utils/Utils.h
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/lib/Dialect/SCF/Utils/Utils.cpp
    mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index fe7f1b03d3f1e..0fa064501fea0 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -250,6 +250,9 @@ def ForOp : SCF_Op<"for",
     void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
     void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
     void setStep(Value step) { getOperation()->setOperand(2, step); }
+    void setIterArg(unsigned iterArgNum, Value iterArgValue) {
+      getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue);
+    }
 
     /// Number of induction variables, always 1 for scf::ForOp.
     unsigned getNumInductionVars() { return 1; }
@@ -267,6 +270,17 @@ def ForOp : SCF_Op<"for",
     unsigned getNumIterOperands() {
       return getOperation()->getNumOperands() - getNumControlOperands();
     }
+    /// Get the iter arg number for an operand. If it isn't an iter arg
+    /// operand, return llvm::None.
+    Optional<unsigned> getIterArgNumberForOpOperand(OpOperand &opOperand) {
+      if (opOperand.getOwner() != getOperation())
+        return llvm::None;
+      unsigned operandNumber = opOperand.getOperandNumber();
+      if (operandNumber < getNumControlOperands())
+        return llvm::None;
+      return operandNumber - getNumControlOperands();
+    }
+
     /// Get the region iter arg that corresponds to an OpOperand.
     /// This helper prevents internal op implementation detail leakage to
     /// clients by hiding the operand / block argument mapping.

diff  --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 63b94a1bae72c..34a4e4f14cba6 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -44,13 +44,15 @@ class FuncOp;
 /// - `loop` isnt erased, but is left in a "no-op" state where the body of the
 ///   loop just yields the basic block arguments that correspond to the
 ///   initialization values of a loop. The loop is dead after this method.
-/// - All uses of the `newIterOperands` within the generated new loop
-///   are replaced with the corresponding `BlockArgument` in the loop body.
+/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
+///   `newIterOperands` within the generated new loop are replaced
+///   with the corresponding `BlockArgument` in the loop body.
 using NewYieldValueFn = std::function<SmallVector<Value>(
     OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
 scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
                                     ValueRange newIterOperands,
-                                    const NewYieldValueFn &newYieldValuesFn);
+                                    const NewYieldValueFn &newYieldValuesFn,
+                                    bool replaceIterOperandsUsesInLoop = true);
 
 /// Update a perfectly nested loop nest to yield new values from the innermost
 /// loop and propagating it up through the loop nest. This function
@@ -64,12 +66,14 @@ scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
 ///   the body of the loop just yields the basic block arguments that correspond
 ///   to the initialization values of a loop. The original loops are dead after
 ///   this method.
-/// - All uses of the `newIterOperands` within the generated new loop
-///   are replaced with the corresponding `BlockArgument` in the loop body.
+/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
+///   `newIterOperands` within the generated new loop are replaced with the
+///   corresponding `BlockArgument` in the loop body.
 SmallVector<scf::ForOp>
 replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
                              ValueRange newIterOperands,
-                             const NewYieldValueFn &newYieldValueFn);
+                             const NewYieldValueFn &newYieldValueFn,
+                             bool replaceIterOperandsUsesInLoop = true);
 
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 15ca875977047..0c6ba3d195da5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -167,6 +167,44 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
   return loops;
 }
 
+/// If the tiled operation is in destination passing style, update the
+/// slice of the destination used (which refers to the untiled destination)
+/// to use the corresponding region argument of the innermost loop.
+///
+/// ```mlir
+/// %0 =
+/// scf.for %iv0 = ... iter_args(%arg = %0) {
+///   %1 = tensor.extract_slice %0
+///   %2 = tiled_op
+///   %3 = tensor.insert_slice %2 into %arg
+///   scf.yield %3
+/// }
+/// ```
+///
+/// is transformed to
+///
+/// ```mlir
+/// scf.for %iv0 = ... iter_args(%arg = %0) {
+///   %1 = tensor.extract_slice %arg
+///   %2 = tiled_op
+///   %3 = tensor.insert_slice %2 into %arg
+///   scf.yield %3
+/// }
+/// ```
+/// TODO: This can be made much cleaner when `DestinationStyleOp` interface is
+/// available generally.
+static void
+updateDestinationOperandsForTiledOp(OpBuilder &builder,
+                                    ValueRange tiledOpDestinationValues,
+                                    ValueRange bbArgsList) {
+  for (auto destValue : llvm::enumerate(tiledOpDestinationValues)) {
+    auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!sliceOp)
+      continue;
+    sliceOp.setOperand(0, bbArgsList[destValue.index()]);
+  }
+}
+
 scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
                                           scf::SCFTilingOptions options,
                                           PatternBenefit benefit)
@@ -281,7 +319,6 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
 
   // 5. If the original operations has results, modify the loop nest to yield
   // the replacement values.
-  SmallVector<Value> replacements;
   if (tilingResult.loops.empty()) {
     // 5a. If there were no loops, the tiled implementation results are the
     // replacements.
@@ -289,7 +326,15 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
     return tilingResult;
   }
 
-  // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
+  // 6. Yield the results of the tiled operation from the loop nest as
+  //    replacements for the original untiled ops.
+  if (tilingResult.tiledOp->getNumResults() != op->getNumResults()) {
+    return rewriter.notifyMatchFailure(
+        tilingResult.tiledOp,
+        "expected tiled op to have as many results as the untiled operation");
+  }
+
+  // `scf.for` with tensor semantics requires the loop nest to yield the
   // replacement values using destructive updates. Use the `TilingInterface`
   // to get the position of the result tiles and use that to generate the
   // destructive update pattern, i.e.,
@@ -335,7 +380,7 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
   };
   SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
       rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
-      yieldValueFn);
+      yieldValueFn, /*replaceIterOperandsUsesInLoops =*/false);
   for (const auto &loop : llvm::enumerate(tilingResult.loops)) {
     rewriter.eraseOp(loop.value());
     tilingResult.loops[loop.index()] = newLoops[loop.index()];
@@ -363,36 +408,26 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::
     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
       tilingPattern(context, std::move(options)) {}
 
-/// Return the `Value` that is defined by an operation that implements
-/// the `TilingInterface`. Looks through `iter_args` of scf.for nest
-/// if required.
-static Optional<OpResult> getFusableProducer(Value v) {
-  while (auto blockArg = v.dyn_cast<BlockArgument>()) {
-    auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
-    if (!loopOp)
-      return llvm::None;
-    v = loopOp.getOpOperandForRegionIterArg(blockArg).get();
-  }
-  if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp()))
-    return llvm::None;
-  return v.cast<OpResult>();
-}
-
-// Replace iter args of the outer most loop with region args of the inner most
-// one.
-static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor,
-                            PatternRewriter &rewriter) {
-  assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() &&
-         "expect same number of iter args");
-  Block *block = &(*innerFor.getRegion().begin());
-  for (auto it :
-       llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) {
-    Value source = std::get<0>(it);
-    Value target = std::get<1>(it);
-    source.replaceUsesWithIf(target, [&](OpOperand &use) {
-      return use.getOwner()->getBlock() == block;
-    });
+/// Return the untiled producer whose slice is used in a tiled consumer. The
+/// method traverses the tile loop nest (`loops`) through `iter_args` if needed.
+/// If every loop is traversed, the second value of the returned tuple is the
+/// `iter_args` operand of the outermost loop, which indicates that this is a
+/// destination operand of the consumer; otherwise the second value is empty.
+static std::tuple<OpResult, Optional<OpOperand *>>
+getUntiledProducerFromSliceSource(OpOperand *source,
+                                  ArrayRef<scf::ForOp> loops) {
+  Optional<OpOperand *> destinationIterArg;
+  auto loopIt = loops.rbegin();
+  while (auto iterArg = source->get().dyn_cast<BlockArgument>()) {
+    scf::ForOp loop = *loopIt;
+    if (iterArg.getOwner()->getParentOp() != loop)
+      break;
+    source = &loop.getOpOperandForRegionIterArg(iterArg);
+    loopIt++;
   }
+  if (loopIt == loops.rend())
+    destinationIterArg = source;
+  return {source->get().dyn_cast<OpResult>(), destinationIterArg};
 }
 
 FailureOr<scf::SCFTileAndFuseResult>
@@ -441,8 +476,9 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
 
     // 2b. Get the producer of the source (potentially walking through
     // `iter_args` of nested `scf.for`)
-    Optional<OpResult> fusableProducer =
-        getFusableProducer(candidateSliceOp.getSource());
+    auto [fusableProducer, destinationIterArg] =
+        getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
+                                          tileAndFuseResult.loops);
     if (!fusableProducer)
       continue;
 
@@ -450,7 +486,7 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
     rewriter.setInsertionPoint(candidateSliceOp);
     FailureOr<Value> fusedProducerValue =
         tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
-                                                     fusableProducer.value());
+                                                     fusableProducer);
     if (failed(fusedProducerValue))
       continue;
     rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());
@@ -462,56 +498,81 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
     tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer);
     addCandidateSlices(fusedProducer, candidates);
 
-    // 2e. If the operation being fused creates a value that is used as `outs`
-    //     in the tiled operation, the result of the unfused operation will be
-    //     used in the `iter_args` of the tiled loop generated. When the
-    //     operation is fused, this use in `iter_args` needs to be modified to
-    //     use the destination of the fused operation. For example, starting
-    //     with
+    // 2e. If the slice is for a destination operand, for example,
     //
-    //     ```mlir
-    //     %0 = linalg.init_tensor ...
-    //     %1 = linalg.fill ... outs(%0:...)...
-    //     %2 = linalg.matmul ... outs(%1:...)....
-    //     ```
+    // ```mlir
+    // %0 = linalg.init
+    // %1 = linalg.fill .. outs(%0 : )
+    // %2 = scf.for .. iter_args(%arg0 = %1) {
+    //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
+    //     %4 = tensor.extract_slice %arg1 [..]
+    //     .. = linalg.matmul .. outs(%4 : )
+    //   }
+    // }
+    // ```
     //
-    //     First the `linalg.matmul` gets tiled
+    // the IR is currently
     //
-    //     ```mlir
-    //     %0 = linalg.init_tensor
-    //     %1 = linalg.fill
-    //     %2 = scf.for .... iter_args(%arg0 = %1)...
-    //        ...
-    //        ... = linalg.matmul ...
+    // ```
+    // %0 = linalg.init
+    // %1 = linalg.fill
+    // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
+    //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
+    //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
+    //     %5 = linalg.fill .. outs(%4 : )
+    //     .. = linalg.matmul .. outs(%5 : )
+    //   }
+    // }
+    // ```
     //
-    //     ```
+    // The untiled `linalg.fill` is still used as the `init_value` since it
+    // was originally a destination operand of the untiled `linalg.matmul`.
+    // When fusing the producer of a destination operand:
+    //   - Update the iter_arg of the outermost loop to use the destination
+    //     of the untiled producer.
+    //   - Update the destination of the slice of the tiled producer generated
+    //     to use the same basic block argument as the slice that was used to
+    //     generate inplace the tiled implementation of the producer.
+    // With this the IR will be.
     //
-    //     When the `linalg.fill` gets fused, the `iter_args` needs to be
-    //     modified
-    //
-    //     ```mlir
-    //     %0 = linalg.init_tensor
-    //     %1 = scf.for ... iter_args(%arg0 = %0)...
-    //        ...
-    //        %2 = linalg.fill ...
-    //        %3 = linalg.matmul ... outs(%2: ...)...
-    //     ```
-    TilingInterface unfusedProducerOp =
-        cast<TilingInterface>(fusableProducer->getOwner());
-    scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front();
-    SmallVector<Value> unfusedProducerOpDestValues =
-        unfusedProducerOp.getDestinationOperands(rewriter);
-    for (OpOperand &uses : unfusedProducerOp->getUses()) {
-      if (uses.getOwner() == outerMostTiledLoop.getOperation()) {
-        unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber();
-        unsigned operandNumber = uses.getOperandNumber();
-        outerMostTiledLoop->setOperand(
-            operandNumber, unfusedProducerOpDestValues[resultNumber]);
+    // ```
+    // %0 = linalg.init
+    // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
+    //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
+    //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
+    //     %4 = linalg.fill .. outs(%3 : )
+    //     .. = linalg.matmul .. outs(%4 : )
+    //   }
+    // }
+    // ```
+    // TODO: This can be modeled better with the `DestinationStyleOpInterface`.
+    // Update to use that when it does become available.
+    scf::ForOp outerMostLoop = tileAndFuseResult.loops.front();
+    Optional<unsigned> iterArgNumber;
+    if (destinationIterArg) {
+      iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand(
+          *destinationIterArg.value());
+    }
+    if (iterArgNumber) {
+      unsigned resultNumber = fusableProducer.getResultNumber();
+      if (auto producerOp =
+              dyn_cast<TilingInterface>(fusableProducer.getOwner())) {
+        SmallVector<Value> destination =
+            producerOp.getDestinationOperands(rewriter);
+        outerMostLoop.setIterArg(iterArgNumber.value(),
+                                 destination[resultNumber]);
+      }
+      if (auto tiledAndFusedInterfaceOp =
+              fusedProducerValue.value().getDefiningOp<TilingInterface>()) {
+        scf::ForOp innerMostLoop = tileAndFuseResult.loops.back();
+        SmallVector<Value> destination =
+            tiledAndFusedInterfaceOp.getDestinationOperands(rewriter);
+        updateDestinationOperandsForTiledOp(
+            rewriter, destination[resultNumber],
+            innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
       }
     }
   }
-  replaceIterArgs(tileAndFuseResult.loops.front(),
-                  tileAndFuseResult.loops.back(), rewriter);
   return tileAndFuseResult;
 }
 

diff  --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index e40fc9cade586..777187387823b 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -40,7 +40,8 @@ struct LoopParams {
 scf::ForOp
 mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
                                ValueRange newIterOperands,
-                               const NewYieldValueFn &newYieldValuesFn) {
+                               const NewYieldValueFn &newYieldValuesFn,
+                               bool replaceIterOperandsUsesInLoop) {
   // Create a new loop before the existing one, with the extra operands.
   OpBuilder::InsertionGuard g(builder);
   builder.setInsertionPoint(loop);
@@ -79,13 +80,15 @@ mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
        llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size())))
     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
 
-  // Replace all uses of `newIterOperands` with the corresponding basic block
-  // arguments.
-  for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
-    std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
-      Operation *user = use.getOwner();
-      return newLoop->isProperAncestor(user);
-    });
+  if (replaceIterOperandsUsesInLoop) {
+    // Replace all uses of `newIterOperands` with the corresponding basic block
+    // arguments.
+    for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
+      std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
+        Operation *user = use.getOwner();
+        return newLoop->isProperAncestor(user);
+      });
+    }
   }
 
   // Replace all uses of the original loop with corresponding values from the
@@ -104,7 +107,8 @@ mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
 
 SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
     OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
-    ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn) {
+    ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn,
+    bool replaceIterOperandsUsesInLoop) {
   if (loopNest.empty())
     return {};
   SmallVector<scf::ForOp> newLoopNest(loopNest.size());
@@ -121,8 +125,41 @@ SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
               newIterOperands.size()));
       return newYields;
     };
-    newLoopNest[loopDepth] = replaceLoopWithNewYields(
-        builder, loopNest[loopDepth], newIterOperands, fn);
+    newLoopNest[loopDepth] =
+        replaceLoopWithNewYields(builder, loopNest[loopDepth], newIterOperands,
+                                 fn, replaceIterOperandsUsesInLoop);
+    if (!replaceIterOperandsUsesInLoop) {
+      /// The yield is expected to produce the following structure
+      /// ```
+      /// %0 = scf.for ... iter_args(%arg0 = %init) {
+      ///   %1 = scf.for ... iter_args(%arg1 = %arg0) {
+      ///     scf.yield %yield
+      ///   }
+      /// }
+      /// ```
+      ///
+      /// since the yield is propagated from inside out, after the inner
+      /// loop is processed the IR is in this form
+      ///
+      /// ```
+      /// scf.for ... iter_args {
+      ///   %1 = scf.for ... iter_args(%arg1 = %init) {
+      ///     scf.yield %yield
+      ///   }
+      /// ```
+      ///
+      /// If `replaceIterOperandsUsesInLoop` is true, there is nothing to do.
+      /// `%init` will be replaced with `%arg0` when it is created for the
+      /// outer loop. Otherwise, this has to be done explicitly.
+      unsigned subLen = newIterOperands.size();
+      unsigned subStart =
+          newLoopNest[loopDepth + 1].getNumIterOperands() - subLen;
+      auto resetOperands =
+          newLoopNest[loopDepth + 1].getInitArgsMutable().slice(subStart,
+                                                                subLen);
+      resetOperands.assign(
+          newLoopNest[loopDepth].getRegionIterArgs().take_back(subLen));
+    }
   }
   return newLoopNest;
 }

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 61aa706b10ae4..dd8631f1fc157 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -30,7 +30,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) ->
 // CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
 // CHECK-SAME:           outs(%[[FILL_TILE]] :
 //      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
-//      CHECK        scf.yield %[[INSERT]]
+//      CHECK:       scf.yield %[[INSERT]]
 
 // -----
 
@@ -68,7 +68,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 // CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
 //  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
 //  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
 //      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
 // CHECK-SAME:           outs(%[[INIT_TILE]] :
 //      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
@@ -80,7 +80,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 // CHECK-SAME:           ins(%[[GEMM_TILE]], %[[BIAS_TILE]] :
 // CHECK-SAME:           outs(%[[OUTS_TILE]] :
 //      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
-//      CHECK        scf.yield %[[INSERT]]
+//      CHECK:       scf.yield %[[INSERT]]
 
 // -----
 
@@ -130,7 +130,7 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %r
 // CHECK-SAME:         ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
 // CHECK-SAME:         outs(%[[FILL1_TILE]] :
 //      CHECK:     %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0]
-//      CHECK      scf.yield %[[INSERT]]
+//      CHECK:     scf.yield %[[INSERT]]
 
 // -----
 
@@ -182,7 +182,7 @@ func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32
 // CHECK-SAME:           ins(%[[GEMM_TILE]] :
 // CHECK-SAME:           outs(%[[OUTS_TILE]] :
 //      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
-//      CHECK        scf.yield %[[INSERT]]
+//      CHECK:       scf.yield %[[INSERT]]
 
 // -----
 
@@ -218,7 +218,7 @@ func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?
 // CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
 //  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
 //  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
-//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV0]]]
 //      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
 // CHECK-SAME:           outs(%[[INIT_TILE]] :
 //      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
@@ -229,7 +229,7 @@ func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?
 // CHECK-SAME:           ins(%[[GEMM_TILE]] :
 // CHECK-SAME:           outs(%[[INIT_TILE_2]] :
 //      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
-//      CHECK        scf.yield %[[INSERT]]
+//      CHECK:       scf.yield %[[INSERT]]
 
 // -----
 


        


More information about the Mlir-commits mailing list