[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 16 09:32:10 PDT 2024


================
@@ -1223,26 +1223,86 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
   return success();
 }
 
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
+/// Fetches the FIRST OpOperand of the tilable user (and use) of the value `val`
+/// which implements `TilingInterface` and `DestinationStyleOpInterface`.
+/// Returns failure otherwise.
 static FailureOr<OpOperand *> getConsumerFromUses(Value val,
                                                   Block *containingOpBlock) {
-  // Step 1. Check that the value has exactly one use.
-  if (!llvm::hasSingleElement(val.getUses()))
-    return failure();
-  // Step 2. Get uses.
-  OpOperand &operand = (*val.getUses().begin());
-  Operation *consumerOp = operand.getOwner();
-  // TODO: We have to init result of consumer before scf.for, use
-  //       DestinationStyleOpInterface to get result shape from init for now.
-  //       Add support for other op such as op has InferTypeOpInterface.
-  if (!isa<TilingInterface>(consumerOp) ||
-      !isa<DestinationStyleOpInterface>(consumerOp))
-    return failure();
-  if (containingOpBlock != consumerOp->getBlock())
+  OpOperand *operand = nullptr;
+  for (auto &use : val.getUses()) {
+    Operation *user = use.getOwner();
+    // Step 1. Check if the user is tilable.
+    if (!isa<TilingInterface, DestinationStyleOpInterface>(user)) {
+      // TODO: We have to init result of consumer before scf.for, use
+      //       DestinationStyleOpInterface to get result shape from init for
+      //       now. Add support for other op such as op has
+      //       InferTypeOpInterface.
+      continue;
+    } else {
+      // Step 2. Check if the user stays in the same block.
+      if (containingOpBlock != user->getBlock())
+        continue;
+      // Step 3. Check if the user itself has any uses. If it has none, that
+      // usually means it has already been tiled.
+      if (user->use_empty())
+        continue;
+      operand = &use;
+      break;
+    }
+  }
+  if (!operand) {
     return failure();
-  return &operand;
+  }
+  return operand;
+}
+
+/// Recursively finds the loops enclosing the given loop (the given loop itself
+/// included) for as long as the predicate function succeeds, sorted from outer
+/// to inner.
+///
+/// @param loop: target loop. Note that this loop is also included in the
+///              result, i.e. if no enclosing loops are found, only the loop
+///              itself is returned.
+/// @param pred: predicate function; the termination condition of the
+///              recursive process.
+/// @return Outer nest loops: the loops enclosing the given target loop (target
+///         loop included).
+///
+/// E.g.
+///
+/// ```
+///  %0 = scf.for()
+///    %1 = scf.for()
+///      %2 = scf.for()
+/// ```
+///
+/// If `%2 = scf.for` is given without a specific predicate function, this
+/// function will return all three nested loops: %0, %1 and %2.
+static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
----------------
MaheshRavishankar wrote:

It is not necessary that the outer loops here are created through tiling....

https://github.com/llvm/llvm-project/pull/94190


More information about the Mlir-commits mailing list