[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)
llvmlistbot at llvm.org
Tue Jul 16 19:09:39 PDT 2024
================
@@ -1223,26 +1223,86 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
   return success();
 }
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
+/// Fetches the FIRST OpOperand of the tilable user (and use) of the value `val`
+/// which implements `TilingInterface` and `DestinationStyleOpInterface`.
+/// Returns failure otherwise.
 static FailureOr<OpOperand *> getConsumerFromUses(Value val,
                                                   Block *containingOpBlock) {
-  // Step 1. Check that the value has exactly one use.
-  if (!llvm::hasSingleElement(val.getUses()))
-    return failure();
-  // Step 2. Get uses.
-  OpOperand &operand = (*val.getUses().begin());
-  Operation *consumerOp = operand.getOwner();
-  // TODO: We have to init result of consumer before scf.for, use
-  // DestinationStyleOpInterface to get result shape from init for now.
-  // Add support for other op such as op has InferTypeOpInterface.
-  if (!isa<TilingInterface>(consumerOp) ||
-      !isa<DestinationStyleOpInterface>(consumerOp))
-    return failure();
-  if (containingOpBlock != consumerOp->getBlock())
+  OpOperand *operand = nullptr;
+  for (auto &use : val.getUses()) {
+    Operation *user = use.getOwner();
+    // Step 1. Check if the user is tilable.
+    if (!isa<TilingInterface>(user) ||
+        !isa<DestinationStyleOpInterface>(user)) {
+      // TODO: We have to init result of consumer before scf.for, use
+      // DestinationStyleOpInterface to get result shape from init for
+      // now. Add support for other ops, such as ops implementing
+      // InferTypeOpInterface.
+      continue;
+    } else {
+      // Step 2. Check that the user is in the same block.
+      if (containingOpBlock != user->getBlock())
+        continue;
+      // Step 3. Check that the user itself has uses. A user without uses
+      // has usually already been tiled.
+      if (user->use_empty())
+        continue;
+      operand = &use;
+      break;
+    }
+  }
+  if (!operand) {
     return failure();
-  return &operand;
+  }
+  return operand;
+}
+
+/// Recursively find the outer nested loops of the given loop (inclusive) while
+/// the predicate function succeeds, sorted from outer to inner.
+///
+/// @param loop: target loop. Note that this loop is also included, i.e. if no
+/// other nested loops are found, only the loop itself is returned.
+/// @param pred: predicate function, the termination condition of the
+/// recursive process.
+/// @return Outer nested loops: the nested loops outside the given target loop
+/// (inclusive).
+///
+/// E.g.
+///
+/// ```
+/// %0 = scf.for()
+///   %1 = scf.for()
+///     %2 = scf.for()
+/// ```
+///
+/// If `%2 = scf.for` is given without a specific predicate function, this
+/// function returns all three nested loops: %0 + %1 + %2.
+static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
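The use-selection policy of the new `getConsumerFromUses` can be modelled outside MLIR. The sketch below is hypothetical: `ToyUse` and `pickConsumer` are illustrative stand-ins rather than MLIR types, and the booleans replace the `isa<...>`, `getBlock()` and `use_empty()` queries. It follows the doc comment in requiring both interfaces; note that LLVM's variadic `isa<A, B>(x)` is true when `x` is an instance of *any* of the listed types, so requiring both needs two separate `isa` checks.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for one use of `val`; NOT an MLIR type.
struct ToyUse {
  std::string owner;        // name of the user op, for readability only
  bool isTiling;            // models isa<TilingInterface>(user)
  bool isDestinationStyle;  // models isa<DestinationStyleOpInterface>(user)
  int blockId;              // models user->getBlock()
  bool ownerHasUses;        // models !user->use_empty()
};

// Returns the index of the first use whose owner passes all three checks,
// or std::nullopt, which plays the role of failure().
std::optional<std::size_t> pickConsumer(const std::vector<ToyUse> &uses,
                                        int containingBlockId) {
  for (std::size_t i = 0; i < uses.size(); ++i) {
    const ToyUse &use = uses[i];
    // Step 1: the owner must implement both interfaces.
    if (!use.isTiling || !use.isDestinationStyle)
      continue;
    // Step 2: the owner must live in the containing op's block.
    if (use.blockId != containingBlockId)
      continue;
    // Step 3: an owner without uses of its own is treated as already tiled.
    if (!use.ownerHasUses)
      continue;
    return i;
  }
  return std::nullopt;
}
```

For example, with uses `a` (not destination-style), `b` (in another block) and `c` (qualifying), `pickConsumer` returns index 2. Unlike the removed `hasSingleElement` check, additional uses no longer make the lookup fail outright.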
----------------
Yun-Fly wrote:
Yes, you are right...
The actual motivation for this utility is to find the nested loops between two candidates at different tiling/loop levels. E.g.
```
%0 = scf.forall() {
  %1 = scf.for() {
    %2 = scf.for() {
      %3 = tensor.insert_slice
      yield %3
    }
    yield %2
  }
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %1
  }
}
%4 = consumer ins(%0)
```
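The documented behaviour of `getOuterNestLoopsWhile` can be sketched on a toy loop tree instead of MLIR's `LoopLikeOpInterface`. `ToyLoop` and `outerNestLoopsWhile` are hypothetical, the walk is iterative rather than recursive, and the assumption that `pred` is evaluated on each enclosing loop before it is added is ours, since the hunk above is truncated before the body.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical loop node; `parent` is the directly enclosing loop or nullptr.
struct ToyLoop {
  std::string name;      // e.g. "%1"
  std::string kind;      // e.g. "scf.for" or "scf.forall"
  const ToyLoop *parent;
};

// Collects `loop` plus every enclosing loop for which `pred` holds, stopping
// at the first enclosing loop that fails it. The result always contains
// `loop` itself and is ordered from outermost to innermost.
std::vector<const ToyLoop *>
outerNestLoopsWhile(const ToyLoop *loop,
                    const std::function<bool(const ToyLoop *)> &pred) {
  std::vector<const ToyLoop *> nest = {loop};
  for (const ToyLoop *outer = loop->parent; outer && pred(outer);
       outer = outer->parent)
    nest.insert(nest.begin(), outer);  // prepend to keep outer-to-inner order
  return nest;
}
```

On the IR above, starting from `%2` with an always-true predicate yields `%0, %1, %2`; with a predicate that accepts only `scf.for` it yields `%1, %2`, i.e. the two loops between the `tensor.insert_slice` and `tensor.parallel_insert_slice` candidates.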
There are two candidates at different loop levels. If we want to fuse the consumer into the innermost candidate iteratively, as you suggested, the steps are:
1. First, fuse the consumer into the first, outer candidate. The previous version of `tileAndFuseConsumerOfSlice` is already sufficient for this. The resulting IR is:
```
%0 = scf.forall() {
  %1 = scf.for() {
    %2 = scf.for() {
      %3 = tensor.insert_slice
      yield %3
    }
    yield %2
  }
  %4 = consumer ins(%1)
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %4
  }
}
```
2. Then, continuing iteratively, fuse the consumer into the second, inner candidate. Note that there are two nested loops between the two candidates (or between the target candidate and the consumer). All of them must be cloned right before the consumer to ensure valid dominance.
As you said, it is quite possible that the `%1` or `%2` loop does not come from tiling. But isn't it still necessary to clone all of them, regardless of where they come from?
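Under the reading that the loop nest between the candidates is rebuilt in front of the consumer, the clones have to be created outer to inner, each inside the body of the previous clone, so the cloned nest keeps the original shape. The toy model below is hypothetical (`ToyLoop` and `cloneNestBefore` are not MLIR APIs); it only tracks parent links, not the position relative to the consumer inside the block, and the PR may implement the cloning differently.

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical loop node that only records nesting.
struct ToyLoop {
  std::string name;
  const ToyLoop *parent;  // enclosing loop; nullptr at the top level
};

// Clones `nest` (ordered outermost to innermost) so that the outermost clone
// is placed directly in `insertionParent` (e.g. the scf.forall whose body
// holds the consumer) and every later clone is nested in the previous clone.
std::vector<std::unique_ptr<ToyLoop>>
cloneNestBefore(const std::vector<const ToyLoop *> &nest,
                const ToyLoop *insertionParent) {
  std::vector<std::unique_ptr<ToyLoop>> clones;
  const ToyLoop *currentParent = insertionParent;
  for (const ToyLoop *loop : nest) {
    clones.push_back(std::make_unique<ToyLoop>(
        ToyLoop{loop->name + "_clone", currentParent}));
    currentParent = clones.back().get();  // the next clone goes inside this one
  }
  return clones;
}
```

Cloning `%1, %2` into `%0` gives `%1_clone` directly under `%0` and `%2_clone` under `%1_clone`. Cloning in inner-to-outer order would instead leave the inner clone without an enclosing cloned loop to live in.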
https://github.com/llvm/llvm-project/pull/94190