[Mlir-commits] [mlir] [mlir][TilingInterface] Avoid looking at operands for getting slices to continue tile + fuse. (PR #107882)
llvmlistbot at llvm.org
Mon Sep 9 09:06:45 PDT 2024
https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/107882
The current implementation of `scf::tileConsumerAndFuseProducersUsingSCF` looks at the operands of tiled (or tiled-and-fused) operations to see if they are produced by `extract_slice` operations, and uses those slices to populate the worklist that drives further fusion. This implicit assumption does not always hold. Instead, make the implementations of `getTiledImplementation` return the slices to use to continue fusion.
This is a breaking change:
- To continue to get the same behavior from `scf::tileConsumerAndFuseProducersUsingSCF`, change all out-of-tree implementations of `TilingInterface::getTiledImplementation` to return the slices on which to continue fusion. All in-tree implementations have been adapted to this.
- This change also required a simplification of the fusion control function (`ControlFnTy`) in `scf::SCFTileAndFuseOptions`. It now returns a `std::optional<scf::SCFTileAndFuseOptions::ControlFnResult>`, where `std::nullopt` means no fusion is to be performed along that slice. Sketches of both adaptations follow below.
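For out-of-tree implementations, the adaptation is mechanical. Below is a minimal hedged sketch for a hypothetical single-result op; `MyTileableOp`, `getInput()`, and `getInit()` are illustrative names and not part of this patch. It mirrors the in-tree updates: keep a handle on every slice created for the tiled operands and return them through the new `generatedSlices` field of `TilingResult`.

// Hedged sketch (not from this patch): `MyTileableOp` is a placeholder for an
// out-of-tree op that implements `TilingInterface` and tiles all dimensions
// with unit strides.
FailureOr<TilingResult>
MyTileableOp::getTiledImplementation(OpBuilder &builder,
                                     ArrayRef<OpFoldResult> offsets,
                                     ArrayRef<OpFoldResult> sizes) {
  Location loc = getLoc();
  auto oneAttr = builder.getI64IntegerAttr(1);
  SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);

  // Previously these slices were anonymous temporaries used only as operands
  // of the tiled op; now keep handles so they can be reported to the caller.
  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getInput(), offsets, sizes, strides);
  auto initSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getInit(), offsets, sizes, strides);

  SmallVector<Value> tiledOperands = {inputSlice.getResult(),
                                      initSlice.getResult()};
  SmallVector<Type> resultTypes = {initSlice.getType()};
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  // The third member is new: the slices the tile-and-fuse driver should use
  // to continue fusion.
  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, initSlice})};
}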
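On the caller side, a hedged sketch of the new control function shape; the types and the `setFusionControlFn` hook come from this patch, while the lambda body is purely illustrative.

scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.setFusionControlFn(
    [](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
       bool isDestinationOperand)
        -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
      // Returning std::nullopt skips fusion along this slice (this replaces
      // the first `bool` of the old tuple return).
      if (isDestinationOperand)
        return std::nullopt;
      // Fuse, and request that a replacement for the fused producer be
      // yielded from the tiled loop (this replaces the second `bool`).
      return scf::SCFTileAndFuseOptions::ControlFnResult{
          /*yieldProducerReplacement=*/true};
    });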
>From 850784b85d6da2e26764b648f62537815b2f37c5 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Sat, 7 Sep 2024 23:31:58 -0700
Subject: [PATCH] [mlir][TilingInterface] Avoid looking at operands for getting
slices to continue tile + fuse.
The current implementation of `scf::tileConsumerAndFuseProducersUsingSCF`
looks at the operands of tiled (or tiled-and-fused) operations to see if
they are produced by `extract_slice` operations, and uses those slices to
populate the worklist that drives further fusion. This implicit
assumption does not always hold. Instead, make the implementations of
`getTiledImplementation` return the slices to use to continue fusion.
This is a breaking change:
- To continue to get the same behavior from
  `scf::tileConsumerAndFuseProducersUsingSCF`, change all out-of-tree
  implementations of `TilingInterface::getTiledImplementation` to
  return the slices on which to continue fusion. All in-tree
  implementations have been adapted to this.
- This change also required a simplification of the fusion control
  function (`ControlFnTy`) in `scf::SCFTileAndFuseOptions`. It now
  returns a `std::optional<scf::SCFTileAndFuseOptions::ControlFnResult>`,
  where `std::nullopt` means no fusion is to be performed along that
  slice.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 11 +--
.../SCF/Transforms/TileUsingInterface.h | 28 ++++---
.../include/mlir/Interfaces/TilingInterface.h | 7 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 81 ++++++++++++-------
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 22 +++--
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 20 ++---
.../SCF/Transforms/TileUsingInterface.cpp | 81 ++++++++++---------
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 71 ++++++++++------
.../TestTilingInterfaceTransformOps.cpp | 12 +--
9 files changed, 205 insertions(+), 128 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 65a1a8b42e1495..f1df49ce3eaa36 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -178,11 +178,12 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
/// controls whether to omit the partial/boundary tile condition check in
/// cases where we statically know that it is unnecessary.
-Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
- ArrayRef<OpFoldResult> tileSizes, AffineMap map,
- ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
- ArrayRef<OpFoldResult> subShapeSizes,
- bool omitPartialTileCheck);
+Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
+ ArrayRef<OpFoldResult> tileSizes, AffineMap map,
+ ArrayRef<OpFoldResult> lbs,
+ ArrayRef<OpFoldResult> ubs,
+ ArrayRef<OpFoldResult> subShapeSizes,
+ bool omitPartialTileCheck);
/// Creates extract_slice/subview ops for all `valuesToTile` of the given
/// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1f21af6d6a29ac..ffd36878686751 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -106,6 +106,9 @@ struct SCFTilingResult {
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
+ /// Slices generated after tiling that can be used for fusing with the tiled
+ /// producer.
+ SmallVector<Operation *> generatedSlices;
};
/// Method to tile an op that implements the `TilingInterface` using
@@ -129,18 +132,22 @@ struct SCFTileAndFuseOptions {
/// 2) the producer value that is to be fused
/// 3) a boolean value set to `true` if the fusion is from
/// a destination operand.
- /// It retuns two booleans
- /// - returns `true` if the fusion should be done through the candidate slice
- /// - returns `true` if a replacement for the fused producer needs to be
- /// yielded from within the tiled loop. Note that it is valid to return
- /// `true` only if the slice fused is disjoint across all iterations of the
- /// tiled loop. It is up to the caller to ensure that this is true for the
- /// fused producers.
- using ControlFnTy = std::function<std::tuple<bool, bool>(
+ /// The control function returns a `std::optional<ControlFnResult>`.
+ /// If the return value is `std::nullopt`, that implies no fusion
+ /// is to be performed along that slice.
+ struct ControlFnResult {
+ /// Set to true if the loop nest has to return a replacement value
+ /// for the fused producer.
+ bool yieldProducerReplacement = false;
+ };
+ using ControlFnTy = std::function<std::optional<ControlFnResult>(
tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
bool isDestinationOperand)>;
- ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) {
- return std::make_tuple(true, false);
+ /// The default control function implements a greedy fusion without yielding
+ /// a replacement for any of the fused results.
+ ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
+ bool) -> std::optional<ControlFnResult> {
+ return ControlFnResult{};
};
SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
fusionControlFn = controlFn;
@@ -156,6 +163,7 @@ struct SCFFuseProducerOfSliceResult {
OpResult origProducer; // Original untiled producer.
Value tiledAndFusedProducer; // Tile and fused producer value.
SmallVector<Operation *> tiledOps;
+ SmallVector<Operation *> generatedSlices;
};
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index 2f51496d1b110a..b33aa1489c3116 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -25,12 +25,15 @@ namespace mlir {
/// Container for result values of tiling.
/// - `tiledOps` contains operations created by the tiling implementation that
-/// are returned to the caller for further transformations.
+/// are returned to the caller for further transformations.
/// - `tiledValues` contains the tiled value corresponding to the result of the
-/// untiled operation.
+/// untiled operation.
+/// - `generatedSlices` contains the list of slices that are generated during
+/// tiling. These slices can be used for fusing producers.
struct TilingResult {
SmallVector<Operation *> tiledOps;
SmallVector<Value> tiledValues;
+ SmallVector<Operation *> generatedSlices;
};
/// Container for the result of merge operation of tiling.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76df3ecf2d2bd4..d423c94487d16c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -66,20 +66,20 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
-static Value getSlice(OpBuilder &b, Location loc, Value source,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
- return TypeSwitch<Type, Value>(source.getType())
- .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
+static Operation *getSlice(OpBuilder &b, Location loc, Value source,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides) {
+ return TypeSwitch<Type, Operation *>(source.getType())
+ .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
strides);
})
- .Case<MemRefType>([&](MemRefType type) -> Value {
+ .Case<MemRefType>([&](MemRefType type) -> Operation * {
return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
strides);
})
- .Default([&](Type t) { return nullptr; });
+ .Default([&](Type t) -> Operation * { return nullptr; });
}
//===----------------------------------------------------------------------===//
@@ -2603,10 +2603,18 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
auto oneAttr = builder.getI64IntegerAttr(1);
SmallVector<OpFoldResult> strides(rank, oneAttr);
SmallVector<Value> tiledOperands;
- tiledOperands.emplace_back(
- getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
- tiledOperands.emplace_back(
- getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
+ Operation *inputSlice =
+ getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
+ if (!inputSlice) {
+ return emitOpError("failed to compute input slice");
+ }
+ tiledOperands.emplace_back(inputSlice->getResult(0));
+ Operation *outputSlice =
+ getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
+ if (!outputSlice) {
+ return emitOpError("failed to compute output slice");
+ }
+ tiledOperands.emplace_back(outputSlice->getResult(0));
SmallVector<Type, 4> resultTypes;
if (hasPureTensorSemantics())
@@ -2614,7 +2622,9 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{{tiledOp},
+ SmallVector<Value>(tiledOp->getResults()),
+ {inputSlice, outputSlice}};
}
LogicalResult SoftmaxOp::getResultTilePosition(
@@ -2961,8 +2971,9 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
int64_t filterRank = getFilterOperandRank();
SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
Location loc = getLoc();
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getFilter(), sliceOffsets, sliceSizes, filterStrides));
+ auto filterSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
+ tiledOperands.emplace_back(filterSlice);
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -2971,15 +2982,19 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+ auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getOutput(), resultOffsets, resultSizes, outputStrides);
+ tiledOperands.emplace_back(outputSlice);
SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{
+ {tiledOp},
+ SmallVector<Value>(tiledOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
@@ -3128,8 +3143,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
{sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
int64_t inputRank = getInputOperandRank();
SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
+ auto inputSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
+ tiledOperands.emplace_back(inputSlice);
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3138,15 +3154,19 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+ auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getOutput(), resultOffsets, resultSizes, outputStrides);
+ tiledOperands.emplace_back(outputSlice);
SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{
+ {tiledOp},
+ SmallVector<Value>(tiledOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
@@ -3290,8 +3310,9 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
sizes[getValueFDim()]});
int64_t valueRank = getValueOperandRank();
SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
+ auto valueSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
+ tiledOperands.emplace_back(valueSlice);
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3300,15 +3321,19 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> strides(outputRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), resultOffsets, resultSizes, strides));
+ auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getOutput(), resultOffsets, resultSizes, strides);
+ tiledOperands.emplace_back(outputSlice);
SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{
+ {tiledOp},
+ SmallVector<Value>(tiledOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index fbff91a94219cc..ca05467534a5b3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -120,8 +120,12 @@ struct LinalgOpTilingInterface
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);
SmallVector<Value> valuesToTile = linalgOp->getOperands();
- SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+ SmallVector<Value> tiledOperands = makeTiledShapes(
b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+ SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
+ llvm::make_filter_range(
+ tiledOperands, [](Value v) -> bool { return v.getDefiningOp(); }),
+ [](Value v) -> Operation * { return v.getDefiningOp(); });
SmallVector<Type> resultTensorTypes =
getTensorOutputTypes(linalgOp, tiledOperands);
@@ -129,7 +133,8 @@ struct LinalgOpTilingInterface
Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{
+ {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
}
/// Utility to fetch the offsets and sizes when applied as per the indexing
@@ -260,7 +265,8 @@ struct LinalgOpTilingInterface
return TilingResult{
tilingResult->tiledOps,
- SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
+ SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
+ tilingResult->generatedSlices};
}
/// Method to generate the tiled implementation of an operation from the tile
@@ -406,8 +412,12 @@ struct LinalgOpPartialReductionInterface
}
// Step 2a: Extract a slice of the input operands.
- SmallVector<Value, 4> tiledInputs = makeTiledShapes(
+ SmallVector<Value> tiledInputs = makeTiledShapes(
b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
+ SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
+ llvm::make_filter_range(
+ tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }),
+ [](Value v) -> Operation * { return v.getDefiningOp(); });
// Step 2b: Extract a slice of the init operands.
SmallVector<Value, 1> tiledInits;
@@ -424,6 +434,7 @@ struct LinalgOpPartialReductionInterface
auto extractSlice = b.create<tensor::ExtractSliceOp>(
loc, valueToTile, initOffset, initSizes, initStride);
tiledInits.push_back(extractSlice);
+ generatedSlices.push_back(extractSlice);
}
// Update the indexing maps.
@@ -453,7 +464,8 @@ struct LinalgOpPartialReductionInterface
return TilingResult{
{genericOp.getOperation()},
llvm::map_to_vector(genericOp->getResults(),
- [](OpResult r) -> Value { return r; })};
+ [](OpResult r) -> Value { return r; }),
+ generatedSlices};
}
FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index fa0598dd96885c..6a3f2fc5fbc496 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -565,9 +565,9 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}
-static Value materializeTiledShape(OpBuilder &builder, Location loc,
- Value valueToTile,
- const SliceParameters &sliceParams) {
+static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
+ Value valueToTile,
+ const SliceParameters &sliceParams) {
auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
.Case([&](MemRefType) {
@@ -583,14 +583,15 @@ static Value materializeTiledShape(OpBuilder &builder, Location loc,
.Default([](ShapedType) -> Operation * {
llvm_unreachable("Unexpected shaped type");
});
- return sliceOp->getResult(0);
+ return sliceOp;
}
-Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
- ArrayRef<OpFoldResult> tileSizes, AffineMap map,
- ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
- ArrayRef<OpFoldResult> subShapeSizes,
- bool omitPartialTileCheck) {
+Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
+ ArrayRef<OpFoldResult> tileSizes, AffineMap map,
+ ArrayRef<OpFoldResult> lbs,
+ ArrayRef<OpFoldResult> ubs,
+ ArrayRef<OpFoldResult> subShapeSizes,
+ bool omitPartialTileCheck) {
SliceParameters sliceParams =
computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
ubs, subShapeSizes, omitPartialTileCheck);
@@ -841,6 +842,7 @@ SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
tiledShapes.push_back(
sliceParams.has_value()
? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
+ ->getResult(0)
: valueToTile);
}
return tiledShapes;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e404c01010a325..fc27f96500abc7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -854,7 +854,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
if (llvm::all_of(tileSizes, isZeroIndex)) {
tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
tilingResult =
- TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
+ TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
+ /*generatedSlices=*/{}};
return success();
}
@@ -910,12 +911,14 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// op.
if (loops.empty()) {
return scf::SCFTilingResult{tilingResult->tiledOps, loops,
- tilingResult->tiledValues};
+ tilingResult->tiledValues,
+ tilingResult->generatedSlices};
}
SmallVector<Value> replacements = llvm::map_to_vector(
loops.front()->getResults(), [](OpResult r) -> Value { return r; });
- return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
+ return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+ tilingResult->generatedSlices};
}
FailureOr<scf::SCFReductionTilingResult>
@@ -1180,9 +1183,9 @@ mlir::scf::tileAndFuseProducerOfSlice(
->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
.set(origDestinationTensors[resultNumber]);
}
- return scf::SCFFuseProducerOfSliceResult{fusableProducer,
- tileAndFuseResult->tiledValues[0],
- tileAndFuseResult->tiledOps};
+ return scf::SCFFuseProducerOfSliceResult{
+ fusableProducer, tileAndFuseResult->tiledValues[0],
+ tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
}
/// Reconstruct the fused producer from within the tiled-and-fused code.
@@ -1358,48 +1361,55 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
// operations. If the producers of the source of the `tensor.extract_slice`
// can be tiled such that the tiled value is generated in-place, that
// effectively tiles + fuses the operations.
- auto addCandidateSlices = [](Operation *fusedOp,
- std::deque<tensor::ExtractSliceOp> &candidates) {
- for (Value operand : fusedOp->getOperands())
- if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
- candidates.push_back(sliceOp);
+ struct WorklistItem {
+ tensor::ExtractSliceOp candidateSlice;
+ SCFTileAndFuseOptions::ControlFnResult controlFnResult;
+ };
+ std::deque<WorklistItem> worklist;
+ auto addCandidateSlices = [&worklist, &options,
+ &loops](ArrayRef<Operation *> candidates) {
+ for (auto candidate : candidates) {
+ auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
+ if (!sliceOp)
+ continue;
+
+ auto [fusableProducer, destinationInitArg] =
+ getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops);
+ if (!fusableProducer)
+ continue;
+ std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
+ options.fusionControlFn(sliceOp, fusableProducer,
+ destinationInitArg.has_value());
+ if (!controlFnResult)
+ continue;
+ worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()});
+ }
};
- std::deque<tensor::ExtractSliceOp> candidates;
- addCandidateSlices(tiledAndFusedOps.back(), candidates);
+ addCandidateSlices(tilingResult->generatedSlices);
OpBuilder::InsertionGuard g(rewriter);
- while (!candidates.empty()) {
+ while (!worklist.empty()) {
// Traverse the slices in BFS fashion.
- tensor::ExtractSliceOp candidateSliceOp = candidates.front();
- candidates.pop_front();
-
- // Find the original producer of the slice.
- auto [fusableProducer, destinationInitArg] =
- getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
- loops);
- if (!fusableProducer)
- continue;
-
- auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
- candidateSliceOp, fusableProducer, destinationInitArg.has_value());
- if (!fuseSlice)
- continue;
+ WorklistItem worklistItem = worklist.front();
+ worklist.pop_front();
// The operands of the fused producer might themselved be slices of
// values produced by operations that implement the `TilingInterface`.
// Add these operations to the worklist.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
- tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
+ tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
+ loops);
if (!fusedResult)
continue;
- if (yieldReplacement) {
+ if (worklistItem.controlFnResult.yieldProducerReplacement) {
// Reconstruct and yield all opResult of fusableProducerOp by default. The
// caller can specific which one to yield by designating optional argument
// named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
- Operation *fusableProducerOp = fusableProducer.getOwner();
+ Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
if (failed(yieldReplacementForFusedProducer(
- rewriter, candidateSliceOp, fusedResult.value(), loops))) {
+ rewriter, worklistItem.candidateSlice, fusedResult.value(),
+ loops))) {
return rewriter.notifyMatchFailure(
fusableProducerOp, "failed to replacement value for this "
"operation from within the tiled loop");
@@ -1412,12 +1422,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
}
}
- if (Operation *tiledAndFusedOp =
- fusedResult->tiledAndFusedProducer.getDefiningOp()) {
- fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
- tiledAndFusedOps.insert(tiledAndFusedOp);
- addCandidateSlices(tiledAndFusedOp, candidates);
- }
+ addCandidateSlices(fusedResult->generatedSlices);
}
DenseMap<Value, Value> replacements;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 9e17184ebed794..104d6ae1f9f6b5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -187,8 +187,9 @@ struct PackOpTiling
SmallVector<OpFoldResult> strides(inputRank, oneAttr);
SmallVector<Value> tiledOperands;
- tiledOperands.push_back(b.create<ExtractSliceOp>(
- loc, packOp.getSource(), inputIndices, inputSizes, strides));
+ auto sourceSlice = b.create<ExtractSliceOp>(
+ loc, packOp.getSource(), inputIndices, inputSizes, strides);
+ tiledOperands.push_back(sourceSlice);
SmallVector<OpFoldResult> outputOffsets, outputSizes;
if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
@@ -196,9 +197,9 @@ struct PackOpTiling
return {};
strides.append(packOp.getDestRank() - inputRank, oneAttr);
- auto extractSlice = b.create<ExtractSliceOp>(
+ auto outSlice = b.create<ExtractSliceOp>(
loc, packOp.getDest(), outputOffsets, outputSizes, strides);
- tiledOperands.push_back(extractSlice);
+ tiledOperands.push_back(outSlice);
if (auto val = packOp.getPaddingValue())
tiledOperands.push_back(val);
@@ -206,10 +207,12 @@ struct PackOpTiling
tiledOperands.push_back(tile);
Operation *tiledPackOp = b.create<PackOp>(
- loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+ loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
- return TilingResult{{tiledPackOp},
- SmallVector<Value>(tiledPackOp->getResults())};
+ return TilingResult{
+ {tiledPackOp},
+ SmallVector<Value>(tiledPackOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
}
LogicalResult
@@ -348,8 +351,9 @@ struct PackOpTiling
SmallVector<OpFoldResult> strides(inputRank, oneAttr);
SmallVector<Value> tiledOperands;
- tiledOperands.push_back(b.create<ExtractSliceOp>(loc, packOp.getSource(),
- offsets, sizes, strides));
+ auto sourceSlice = b.create<ExtractSliceOp>(loc, packOp.getSource(),
+ offsets, sizes, strides);
+ tiledOperands.push_back(sourceSlice);
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
if (failed(getIterationDomainTileFromOperandTile(
@@ -363,19 +367,21 @@ struct PackOpTiling
return failure();
strides.append(packOp.getDestRank() - inputRank, oneAttr);
- auto extractSlice = b.create<ExtractSliceOp>(
+ auto outSlice = b.create<ExtractSliceOp>(
loc, packOp.getDest(), outputOffsets, outputSizes, strides);
- tiledOperands.push_back(extractSlice);
+ tiledOperands.push_back(outSlice);
assert(!packOp.getPaddingValue() && "Expect no padding semantic");
for (auto tile : packOp.getInnerTiles())
tiledOperands.push_back(tile);
Operation *tiledPackOp = b.create<PackOp>(
- loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+ loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
- return TilingResult{{tiledPackOp},
- SmallVector<Value>(tiledPackOp->getResults())};
+ return TilingResult{
+ {tiledPackOp},
+ SmallVector<Value>(tiledPackOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
}
};
@@ -554,9 +560,12 @@ struct UnPackOpTiling
SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
Value sliceDest;
+ SmallVector<Operation *> generatedSlices;
if (isPerfectTilingCase) {
- sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
- sizes, destStrides);
+ auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
+ offsets, sizes, destStrides);
+ sliceDest = destSliceOp;
+ generatedSlices.push_back(destSliceOp);
} else {
sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
unpackOp.getDestType().getElementType());
@@ -571,12 +580,15 @@ struct UnPackOpTiling
if (isPerfectTilingCase)
return TilingResult{{tiledUnpackOp},
- SmallVector<Value>(tiledUnpackOp->getResults())};
+ SmallVector<Value>(tiledUnpackOp->getResults()),
+ generatedSlices};
auto extractSlice =
b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
resultOffsetsFromDest, sizes, destStrides);
- return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
+ generatedSlices.push_back(extractSlice);
+ return TilingResult{
+ {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
}
LogicalResult
@@ -697,7 +709,9 @@ struct UnPackOpTiling
tiledOperands, op->getAttrs());
return TilingResult{{tiledUnPackOp},
- SmallVector<Value>(tiledUnPackOp->getResults())};
+ SmallVector<Value>(tiledUnPackOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{
+ extractSourceSlice, extractDestSlice})};
}
};
@@ -867,7 +881,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
// the result shape of the new SliceOp has a zero dimension.
auto createPadOfExtractSlice = [&]() {
// Create pad(extract_slice(x)).
- Value newSliceOp = b.create<tensor::ExtractSliceOp>(
+ auto newSliceOp = b.create<tensor::ExtractSliceOp>(
loc, padOp.getSource(), newOffsets, newLengths, newStrides);
auto newPadOp = b.create<PadOp>(
loc, Type(), newSliceOp, newLows, newHighs,
@@ -879,14 +893,16 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);
// Cast result and return.
- return newPadOp;
+ return std::make_tuple(newPadOp, newSliceOp);
};
// Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that
// the original data source x is not used.
if (hasZeroLen) {
Operation *generateOp = createGenerateOp();
- return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
+ return TilingResult{{generateOp},
+ {castResult(generateOp->getResult(0))},
+ /*generatedSlices=*/{}};
}
// If there are dynamic dimensions: Generate an scf.if check to avoid
@@ -894,6 +910,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
if (generateZeroSliceGuard && dynHasZeroLenCond) {
Operation *thenOp;
Operation *elseOp;
+ Operation *sliceOp;
auto result = b.create<scf::IfOp>(
loc, dynHasZeroLenCond,
/*thenBuilder=*/
@@ -903,14 +920,16 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
},
/*elseBuilder=*/
[&](OpBuilder &b, Location loc) {
- elseOp = createPadOfExtractSlice();
+ std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
});
- return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
+ return TilingResult{
+ {elseOp}, SmallVector<Value>(result->getResults()), {sliceOp}};
}
- Operation *newPadOp = createPadOfExtractSlice();
- return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
+ auto [newPadOp, sliceOp] = createPadOfExtractSlice();
+ return TilingResult{
+ {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
}
void mlir::tensor::registerTilingInterfaceExternalModels(
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 7aa7b58433f36c..b6da47977cb4cf 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -91,11 +91,13 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
[&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
- bool isDestinationOperand) {
- Operation *owner = originalProducer.getOwner();
- bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
- return std::make_tuple(true, yieldProducerReplacement);
- };
+ bool isDestinationOperand)
+ -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
+ Operation *owner = originalProducer.getOwner();
+ bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
+ return scf::SCFTileAndFuseOptions::ControlFnResult{
+ yieldProducerReplacement};
+ };
tileAndFuseOptions.setFusionControlFn(controlFn);
rewriter.setInsertionPoint(target);
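For downstream drivers that previously walked the tiled op's operands looking for `tensor.extract_slice` producers, here is a hedged usage sketch of the new result field; the variable names (`rewriter`, `tilingInterfaceOp`, `tilingOptions`) are assumed to be in scope and are not from this patch.

// Assumed context: `rewriter` is a RewriterBase, `tilingInterfaceOp` is a
// TilingInterface op, and `tilingOptions` is a configured scf::SCFTilingOptions.
FailureOr<scf::SCFTilingResult> tilingResult =
    scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
if (failed(tilingResult))
  return failure();

// Seed a fusion worklist from the slices reported by the tiled
// implementation instead of pattern-matching the tiled op's operands.
std::deque<tensor::ExtractSliceOp> worklist;
for (Operation *candidate : tilingResult->generatedSlices)
  if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate))
    worklist.push_back(sliceOp);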