[Mlir-commits] [mlir] [mlir][linalg] Enable fuse consumer (PR #85528)
donald chen
llvmlistbot at llvm.org
Sat Mar 16 07:32:55 PDT 2024
https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/85528
>From 7ca5ae58b93a182b1bc4795c5e851c77eb7119fb Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Sat, 16 Mar 2024 09:12:12 +0800
Subject: [PATCH] [mlir][linalg] Enable fuse consumer
This patch adds support for consumer fusion to the tiling interface and
implements consumer fusion in FuseIntoContainingOp.
- Add the interface method 'getIterDomainTilePositionFromOperandPosition' to the
tiling interface, which computes the iteration domain position from an operand
position.
- Add the interface method 'getTiledImplementationFromOperandPosition' to the
tiling interface, which generates a tiled implementation according to an
operand position.
- Implement the above two methods for Linalg ops and support consumer fusion in
FuseIntoContainingOp (a usage sketch follows below).
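
A minimal sketch of how the two new hooks compose for consumer fusion; it
mirrors the default implementation added to TilingInterfaceImpl.cpp in this
patch, and the helper name tileConsumerForOperandTile is illustrative, not part
of the patch:

#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;

// Sketch: tile a consumer op for a given tile of one of its operands.
static FailureOr<TilingResult>
tileConsumerForOperandTile(OpBuilder &b, TilingInterface consumer,
                           unsigned operandNumber,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes) {
  // 1. Map the operand tile back to a tile of the iteration domain.
  SmallVector<OpFoldResult> iterOffsets, iterSizes;
  if (failed(consumer.getIterDomainTilePositionFromOperandPosition(
          b, operandNumber, offsets, sizes, iterOffsets, iterSizes)))
    return failure();
  // 2. Emit the tiled implementation for that iteration-domain tile.
  return consumer.getTiledImplementation(b, iterOffsets, iterSizes);
}

In FuseIntoContainingOp, `offsets` and `sizes` are taken from the
tensor.parallel_insert_slice that yields the bridge value inside the scf.forall.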
---
.../Linalg/TransformOps/LinalgTransformOps.td | 48 ++-
.../mlir/Interfaces/TilingInterface.td | 55 +++
.../TransformOps/LinalgTransformOps.cpp | 369 +++++++++++++++---
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 96 ++++-
.../transform-op-fuse-into-containing.mlir | 125 +++++-
5 files changed, 590 insertions(+), 103 deletions(-)
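
For orientation before the diff, a hedged C++ sketch of building the renamed op
programmatically; only the builder signature and result accessors come from
this patch, the wrapper function is illustrative:

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"

#include <utility>

using namespace mlir;

// Sketch: create the fuse_into_containing_op given existing transform handles;
// `targetHandle` may now refer to a producer or a consumer of the op behind
// `containingHandle`.
static std::pair<Value, Value>
buildFuseIntoContaining(OpBuilder &b, Location loc, Value targetHandle,
                        Value containingHandle) {
  auto fuse = b.create<transform::FuseIntoContainingOp>(loc, targetHandle,
                                                        containingHandle);
  // Handles to the fused payload ops and to the (possibly new) containing op.
  return {fuse.getFusedOp(), fuse.getNewContainingOp()};
}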
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index bdeab55091b9f3..08c41dbee54dbc 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -310,51 +310,61 @@ def FuseIntoContainingOp :
["allowsRepeatedHandleOperands"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
- let summary = "Fuse a producer into a containing operation.";
+ let summary = "Fuse a target into a containing operation.";
let description = [{
- Fuses the `producer_op` into the `containing_op`.
+ Fuses the `target_op` into the `containing_op`.
Returns a handle to the fused ops and the `new_containing_op`.
- The producer is typically a slice of a tileable op (i.e., implements
- TilingInterface). In that case, this transform computes the accessed
- producer slice inside of the containing op ("tile and fuse") and if required,
- creates a new containing op with outputs from the fused producer. Otherwise,
- the entire producer is cloned inside the containing op ("clone and fuse").
+ This operation supports fusing either a producer or a consumer into the
+ containing op. Below, the value that connects the containing op and the
+ target op is referred to as the "bridge".
+
+ When fusing a producer, the bridge is typically a slice of a tileable op
+ (i.e., an op that implements TilingInterface). In that case, this transform
+ computes the accessed bridge slice inside of the containing op ("tile and
+ fuse") and, if required, creates a new containing op with outputs from the
+ fused target. Otherwise, the entire target is cloned inside the containing
+ op ("clone and fuse").
+
+ When fusing a consumer, the bridge is a result of the containing op that is
+ also an operand of a tileable op. In this case, this transform computes the
+ accessed bridge slice inside the containing op ("tile and fuse") and creates
+ a new containing op that also yields the consumer's outputs.
The containing op handle must be associated with exactly one payload op. The
- producer op handle may be associated with multiple payload ops. This
- transform fuses producers one-by-one, always picking an unspecified producer
+ target op handle may be associated with multiple payload ops. This
+ transform fuses targets one-by-one, always picking an unspecified target
that has at least one use inside the containing op among the
- producers. A producer can be listed multiple times in the handle.
+ targets. A target can be listed multiple times in the handle.
- Note: If a producer has multiple uses inside the containing op, it is
+ Note: If a target has multiple uses inside the containing op, it is
currently tiled and/or cloned multiple times into the containing op.
TODO: Reuse already fused OpResults instead of tiling/cloning a second time
- when possible. Fuse producers according to a topological sorting to achieve
+ when possible. Fuse targets according to a topological sorting to achieve
the largest amount of reuse.
#### Return modes
- If at least one producer could not be fused, this operation produces a
+ If at least one target could not be fused, this operation produces a
silenceable failure. This is the case when tiling fails or when no
- producer op could be found among the remaining producers that has at least
- one use within the containing op. I.e., "producers" that are not consumed
+ target op could be found among the remaining targets that has at least
+ one use within the containing op. I.e., "targets" that are not consumed
within the containing op are rejected by this operation.
- This operation consumes the producer handle.
+ This operation consumes the target handle.
This operation only reads the containing op handle.
}];
- let arguments = (ins TransformHandleTypeInterface:$producer_op,
+ let arguments = (ins TransformHandleTypeInterface:$target_op,
TransformHandleTypeInterface:$containing_op);
let results = (outs TransformHandleTypeInterface:$fused_op,
TransformHandleTypeInterface:$new_containing_op);
- let assemblyFormat = "$producer_op `into` $containing_op attr-dict "
+ let assemblyFormat = "$target_op `into` $containing_op attr-dict "
" `:` functional-type(operands, results)";
let builders = [
- OpBuilder<(ins "Value":$producerOp, "Value":$containingOp)>
+ OpBuilder<(ins "Value":$targetOp, "Value":$containingOp)>
];
}
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..4c62d45822ad44 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return {};
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to return the position of the iteration domain tile computed
+ from the position of the given operand tile.
+ }],
+ /*retType=*/"LogicalResult",
+ /*methodName=*/"getIterDomainTilePositionFromOperandPosition",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult> ":$offsets,
+ "ArrayRef<OpFoldResult> ":$sizes,
+ "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+ "SmallVector<OpFoldResult> &":$iterDomainSizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Method to return the position of the result tile computed by the tiled operation.
@@ -96,6 +115,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to generate the tiled implementation of an operation from
+ the position of an operand tile.
+
+ Generates the IR that computes the tiled implementation of an
+ operation from an operand position. The `offsets` and `sizes`
+ describe the required tile of the operand. This differs from
+ `getTiledImplementation`, which generates the tiled
+ implementation of the operation given a tile of the
+ iteration space. Instead, this method derives the iteration
+ space tile from the requested operand tile. It enables
+ consumer fusion via tile and fuse. The method returns
+ failure if the operation cannot be tiled to generate the
+ requested operand tile; in practical terms this implies it
+ cannot be tiled and fused with its producers.
+
+ - `offsets` provides the offset of the tile in the coordinate system
+ of the original iteration space, i.e., if an iteration space
+ dimension had non-zero offset, it must be included in the offset
+ provided here (as opposed to zero-based offset "relative" to the
+ iteration space).
+ - `sizes` provides the size of the tile.
+ }],
+ /*retType=*/"FailureOr<TilingResult>",
+ /*methodName=*/"getTiledImplementationFromOperandPosition",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult>":$offsets,
+ "ArrayRef<OpFoldResult>":$sizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ae28049f02e391..ecffb910b236e8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -546,9 +546,9 @@ LogicalResult transform::FuseOp::verify() {
void transform::FuseIntoContainingOp::build(OpBuilder &builder,
OperationState &result,
- Value producerOp,
+ Value targetOp,
Value containingOp) {
- result.addOperands({producerOp, containingOp});
+ result.addOperands({targetOp, containingOp});
auto resultType = transform::AnyOpType::get(builder.getContext());
result.addTypes({resultType, resultType});
}
@@ -631,6 +631,223 @@ static Operation *replaceForAllWithNewSignature(
return newforallOp;
}
+static std::tuple<SmallVector<Operation *>, Operation *>
+tileAndFuseParallelInsertSlice(RewriterBase &rewriter, Diagnostic &diag,
+ Operation *consumerOp, Operation *containingOp) {
+ // Check consumer has tiling interface.
+ LLVM_DEBUG(DBGS() << "Try to fuse a consumer\n");
+ auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+ if (!tileableConsumer) {
+ diag.attachNote(consumerOp->getLoc())
+ << "consumer is not a TileableInterface: " << *consumerOp;
+ return {};
+ }
+
+ // Check containing op is "scf::ForallOp".
+ auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+ if (!forallOp) {
+ diag.attachNote(containingOp->getLoc())
+ << "containing op is not a scf.forall: " << containingOp;
+ return {};
+ }
+
+ // Check dominance.
+ DominanceInfo domInfo(
+ containingOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>());
+ if (llvm::any_of(consumerOp->getOperands(), [&](Value v) {
+ return v.getDefiningOp() != containingOp &&
+ !domInfo.properlyDominates(v, containingOp);
+ })) {
+ diag.attachNote(consumerOp->getLoc())
+ << "consumer's operand can't dominate containing op";
+ return {};
+ }
+
+ // Check that the consumer doesn't use more than one result of containingOp.
+ Value bridge(nullptr);
+ SmallVector<unsigned> operandNums;
+ for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+ if (opd.getDefiningOp() == containingOp) {
+ operandNums.push_back(idx);
+ if (!bridge) {
+ bridge = opd;
+ } else if (bridge != opd) {
+ diag.attachNote(consumerOp->getLoc())
+ << "consumer's operand use more than one containingOp's result";
+ return {};
+ }
+ }
+ }
+
+ // TODO: We have to initialize the consumer's results before the scf.forall;
+ // for now, use DestinationStyleOpInterface to get the result shapes from
+ // the inits. Add support for other ops, e.g. InferTypeOpInterface.
+ // Check that the consumer implements DestinationStyleOpInterface.
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp) {
+ diag.attachNote(consumerOp->getLoc())
+ << "consumer op should have destination style op interface";
+ return {};
+ }
+
+ // Check that the consumer doesn't use the scf.forall's output as an init.
+ SmallVector<Value> dpsInits = llvm::to_vector<4>(
+ llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+ if (llvm::is_contained(dpsInits, bridge)) {
+ diag.attachNote(consumerOp->getLoc())
+ << "consumer op take result of scf.forall as init";
+ return {};
+ }
+
+ // Check result was inserted only once.
+ int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+ auto bridgeBlockArg = forallOp.getRegionOutArgs()[bridgeResultIdx];
+ scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+ tensor::ParallelInsertSliceOp targetInsertOp(nullptr);
+ for (Operation &op : terminatorOp.getRegion().front().getOperations()) {
+ auto parallelInsertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+ if (parallelInsertSliceOp.getDest() == bridgeBlockArg) {
+ if (!targetInsertOp) {
+ targetInsertOp = parallelInsertSliceOp;
+ } else {
+ diag.attachNote(containingOp->getLoc())
+ << "containingOp's result inserted multi time";
+ return {};
+ }
+ }
+ }
+
+ if (!targetInsertOp) {
+ diag.attachNote(containingOp->getLoc())
+ << "containingOp's result was not inserted";
+ return {};
+ }
+
+ SmallVector<OpFoldResult> offsets = targetInsertOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = targetInsertOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = targetInsertOp.getMixedStrides();
+
+ // Check that all insert strides are 1.
+ if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+ if (auto attr = foldRes.dyn_cast<Attribute>()) {
+ return cast<IntegerAttr>(attr).getInt() != 1;
+ }
+ return true;
+ })) {
+ diag.attachNote(containingOp->getLoc())
+ << "containingOp's result yield with stride";
+ return {};
+ }
+
+ Location loc = forallOp.getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(terminatorOp);
+
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+ // Try to get the iteration domain position from the operand position.
+ if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+ rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ diag.attachNote(consumerOp->getLoc())
+ << "can't get iter domain position from input position";
+ return {};
+ }
+
+ // Try to get all of the consumer op's result positions from the iteration
+ // domain position.
+ llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+ llvm::SmallVector<OpFoldResult>>>
+ resultPositions(consumerOp->getNumResults());
+ for (auto [idx, v] : llvm::enumerate(consumerOp->getResults())) {
+ if (failed(tileableConsumer.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultPositions[idx].first, resultPositions[idx].second))) {
+ diag.attachNote(consumerOp->getLoc())
+ << "can't get result domain position from iter domain position";
+ return {};
+ }
+ }
+
+ // All checks passed; try to fuse the consumer.
+ // Create the tiled implementation of the consumer op.
+ FailureOr<TilingResult> tileAndFuseResult =
+ tileableConsumer.getTiledImplementationFromOperandPosition(
+ rewriter, operandNums.front(), offsets, sizes);
+ if (failed(tileAndFuseResult)) {
+ diag.attachNote(consumerOp->getLoc()) << "get tiled implementation failed";
+ return {};
+ }
+
+ auto tiledOps = tileAndFuseResult->tiledOps;
+ if (failed(tileAndFuseResult) || tiledOps.size() != 1) {
+ diag.attachNote(tileableConsumer->getLoc())
+ << "failed to tile consumer op: " << *tileableConsumer;
+ return {};
+ }
+
+ // Replace the tiled op's operands.
+ for (auto operandNum : operandNums) {
+ tiledOps[0]->setOperand(operandNum, targetInsertOp.getSource());
+ }
+ rewriter.replaceUsesWithIf(bridge, forallOp.getOutputs()[bridgeResultIdx],
+ [&](OpOperand &use) {
+ Operation *op = use.getOwner();
+ return forallOp->isProperAncestor(op);
+ });
+
+ SmallVector<Value> newOuts(forallOp.getOutputs());
+ newOuts.append(dpsInits);
+
+ // Create new scf.forall op.
+ rewriter.setInsertionPoint(forallOp);
+ auto newforallOp = rewriter.create<scf::ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+ rewriter.eraseBlock(newforallOp.getBody());
+ newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+ for (auto v : dpsInits) {
+ newforallOp.getBody()->addArgument(v.getType(), v.getLoc());
+ auto bbArgs = newforallOp.getBody()->getArguments();
+ rewriter.replaceUsesWithIf(v, bbArgs.back(), [&](OpOperand &use) {
+ Operation *op = use.getOwner();
+ return newforallOp->isProperAncestor(op);
+ });
+ }
+
+ // Fix terminator.
+ scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+ SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
+ newTerminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
+ Operation *firstYieldOp = yieldingOps.front();
+ rewriter.setInsertionPoint(firstYieldOp);
+ auto bbArgs = newforallOp.getBody()->getArguments();
+ for (auto [idx, v] : llvm::enumerate(tiledOps[0]->getResults())) {
+ SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+ rewriter.getIndexAttr(1));
+ rewriter.create<tensor::ParallelInsertSliceOp>(
+ firstYieldOp->getLoc(), v,
+ bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+ resultPositions[idx].first, resultPositions[idx].second, strides);
+ }
+
+ // Replace the results of the forall and the consumer op.
+ for (auto result : llvm::enumerate(forallOp.getResults())) {
+ rewriter.replaceAllUsesWith(result.value(),
+ newforallOp->getResult(result.index()));
+ }
+
+ for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+ rewriter.replaceAllUsesWith(
+ consumerResult.value(),
+ newforallOp->getResult(forallOp.getOutputs().size() +
+ consumerResult.index()));
+ }
+
+ return std::make_tuple(tileAndFuseResult->tiledOps, newforallOp);
+}
+
/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
@@ -880,7 +1097,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<Operation *> fusedOps;
- auto producerOps = state.getPayloadOps(getProducerOp());
+ auto targetOps = state.getPayloadOps(getTargetOp());
auto containingOps = state.getPayloadOps(getContainingOp());
if (!llvm::hasSingleElement(containingOps)) {
return emitDefiniteFailure()
@@ -890,69 +1107,115 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
Operation *containingOp = *containingOps.begin();
// If nothing to fuse, propagate success.
- if (std::empty(producerOps)) {
+ if (std::empty(targetOps)) {
results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
return DiagnosedSilenceableFailure::success();
}
- // Helper function to find the next producer that should be fused. Take any
- // producer that has a use inside the containing op.
- SetVector<Operation *> remainingProducers(producerOps.begin(),
- producerOps.end());
- auto getNextProducer = [&]() -> FailureOr<Operation *> {
- for (const auto &it : enumerate(remainingProducers)) {
- Operation *producerOp = it.value();
- // The containing op may be a user of producerOp: use isAncestor.
+ // Helper function to find the next target that should be fused. Take any
+ // target that has a use inside the containing op or consumes one of its
+ // results. Return the target op and a bool indicating if it is a producer.
+ SetVector<Operation *> remainingTargets(targetOps.begin(), targetOps.end());
+ auto getNextTarget = [&]() -> FailureOr<std::pair<Operation *, bool>> {
+ for (const auto &it : enumerate(remainingTargets)) {
+ Operation *targetOp = it.value();
+ // The containing op may be a user of targetOp: use isAncestor.
int64_t numUsesInContainingOp =
- llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
+ llvm::count_if(targetOp->getUsers(), [&](Operation *op) {
return containingOp->isAncestor(op);
});
// TODO: When resolving the TODO below (no duplicate ops), take an op
- // that has no use among the remaining producers. This is a topological
+ // that has no use among the remaining targets. This is a topological
// sorting.
if (numUsesInContainingOp > 0) {
if (numUsesInContainingOp == 1)
- remainingProducers.erase(remainingProducers.begin() + it.index());
- return producerOp;
+ remainingTargets.erase(remainingTargets.begin() + it.index());
+ return std::make_pair(targetOp, true);
+ }
+
+ // The containing op may be a producer of targetOp: use getDefiningOp.
+ if (llvm::any_of(
+ targetOp->getOperands(),
+ [&](Value v) { return v.getDefiningOp() == containingOp; }) &&
+ !targetOp->getUses().empty()) {
+ remainingTargets.erase(remainingTargets.begin() + it.index());
+ return std::make_pair(targetOp, false);
}
}
return failure();
};
- while (!remainingProducers.empty()) {
- auto nextProducer = getNextProducer();
- if (failed(nextProducer)) {
+ while (!remainingTargets.empty()) {
+ auto nextTarget = getNextTarget();
+ if (failed(nextTarget)) {
auto diag = mlir::emitSilenceableFailure(getLoc())
- << "could not find next producer to fuse into container";
+ << "could not find next target to fuse into container";
diag.attachNote(containingOp->getLoc()) << "containing op";
return diag;
}
- Operation *producerOp = *nextProducer;
+ Operation *targetOp = nextTarget->first;
// Default diagnostic, to be complemented with more failure information.
- Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
- diag << "could not fuse " << *producerOp << " into " << *containingOp;
-
- // TODO: If there are multiple uses of the producer in the containing op,
- // we currently tile/clone the op multiple times (once per use). In some
- // cases, we can tile/clone once and reuse the value for each use.
- // Futhermore, producers should then be traversed according to a
- // topological sorting.
- auto [tiledOps, newContainingOp] =
- tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
- if (!tiledOps.empty()) {
- LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
- fusedOps.append(tiledOps);
- if (newContainingOp) {
- // Update handles associated with the containing op so we don't need to
- // invalidate them. This is a hack to support better composability
- // between tiling and fusion while a proper mechanism is being
- // investigated.
- //
- // DO NOT replicate this elsewhere unless you understand what you are
- // doing.
+ Diagnostic diag(targetOp->getLoc(), DiagnosticSeverity::Remark);
+ diag << "could not fuse " << *targetOp << " into " << *containingOp;
+
+ if (nextTarget->second) {
+ // Fuse producer.
+ // TODO: If there are multiple uses of the target in the containing op,
+ // we currently tile/clone the op multiple times (once per use). In some
+ // cases, we can tile/clone once and reuse the value for each use.
+ // Furthermore, targets should then be traversed according to a
+ // topological sorting.
+ auto [tiledOps, newContainingOp] =
+ tileAndFuseFirstExtractUse(rewriter, diag, targetOp, containingOp);
+ if (!tiledOps.empty()) {
+ LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
+ fusedOps.append(tiledOps);
+ if (newContainingOp) {
+ // Update handles associated with the containing op so we don't need
+ // to invalidate them. This is a hack to support better composability
+ // between tiling and fusion while a proper mechanism is being
+ // investigated.
+ //
+ // DO NOT replicate this elsewhere unless you understand what you are
+ // doing.
+ LogicalResult replacementStatus =
+ rewriter.notifyPayloadOperationReplaced(containingOp,
+ newContainingOp);
+ (void)replacementStatus;
+ assert(succeeded(replacementStatus) &&
+ "unable to update transform state mapping");
+ rewriter.eraseOp(containingOp);
+ containingOp = newContainingOp;
+ }
+ continue;
+ }
+
+ SmallVector<Operation *> tiledContainingOpOperand =
+ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
+ rewriter, diag, targetOp, containingOp);
+ if (!tiledContainingOpOperand.empty()) {
+ LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
+ << *containingOp);
+ fusedOps.append(tiledContainingOpOperand);
+ continue;
+ }
+
+ Operation *cloned =
+ cloneAndFuseFirstUse(rewriter, diag, targetOp, containingOp);
+ if (cloned) {
+ LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
+ fusedOps.push_back(cloned);
+ continue;
+ }
+ } else {
+ // Fuse consumer.
+ auto [tiledOps, newContainingOp] = tileAndFuseParallelInsertSlice(
+ rewriter, diag, targetOp, containingOp);
+ if (!tiledOps.empty()) {
+ fusedOps.append(tiledOps);
LogicalResult replacementStatus =
rewriter.notifyPayloadOperationReplaced(containingOp,
newContainingOp);
@@ -961,26 +1224,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
"unable to update transform state mapping");
rewriter.eraseOp(containingOp);
containingOp = newContainingOp;
+ continue;
}
- continue;
- }
-
- SmallVector<Operation *> tiledContainingOpOperand =
- tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
- rewriter, diag, producerOp, containingOp);
- if (!tiledContainingOpOperand.empty()) {
- LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
- << *containingOp);
- fusedOps.append(tiledContainingOpOperand);
- continue;
- }
-
- Operation *cloned =
- cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
- if (cloned) {
- LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
- fusedOps.push_back(cloned);
- continue;
}
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
@@ -992,7 +1237,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
void transform::FuseIntoContainingOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- consumesHandle(getProducerOp(), effects);
+ consumesHandle(getTargetOp(), effects);
onlyReadsHandle(getContainingOp(), effects);
producesHandle(getResults(), effects);
modifiesPayload(effects);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..01bf19764b0938 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -132,6 +132,59 @@ struct LinalgOpTilingInterface
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
+ void getMappedOffsetAndSize(Operation *op, OpBuilder &b,
+ AffineMap indexingMap,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVector<OpFoldResult> &mappedOffsets,
+ SmallVector<OpFoldResult> &mappedSizes) const {
+ auto linalgOp = cast<LinalgOp>(op);
+ auto numLoops = linalgOp.getNumLoops();
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
+ mappedOffsets.resize(numLoops);
+ mappedSizes.resize(numLoops);
+ if (!indexingMap.isPermutation()) {
+ SmallVector<Range> iterationDomain =
+ tilingInterfaceOp.getIterationDomain(b);
+ for (const auto &range : llvm::enumerate(iterationDomain)) {
+ mappedOffsets[range.index()] = range.value().offset;
+ mappedSizes[range.index()] = range.value().size;
+ }
+ }
+ for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
+ unsigned dimPosition =
+ cast<AffineDimExpr>(resultExpr.value()).getPosition();
+ mappedOffsets[dimPosition] = offsets[resultExpr.index()];
+ mappedSizes[dimPosition] = sizes[resultExpr.index()];
+ }
+ }
+
+ // Return the tile of the iteration domain that corresponds to the given
+ // tile of the operand.
+ LogicalResult getIterDomainTilePositionFromOperandPosition(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ SmallVector<OpFoldResult> &iterDomainOffsets,
+ SmallVector<OpFoldResult> &iterDomainSizes) const {
+ auto linalgOp = cast<LinalgOp>(op);
+
+ // Check that the indexing map used for the operand is a projected
+ // permutation. This could be relaxed with a more general approach that can
+ // map the offsets and sizes from the operand to iteration space tiles
+ // (filling in full extent for dimensions not used to access the result).
+ AffineMap indexingMap =
+ linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+ if (!indexingMap.isProjectedPermutation()) {
+ return op->emitOpError(
+ "unhandled get iter domain position when operand is not "
+ "accessed using a permuted projection");
+ }
+
+ getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes,
+ iterDomainOffsets, iterDomainSizes);
+ return success();
+ }
+
// Return the details of the output tile generated by the tiled
// implementation.
LogicalResult
@@ -160,6 +213,20 @@ struct LinalgOpTilingInterface
return success();
}
+ FailureOr<TilingResult> getTiledImplementationFromOperandPosition(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
+ if (failed(tilingInterfaceOp.getIterDomainTilePositionFromOperandPosition(
+ b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+ return op->emitOpError(
+ "unable to obtain the iter domain position of the operation.");
+ }
+ return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+ mappedSizes);
+ }
+
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}
-
- auto numLoops = linalgOp.getNumLoops();
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes, mappedOffsets,
+ mappedSizes);
auto tilingInterfaceOp = cast<TilingInterface>(op);
- SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
- iterationTileSizes(numLoops);
- if (!indexingMap.isPermutation()) {
- SmallVector<Range> iterationDomain =
- tilingInterfaceOp.getIterationDomain(b);
- for (const auto &range : llvm::enumerate(iterationDomain)) {
- iterationTileOffsets[range.index()] = range.value().offset;
- iterationTileSizes[range.index()] = range.value().size;
- }
- }
- for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
- unsigned dimPosition =
- cast<AffineDimExpr>(resultExpr.value()).getPosition();
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
- iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
- }
-
FailureOr<TilingResult> tilingResult =
- tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
- iterationTileSizes);
+ tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+ if (failed(tilingResult))
+ return failure();
+
if (tilingResult->tiledOps.size() != 1)
return op->emitOpError("failed to generate tiled implementation");
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 4115f2857a20c6..5a1f1b01483750 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -585,6 +585,129 @@ module {
// -----
+#map = affine_map<()[s0] -> (64 ceildiv s0)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_consumer
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
+ // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
+ func.func @fuse_tileable_consumer(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+ // CHECK: %[[SLICE:.*]] = tensor.empty(%[[CHUNK_SIZE]]) : tensor<?xf32>
+ %0 = tensor.empty(%arg0) : tensor<?xf32>
+ %1 = affine.apply #map()[%arg0]
+ // CHECK: %[[EMPTY0:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+ %2 = tensor.empty() : tensor<64xf32>
+ // CHECK: %[[EMPTY1:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+ %3 = tensor.empty() : tensor<64xf32>
+ // CHECK: %[[RES:[0-9a-z]+]]:2 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[OUT]], %[[LOOP_ARG1:.*]] = %[[EMPTY1]]
+ %4 = scf.forall (%arg3) in (%1) shared_outs(%arg4 = %arg2) -> (tensor<64xf32>) {
+ %6 = affine.apply #map1(%arg3)[%arg0]
+ %7 = affine.min #map2(%arg3)[%arg0]
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[LOOP_ARG0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ %extracted_slice = tensor.extract_slice %arg4[%6] [%7] [1] : tensor<64xf32> to tensor<?xf32>
+ // CHECK: %[[T1:[0-9a-z]+]] = linalg.elemwise_unary
+ %8 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%extracted_slice : tensor<?xf32>) -> tensor<?xf32>
+
+ // CHECK: %[[T2:.*]] = tensor.extract_slice %[[EMPTY0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ // CHECK: %[[T3:.*]] = tensor.extract_slice %[[LOOP_ARG1]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ // CHECK: %[[T4:.*]] = linalg.elemwise_binary {{.*}} ins(%[[T1]], %[[T2]] : {{.*}} outs(%[[T3]]
+
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[T4]] into %[[LOOP_ARG1]]
+ // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[LOOP_ARG0]]
+ tensor.parallel_insert_slice %8 into %arg4[%6] [%7] [1] : tensor<?xf32> into tensor<64xf32>
+ }
+ }
+ // CHECK: %[[ORI_OUTPUT:.*]] = linalg.elemwise_binary
+ %5 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%4, %2 : tensor<64xf32>, tensor<64xf32>) outs(%3 : tensor<64xf32>) -> tensor<64xf32>
+ // CHECK: return %[[RES]]#1
+ return %5 : tensor<64xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.forall">
+ %1 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg0 : (!transform.any_op) -> !transform.op<"linalg.elemwise_binary">
+ %fused_op, %new_containing_op = transform.structured.fuse_into_containing_op %1 into %0 : (!transform.op<"linalg.elemwise_binary">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * -50 + 123, 50)>
+#map1 = affine_map<(d0) -> (d0 * -16 + 789, 16)>
+#map2 = affine_map<(d0) -> (d0 * 50)>
+#map3 = affine_map<(d0) -> (d0 * 16)>
+#map4 = affine_map<(d0, d1) -> (d0, d1)>
+#map5 = affine_map<(d0, d1) -> (d1, d0)>
+module {
+ // CHECK-LABEL: func.func @fuse_consumer_multi_output
+ // CHECK-SAME: %[[IN0:[0-9a-z]+]]: tensor<123x456xf32>
+ // CHECK-SAME: %[[IN1:[0-9a-z]+]]: tensor<456x789xf32>
+ // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<123x789xf32>
+ func.func @fuse_consumer_multi_output(%arg0: tensor<123x456xf32>, %arg1: tensor<456x789xf32>, %arg2: tensor<123x789xf32>) -> (tensor<123x789xf32>, tensor<789x123xf32>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[INIT:.*]] = linalg.fill
+ %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<123x789xf32>) -> tensor<123x789xf32>
+ // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<123x789xf32>
+ %1 = tensor.empty() : tensor<123x789xf32>
+ // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<789x123xf32>
+ %2 = tensor.empty() : tensor<789x123xf32>
+ // CHECK: %[[RES:[0-9a-z]+]]:3 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[INIT]], %[[LOOP_ARG1:.*]] = %[[EMPTY0]], %[[LOOP_ARG2:.*]] = %[[EMPTY1]]
+ %3 = scf.forall (%arg3, %arg4) in (3, 50) shared_outs(%arg5 = %0) -> (tensor<123x789xf32>) {
+ %5 = affine.min #map(%arg3)
+ %6 = affine.min #map1(%arg4)
+ %7 = affine.apply #map2(%arg3)
+ %8 = affine.apply #map3(%arg4)
+ %9 = affine.apply #map2(%arg3)
+ %10 = affine.apply #map3(%arg4)
+ // CHECK: %[[EXTRACT_IN0:.*]] = tensor.extract_slice %[[IN0]]
+ %extracted_slice = tensor.extract_slice %arg0[%7, 0] [%5, 456] [1, 1] : tensor<123x456xf32> to tensor<?x456xf32>
+ // CHECK: %[[EXTRACT_IN1:.*]] = tensor.extract_slice %[[IN1]]
+ %extracted_slice_0 = tensor.extract_slice %arg1[0, %8] [456, %6] [1, 1] : tensor<456x789xf32> to tensor<456x?xf32>
+ // CHECK: %[[EXTRACT_OUT:.*]] = tensor.extract_slice %[[LOOP_ARG0]]
+ %extracted_slice_1 = tensor.extract_slice %arg5[%9, %10] [%5, %6] [1, 1] : tensor<123x789xf32> to tensor<?x?xf32>
+ // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul ins(%[[EXTRACT_IN0]], %[[EXTRACT_IN1]] {{.*}} outs(%[[EXTRACT_OUT]]
+ %11 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<?x456xf32>, tensor<456x?xf32>) outs(%extracted_slice_1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ // CHECK: %[[EXTRACT_EMPTY0:.*]] = tensor.extract_slice %[[LOOP_ARG1]]
+ // CHECK: %[[EXTRACT_EMPTY1:.*]] = tensor.extract_slice %[[LOOP_ARG2]]
+ // CHECK: %[[GENERIC_RES:.*]]:2 = linalg.generic {{.*}} ins(%[[MATMUL_RES]] : tensor<?x?xf32>) outs(%[[EXTRACT_EMPTY0]], %[[EXTRACT_EMPTY1]]
+
+ %12 = affine.apply #map2(%arg3)
+ %13 = affine.apply #map3(%arg4)
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[GENERIC_RES]]#0 into %[[LOOP_ARG1]]
+ // CHECK: tensor.parallel_insert_slice %[[GENERIC_RES]]#1 into %[[LOOP_ARG2]]
+ // CHECK: tensor.parallel_insert_slice %[[MATMUL_RES]] into %[[LOOP_ARG0]]
+ tensor.parallel_insert_slice %11 into %arg5[%12, %13] [%5, %6] [1, 1] : tensor<?x?xf32> into tensor<123x789xf32>
+ }
+ }
+ // CHECK: %[[ORI_OUTPUT:.*]]:2 = linalg.generic
+ %4:2 = linalg.generic {indexing_maps = [#map4, #map4, #map5], iterator_types = ["parallel", "parallel"]} ins(%3 : tensor<123x789xf32>) outs(%1, %2 : tensor<123x789xf32>, tensor<789x123xf32>) {
+ ^bb0(%in: f32, %out: f32, %out_0: f32):
+ %5 = arith.addf %in, %out : f32
+ %6 = arith.addf %5, %out_0 : f32
+ linalg.yield %5, %6 : f32, f32
+ } -> (tensor<123x789xf32>, tensor<789x123xf32>)
+ // CHECK: return %[[RES]]#1, %[[RES]]#2
+ return %4#0, %4#1 : tensor<123x789xf32>, tensor<789x123xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.forall">
+ %1 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.op<"linalg.generic">
+ %fused_op, %new_containing_op = transform.structured.fuse_into_containing_op %1 into %0 : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
+
+// -----
+
// This is a regression test. Make sure that the transform succeeds and valid
// IR is generated.
@@ -677,7 +800,7 @@ module attributes {transform.with_named_sequence} {
num_threads [] tile_sizes [50, 16]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Note that we pass in %tiled_op, which isn't a container op.
- // expected-error @+2 {{could not find next producer to fuse into container}}
+ // expected-error @+2 {{could not find next target to fuse into container}}
%fused_op, %new_containing_op =
transform.structured.fuse_into_containing_op %0 into %tiled_op
: (!transform.any_op, !transform.any_op)
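
As a closing note, a hedged sketch (not part of this patch) of how a non-Linalg,
purely elementwise-like op could opt into consumer fusion by implementing the
new hook through an external model; the template and its name are illustrative,
only the method signature comes from this patch:

#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;

// Hedged sketch: for an elementwise-like op, the requested operand tile
// coincides with the iteration-domain tile, so the new hook is a straight
// copy of the offsets and sizes.
template <typename OpTy>
struct ElemwiseLikeOpTilingInterface
    : public TilingInterface::ExternalModel<ElemwiseLikeOpTilingInterface<OpTy>,
                                            OpTy> {
  LogicalResult getIterDomainTilePositionFromOperandPosition(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVector<OpFoldResult> &iterDomainOffsets,
      SmallVector<OpFoldResult> &iterDomainSizes) const {
    iterDomainOffsets.assign(offsets.begin(), offsets.end());
    iterDomainSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
};

Ops whose operands are accessed through projected permutations get the
equivalent behavior from the LinalgOpTilingInterface implementation in this
patch.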
More information about the Mlir-commits mailing list