[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)
llvmlistbot at llvm.org
Tue Jul 16 09:32:10 PDT 2024
================
@@ -1223,26 +1223,86 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
   return success();
 }
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
+/// Fetches the FIRST OpOperand of the tilable user (and use) of the value `val`
+/// which implements `TilingInterface` and `DestinationStyleOpInterface`.
+/// Returns failure otherwise.
 static FailureOr<OpOperand *> getConsumerFromUses(Value val,
                                                   Block *containingOpBlock) {
-  // Step 1. Check that the value has exactly one use.
-  if (!llvm::hasSingleElement(val.getUses()))
-    return failure();
-  // Step 2. Get uses.
-  OpOperand &operand = (*val.getUses().begin());
-  Operation *consumerOp = operand.getOwner();
-  // TODO: We have to init result of consumer before scf.for, use
-  //       DestinationStyleOpInterface to get result shape from init for now.
-  //       Add support for other op such as op has InferTypeOpInterface.
-  if (!isa<TilingInterface>(consumerOp) ||
-      !isa<DestinationStyleOpInterface>(consumerOp))
-    return failure();
-  if (containingOpBlock != consumerOp->getBlock())
+  OpOperand *operand = nullptr;
+  for (auto &use : val.getUses()) {
+    Operation *user = use.getOwner();
+    // Step 1. Check if the user is tilable.
+    if (!isa<TilingInterface, DestinationStyleOpInterface>(user)) {
+      // TODO: We have to init result of consumer before scf.for, use
+      //       DestinationStyleOpInterface to get result shape from init for
+      //       now. Add support for other op such as op has
+      //       InferTypeOpInterface.
+      continue;
+    } else {
+      // Step 2. Check if user stay in the same block.
+      if (containingOpBlock != user->getBlock())
+        continue;
+      // Step 3. Check if user has succeeding user. Otherwise, it usually
+      // represents already tiled.
+      if (user->use_empty())
+        continue;
+      operand = &use;
+      break;
+    }
+  }
+  if (!operand) {
     return failure();
-  return &operand;
+  }
+  return operand;
+}
+
+/// Recursively find the outer nest loops of given loop(included) while the
+/// predict function succeed, sorted from outer to inner.
+///
+/// @param loop: target loop, note that this loop will be also included. I.e.
+///              if no other nest loops were found, just return itself.
+/// @param pred: predict function, the termination condition of recursive
+/// process.
+/// @return Outer Nest Loops: nest loops outside given target loop(included).
+///
+/// E.g.
+///
+/// ```
+///  %0 = scf.for()
+///    %1 = scf.for()
+///      %2 = scf.for()
+/// ```
+///
+/// If `%2 = scf.for` is given without specific prediction function, this
+/// function will return three nest loops: %0 + %1 + %2.
+static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
----------------
MaheshRavishankar wrote:
It is not necessary that the outer loops here are created through tiling....
https://github.com/llvm/llvm-project/pull/94190
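
For readers following the thread, here is a minimal standalone sketch of the
walk that the doc comment on `getOuterNestLoopsWhile` describes: start at the
given loop, climb through enclosing loops while a predicate accepts them, and
return the result ordered from outer to inner. It is not the code from the
patch and does not use MLIR. `ToyLoop` and `outerNestLoopsWhile` are
hypothetical names made up for illustration, and the `bool` predicate is an
assumption (the quoted hunk does not show the real signature).

```cpp
#include <algorithm>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for a loop operation; this is not an MLIR type.
struct ToyLoop {
  std::string name;
  ToyLoop *parent; // directly enclosing loop, or nullptr at the top level
};

// Returns `loop` together with its enclosing loops, ordered outer to inner.
// The climb stops at the first enclosing loop that `pred` rejects; `loop`
// itself is always included and is never tested against `pred`.
std::vector<ToyLoop *>
outerNestLoopsWhile(ToyLoop *loop,
                    const std::function<bool(ToyLoop *)> &pred) {
  std::vector<ToyLoop *> nest = {loop};
  for (ToyLoop *outer = loop->parent; outer && pred(outer);
       outer = outer->parent)
    nest.push_back(outer);
  // Collected inner to outer, so flip to match the documented order.
  std::reverse(nest.begin(), nest.end());
  return nest;
}
```

In this sketch, a predicate that always accepts yields the whole nest, while a
stricter predicate cuts the walk short at the first loop it rejects. Whether
and how such a predicate should distinguish loops that were produced by tiling
is the question raised in the comment above; the sketch takes no position on
it.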