[Mlir-commits] [mlir] [mlir][scf] Extend option to yield replacement for multiple results case (PR #93144)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 27 22:43:27 PDT 2024
================
@@ -940,49 +940,114 @@ mlir::scf::tileAndFuseProducerOfSlice(
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
- MutableArrayRef<LoopLikeOpInterface> loops) {
+ MutableArrayRef<LoopLikeOpInterface> loops,
+ ArrayRef<unsigned> yieldResultNumber) {
if (loops.empty())
return success();
- OpResult fusableProducer = fusedProducerInfo.origProducer;
- Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
- FailureOr<Value> initValue = tensor::getOrCreateDestination(
- rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
- if (succeeded(initValue)) {
-
- YieldTiledValuesFn newYieldValuesFn =
- [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
- ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
- SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
- SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
- -> LogicalResult {
- OpBuilder::InsertionGuard g(innerRewriter);
- if (auto tiledDestStyleOp =
- tiledAndFusedProducer
- .getDefiningOp<DestinationStyleOpInterface>()) {
- rewriter.setInsertionPoint(tiledDestStyleOp);
- Value newRegionArg = newRegionIterArgs.back();
+ Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
+ *tiledOwner = fusedProducerInfo.tiledOps[0];
+
+ Location loc = originalOwner->getLoc();
+ // a. collect all init Value to be appended
+ SmallVector<unsigned> initNumberList =
+ yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
+ 0, originalOwner->getNumResults()))
+ : llvm::to_vector(yieldResultNumber);
+ SmallVector<Value> initValueList;
+ for (const auto &resultNumber : initNumberList) {
+ FailureOr<Value> initValue = tensor::getOrCreateDestination(
+ rewriter, loc, originalOwner->getResult(resultNumber));
+ if (succeeded(initValue)) {
+ initValueList.push_back(initValue.value());
+ } else {
+ return failure();
+ }
+ }
+
+ YieldTiledValuesFn newYieldValuesFn =
+ [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+ ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+ SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+ SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+ OpBuilder::InsertionGuard g(innerRewriter);
+
+ // get sliceOp tile information
+ SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
+ sliceSizes = sliceOp.getMixedSizes();
+
+ // expect all strides of sliceOp being 1
+ if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+ return !isConstantIntValue(ofr, 1);
+ }))
+ return failure();
+
+ unsigned sliceResultNumber =
+ fusedProducerInfo.origProducer.getResultNumber();
+
+ auto tilableOp = cast<TilingInterface>(originalOwner);
+ // b. get iterDomain Offset and Sizes based on sliceOp tile
+ SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
+ // skip tensor.pack/unpack/pad, which expects single opResult
+ if (tilableOp->getNumResults() > 1 &&
+ failed(tilableOp.getIterationDomainTileFromResultTile(
----------------
Yun-Fly wrote:
> Can you just add a comment here as to why this is a failure for now
Yes, sure. It will serve as a helpful reminder for all of us.
> It looks good to me to land.
Thanks again!
https://github.com/llvm/llvm-project/pull/93144
More information about the Mlir-commits mailing list