[Mlir-commits] [mlir] [NFC] Simplify the tiling implementation using cloning. (PR #72178)

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 15 04:34:09 PST 2023


================
@@ -496,42 +496,59 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
       reductionDims.push_back(idx);
   }
 
-  // 1. create the inital tensor value.
+  // 2. create the inital tensor value.
   FailureOr<Operation *> identityTensor =
       op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
                                                   reductionDims);
   if (failed(identityTensor))
     return b.notifyMatchFailure(op,
                                 "cannot create a tensor of identity value.");
-  // 2. Create the nested loops.
+  // 3. Create the nested loops.
   SmallVector<OpFoldResult> offsets, sizes;
-  SmallVector<scf::ForOp> loops = generateTileLoopNest(
-      b, loc, iterationDomain, tileSizesVector, offsets, sizes);
+  SmallVector<scf::ForOp> loops =
+      generateTileLoopNest(b, loc, iterationDomain, tileSizesVector, offsets,
+                           sizes, identityTensor.value()->getResults());
+
+  // 4. Generate the tiled implementation within the inner most loop.
+  // 4a. Clone the operation within the loop body.
+  SmallVector<Value> clonedOpDestination =
+      llvm::map_to_vector(identityTensor.value()->getResults(),
+                          [](OpResult res) -> Value { return res; });
+  if (!loops.empty()) {
+    b.setInsertionPointToEnd(loops.back().getBody());
+    clonedOpDestination =
+        llvm::map_to_vector(loops.back().getRegionIterArgs(),
+                            [](BlockArgument b) -> Value { return b; });
+  }
+  auto clonedOp = cast<PartialReductionOpInterface>(
+      cloneOpAndUpdateDestinationArgs(b, op, clonedOpDestination));
 
-  // 3. Generate the tiled implementation within the inner most loop.
-  b.setInsertionPoint(loops.back().getBody()->getTerminator());
-  Operation *parallelOp = op.tileToPartialReduction(
-      b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims);
+  // 4b. Tile the cloned operation.
+  Operation *parallelOp = clonedOp.tileToPartialReduction(
+      b, loc, clonedOpDestination, offsets, sizes, reductionDims);
+  // 4c. Delete the cloned operation.
+  b.eraseOp(clonedOp);
 
-  SmallVector<OpFoldResult> resultSizesList;
-  for (size_t i = 0; i < offsets.size(); i++)
-    resultSizesList.push_back(
+  SmallVector<OpFoldResult> outSizes;
+  for (size_t i = 0; i < offsets.size(); i++) {
+    outSizes.push_back(
         tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
+  }
   SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-  SmallVector<Value> replacements = yieldTiledValues(
-      b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
-      resultSizesList, loops);
-
-  auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
-  auto innerMostLoop = loops.back();
-  SmallVector<Value> destinationTensors = llvm::to_vector(dstOp.getDpsInits());
-  assert(destinationTensors.size() ==
-             innerMostLoop.getRegionIterArgs().size() &&
-         "unexpected number of outputs");
-  updateDestinationOperandsForTiledOp(b, destinationTensors,
----------------
nicolasvasilache wrote:

++1

https://github.com/llvm/llvm-project/pull/72178


More information about the Mlir-commits mailing list