[Mlir-commits] [mlir] 9bc3102 - [mlir][scf] Extend consumer fusion to multiple tilable users (#111955)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 5 18:03:27 PST 2024
Author: Yun-Fly
Date: 2024-11-06T10:03:23+08:00
New Revision: 9bc3102bea80f422f4f3b788186f6e1c636e0fba
URL: https://github.com/llvm/llvm-project/commit/9bc3102bea80f422f4f3b788186f6e1c636e0fba
DIFF: https://github.com/llvm/llvm-project/commit/9bc3102bea80f422f4f3b788186f6e1c636e0fba.diff
LOG: [mlir][scf] Extend consumer fusion to multiple tilable users (#111955)
Previously, consumer fusion expected the loop result to have a single use
(any other uses had to be terminator ops). This patch adds support for
fusing multiple tilable consumers.
E.g.
```
%0 = scf.for {
...
%p = tiledProducer
...
}
%1 = tilableConsumer1 ins(%0 : ...)
%2 = tilableConsumer2 ins(%0 : ...)
```
===>
```
%0:3 = scf.for {
...
%p = tiledProducer
%1 = tiledConsumer1 ins(%p : ...)
%2 = tiledConsumer2 ins(%p : ...)
...
}
```
The key requirement is that the first user of the loop must not
dominate any definition of the consumer's operand(s).
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e2feb10b314540..02e58141bdc303 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -12,6 +12,8 @@
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -1580,33 +1582,163 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
return success();
}
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
-static FailureOr<OpOperand *> getConsumerFromUses(Value val,
- Block *containingOpBlock) {
- // Check that the value has exactly one use which isn't a scf.yield or a
- // tensor.parallel_insert_slice op.
- OpOperand *operand = nullptr;
- for (OpOperand &opOperand : val.getUses()) {
- Operation *consumerOp = opOperand.getOwner();
- if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
- continue;
- if (operand)
- return failure();
- // TODO: We have to init result of consumer before scf.for, use
- // DestinationStyleOpInterface to get result shape from init for now.
- // Add support for other op such as op has InferTypeOpInterface.
- if (!isa<TilingInterface>(consumerOp) ||
- !isa<DestinationStyleOpInterface>(consumerOp))
+/// An utility to get the first user of the given loopOp. If any of user stay in
+/// different block of loopOp, return failure.
+static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
+ if (!isa<LoopLikeOpInterface>(loopOp))
+ return failure();
+ Operation *firstUserOfLoop = nullptr;
+ for (Operation *userOp : loopOp->getUsers()) {
+ // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
+ // block with any other types of operation. Thus, just redirecting to its
+ // parent `InParallelOp`. E.g.
+ //
+ // ```
+ // %1 = scf.for {
+ // ...
+ // }
+ // %2 = consumerOp ins(%1, ...)
+ // scf.forall.in_parallel {
+ // tensor.parallel_insert_slice %1
+ // }
+ // ```
+ // where `InParallelOp` but not `ParallelInsertSlice` stays in the same
+ // same block with `consumerOp`.
+ if (isa<tensor::ParallelInsertSliceOp>(userOp))
+ userOp = userOp->getParentOfType<scf::InParallelOp>();
+
+ if (loopOp->getBlock() != userOp->getBlock())
return failure();
- if (containingOpBlock != consumerOp->getBlock())
+
+ if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
+ firstUserOfLoop = userOp;
+ }
+ return firstUserOfLoop;
+}
+
+/// This utility currently checks whether the first userOp of loop is NOT before
+/// the last defineOp of consumer operand. Because that we need to move the
+/// whole loop structure right before the `firstUserOfLoop`. This utility thus
+/// helps ensuring that no invalid IR is formed, i.e. no backward slice of
+/// consumerOp is dominated by the `firstUserOfLoop`. Saying that:
+///
+/// ```
+/// %0 = scf.for() {
+/// ...
+/// }
+/// ...
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumerOperand
+/// ...
+/// %3 = consumerOp(%2)
+/// ```
+///
+/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would
+/// be invalid to move the `loopOp` right before the `firstUserOfLoop`, a.k.a.
+/// use-def chain violation:
+///
+/// ```
+/// %0:2 = scf.for() {
+/// // use before define error
+/// %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumerOperand
+/// ```
+///
+/// @param loopOp: loop operation
+/// @param consumerOp: consumer operation
+/// @param reorderOperations: the flag controls whether to reorder the backward
+/// slice w.r.t. the defineOp of `consumerOp` operands.
+/// @return: computed backward slice of consumerOp, but excluding those already
+/// dominates `firstUserOfLoop`.
+static FailureOr<llvm::SetVector<Operation *>>
+checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
+ bool reorderOperations) {
+ FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
+ if (failed(firstUserOfLoop))
+ return failure();
+
+ BackwardSliceOptions options;
+ DominanceInfo dominanceInfo;
+ options.inclusive = true;
+ options.omitBlockArguments = true;
+ bool includeLoopOp = false;
+ options.filter = [&](Operation *op) {
+ if (op == loopOp) {
+ includeLoopOp = true;
+ return false;
+ }
+ // Cut off the slice to not include any operation that already dominates
+ // firstUserOfLoop.
+ return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
+ };
+ llvm::SetVector<Operation *> slice;
+ for (auto operand : consumerOp->getOperands()) {
+ getBackwardSlice(operand, &slice, options);
+ }
+
+ if (!slice.empty()) {
+ // If consumerOp has one producer, which is also the user of loopOp.
+ // E.g.
+ // ```
+ // %0 = %loopOp
+ // %1 = consumerOp1 ins(%0)
+ // %2 = consumerOp2 ins(%0, %1)
+ // ```
+ // We can not fuse consumerOp2 into loopOp due to UD chain, unless
+ // consumerOp1 has already been fused into loopOp before.
+ if (includeLoopOp || !reorderOperations)
return failure();
- operand = &opOperand;
}
- if (operand)
- return operand;
+ return slice;
+}
+
+/// Fetches the OpOperand of the first valid user (and use) of the value `val`
+/// which implements `TilingInterface` and `DestinationStyleOpInterface`.
+/// Returns failure otherwise.
+static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
+ Operation *loopOp,
+ unsigned resultNumber) {
+ if (!isa<LoopLikeOpInterface>(loopOp))
+ return failure();
+ Value val = loopOp->getResult(resultNumber);
+ Block *loopBlock = loopOp->getBlock();
+ for (OpOperand &opOperand : val.getUses()) {
+ Operation *consumerOp = opOperand.getOwner();
+ // Step 1. Check if the user is tilable.
+ if (!isa<TilingInterface, DestinationStyleOpInterface>(consumerOp)) {
+ // TODO: We have to init result of consumer before scf.for, use
+ // DestinationStyleOpInterface to get result shape from init for now. Add
+ // support for other op such as op has InferTypeOpInterface.
+ continue;
+ }
+ // Step 2. Check if user stay in the same block.
+ if (loopBlock != consumerOp->getBlock())
+ continue;
+ // Step 3. Check if user has succeeding user. Otherwise, it usually
+ // represents already tiled.
+ if (consumerOp->use_empty())
+ continue;
+ // Step 4. Check assumption for loop with `reorderOperations` enabled.
+ FailureOr<llvm::SetVector<Operation *>> slice =
+ checkAssumptionForLoop(loopOp, consumerOp, true);
+ if (failed(slice))
+ continue;
+ // Step 5. If backward sice is not empty, move them before firstUserOfLoop.
+ if (!slice->empty()) {
+ mlir::topologicalSort(*slice);
+ FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
+ assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
+ for (auto op : *slice) {
+ rewriter.moveOpBefore(op, *firstUserOfLoop);
+ }
+ }
+ return &opOperand;
+ }
return failure();
}
@@ -1659,7 +1791,8 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
/// 1. tensor.insert_slice has scf.yield as its only user.
/// 2. scf.for's corresponding result has only one use.
static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+getUntiledConsumerFromSlice(RewriterBase &rewriter,
+ tensor::InsertSliceOp candidateSliceOp) {
if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
return failure();
Value sliceResult = candidateSliceOp.getResult();
@@ -1672,15 +1805,15 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
if (!forOp)
return failure();
scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
- Value resultingValue = topLevelForOp->getResult(resultNumber);
- return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
+ return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
}
/// Fetch the first untiled consumer of a scf.forall's result which is yielded
/// by a tensor.parallel_insert_slice.
static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+getUntiledConsumerFromSlice(RewriterBase &rewriter,
+ tensor::ParallelInsertSliceOp candidateSliceOp) {
// Step 1. Fetch the corresponding output
Value sliceDest = candidateSliceOp.getDest();
auto iterArg = dyn_cast<BlockArgument>(sliceDest);
@@ -1693,45 +1826,22 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
if (!forallOp)
return failure();
- Value resultingValue =
- forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
-
- return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
+ unsigned resultNumber =
+ forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
+ .getResultNumber();
-/// This utility currently checks whether the loop either :-
-/// 1. Yields exactly one result.
-/// 2. Has consumer op as its first user and other users to be in the same
-/// containing block as that of consumer op's. Currently we clone the loop op
-/// right before the consumer op in order to maintain a valid def-use chain.
-/// This utility thus helps ensuring that no invalid IR is formed due to the
-/// same.
-static LogicalResult checkAssumptionForLoop(Operation *loopOp,
- Operation *consumerOp) {
- // Check if the loop op yields one result.
- if (loopOp->getNumResults() == 1)
- return success();
- // Check if the consumerOp is the first user of the loopOp and if other users
- // are in the same containing block as that of consumer op's.
- Block *parentBlock = consumerOp->getBlock();
- for (Operation *userOp : loopOp->getUsers()) {
- if (userOp == consumerOp)
- continue;
- if (parentBlock != userOp->getBlock() ||
- !consumerOp->isBeforeInBlock(userOp))
- return failure();
- }
- return success();
+ return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
}
/// A utility to fetch an untiled consumer of
/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(insertSlice);
+ return getUntiledConsumerFromSlice(rewriter, insertSlice);
} else if (auto parallelInsertSlice =
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(parallelInsertSlice);
+ return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
} else {
return failure();
}
@@ -1751,7 +1861,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
// 1. Get the consumer of scf.for for the result yielded by
// tensor.insert_slice/parallel_insert_slice.
FailureOr<OpOperand *> maybeConsumerOpOperand =
- getUntiledConsumerFromSlice(candidateSliceOp);
+ getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSliceOp,
"could not fetch consumer to fuse");
@@ -1787,11 +1897,11 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
LoopLikeOpInterface outerMostLoop = nestedLoops.front();
- if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
+ // Check assumption for loop with `reorderOperations` disabled.
+ if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
return rewriter.notifyMatchFailure(
- outerMostLoop,
- "containing loop op should either yield just one value or "
- "have the consumer op as its first user");
+ outerMostLoop, "the first user of loop should not dominate any define "
+ "of consumer operand(s)");
}
OpBuilder::InsertionGuard g(rewriter);
@@ -1812,9 +1922,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
Location loc = outerMostLoop->getLoc();
- // 3. Move the whole loop structure right before consumer Op, the dominance
- // should be already ensured by `checkAssumptionForLoop`.
- rewriter.moveOpBefore(outerMostLoop, consumerOp);
+ // 3. Move the whole loop structure right before firstUserOfLoop, the
+ // dominance should be already ensured by `checkAssumptionForLoop`.
+ FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
+ if (failed(firstUserOfLoop)) {
+ return rewriter.notifyMatchFailure(
+ outerMostLoop, "could not find the first user of outer most loop");
+ }
+ rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
// 4. Set insertion point before terminator op of the loop and create a new
// tensor.insert_slice. In the scf.for case this is a clone of the
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index f5f703d95e2d5b..af836d18e8c028 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -508,3 +508,65 @@ module {
// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
// CHECK: }
// CHECK: return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 :
+
+// -----
+
+module {
+ func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dest0 = tensor.empty() : tensor<256x256xf32>
+ %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
+ %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
+ %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+ scf.yield %insert_slice : tensor<256x256xf32>
+ }
+ %4 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ %5 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ return %4, %5 : tensor<256x256xf32>, tensor<256x256xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 2
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @fuse_add_multiple_tilable_consumers(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
+// CHECK-SAME: {
+// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[ADD_INS1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME: ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] :
+// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
+// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[TILED_ADD_OUT]] :
+// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] :
+// CHECK: %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[TILED_MUL_OUT:.*]] = linalg.mul
+// CHECK-SAME: ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] :
+// CHECK-SAME: outs(%[[MUL_OUT_SLICE]] :
+// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
+// CHECK: }
+// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index b6da47977cb4cf..5e903e378daf82 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -171,24 +171,27 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
template <typename Range>
static LogicalResult
applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
- Range &&payloadOps, TransformResults &transformResults) {
+ Range &&payloadOps, uint32_t numConsumerToFuse,
+ TransformResults &transformResults) {
SmallVector<Operation *> originalConsumerOps;
SmallVector<Operation *> fusedConsumerOps;
for (Operation *target : payloadOps) {
rewriter.setInsertionPoint(target);
- FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
- scf::tileAndFuseConsumerOfSlice(rewriter, target);
+ while (numConsumerToFuse--) {
+ FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+ scf::tileAndFuseConsumerOfSlice(rewriter, target);
- if (failed(fuseConsumerResults))
- return failure();
+ if (failed(fuseConsumerResults))
+ return failure();
- // Report back the relevant handles to the transform op.
- originalConsumerOps.push_back(
- fuseConsumerResults->origConsumerOperand->getOwner());
- fusedConsumerOps.push_back(
- fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+ // Report back the relevant handles to the transform op.
+ originalConsumerOps.push_back(
+ fuseConsumerResults->origConsumerOperand->getOwner());
+ fusedConsumerOps.push_back(
+ fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+ }
}
transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
@@ -200,9 +203,9 @@ DiagnosedSilenceableFailure
transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
TransformResults &transformResults,
TransformState &state) {
- LogicalResult result =
- applyFuseConsumer(rewriter, getOperation(),
- state.getPayloadOps(getTarget()), transformResults);
+ LogicalResult result = applyFuseConsumer(
+ rewriter, getOperation(), state.getPayloadOps(getTarget()),
+ getNumConsumerToFuse(), transformResults);
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index d55d746bd6aa90..34b075a5c17f9e 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -59,12 +59,14 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
}];
let arguments =
- (ins TransformHandleTypeInterface:$target);
+ (ins TransformHandleTypeInterface:$target,
+ DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
let results = (outs TransformHandleTypeInterface:$consumer,
TransformHandleTypeInterface:$fused_consumer);
let assemblyFormat = [{
- $target attr-dict `:` functional-type(operands, results)
+ $target (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)?
+ attr-dict `:` functional-type(operands, results)
}];
}
More information about the Mlir-commits
mailing list