[Mlir-commits] [mlir] 7ee3455 - [mlir][TilingInterface] Fix `iter_args` handling in tile (and fuse).
Mahesh Ravishankar
llvmlistbot at llvm.org
Mon Sep 26 12:14:54 PDT 2022
Author: Mahesh Ravishankar
Date: 2022-09-26T19:09:29Z
New Revision: 7ee34550f5495479428098256d0685c498036ec2
URL: https://github.com/llvm/llvm-project/commit/7ee34550f5495479428098256d0685c498036ec2
DIFF: https://github.com/llvm/llvm-project/commit/7ee34550f5495479428098256d0685c498036ec2.diff
LOG: [mlir][TilingInterface] Fix `iter_args` handling in tile (and fuse).
The previous approach to handling `iter_args` replaced all uses of the
value used as the `init` value with the corresponding region block
argument inside the `scf.for`. This is not always correct, so a more
deliberate approach is needed. If the slice being fused is a slice of
the destination operand of the untiled op, then:
- Use the destination of the fused producer as the `init` value of the
loop nest.
- In the tiled and fused producer op that is created, replace the slice
of the destination operand with a slice of the corresponding region
`iter_arg` of the innermost loop of the generated loop nest.
Differential Revision: https://reviews.llvm.org/D134411
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index fe7f1b03d3f1e..0fa064501fea0 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -250,6 +250,9 @@ def ForOp : SCF_Op<"for",
void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
void setStep(Value step) { getOperation()->setOperand(2, step); }
+ void setIterArg(unsigned iterArgNum, Value iterArgValue) {
+ getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue);
+ }
/// Number of induction variables, always 1 for scf::ForOp.
unsigned getNumInductionVars() { return 1; }
@@ -267,6 +270,17 @@ def ForOp : SCF_Op<"for",
unsigned getNumIterOperands() {
return getOperation()->getNumOperands() - getNumControlOperands();
}
+ /// Get the iter arg number for an operand. If it isnt an iter arg
+ /// operand return llvm::None.
+ Optional<unsigned> getIterArgNumberForOpOperand(OpOperand &opOperand) {
+ if (opOperand.getOwner() != getOperation())
+ return llvm::None;
+ unsigned operandNumber = opOperand.getOperandNumber();
+ if (operandNumber < getNumControlOperands())
+ return llvm::None;
+ return operandNumber - getNumControlOperands();
+ }
+
/// Get the region iter arg that corresponds to an OpOperand.
/// This helper prevents internal op implementation detail leakage to
/// clients by hiding the operand / block argument mapping.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 63b94a1bae72c..34a4e4f14cba6 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -44,13 +44,15 @@ class FuncOp;
/// - `loop` isnt erased, but is left in a "no-op" state where the body of the
/// loop just yields the basic block arguments that correspond to the
/// initialization values of a loop. The loop is dead after this method.
-/// - All uses of the `newIterOperands` within the generated new loop
-/// are replaced with the corresponding `BlockArgument` in the loop body.
+/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
+/// `newIterOperands` within the generated new loop are replaced
+/// with the corresponding `BlockArgument` in the loop body.
using NewYieldValueFn = std::function<SmallVector<Value>(
OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
ValueRange newIterOperands,
- const NewYieldValueFn &newYieldValuesFn);
+ const NewYieldValueFn &newYieldValuesFn,
+ bool replaceIterOperandsUsesInLoop = true);
/// Update a perfectly nested loop nest to yield new values from the innermost
/// loop and propagating it up through the loop nest. This function
@@ -64,12 +66,14 @@ scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
/// the body of the loop just yields the basic block arguments that correspond
/// to the initialization values of a loop. The original loops are dead after
/// this method.
-/// - All uses of the `newIterOperands` within the generated new loop
-/// are replaced with the corresponding `BlockArgument` in the loop body.
+/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
+/// `newIterOperands` within the generated new loop are replaced with the
+/// corresponding `BlockArgument` in the loop body.
SmallVector<scf::ForOp>
replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
ValueRange newIterOperands,
- const NewYieldValueFn &newYieldValueFn);
+ const NewYieldValueFn &newYieldValueFn,
+ bool replaceIterOperandsUsesInLoop = true);
/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 15ca875977047..0c6ba3d195da5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -167,6 +167,44 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
return loops;
}
+/// If the tiled operation is in destination passing style, update the
+/// slice of the destination used (which refers to the untiled destination)
+/// to use the corresponding region argument of the innermost loop.
+///
+/// ```mlir
+/// %0 =
+/// scf.for %iv0 = ... iter_args(%arg = %0) {
+/// %1 = tensor.extract_slice %0
+/// %2 = tiled_op
+/// %3 = tensor.insert_slice %2 into %arg
+/// scf.yield %3
+/// }
+/// ```
+///
+/// is transformed to
+///
+/// ```mlir
+/// scf.for %iv0 = ... iter_args(%arg = %0) {
+/// %1 = tensor.extract_slice %arg
+/// %2 = tiled_op
+/// %3 = tensor.insert_slice %2 into %arg
+/// scf.yield %3
+/// }
+/// ```
+/// TODO: This can be made much cleaner when `DestinationStyleOp` interface is
+/// available generally.
+static void
+updateDestinationOperandsForTiledOp(OpBuilder &builder,
+ ValueRange tiledOpDestinationValues,
+ ValueRange bbArgsList) {
+ for (auto destValue : llvm::enumerate(tiledOpDestinationValues)) {
+ auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!sliceOp)
+ continue;
+ sliceOp.setOperand(0, bbArgsList[destValue.index()]);
+ }
+}
+
scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
scf::SCFTilingOptions options,
PatternBenefit benefit)
@@ -281,7 +319,6 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
// 5. If the original operations has results, modify the loop nest to yield
// the replacement values.
- SmallVector<Value> replacements;
if (tilingResult.loops.empty()) {
// 5a. If there were no loops, the tiled implementation results are the
// replacements.
@@ -289,7 +326,15 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
return tilingResult;
}
- // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
+ // 6. Yield the results of the tiled operation from the loop nest as
+ // replacements for the original untiled ops.
+ if (tilingResult.tiledOp->getNumResults() != op->getNumResults()) {
+ return rewriter.notifyMatchFailure(
+ tilingResult.tiledOp,
+ "expected tiled op to have as many results as the untiled operation");
+ }
+
+ // `scf.for` with tensor semantics requires the loop nest to yield the
// replacement values using destructive updates. Use the `TilingInterface`
// to get the position of the result tiles and use that to generate the
// destructive update pattern, i.e.,
@@ -335,7 +380,7 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
};
SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
- yieldValueFn);
+ yieldValueFn, /*replaceIterOperandsUsesInLoops =*/false);
for (const auto &loop : llvm::enumerate(tilingResult.loops)) {
rewriter.eraseOp(loop.value());
tilingResult.loops[loop.index()] = newLoops[loop.index()];
@@ -363,36 +408,26 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
tilingPattern(context, std::move(options)) {}
-/// Return the `Value` that is defined by an operation that implements
-/// the `TilingInterface`. Looks through `iter_args` of scf.for nest
-/// if required.
-static Optional<OpResult> getFusableProducer(Value v) {
- while (auto blockArg = v.dyn_cast<BlockArgument>()) {
- auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
- if (!loopOp)
- return llvm::None;
- v = loopOp.getOpOperandForRegionIterArg(blockArg).get();
- }
- if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp()))
- return llvm::None;
- return v.cast<OpResult>();
-}
-
-// Replace iter args of the outer most loop with region args of the inner most
-// one.
-static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor,
- PatternRewriter &rewriter) {
- assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() &&
- "expect same number of iter args");
- Block *block = &(*innerFor.getRegion().begin());
- for (auto it :
- llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) {
- Value source = std::get<0>(it);
- Value target = std::get<1>(it);
- source.replaceUsesWithIf(target, [&](OpOperand &use) {
- return use.getOwner()->getBlock() == block;
- });
+/// Return the untiled producer whose slice is used in a tiled consumer. The
+/// method traverses the tile loop nest (`loops`) if needed, and returns the
+/// `iter_args` of the outer most that is encountered. Traversing the iter_args
+/// indicates that this is a destination operand of the consumer. If there was
+/// no loop traversal needed, the second value of the returned tuple is empty.
+static std::tuple<OpResult, Optional<OpOperand *>>
+getUntiledProducerFromSliceSource(OpOperand *source,
+ ArrayRef<scf::ForOp> loops) {
+ Optional<OpOperand *> destinationIterArg;
+ auto loopIt = loops.rbegin();
+ while (auto iterArg = source->get().dyn_cast<BlockArgument>()) {
+ scf::ForOp loop = *loopIt;
+ if (iterArg.getOwner()->getParentOp() != loop)
+ break;
+ source = &loop.getOpOperandForRegionIterArg(iterArg);
+ loopIt++;
}
+ if (loopIt == loops.rend())
+ destinationIterArg = source;
+ return {source->get().dyn_cast<OpResult>(), destinationIterArg};
}
FailureOr<scf::SCFTileAndFuseResult>
@@ -441,8 +476,9 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
// 2b. Get the producer of the source (potentially walking through
// `iter_args` of nested `scf.for`)
- Optional<OpResult> fusableProducer =
- getFusableProducer(candidateSliceOp.getSource());
+ auto [fusableProducer, destinationIterArg] =
+ getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
+ tileAndFuseResult.loops);
if (!fusableProducer)
continue;
@@ -450,7 +486,7 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
rewriter.setInsertionPoint(candidateSliceOp);
FailureOr<Value> fusedProducerValue =
tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
- fusableProducer.value());
+ fusableProducer);
if (failed(fusedProducerValue))
continue;
rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());
@@ -462,56 +498,81 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer);
addCandidateSlices(fusedProducer, candidates);
- // 2e. If the operation being fused creates a value that is used as `outs`
- // in the tiled operation, the result of the unfused operation will be
- // used in the `iter_args` of the tiled loop generated. When the
- // operation is fused, this use in `iter_args` needs to be modified to
- // use the destination of the fused operation. For example, starting
- // with
+ // 2e. If the slice is for a destination operand, for example,
//
- // ```mlir
- // %0 = linalg.init_tensor ...
- // %1 = linalg.fill ... outs(%0:...)...
- // %2 = linalg.matmul ... outs(%1:...)....
- // ```
+ // ```mlir
+ // %0 = linalg.init
+ // %1 = linalg.fill .. outs(%0 : )
+ // %2 = scf.for .. iter_args(%arg0 = %1) {
+ // %3 = scf.for .. iter_args(%arg1 = %arg0) {
+ // %4 = tensor.extract_slice %arg1 [..]
+ // .. = linalg.matmul .. outs(%4 : )
+ // }
+ // }
+ // ```
//
- // First the `linalg.matmul` gets tiled
+ // the IR is currently
//
- // ```mlir
- // %0 = linalg.init_tensor
- // %1 = linalg.fill
- // %2 = scf.for .... iter_args(%arg0 = %1)...
- // ...
- // ... = linalg.matmul ...
+ // ```
+ // %0 = linalg.init
+ // %1 = linalg.fill
+ // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
+ // %3 = scf.for .. iter_args(%arg1 = %arg0) {
+ // %4 = tensor.extract_slice %0 /*incorrect value */ [..]
+ // %5 = linalg.fill .. outs(%4 : )
+ // .. = linalg.matmul .. outs(%5 : )
+ // }
+ // }
+ // ```
//
- // ```
+ // The untiled `linalg.fill` is still used as the `init_value` since it
+ // was originally a destination operand of the untiled `linalg.matmul`.
+ // When fusing an operand that is a destination operand.
+ // - Update the iter_arg of the outer most loop to use the destination
+ // of the untiled producer.
+ // - Update the destination of the slice of the tiled producer generated
+ // to use the same basic block argument as the slice that was used to
+ // generate inplace the tiled implementation of the producer.
+ // With this the IR will be.
//
- // When the `linalg.fill` gets fused, the `iter_args` needs to be
- // modified
- //
- // ```mlir
- // %0 = linalg.init_tensor
- // %1 = scf.for ... iter_args(%arg0 = %0)...
- // ...
- // %2 = linalg.fill ...
- // %3 = linalg.matmul ... outs(%2: ...)...
- // ```
- TilingInterface unfusedProducerOp =
- cast<TilingInterface>(fusableProducer->getOwner());
- scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front();
- SmallVector<Value> unfusedProducerOpDestValues =
- unfusedProducerOp.getDestinationOperands(rewriter);
- for (OpOperand &uses : unfusedProducerOp->getUses()) {
- if (uses.getOwner() == outerMostTiledLoop.getOperation()) {
- unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber();
- unsigned operandNumber = uses.getOperandNumber();
- outerMostTiledLoop->setOperand(
- operandNumber, unfusedProducerOpDestValues[resultNumber]);
+ // ```
+ // %0 = linalg.init
+ // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
+ // %2 = scf.for .. iter_args(%arg1 = %arg0) {
+ // %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
+ // %4 = linalg.fill .. outs(%3 : )
+ // .. = linalg.matmul .. outs(%4 : )
+ // }
+ // }
+ // ```
+ // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
+ // Update to use that when it does become available.
+ scf::ForOp outerMostLoop = tileAndFuseResult.loops.front();
+ Optional<unsigned> iterArgNumber;
+ if (destinationIterArg) {
+ iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand(
+ *destinationIterArg.value());
+ }
+ if (iterArgNumber) {
+ unsigned resultNumber = fusableProducer.getResultNumber();
+ if (auto producerOp =
+ dyn_cast<TilingInterface>(fusableProducer.getOwner())) {
+ SmallVector<Value> destination =
+ producerOp.getDestinationOperands(rewriter);
+ outerMostLoop.setIterArg(iterArgNumber.value(),
+ destination[resultNumber]);
+ }
+ if (auto tiledAndFusedInterfaceOp =
+ fusedProducerValue.value().getDefiningOp<TilingInterface>()) {
+ scf::ForOp innerMostLoop = tileAndFuseResult.loops.back();
+ SmallVector<Value> destination =
+ tiledAndFusedInterfaceOp.getDestinationOperands(rewriter);
+ updateDestinationOperandsForTiledOp(
+ rewriter, destination[resultNumber],
+ innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
}
}
}
- replaceIterArgs(tileAndFuseResult.loops.front(),
- tileAndFuseResult.loops.back(), rewriter);
return tileAndFuseResult;
}
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index e40fc9cade586..777187387823b 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -40,7 +40,8 @@ struct LoopParams {
scf::ForOp
mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
ValueRange newIterOperands,
- const NewYieldValueFn &newYieldValuesFn) {
+ const NewYieldValueFn &newYieldValuesFn,
+ bool replaceIterOperandsUsesInLoop) {
// Create a new loop before the existing one, with the extra operands.
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPoint(loop);
@@ -79,13 +80,15 @@ mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size())))
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
- // Replace all uses of `newIterOperands` with the corresponding basic block
- // arguments.
- for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
- std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
- Operation *user = use.getOwner();
- return newLoop->isProperAncestor(user);
- });
+ if (replaceIterOperandsUsesInLoop) {
+ // Replace all uses of `newIterOperands` with the corresponding basic block
+ // arguments.
+ for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
+ std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
+ Operation *user = use.getOwner();
+ return newLoop->isProperAncestor(user);
+ });
+ }
}
// Replace all uses of the original loop with corresponding values from the
@@ -104,7 +107,8 @@ mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
- ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn) {
+ ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn,
+ bool replaceIterOperandsUsesInLoop) {
if (loopNest.empty())
return {};
SmallVector<scf::ForOp> newLoopNest(loopNest.size());
@@ -121,8 +125,41 @@ SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
newIterOperands.size()));
return newYields;
};
- newLoopNest[loopDepth] = replaceLoopWithNewYields(
- builder, loopNest[loopDepth], newIterOperands, fn);
+ newLoopNest[loopDepth] =
+ replaceLoopWithNewYields(builder, loopNest[loopDepth], newIterOperands,
+ fn, replaceIterOperandsUsesInLoop);
+ if (!replaceIterOperandsUsesInLoop) {
+ /// The yield is expected to producer the following structure
+ /// ```
+ /// %0 = scf.for ... iter_args(%arg0 = %init) {
+ /// %1 = scf.for ... iter_args(%arg1 = %arg0) {
+ /// scf.yield %yield
+ /// }
+ /// }
+ /// ```
+ ///
+ /// since the yield is propagated from inside out, after the inner
+ /// loop is processed the IR is in this form
+ ///
+ /// ```
+ /// scf.for ... iter_args {
+ /// %1 = scf.for ... iter_args(%arg1 = %init) {
+ /// scf.yield %yield
+ /// }
+ /// ```
+ ///
+ /// If `replaceIterOperandUsesInLoops` is true, there is nothing to do.
+ /// `%init` will be replaced with `%arg0` when it is created for the
+ /// outer loop. But without that this has to be done explicitly.
+ unsigned subLen = newIterOperands.size();
+ unsigned subStart =
+ newLoopNest[loopDepth + 1].getNumIterOperands() - subLen;
+ auto resetOperands =
+ newLoopNest[loopDepth + 1].getInitArgsMutable().slice(subStart,
+ subLen);
+ resetOperands.assign(
+ newLoopNest[loopDepth].getRegionIterArgs().take_back(subLen));
+ }
}
return newLoopNest;
}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 61aa706b10ae4..dd8631f1fc157 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -30,7 +30,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) ->
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME: outs(%[[FILL_TILE]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
-// CHECK scf.yield %[[INSERT]]
+// CHECK: scf.yield %[[INSERT]]
// -----
@@ -68,7 +68,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
+// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT_TILE]] :
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
@@ -80,7 +80,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
// CHECK-SAME: ins(%[[GEMM_TILE]], %[[BIAS_TILE]] :
// CHECK-SAME: outs(%[[OUTS_TILE]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
-// CHECK scf.yield %[[INSERT]]
+// CHECK: scf.yield %[[INSERT]]
// -----
@@ -130,7 +130,7 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %r
// CHECK-SAME: ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
// CHECK-SAME: outs(%[[FILL1_TILE]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0]
-// CHECK scf.yield %[[INSERT]]
+// CHECK: scf.yield %[[INSERT]]
// -----
@@ -182,7 +182,7 @@ func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32
// CHECK-SAME: ins(%[[GEMM_TILE]] :
// CHECK-SAME: outs(%[[OUTS_TILE]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
-// CHECK scf.yield %[[INSERT]]
+// CHECK: scf.yield %[[INSERT]]
// -----
@@ -218,7 +218,7 @@ func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
-// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
+// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV0]]]
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT_TILE]] :
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
@@ -229,7 +229,7 @@ func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?
// CHECK-SAME: ins(%[[GEMM_TILE]] :
// CHECK-SAME: outs(%[[INIT_TILE_2]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
-// CHECK scf.yield %[[INSERT]]
+// CHECK: scf.yield %[[INSERT]]
// -----
More information about the Mlir-commits
mailing list