[Mlir-commits] [mlir] [mlir][scf] Extend option to yield replacement for multiple results case (PR #93144)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 27 22:58:23 PDT 2024
https://github.com/Yun-Fly updated https://github.com/llvm/llvm-project/pull/93144
>From 4a786b2a7adbe92197890bf3c60361dafec2f3ca Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Wed, 22 May 2024 23:15:36 -0700
Subject: [PATCH 1/5] yield replacement for multiple results
---
.../SCF/Transforms/TileUsingInterface.h | 6 +-
.../mlir/Interfaces/TilingInterface.td | 38 ++++-
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 25 ++-
.../SCF/Transforms/TileUsingInterface.cpp | 148 +++++++++++++-----
.../tile-fuse-and-yield-using-interface.mlir | 62 ++++++++
5 files changed, 233 insertions(+), 46 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index dac79111af3c9..807379d99b599 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -191,10 +191,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
/// where `%0` had other uses as well. If not reconstructed from within the loop
/// body, uses of `%0` could not be replaced, making it still live and the
/// fusion immaterial.
+///
+/// The @param `yieldResultNumber` decides which results are yielded. If not
+/// given, all `opResult`s of the fused producer are yielded.
LogicalResult yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
- MutableArrayRef<LoopLikeOpInterface> loops);
+ MutableArrayRef<LoopLikeOpInterface> loops,
+ std::optional<ArrayRef<unsigned>> yieldResultNumber = std::nullopt);
/// Transformation information returned after tile and fuse.
struct SCFTileAndFuseResult {
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 8865aba3b4ef0..3cd9c8ccce075 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -51,7 +51,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
For an operation to be "tiled and fused" with its (already tiled) consumer,
an operation has to implement the following additional method (see
description below):
- - `generateResultTileValue
+ - `generateResultTileValue`
+ - `getIterationDomainTileFromResultTile`
For an operation to be "tiled and fused" with its (already tiled) producer,
an operation has to implement the following additional methods (see
@@ -302,6 +303,41 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to return the tile of the iteration domain based
+ on the given tile of a particular result.
+
+ This method is required to allow operations to be "tiled and fused"
+ with an (already tiled) consumer. Given a tile of a result,
+ returns the tile of the iteration space that uses this tile.
+ - `resultNumber` is the result of the producer used by the consumer.
+ - `offsets` is the offset of the slice of the producer result used by
+ the tiled implementation of the consumer.
+ - `sizes` is the size of the slice of the producer result used by the
+ consumer.
+ If fusion of the producer with the consumer is not legal for the
+ result, or if this mapping cannot be computed, the implementation
+ should return a failure.
+
+ For most cases `generateResultTileValue` could be implemented using
+ `getIterationDomainTileFromResultTile` + `getTiledImplementation`
+ methods.
+ }],
+ /*retType=*/"::mlir::LogicalResult",
+ /*methodName=*/"getIterationDomainTileFromResultTile",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$resultNumber,
+ "ArrayRef<OpFoldResult> ":$offsets,
+ "ArrayRef<OpFoldResult> ":$sizes,
+ "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
+ "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Generates the scalar implementation of the operation.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c3ab3cecfada7..424f29e787215 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -215,10 +215,11 @@ struct LinalgOpTilingInterface
return success();
}
- FailureOr<TilingResult>
- generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes) const {
+ LogicalResult getIterationDomainTileFromResultTile(
+ Operation *op, OpBuilder &b, unsigned resultNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+ SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op);
// Check that the indexing map used for the output is a projected
@@ -232,9 +233,21 @@ struct LinalgOpTilingInterface
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}
- SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
- mappedOffsets, mappedSizes);
+ iterDomainOffsets, iterDomainSizes);
+ return success();
+ }
+
+ FailureOr<TilingResult>
+ generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) const {
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ if (failed(getIterationDomainTileFromResultTile(
+ op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+ return failure();
+ }
auto tilingInterfaceOp = cast<TilingInterface>(op);
FailureOr<TilingResult> tilingResult =
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..e8cf7298be681 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -940,49 +940,114 @@ mlir::scf::tileAndFuseProducerOfSlice(
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
- MutableArrayRef<LoopLikeOpInterface> loops) {
+ MutableArrayRef<LoopLikeOpInterface> loops,
+ std::optional<ArrayRef<unsigned>> yieldResultNumber) {
if (loops.empty())
return success();
- OpResult fusableProducer = fusedProducerInfo.origProducer;
- Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
- FailureOr<Value> initValue = tensor::getOrCreateDestination(
- rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
- if (succeeded(initValue)) {
-
- YieldTiledValuesFn newYieldValuesFn =
- [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
- ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
- SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
- SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
- -> LogicalResult {
- OpBuilder::InsertionGuard g(innerRewriter);
- if (auto tiledDestStyleOp =
- tiledAndFusedProducer
- .getDefiningOp<DestinationStyleOpInterface>()) {
- rewriter.setInsertionPoint(tiledDestStyleOp);
- Value newRegionArg = newRegionIterArgs.back();
+ Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
+ *tiledOwner = fusedProducerInfo.tiledOps[0];
+
+ Location loc = originalOwner->getLoc();
+ // a. collect all init Value to be appended
+ ArrayRef<unsigned> initNumberList =
+ yieldResultNumber ? yieldResultNumber.value()
+ : llvm::to_vector(llvm::seq<unsigned>(
+ 0, originalOwner->getNumResults()));
+ SmallVector<Value> initValueList;
+ for (const auto &resultNumber : initNumberList) {
+ FailureOr<Value> initValue = tensor::getOrCreateDestination(
+ rewriter, loc, originalOwner->getResult(resultNumber));
+ if (succeeded(initValue)) {
+ initValueList.push_back(initValue.value());
+ } else {
+ return failure();
+ }
+ }
+
+ YieldTiledValuesFn newYieldValuesFn =
+ [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+ ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+ SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+ SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+ OpBuilder::InsertionGuard g(innerRewriter);
+
+ // get sliceOp tile information
+ SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
+ sliceSizes = sliceOp.getMixedSizes();
+
+ // expect all strides of sliceOp being 1
+ if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+ return !isConstantIntValue(ofr, 1);
+ }))
+ return failure();
+
+ unsigned sliceResultNumber =
+ fusedProducerInfo.origProducer.getResultNumber();
+
+ auto tilableOp = cast<TilingInterface>(originalOwner);
+ // b. get iterDomain Offset and Sizes based on sliceOp tile
+ SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
+ // skip tensor.pack/unpack/pad, which expects single opResult
+ if (tilableOp->getNumResults() > 1 &&
+ failed(tilableOp.getIterationDomainTileFromResultTile(
+ rewriter, sliceResultNumber, sliceOffset, sliceSizes,
+ iterDomainOffset, iterDomainSizes))) {
+ return failure();
+ }
+
+ // c. calculate offsets and sizes info of all OpResults respectively based
+ // on iteration Domain Tile
+ SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
+ for (const auto &resultNumber : initNumberList) {
+ if (resultNumber == fusedProducerInfo.origProducer.getResultNumber()) {
+ offsetList.push_back(sliceOffset);
+ sizesList.push_back(sliceSizes);
+ } else {
+ assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
+ // infer result tile according to the iteration domain tile
+ SmallVector<OpFoldResult> offset, sizes;
+ if (failed(tilableOp.getResultTilePosition(
+ rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
+ offset, sizes))) {
+ return failure();
+ }
+ offsetList.push_back(offset);
+ sizesList.push_back(sizes);
+ }
+ }
+
+ // d. create `extract_slice` for `iter_args` for DPS operation if necessary
+ if (auto tiledDestStyleOp =
+ dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
+ rewriter.setInsertionPoint(tiledDestStyleOp);
+ for (const auto &&[index, newRegionArg] :
+ llvm::enumerate(newRegionIterArgs)) {
auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
- sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
- unsigned resultNumber = fusableProducer.getResultNumber();
+ loc, newRegionArg, offsetList[index], sizesList[index],
+ SmallVector<OpFoldResult>(offsetList[index].size(),
+ rewriter.getIndexAttr(1)));
+ unsigned resultNumber = initNumberList[index];
rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
});
}
- Block *block = rewriter.getInsertionPoint()->getBlock();
- rewriter.setInsertionPoint(block->getTerminator());
- tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
- tiledOffset.emplace_back(sliceOp.getMixedOffsets());
- tiledSizes.emplace_back(sliceOp.getMixedSizes());
- return success();
- };
+ }
- return addInitOperandsToLoopNest(rewriter, loops,
- SmallVector<Value>{initValue.value()},
- newYieldValuesFn);
- }
- return success();
+ // e. prepare tiled offset and sizes for later `insert_slice` creation by
+ // caller
+ Block *block = rewriter.getInsertionPoint()->getBlock();
+ rewriter.setInsertionPoint(block->getTerminator());
+ for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
+ tiledResult.push_back(tiledOwner->getResult(resultNumber));
+ tiledOffset.emplace_back(offsetList[index]);
+ tiledSizes.emplace_back(sizesList[index]);
+ }
+ return success();
+ };
+
+ return addInitOperandsToLoopNest(rewriter, loops, initValueList,
+ newYieldValuesFn);
}
/// Implementation of tile consumer and fuse producer greedily.
@@ -1072,14 +1137,21 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
continue;
if (yieldReplacement) {
+ // Reconstruct and yield all opResults of fusableProducerOp by default. The
+ // caller can specify which ones to yield through the optional argument
+ // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
+ Operation *fusableProducerOp = fusableProducer.getOwner();
if (failed(yieldReplacementForFusedProducer(
rewriter, candidateSliceOp, fusedResult.value(), loops))) {
return rewriter.notifyMatchFailure(
- fusableProducer.getOwner(), "failed to replacement value for this "
- "oepration from within the tiled loop");
+ fusableProducerOp, "failed to replacement value for this "
+ "operation from within the tiled loop");
+ }
+ for (const auto &result : fusableProducerOp->getResults()) {
+ origValToResultNumber[result] =
+ loops.front()->getNumResults() -
+ (fusableProducerOp->getNumResults() - result.getResultNumber());
}
- origValToResultNumber[fusableProducer] =
- loops.front()->getNumResults() - 1;
}
if (Operation *tiledAndFusedOp =
diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
index 7356c11e85ac0..3c0ada9d2cabc 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
@@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]]
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0
+
+// -----
+
+func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
+ %rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>,
+ %rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>)
+ -> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
+ %out0, %out1 = linalg.generic {
+ indexing_maps = [affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (j, i)>],
+ iterator_types = ["parallel", "parallel"]
+ }
+ ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
+ outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
+ ^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
+ %4 = arith.mulf %0, %1 : f32
+ %5 = arith.addf %0, %1 : f32
+ linalg.yield %4, %5: f32, f32
+ } -> (tensor<32x32xf32>, tensor<32x32xf32>)
+
+ %out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>
+
+ return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %add = transform.structured.match ops{["linalg.add"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_and_yield %add [16]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @multiple_outputs_fusion_yield_all(
+// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
+// CHECK: %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
+// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
+// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
+// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
+// CHECK: %[[GENERIC_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
+// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
+// CHECK-DAG: %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+// CHECK: %[[ADD_TILE:.+]] = linalg.add
+// CHECK-SAME: ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
+// CHECK-SAME: outs(%[[INIT2_TILE]] :
+// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
+// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
+// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
+// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
+// CHECK: return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0
>From b9450f784ec5a56e3e000bc4879321f2fed4b260 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Wed, 5 Jun 2024 07:00:01 -0700
Subject: [PATCH 2/5] change default arguments
---
.../mlir/Dialect/SCF/Transforms/TileUsingInterface.h | 2 +-
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 8 ++++----
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 807379d99b599..781da1b4ef8a2 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -198,7 +198,7 @@ LogicalResult yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<LoopLikeOpInterface> loops,
- std::optional<ArrayRef<unsigned>> yieldResultNumber = std::nullopt);
+ ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});
/// Transformation information returned after tile and fuse.
struct SCFTileAndFuseResult {
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e8cf7298be681..33142e61750d2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -941,7 +941,7 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<LoopLikeOpInterface> loops,
- std::optional<ArrayRef<unsigned>> yieldResultNumber) {
+ ArrayRef<unsigned> yieldResultNumber) {
if (loops.empty())
return success();
@@ -951,9 +951,9 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
Location loc = originalOwner->getLoc();
// a. collect all init Value to be appended
ArrayRef<unsigned> initNumberList =
- yieldResultNumber ? yieldResultNumber.value()
- : llvm::to_vector(llvm::seq<unsigned>(
- 0, originalOwner->getNumResults()));
+ yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
+ 0, originalOwner->getNumResults()))
+ : yieldResultNumber;
SmallVector<Value> initValueList;
for (const auto &resultNumber : initNumberList) {
FailureOr<Value> initValue = tensor::getOrCreateDestination(
>From eda4bf35b0535cc248b8555d378176a86831df0e Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Wed, 5 Jun 2024 18:47:54 -0700
Subject: [PATCH 3/5] fix CI
---
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 33142e61750d2..a6677393c73dc 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -950,10 +950,10 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
Location loc = originalOwner->getLoc();
// a. collect all init Value to be appended
- ArrayRef<unsigned> initNumberList =
+ SmallVector<unsigned> initNumberList =
yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
0, originalOwner->getNumResults()))
- : yieldResultNumber;
+ : llvm::to_vector(yieldResultNumber);
SmallVector<Value> initValueList;
for (const auto &resultNumber : initNumberList) {
FailureOr<Value> initValue = tensor::getOrCreateDestination(
@@ -1000,7 +1000,7 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
// on iteration Domain Tile
SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
for (const auto &resultNumber : initNumberList) {
- if (resultNumber == fusedProducerInfo.origProducer.getResultNumber()) {
+ if (resultNumber == sliceResultNumber) {
offsetList.push_back(sliceOffset);
sizesList.push_back(sliceSizes);
} else {
>From 46e3751412ba00281912cb41965b109bc4521ee8 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Thu, 20 Jun 2024 00:18:51 -0700
Subject: [PATCH 4/5] fix comment
---
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a6677393c73dc..dab66ce97a6b5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1147,10 +1147,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
fusableProducerOp, "failed to replacement value for this "
"operation from within the tiled loop");
}
- for (const auto &result : fusableProducerOp->getResults()) {
- origValToResultNumber[result] =
- loops.front()->getNumResults() -
- (fusableProducerOp->getNumResults() - result.getResultNumber());
+ for (auto [index, result] :
+ llvm::enumerate(fusableProducerOp->getResults())) {
+ origValToResultNumber[result] = loops.front()->getNumResults() -
+ fusableProducerOp->getNumResults() +
+ index;
}
}
>From 1f0cbdbaf9cda930030fd2a5b9b6feba1b2840d6 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Thu, 27 Jun 2024 22:58:03 -0700
Subject: [PATCH 5/5] add a comment on why return failure
---
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index dab66ce97a6b5..2efa8149f52ba 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -993,6 +993,14 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
failed(tilableOp.getIterationDomainTileFromResultTile(
rewriter, sliceResultNumber, sliceOffset, sliceSizes,
iterDomainOffset, iterDomainSizes))) {
+ // In theory, it is unnecessary to raise an error here. Although this
+ // fails to reconstruct the result tensor, it should not break the current
+ // fusion. The reason we must return failure for now is that the
+ // callback function `newYieldValuesFn` is called after the new init
+ // operand(s) have already been appended. More refactoring is needed to
+ // make sure the init operands are added consistently in the future. For
+ // more details, please refer to:
+ // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
return failure();
}
More information about the Mlir-commits
mailing list