[Mlir-commits] [mlir] [mlir][TilingInterface] Extend option to yield replacement for multiple results case (PR #93144)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 5 06:47:32 PDT 2024
https://github.com/Yun-Fly updated https://github.com/llvm/llvm-project/pull/93144
>From 4377bf0ce4e2108e1de01ca2b1c0a7d455068264 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Wed, 22 May 2024 23:15:36 -0700
Subject: [PATCH] yield replacement for multiple results
---
.../SCF/Transforms/TileUsingInterface.h | 6 +-
.../mlir/Interfaces/TilingInterface.td | 19 +++
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 35 +++--
.../SCF/Transforms/TileUsingInterface.cpp | 148 +++++++++++++-----
.../tile-fuse-and-yield-using-interface.mlir | 62 ++++++++
5 files changed, 220 insertions(+), 50 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index dac79111af3c9..807379d99b599 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -191,10 +191,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
/// where `%0` had other uses as well. If not reconstructed from within the loop
/// body, uses of `%0` could not be replaced, making it still live and the
/// fusion immaterial.
+///
+/// `yieldResultNumber` lists which results of the fused producer are yielded
+/// from the loop nest; if not given, every result of the producer is yielded.
LogicalResult yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
- MutableArrayRef<LoopLikeOpInterface> loops);
+ MutableArrayRef<LoopLikeOpInterface> loops,
+ std::optional<ArrayRef<unsigned>> yieldResultNumber = std::nullopt);
/// Transformation information returned after tile and fuse.
struct SCFTileAndFuseResult {
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index bc83c81c0086c..95e204cf5e1d2 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -115,6 +115,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to return the tile of the iteration domain (offsets and sizes)
+ needed to compute the given tile of the result `resultNumber`.
+ }],
+ /*retType=*/"::mlir::LogicalResult",
+ /*methodName=*/"getIterationDomainTileFromResultTile",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$resultNumber,
+ "ArrayRef<OpFoldResult> ":$resultOffsets,
+ "ArrayRef<OpFoldResult> ":$resultSizes,
+ "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
+ "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c3ab3cecfada7..0dbfe93fc5650 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -215,26 +215,42 @@ struct LinalgOpTilingInterface
return success();
}
- FailureOr<TilingResult>
- generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes) const {
+ /// Computes the iteration-domain tile that produces the tile (`offsets`,
+ /// `sizes`) of result `resultNumber`. Emits an error and fails when the
+ /// result's indexing map is not a projected permutation.
+ LogicalResult getIterationDomainTileFromResultTile(
+ Operation *op, OpBuilder &b, unsigned resultNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+ SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op);
// Check that the indexing map used for the output is a projected
// permutation. This could be relaxed with a more general approach that can
// map the offsets and sizes from the result to iteration space tiles
// (filling in full extent for dimensions not used to access the result).
AffineMap indexingMap =
linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
if (!indexingMap.isProjectedPermutation()) {
- return op->emitOpError(
- "unhandled tiled implementation generation when result is not "
- "accessed using a permuted projection");
+ return op->emitOpError(
+ "unhandled iteration domain tile computation when result is not "
+ "accessed using a permuted projection");
}
- SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
- mappedOffsets, mappedSizes);
+ iterDomainOffsets, iterDomainSizes);
+ return success();
+ }
+
+ FailureOr<TilingResult>
+ generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) const {
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ if (failed(getIterationDomainTileFromResultTile(
+ op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+ return failure();
+ }
auto tilingInterfaceOp = cast<TilingInterface>(op);
FailureOr<TilingResult> tilingResult =
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..e8cf7298be681 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -940,49 +940,114 @@ mlir::scf::tileAndFuseProducerOfSlice(
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
- MutableArrayRef<LoopLikeOpInterface> loops) {
+ MutableArrayRef<LoopLikeOpInterface> loops,
+ std::optional<ArrayRef<unsigned>> yieldResultNumber) {
if (loops.empty())
return success();
- OpResult fusableProducer = fusedProducerInfo.origProducer;
- Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
- FailureOr<Value> initValue = tensor::getOrCreateDestination(
- rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
- if (succeeded(initValue)) {
-
- YieldTiledValuesFn newYieldValuesFn =
- [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
- ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
- SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
- SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
- -> LogicalResult {
- OpBuilder::InsertionGuard g(innerRewriter);
- if (auto tiledDestStyleOp =
- tiledAndFusedProducer
- .getDefiningOp<DestinationStyleOpInterface>()) {
- rewriter.setInsertionPoint(tiledDestStyleOp);
- Value newRegionArg = newRegionIterArgs.back();
+ Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
*tiledOwner = fusedProducerInfo.tiledOps[0];
+
+ Location loc = originalOwner->getLoc();
+ // a. Collect the result numbers to yield. This must own its storage: an
+ // ArrayRef to the temporary from `to_vector` dangles after this statement.
+ SmallVector<unsigned> initNumberList =
+ yieldResultNumber ? llvm::to_vector(*yieldResultNumber)
+ : llvm::to_vector(llvm::seq<unsigned>(
+ 0, originalOwner->getNumResults()));
+ SmallVector<Value> initValueList;
+ for (const auto &resultNumber : initNumberList) {
+ FailureOr<Value> initValue = tensor::getOrCreateDestination(
+ rewriter, loc, originalOwner->getResult(resultNumber));
+ if (failed(initValue))
+ return failure();
+ initValueList.push_back(*initValue);
+ }
+
+ YieldTiledValuesFn newYieldValuesFn =
+ [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+ ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+ SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+ SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+ OpBuilder::InsertionGuard g(innerRewriter);
+
+ // get sliceOp tile information
+ SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
sliceSizes = sliceOp.getMixedSizes();
+
+ // expect all strides of sliceOp being 1
+ if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
return !isConstantIntValue(ofr, 1);
}))
+ return failure();
+
+ unsigned sliceResultNumber =
+ fusedProducerInfo.origProducer.getResultNumber();
+
+ auto tilableOp = cast<TilingInterface>(originalOwner);
+ // b. get iterDomain Offset and Sizes based on sliceOp tile
+ SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
+ // Single-result ops (e.g. tensor.pack/unpack/pad) only yield the sliced
+ // result, whose tile is the slice itself, so skip this query for them.
+ if (tilableOp->getNumResults() > 1 &&
+ failed(tilableOp.getIterationDomainTileFromResultTile(
+ rewriter, sliceResultNumber, sliceOffset, sliceSizes,
+ iterDomainOffset, iterDomainSizes))) {
+ return failure();
+ }
+
+ // c. calculate offsets and sizes info of all OpResults respectively based
+ // on iteration Domain Tile
+ SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
+ for (const auto &resultNumber : initNumberList) {
+ if (resultNumber == fusedProducerInfo.origProducer.getResultNumber()) {
+ offsetList.push_back(sliceOffset);
+ sizesList.push_back(sliceSizes);
+ } else {
+ assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
+ // infer result tile according to the iteration domain tile
+ SmallVector<OpFoldResult> offset, sizes;
+ if (failed(tilableOp.getResultTilePosition(
+ rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
+ offset, sizes))) {
+ return failure();
+ }
+ offsetList.push_back(offset);
+ sizesList.push_back(sizes);
+ }
+ }
+
+ // d. create `extract_slice` for `iter_args` for DPS operation if necessary
+ if (auto tiledDestStyleOp =
+ dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
+ rewriter.setInsertionPoint(tiledDestStyleOp);
+ for (const auto &&[index, newRegionArg] :
+ llvm::enumerate(newRegionIterArgs)) {
auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
- sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
- unsigned resultNumber = fusableProducer.getResultNumber();
+ loc, newRegionArg, offsetList[index], sizesList[index],
+ SmallVector<OpFoldResult>(offsetList[index].size(),
rewriter.getIndexAttr(1)));
+ unsigned resultNumber = initNumberList[index];
rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
});
}
- Block *block = rewriter.getInsertionPoint()->getBlock();
- rewriter.setInsertionPoint(block->getTerminator());
- tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
- tiledOffset.emplace_back(sliceOp.getMixedOffsets());
- tiledSizes.emplace_back(sliceOp.getMixedSizes());
- return success();
- };
+ }
- return addInitOperandsToLoopNest(rewriter, loops,
- SmallVector<Value>{initValue.value()},
- newYieldValuesFn);
- }
- return success();
+ // e. prepare tiled offset and sizes for later `insert_slice` creation by
+ // caller
+ Block *block = rewriter.getInsertionPoint()->getBlock();
+ rewriter.setInsertionPoint(block->getTerminator());
+ for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
+ tiledResult.push_back(tiledOwner->getResult(resultNumber));
+ tiledOffset.emplace_back(offsetList[index]);
+ tiledSizes.emplace_back(sizesList[index]);
+ }
+ return success();
+ };
+
+ return addInitOperandsToLoopNest(rewriter, loops, initValueList,
+ newYieldValuesFn);
}
/// Implementation of tile consumer and fuse producer greedily.
@@ -1072,14 +1137,21 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
continue;
if (yieldReplacement) {
+ // Yield every result of fusableProducerOp (the default when
+ // `yieldResultNumber` is not passed to `yieldReplacementForFusedProducer`);
+ // they are appended, in result order, as the last results of the loop nest.
+ Operation *fusableProducerOp = fusableProducer.getOwner();
if (failed(yieldReplacementForFusedProducer(
rewriter, candidateSliceOp, fusedResult.value(), loops))) {
return rewriter.notifyMatchFailure(
- fusableProducer.getOwner(), "failed to replacement value for this "
- "oepration from within the tiled loop");
+ fusableProducerOp, "failed to yield replacement value for this "
"operation from within the tiled loop");
+ }
+ for (const auto &result : fusableProducerOp->getResults()) {
+ origValToResultNumber[result] =
loops.front()->getNumResults() -
(fusableProducerOp->getNumResults() - result.getResultNumber());
}
- origValToResultNumber[fusableProducer] =
- loops.front()->getNumResults() - 1;
}
if (Operation *tiledAndFusedOp =
diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
index 7356c11e85ac0..3c0ada9d2cabc 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
@@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]]
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0
+
+// -----
+// Fuse a two-result generic (second result transposed) and yield all results.
+func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
+ %rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>,
+ %rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>)
+ -> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
+ %out0, %out1 = linalg.generic {
+ indexing_maps = [affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (j, i)>],
+ iterator_types = ["parallel", "parallel"]
+ }
+ ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
+ outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
+ ^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
+ %4 = arith.mulf %0, %1 : f32
+ %5 = arith.addf %0, %1 : f32
+ linalg.yield %4, %5: f32, f32
+ } -> (tensor<32x32xf32>, tensor<32x32xf32>)
+
+ %out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>
+
+ return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %add = transform.structured.match ops{["linalg.add"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_and_yield %add [16]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @multiple_outputs_fusion_yield_all(
+// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
+// CHECK: %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
+// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
+// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
+// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
+// CHECK: %[[GENERIC_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
+// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
+// CHECK-DAG: %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+// CHECK: %[[ADD_TILE:.+]] = linalg.add
+// CHECK-SAME: ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
+// CHECK-SAME: outs(%[[INIT2_TILE]] :
+// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
+// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
+// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
+// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
+// CHECK: return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0
More information about the Mlir-commits
mailing list