[Mlir-commits] [mlir] 04dac2c - [mlir][SCF][bufferize][NFC] Implement resolveConflicts for ParallelInsertSliceOp
Matthias Springer
llvmlistbot at llvm.org
Tue Jun 28 03:23:22 PDT 2022
Author: Matthias Springer
Date: 2022-06-28T12:18:22+02:00
New Revision: 04dac2ca7c06d0ce173e53527e3b90a07e3b325d
URL: https://github.com/llvm/llvm-project/commit/04dac2ca7c06d0ce173e53527e3b90a07e3b325d
DIFF: https://github.com/llvm/llvm-project/commit/04dac2ca7c06d0ce173e53527e3b90a07e3b325d.diff
LOG: [mlir][SCF][bufferize][NFC] Implement resolveConflicts for ParallelInsertSliceOp
This was previously implemented as part of the BufferizableOpInterface of ForEachThreadOp. This commit moves the implementation to ParallelInsertSliceOp to be consistent with the remaining ops and to provide a nice example op that can serve as a blueprint for other ops.
Differential Revision: https://reviews.llvm.org/D128666
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index cde966592212d..5de21cc350ac1 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -578,6 +578,10 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
/// Return the number of leading operands before `offsets`, `sizes` and
/// `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+ /// Return the OpResult of the enclosing ForeachThreadOp that corresponds
+ /// to this ParallelInsertSliceOp.
+ OpResult getTiedOpResult();
}];
let builders = [
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index bd0f16dbd0e07..557a9edc2f18e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1215,6 +1215,18 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
// ParallelInsertSliceOp
//===----------------------------------------------------------------------===//
+OpResult ParallelInsertSliceOp::getTiedOpResult() {
+ auto foreachThreadOp = getOperation()->getParentOfType<ForeachThreadOp>();
+ assert(foreachThreadOp && "unlinked ParallelInsertSliceOp");
+ PerformConcurrentlyOp performConcurrentlyOp = foreachThreadOp.getTerminator();
+ for (const auto &it : llvm::enumerate(performConcurrentlyOp.yieldingOps())) {
+ Operation &nextOp = it.value();
+ if (&nextOp == getOperation())
+ return foreachThreadOp->getResult(it.index());
+ }
+ llvm_unreachable("ParallelInsertSliceOp not found");
+}
+
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest,
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 0227816b77845..1f6359bfbb497 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -961,42 +961,6 @@ struct ForeachThreadOpInterface
return BufferRelation::Equivalent;
}
- LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
- const AnalysisState &state) const {
- auto bufferizableOp = cast<BufferizableOpInterface>(op);
- if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
- return failure();
-
- OpBuilder::InsertionGuard g(rewriter);
- auto foreachThreadOp = cast<ForeachThreadOp>(op);
- for (OpResult opResult : foreachThreadOp->getOpResults()) {
- SmallVector<OpOperand *> destOperands =
- state.getAliasingOpOperand(opResult);
- assert(destOperands.size() == 1 &&
- "expected exactly one aliasing OpOperand");
- assert(isa<ParallelInsertSliceOp>(destOperands.front()->getOwner()) &&
- "expected ParallelInsertSliceOp");
-
- // Nothing to do if there is no conflict.
- if (state.isInPlace(*destOperands.front()))
- continue;
-
- // Insert tensor allocation.
- bool isYielded = state.isTensorYielded(opResult);
- FailureOr<Value> alloc = allocateTensorForShapedValue(
- rewriter, op->getLoc(), destOperands.front()->get(),
- /*escape=*/isYielded, state.getOptions());
- if (failed(alloc))
- return failure();
-
- // Update terminator operand.
- rewriter.updateRootInPlace(destOperands.front()->getOwner(),
- [&]() { destOperands.front()->set(*alloc); });
- }
-
- return success();
- }
-
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto foreachThreadOp = cast<ForeachThreadOp>(op);
@@ -1118,7 +1082,55 @@ struct ParallelInsertSliceOpInterface
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
const AnalysisState &state) const {
- // RaW conflicts are resolved as part of ForeachThreadOp.
+ // This interface method is overridden because we want to set a custom
+ // insertion point for tensor copies. They should be inserted right before
+ // the ForeachThreadOp. E.g.:
+ //
+ // %r0, %r1 = foreach_thread ... {
+ // ...
+ // perform_concurrently {
+ // parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
+ // parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
+ // }
+ // }
+ //
+ // After TensorCopyInsertion:
+ //
+ // %copy = bufferization.alloc_tensor() copy(%d)
+ // %r0, %r1 = foreach_thread ... {
+ // ...
+ // perform_concurrently {
+ // parallel_insert_slice %a into %b ...
+ // parallel_insert_slice %c into %copy ...
+ // }
+ // }
+
+ OpBuilder::InsertionGuard g(rewriter);
+ auto insertOp = cast<ParallelInsertSliceOp>(op);
+ auto foreachThreadOp = insertOp->getParentOfType<ForeachThreadOp>();
+
+ // Nothing to do if the destination tensor is inplace.
+ assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
+ "source is always in-place");
+ if (state.isInPlace(op->getOpOperand(1) /*dest*/))
+ return success();
+
+ // Find corresponding OpResult.
+ OpResult opResult = insertOp.getTiedOpResult();
+
+ // Insert tensor allocation right before the ForeachThreadOp.
+ rewriter.setInsertionPoint(foreachThreadOp);
+ bool isYielded = state.isTensorYielded(opResult);
+ FailureOr<Value> alloc =
+ allocateTensorForShapedValue(rewriter, op->getLoc(), insertOp.getDest(),
+ /*escape=*/isYielded, state.getOptions());
+ if (failed(alloc))
+ return failure();
+
+ // Update destination operand.
+ rewriter.updateRootInPlace(
+ insertOp, [&]() { insertOp.getDestMutable().assign(*alloc); });
+
return success();
}
@@ -1149,29 +1161,20 @@ struct ParallelInsertSliceOpInterface
if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
subview)))
return failure();
- rewriter.eraseOp(op);
// Replace all uses of ForeachThreadOp (just the corresponding result).
rewriter.setInsertionPointAfter(foreachThreadOp);
Value toTensorOp =
rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
- // PerformConcurrentlyOp can have multiple ParallelInserSliceOps. Find the
- // index of `op` within yielding ops.
- unsigned resultNum = 0;
- for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
- if (&nextOp == op)
- break;
- resultNum++;
- }
- assert(resultNum < foreachThreadOp->getNumResults() &&
- "ParallelInsertSliceOp not found in PerformConcurrentlyOp");
- SmallVector<OpOperand *> resultUses = llvm::to_vector(
- llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
- [](OpOperand &use) { return &use; }));
+ // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
+ SmallVector<OpOperand *> resultUses =
+ llvm::to_vector(llvm::map_range(insertOp.getTiedOpResult().getUses(),
+ [](OpOperand &use) { return &use; }));
for (OpOperand *use : resultUses) {
rewriter.updateRootInPlace(use->getOwner(),
[&]() { use->set(toTensorOp); });
}
+ rewriter.eraseOp(op);
return success();
}
More information about the Mlir-commits
mailing list