[Mlir-commits] [mlir] [NFC] Simplify the tiling implementation using cloning. (PR #72178)
Tomás Longeri
llvmlistbot at llvm.org
Mon Nov 20 11:03:56 PST 2023
================
@@ -636,28 +665,46 @@ void mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<scf::ForOp> loops) {
- auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
- fusedProducerInfo;
- SmallVector<Value> initValues;
+ if (loops.empty())
+ return;
+
+ OpResult fusableProducer = fusedProducerInfo.origProducer;
+ Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
FailureOr<Value> initValue = tensor::getOrCreateDestination(
rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
if (succeeded(initValue)) {
- SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
- SmallVector<Value> yieldedVals =
- yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
- resultOffsets, resultSizes, loops);
- }
- for (auto tileAndFusedOp : tileAndFusedOps) {
- auto dstStyleProducer =
- dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
- if (!dstStyleProducer)
- continue;
- Value dstValue =
- dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
- ->get();
- updateDestinationOperandsForTiledOp(
- rewriter, dstValue, loops.back().getRegionIterArgs().back());
+
+ auto newYieldValuesFn =
+ [&](RewriterBase &innerRewriter, Value iv,
+ ValueRange newRegionIterArgs) -> SmallVector<Value> {
+ OpBuilder::InsertionGuard g(innerRewriter);
+ if (auto tiledDestStyleOp =
+ tiledAndFusedProducer
+ .getDefiningOp<DestinationStyleOpInterface>()) {
+ rewriter.setInsertionPoint(tiledDestStyleOp);
+ BlockArgument newRegionArg = loops.back().getRegionIterArgs().back();
+ auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
+ sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
+ sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+ unsigned resultNumber = fusableProducer.getResultNumber();
+ rewriter.updateRootInPlace(tiledDestStyleOp, [&]() {
+ tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
+ });
+
+ Block *block = rewriter.getInsertionPoint()->getBlock();
+ rewriter.setInsertionPoint(block->getTerminator());
+ Value replacement = rewriter.create<tensor::InsertSliceOp>(
+ fusedProducerInfo.origProducer.getLoc(),
+ fusedProducerInfo.tiledAndFusedProducer,
+ loops.back().getRegionIterArgs().back(), sliceOp.getMixedOffsets(),
+ sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+ return {replacement};
+ }
----------------
tlongeri wrote:
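It looks like the `newYieldValuesFn` lambda is missing a return statement for the case where the `if` condition is false, i.e. when `tiledAndFusedProducer` is not defined by a `DestinationStyleOpInterface` op. The lambda is declared to return `SmallVector<Value>`, and I am getting build errors because of this.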
https://github.com/llvm/llvm-project/pull/72178
More information about the Mlir-commits mailing list