[Mlir-commits] [mlir] [mlir][TilingInterface] Make `tileAndFuseConsumerOfSlice` take surrounding loops as an argument. (PR #132082)

Abhishek Varma llvmlistbot at llvm.org
Thu Mar 20 05:07:50 PDT 2025


================
@@ -1917,35 +1973,49 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
 /// by a tensor.parallel_insert_slice.
 static FailureOr<OpOperand *>
 getUntiledConsumerFromSlice(RewriterBase &rewriter,
-                            tensor::ParallelInsertSliceOp candidateSliceOp) {
-  // Step 1. Fetch the corresponding output
+                            tensor::ParallelInsertSliceOp candidateSliceOp,
+                            MutableArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "unexpected loops to be empty");
+  // 1. Check that the surrounding loop is a single scf.forall loop.
+  if (loops.size() != 1) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "expected single surrounding scf.forall");
+  }
+  auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
+  if (!forallOp) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "expected single surrounding scf.forall");
+  }
+
+  // 2. Fetch the corresponding output
   Value sliceDest = candidateSliceOp.getDest();
   auto iterArg = dyn_cast<BlockArgument>(sliceDest);
   if (!iterArg)
     return failure();
-  Operation *containingOp = iterArg.getOwner()->getParentOp();
-  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
-    return failure();
-  // Step 2. Check that the containing op is scf.forall.
-  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
-  if (!forallOp)
+  if (iterArg.getOwner()->getParentOp() != forallOp)
     return failure();
+
   unsigned resultNumber =
       forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
           .getResultNumber();
 
-  return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
+  return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
 }
 
 /// A utility to fetch an untiled consumer of
 /// tensor.insert_slice/tensor.parallel_insert_slice.
 static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
+getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
+                            MutableArrayRef<LoopLikeOpInterface> loops) {
+  if (loops.empty()) {
+    return rewriter.notifyMatchFailure(sliceOp, "unexpected empty loops");
+  }
----------------
Abhishek-Varma wrote:

This check can be removed: the caller of this function, `tileAndFuseConsumerOfSlice`, already performs the same non-empty check on `loops`.

https://github.com/llvm/llvm-project/pull/132082


More information about the Mlir-commits mailing list