[Mlir-commits] [mlir] [NFC] Simplify the tiling implementation using cloning. (PR #72178)

Tomás Longeri llvmlistbot at llvm.org
Mon Nov 20 11:03:56 PST 2023


================
@@ -636,28 +665,46 @@ void mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
     MutableArrayRef<scf::ForOp> loops) {
-  auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
-      fusedProducerInfo;
-  SmallVector<Value> initValues;
+  if (loops.empty())
+    return;
+
+  OpResult fusableProducer = fusedProducerInfo.origProducer;
+  Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
   FailureOr<Value> initValue = tensor::getOrCreateDestination(
       rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
   if (succeeded(initValue)) {
-    SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
-    SmallVector<Value> yieldedVals =
-        yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
-                         resultOffsets, resultSizes, loops);
-  }
-  for (auto tileAndFusedOp : tileAndFusedOps) {
-    auto dstStyleProducer =
-        dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
-    if (!dstStyleProducer)
-      continue;
-    Value dstValue =
-        dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
-            ->get();
-    updateDestinationOperandsForTiledOp(
-        rewriter, dstValue, loops.back().getRegionIterArgs().back());
+
+    auto newYieldValuesFn =
+        [&](RewriterBase &innerRewriter, Value iv,
+            ValueRange newRegionIterArgs) -> SmallVector<Value> {
+      OpBuilder::InsertionGuard g(innerRewriter);
+      if (auto tiledDestStyleOp =
+              tiledAndFusedProducer
+                  .getDefiningOp<DestinationStyleOpInterface>()) {
+        rewriter.setInsertionPoint(tiledDestStyleOp);
+        BlockArgument newRegionArg = loops.back().getRegionIterArgs().back();
+        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
+            sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
+            sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+        unsigned resultNumber = fusableProducer.getResultNumber();
+        rewriter.updateRootInPlace(tiledDestStyleOp, [&]() {
+          tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
+        });
+
+        Block *block = rewriter.getInsertionPoint()->getBlock();
+        rewriter.setInsertionPoint(block->getTerminator());
+        Value replacement = rewriter.create<tensor::InsertSliceOp>(
+            fusedProducerInfo.origProducer.getLoc(),
+            fusedProducerInfo.tiledAndFusedProducer,
+            loops.back().getRegionIterArgs().back(), sliceOp.getMixedOffsets(),
+            sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+        return {replacement};
+      }
----------------
tlongeri wrote:

It looks like the lambda is missing a return statement for the case where the `if` condition is false. I am getting build errors because of this.

https://github.com/llvm/llvm-project/pull/72178


More information about the Mlir-commits mailing list