[Mlir-commits] [mlir] 7ebf70d - [mlir][SCF][bufferize][NFC] Bufferize parallel_insert_slice separately
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 27 04:16:11 PDT 2022
Author: Matthias Springer
Date: 2022-06-27T13:16:02+02:00
New Revision: 7ebf70d85d63acaaea56bdb1a13c5a2573868e1c
URL: https://github.com/llvm/llvm-project/commit/7ebf70d85d63acaaea56bdb1a13c5a2573868e1c
DIFF: https://github.com/llvm/llvm-project/commit/7ebf70d85d63acaaea56bdb1a13c5a2573868e1c.diff
LOG: [mlir][SCF][bufferize][NFC] Bufferize parallel_insert_slice separately
This allows for better type inference during bufferization and is in preparation for supporting memory spaces.
Differential Revision: https://reviews.llvm.org/D128580
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 785cd0d7806d..fb514a2f2b08 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -949,64 +949,32 @@ struct ForeachThreadOpInterface
return success();
}
- LogicalResult bufferize(Operation *op, RewriterBase &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
- OpBuilder::InsertionGuard g(b);
auto foreachThreadOp = cast<ForeachThreadOp>(op);
- // Gather new results of the ForeachThreadOp.
- SmallVector<Value> newResults;
- for (OpResult opResult : foreachThreadOp->getOpResults()) {
- OpOperand *insertDest =
- getInsertionDest(foreachThreadOp)[opResult.getResultNumber()];
- // Insert copies right before the PerformConcurrentlyOp terminator. They
- // should not be inside terminator (which would be the default insertion
- // point).
- Value buffer = getBuffer(b, insertDest->get(), options);
- newResults.push_back(buffer);
- }
+#ifndef NDEBUG
+ // ParallelInsertSliceOpInterface replaces all uses.
+ for (OpResult opResult : foreachThreadOp->getOpResults())
+ assert(opResult.getUses().empty() &&
+ "expected that all uses were already replaced");
+#endif // NDEBUG
// Create new ForeachThreadOp without any results and drop the automatically
// introduced terminator.
TypeRange newResultTypes;
- auto newForeachThreadOp =
- b.create<ForeachThreadOp>(foreachThreadOp.getLoc(), newResultTypes,
- foreachThreadOp.getNumThreads());
+ auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
+ foreachThreadOp.getLoc(), newResultTypes,
+ foreachThreadOp.getNumThreads());
newForeachThreadOp.getBody()->getTerminator()->erase();
// Move over block contents of the old op.
- b.mergeBlocks(foreachThreadOp.getBody(), newForeachThreadOp.getBody(),
- {newForeachThreadOp.getBody()->getArguments()});
-
- // Bufferize terminator.
- auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(
- newForeachThreadOp.getBody()->getTerminator());
- b.setInsertionPoint(performConcurrentlyOp);
- unsigned resultCounter = 0;
- WalkResult walkResult =
- performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
- Location loc = insertOp.getLoc();
- Type srcType = getMemRefType(
- insertOp.getSource().getType().cast<RankedTensorType>(), options);
- // ParallelInsertSliceOp bufferizes to a copy.
- auto srcMemref = b.create<bufferization::ToMemrefOp>(
- loc, srcType, insertOp.getSource());
- Value destMemref = newResults[resultCounter++];
- Value subview = b.create<memref::SubViewOp>(
- loc, destMemref, insertOp.getMixedOffsets(),
- insertOp.getMixedSizes(), insertOp.getMixedStrides());
- // This memcpy will fold away if everything bufferizes in-place.
- if (failed(options.createMemCpy(b, insertOp.getLoc(), srcMemref,
- subview)))
- return WalkResult::interrupt();
- b.eraseOp(insertOp);
- return WalkResult::advance();
- });
- if (walkResult.wasInterrupted())
- return failure();
+ rewriter.mergeBlocks(foreachThreadOp.getBody(),
+ newForeachThreadOp.getBody(),
+ {newForeachThreadOp.getBody()->getArguments()});
- // Replace the op.
- replaceOpWithBufferizedValues(b, op, newResults);
+ // Remove the old op.
+ rewriter.eraseOp(op);
return success();
}
@@ -1104,9 +1072,50 @@ struct ParallelInsertSliceOpInterface
return success();
}
- LogicalResult bufferize(Operation *op, RewriterBase &b,
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
- // Will be bufferized as part of ForeachThreadOp.
+ OpBuilder::InsertionGuard g(rewriter);
+ auto insertOp = cast<ParallelInsertSliceOp>(op);
+ auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
+ auto foreachThreadOp =
+ cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());
+
+ // If the op bufferizes out-of-place, allocate the copy before the
+ // ForeachThreadOp.
+ rewriter.setInsertionPoint(foreachThreadOp);
+ Value destBuffer = getBuffer(rewriter, insertOp.getDest(), options);
+
+ // Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
+ rewriter.setInsertionPoint(performConcurrentlyOp);
+ Value srcBuffer = getBuffer(rewriter, insertOp.getSource(), options);
+ Value subview = rewriter.create<memref::SubViewOp>(
+ insertOp.getLoc(), destBuffer, insertOp.getMixedOffsets(),
+ insertOp.getMixedSizes(), insertOp.getMixedStrides());
+ // This memcpy will fold away if everything bufferizes in-place.
+ if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), srcBuffer,
+ subview)))
+ return failure();
+ rewriter.eraseOp(op);
+
+ // Replace all uses of ForeachThreadOp (just the corresponding result).
+ rewriter.setInsertionPointAfter(foreachThreadOp);
+ Value toTensorOp =
+ rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), destBuffer);
+ unsigned resultNum = 0;
+ for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
+ if (&nextOp == op)
+ break;
+ resultNum++;
+ }
+ assert(resultNum < foreachThreadOp->getNumResults() &&
+ "ParallelInsertSliceOp not found in PerformConcurrentlyOp");
+ SmallVector<OpOperand *> resultUses = llvm::to_vector(
+ llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
+ [](OpOperand &use) { return &use; }));
+ for (OpOperand *use : resultUses) {
+ rewriter.updateRootInPlace(use->getOwner(),
+ [&]() { use->set(toTensorOp); });
+ }
return success();
}
More information about the Mlir-commits mailing list