[Mlir-commits] [mlir] ce349ff - [mlir][TilingInterface] NFC: Separate out a utility method to perform one step of tile + fuse.
Mahesh Ravishankar
llvmlistbot at llvm.org
Sun Jan 15 21:21:34 PST 2023
Author: Mahesh Ravishankar
Date: 2023-01-16T05:03:41Z
New Revision: ce349ff1a483e5d08bdd394dddea63b549e646fd
URL: https://github.com/llvm/llvm-project/commit/ce349ff1a483e5d08bdd394dddea63b549e646fd
DIFF: https://github.com/llvm/llvm-project/commit/ce349ff1a483e5d08bdd394dddea63b549e646fd.diff
LOG: [mlir][TilingInterface] NFC: Separate out a utility method to perform one step of tile + fuse.
Differential Revision: https://reviews.llvm.org/D141027
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 52cd7609d55e..dd0ed448ea3a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -505,6 +505,101 @@ getUntiledProducerFromSliceSource(OpOperand *source,
return {source->get().dyn_cast<OpResult>(), destinationIterArg};
}
+static std::optional<Operation *>
+tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+ tensor::ExtractSliceOp candidateSliceOp,
+ MutableArrayRef<scf::ForOp> loops) {
+ // 1. Get the producer of the source (potentially walking through
+ // `iter_args` of nested `scf.for`)
+ auto [fusableProducer, destinationIterArg] =
+ getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
+ loops);
+ if (!fusableProducer)
+ return std::nullopt;
+
+ // 2. Generate the tiled implementation of the producer of the source
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(candidateSliceOp);
+ FailureOr<Value> fusedProducerValue =
+ tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
+ fusableProducer);
+ if (failed(fusedProducerValue))
+ return std::nullopt;
+ rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());
+
+ // 3. If the slice is for a destination operand, for example,
+ //
+ // ```mlir
+ // %0 = linalg.init
+ // %1 = linalg.fill .. outs(%0 : )
+ // %2 = scf.for .. iter_args(%arg0 = %1) {
+ // %3 = scf.for .. iter_args(%arg1 = %arg0) {
+ // %4 = tensor.extract_slice %arg1 [..]
+ // .. = linalg.matmul .. outs(%4 : )
+ // }
+ // }
+ // ```
+ //
+ // the IR is currently
+ //
+ // ```
+ // %0 = linalg.init
+ // %1 = linalg.fill
+ // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
+ // %3 = scf.for .. iter_args(%arg1 = %arg0) {
+ // %4 = tensor.extract_slice %0 /*incorrect value */ [..]
+ // %5 = linalg.fill .. outs(%4 : )
+ // .. = linalg.matmul .. outs(%5 : )
+ // }
+ // }
+ // ```
+ //
+ // The untiled `linalg.fill` is still used as the `init_value` since it
+ // was originally a destination operand of the untiled `linalg.matmul`.
+ // When fusing an operand that is a destination operand:
+ // - Update the iter_arg of the outermost loop to use the destination
+ // of the untiled producer.
+ // - Update the destination of the slice of the tiled producer generated
+ // to use the same basic block argument as the slice that was used to
+ // generate in place the tiled implementation of the producer.
+ // With this, the IR will be:
+ //
+ // ```
+ // %0 = linalg.init
+ // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
+ // %2 = scf.for .. iter_args(%arg1 = %arg0) {
+ // %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
+ // %4 = linalg.fill .. outs(%3 : )
+ // .. = linalg.matmul .. outs(%4 : )
+ // }
+ // }
+ // ```
+ // TODO: This can be modeled better using the `DestinationStyleOpInterface`.
+ // Update to use that when it does become available.
+ scf::ForOp outerMostLoop = loops.front();
+ Optional<unsigned> iterArgNumber;
+ if (destinationIterArg) {
+ iterArgNumber =
+ outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
+ }
+ if (iterArgNumber) {
+ int64_t resultNumber = fusableProducer.getResultNumber();
+ if (auto dstOp =
+ dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
+ outerMostLoop.setIterArg(iterArgNumber.value(),
+ dstOp.getTiedOpOperand(fusableProducer)->get());
+ }
+ if (auto dstOp = fusedProducerValue.value()
+ .getDefiningOp<DestinationStyleOpInterface>()) {
+ scf::ForOp innerMostLoop = loops.back();
+ updateDestinationOperandsForTiledOp(
+ rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
+ innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
+ }
+ }
+ return fusedProducerValue->getDefiningOp();
+}
+
/// Implementation of tile consumer and fuse producer greedily.
FailureOr<scf::SCFTileAndFuseResult>
mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
@@ -559,105 +654,20 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
OpBuilder::InsertionGuard g(rewriter);
while (!candidates.empty()) {
- // 2a. Traverse the slices in BFS fashion.
+ // Traverse the slices in BFS fashion.
tensor::ExtractSliceOp candidateSliceOp = candidates.front();
candidates.pop_front();
- // 2b. Get the producer of the source (potentially walking through
- // `iter_args` of nested `scf.for`)
- auto [fusableProducer, destinationIterArg] =
- getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
- tileAndFuseResult.loops);
- if (!fusableProducer)
+ // The operands of the fused producer might themselves be slices of
+ // values produced by operations that implement the `TilingInterface`.
+ // Add these operations to the worklist.
+ Optional<Operation *> fusedProducer = tileAndFuseProducerOfSlice(
+ rewriter, candidateSliceOp, tileAndFuseResult.loops);
+ if (!fusedProducer)
continue;
- // 2c. Generate the tiled implementation of the producer of the source
- rewriter.setInsertionPoint(candidateSliceOp);
- FailureOr<Value> fusedProducerValue =
- tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
- fusableProducer);
- if (failed(fusedProducerValue))
- continue;
- rewriter.replaceOp(candidateSliceOp, *fusedProducerValue);
-
- // 2d. The operands of the fused producer might themselved be slices of
- // values produced by operations that implement the `TilingInterface`.
- // Add these operations to the worklist.
- Operation *fusedProducer = fusedProducerValue->getDefiningOp();
- tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer);
- addCandidateSlices(fusedProducer, candidates);
-
- // 2e. If the slice is for a destination operand, for example,
- //
- // ```mlir
- // %0 = linalg.init
- // %1 = linalg.fill .. outs(%0 : )
- // %2 = scf.for .. iter_args(%arg0 = %1) {
- // %3 = scf.for .. iter_args(%arg1 = %arg0) {
- // %4 = tensor.extract_slice %arg1 [..]
- // .. = linalg.matmul .. outs(%4 : )
- // }
- // }
- // ```
- //
- // the IR is currently
- //
- // ```
- // %0 = linalg.init
- // %1 = linalg.fill
- // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
- // %3 = scf.for .. iter_args(%arg1 = %arg0) {
- // %4 = tensor.extract_slice %0 /*incorrect value */ [..]
- // %5 = linalg.fill .. outs(%4 : )
- // .. = linalg.matmul .. outs(%5 : )
- // }
- // }
- // ```
- //
- // The untiled `linalg.fill` is still used as the `init_value` since it
- // was originally a destination operand of the untiled `linalg.matmul`.
- // When fusing an operand that is a destination operand.
- // - Update the iter_arg of the outer most loop to use the destination
- // of the untiled producer.
- // - Update the destination of the slice of the tiled producer generated
- // to use the same basic block argument as the slice that was used to
- // generate inplace the tiled implementation of the producer.
- // With this the IR will be.
- //
- // ```
- // %0 = linalg.init
- // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
- // %2 = scf.for .. iter_args(%arg1 = %arg0) {
- // %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
- // %4 = linalg.fill .. outs(%3 : )
- // .. = linalg.matmul .. outs(%4 : )
- // }
- // }
- // ```
- // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
- // Update to use that when it does become available.
- scf::ForOp outerMostLoop = tileAndFuseResult.loops.front();
- std::optional<unsigned> iterArgNumber;
- if (destinationIterArg) {
- iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand(
- *destinationIterArg.value());
- }
- if (iterArgNumber) {
- int64_t resultNumber = fusableProducer.getResultNumber();
- if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(
- fusableProducer.getOwner())) {
- outerMostLoop.setIterArg(
- iterArgNumber.value(),
- dstOp.getTiedOpOperand(fusableProducer)->get());
- }
- if (auto dstOp = fusedProducerValue
- ->getDefiningOp<DestinationStyleOpInterface>()) {
- scf::ForOp innerMostLoop = tileAndFuseResult.loops.back();
- updateDestinationOperandsForTiledOp(
- rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
- innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
- }
- }
+ tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer.value());
+ addCandidateSlices(fusedProducer.value(), candidates);
}
return tileAndFuseResult;
}
More information about the Mlir-commits
mailing list