[Mlir-commits] [mlir] [mlir][TilingInterface] Avoid looking at operands for getting slices to continue tile + fuse. (PR #107882)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 11 09:47:27 PDT 2024
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/107882
>From 1b1f2caf9f94b239eaf24a345316c98eaf86babb Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Sat, 7 Sep 2024 23:31:58 -0700
Subject: [PATCH] [mlir][TilingInterface] Avoid looking at operands for getting
slices to continue tile + fuse.
The current implementation of `scf::tileConsumerAndFuseProducersUsingSCF`
looks at the operands of tiled/tiled+fused operations to see if they are
produced by `extract_slice` operations to populate the worklist used
to continue fusion. This implicit assumption does not always
hold. Instead, make the implementations of `getTiledImplementation`
return the slices to use to continue fusion.
This is a breaking change:
- To continue to get the same behavior of
`scf::tileConsumerAndFuseProducersUsingSCF`, change all out-of-tree
implementations of `TilingInterface::getTiledImplementation` to
return the slices to continue fusion on. All in-tree implementations
have been adapted to this.
- This change also required simplifying the `ControlFnTy` in
`scf::SCFTileAndFuseOptions`. The control function now returns a
`std::optional<scf::SCFTileAndFuseOptions::ControlFnResult>` object
that should be `std::nullopt` if fusion is not to be performed.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 11 ++-
.../SCF/Transforms/TileUsingInterface.h | 33 ++++---
.../include/mlir/Interfaces/TilingInterface.h | 7 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 82 ++++++++++------
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 26 +++++-
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 20 ++--
.../SCF/Transforms/TileUsingInterface.cpp | 93 +++++++++++--------
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 71 ++++++++------
.../tile-and-fuse-using-interface.mlir | 45 +++++++++
.../TestTilingInterfaceTransformOps.cpp | 12 ++-
10 files changed, 271 insertions(+), 129 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 65a1a8b42e1495..f1df49ce3eaa36 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -178,11 +178,12 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
/// controls whether to omit the partial/boundary tile condition check in
/// cases where we statically know that it is unnecessary.
-Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
- ArrayRef<OpFoldResult> tileSizes, AffineMap map,
- ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
- ArrayRef<OpFoldResult> subShapeSizes,
- bool omitPartialTileCheck);
+Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
+ ArrayRef<OpFoldResult> tileSizes, AffineMap map,
+ ArrayRef<OpFoldResult> lbs,
+ ArrayRef<OpFoldResult> ubs,
+ ArrayRef<OpFoldResult> subShapeSizes,
+ bool omitPartialTileCheck);
/// Creates extract_slice/subview ops for all `valuesToTile` of the given
/// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1f21af6d6a29ac..77c812cde71533 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -106,6 +106,9 @@ struct SCFTilingResult {
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
+ /// Slices generated after tiling that can be used for fusing with the tiled
+ /// producer.
+ SmallVector<Operation *> generatedSlices;
};
/// Method to tile an op that implements the `TilingInterface` using
@@ -129,18 +132,22 @@ struct SCFTileAndFuseOptions {
/// 2) the producer value that is to be fused
/// 3) a boolean value set to `true` if the fusion is from
/// a destination operand.
- /// It retuns two booleans
- /// - returns `true` if the fusion should be done through the candidate slice
- /// - returns `true` if a replacement for the fused producer needs to be
- /// yielded from within the tiled loop. Note that it is valid to return
- /// `true` only if the slice fused is disjoint across all iterations of the
- /// tiled loop. It is up to the caller to ensure that this is true for the
- /// fused producers.
- using ControlFnTy = std::function<std::tuple<bool, bool>(
+ /// The control function returns an `std::optional<ControlFnResult>`.
+ /// If the return value is `std::nullopt`, that implies no fusion
+ /// is to be performed along that slice.
+ struct ControlFnResult {
+ /// Set to true if the loop nest has to return a replacement value
+ /// for the fused producer.
+ bool yieldProducerReplacement = false;
+ };
+ using ControlFnTy = std::function<std::optional<ControlFnResult>(
tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
bool isDestinationOperand)>;
- ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) {
- return std::make_tuple(true, false);
+ /// The default control function implements greedy fusion without yielding
+ /// a replacement for any of the fused results.
+ ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
+ bool) -> std::optional<ControlFnResult> {
+ return ControlFnResult{};
};
SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
fusionControlFn = controlFn;
@@ -156,6 +163,7 @@ struct SCFFuseProducerOfSliceResult {
OpResult origProducer; // Original untiled producer.
Value tiledAndFusedProducer; // Tile and fused producer value.
SmallVector<Operation *> tiledOps;
+ SmallVector<Operation *> generatedSlices;
};
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
@@ -215,7 +223,10 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
///
/// The @param `yieldResultNumber` decides which result would be yield. If not
/// given, yield all `opResult` of fused producer.
-LogicalResult yieldReplacementForFusedProducer(
+///
+/// The method returns the list of new slices added during the process (which
+/// can be used to fuse along).
+FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<LoopLikeOpInterface> loops,
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index 2f51496d1b110a..b33aa1489c3116 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -25,12 +25,15 @@ namespace mlir {
/// Container for result values of tiling.
/// - `tiledOps` contains operations created by the tiling implementation that
-/// are returned to the caller for further transformations.
+/// are returned to the caller for further transformations.
/// - `tiledValues` contains the tiled value corresponding to the result of the
-/// untiled operation.
+/// untiled operation.
+/// - `generatedSlices` contains the list of slices that are generated during
+/// tiling. These slices can be used for fusing producers.
struct TilingResult {
SmallVector<Operation *> tiledOps;
SmallVector<Value> tiledValues;
+ SmallVector<Operation *> generatedSlices;
};
/// Container for the result of merge operation of tiling.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 630985d76a0ebf..b888005625eda7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -67,20 +67,20 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
-static Value getSlice(OpBuilder &b, Location loc, Value source,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
- return TypeSwitch<Type, Value>(source.getType())
- .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
+static Operation *getSlice(OpBuilder &b, Location loc, Value source,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides) {
+ return TypeSwitch<Type, Operation *>(source.getType())
+ .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
strides);
})
- .Case<MemRefType>([&](MemRefType type) -> Value {
+ .Case<MemRefType>([&](MemRefType type) -> Operation * {
return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
strides);
})
- .Default([&](Type t) { return nullptr; });
+ .Default([&](Type t) -> Operation * { return nullptr; });
}
//===----------------------------------------------------------------------===//
@@ -2634,10 +2634,18 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
auto oneAttr = builder.getI64IntegerAttr(1);
SmallVector<OpFoldResult> strides(rank, oneAttr);
SmallVector<Value> tiledOperands;
- tiledOperands.emplace_back(
- getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
- tiledOperands.emplace_back(
- getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
+ Operation *inputSlice =
+ getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
+ if (!inputSlice) {
+ return emitOpError("failed to compute input slice");
+ }
+ tiledOperands.emplace_back(inputSlice->getResult(0));
+ Operation *outputSlice =
+ getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
+ if (!outputSlice) {
+ return emitOpError("failed to compute output slice");
+ }
+ tiledOperands.emplace_back(outputSlice->getResult(0));
SmallVector<Type, 4> resultTypes;
if (hasPureTensorSemantics())
@@ -2645,7 +2653,10 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{
+ {tiledOp},
+ SmallVector<Value>(tiledOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
LogicalResult SoftmaxOp::getResultTilePosition(
@@ -2992,8 +3003,9 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
int64_t filterRank = getFilterOperandRank();
SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
Location loc = getLoc();
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getFilter(), sliceOffsets, sliceSizes, filterStrides));
+ auto filterSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
+ tiledOperands.emplace_back(filterSlice);
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3002,15 +3014,19 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+ auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getOutput(), resultOffsets, resultSizes, outputStrides);
+ tiledOperands.emplace_back(outputSlice);
SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{
+ {tiledOp},
+ SmallVector<Value>(tiledOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
@@ -3159,8 +3175,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
{sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
int64_t inputRank = getInputOperandRank();
SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
+ auto inputSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
+ tiledOperands.emplace_back(inputSlice);
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3169,15 +3186,19 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+ auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getOutput(), resultOffsets, resultSizes, outputStrides);
+ tiledOperands.emplace_back(outputSlice);
SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{
+ {tiledOp},
+ SmallVector<Value>(tiledOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
@@ -3321,8 +3342,9 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
sizes[getValueFDim()]});
int64_t valueRank = getValueOperandRank();
SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
+ auto valueSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
+ tiledOperands.emplace_back(valueSlice);
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3331,15 +3353,19 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> strides(outputRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), resultOffsets, resultSizes, strides));
+ auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getOutput(), resultOffsets, resultSizes, strides);
+ tiledOperands.emplace_back(outputSlice);
SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{
+ {tiledOp},
+ SmallVector<Value>(tiledOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index fbff91a94219cc..cf5ca9aa2b0e04 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -120,8 +120,16 @@ struct LinalgOpTilingInterface
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);
SmallVector<Value> valuesToTile = linalgOp->getOperands();
- SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+ SmallVector<Value> tiledOperands = makeTiledShapes(
b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+ SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
+ llvm::make_filter_range(
+ tiledOperands,
+ [](Value v) -> bool {
+ return isa<tensor::ExtractSliceOp, memref::SubViewOp>(
+ v.getDefiningOp());
+ }),
+ [](Value v) -> Operation * { return v.getDefiningOp(); });
SmallVector<Type> resultTensorTypes =
getTensorOutputTypes(linalgOp, tiledOperands);
@@ -129,7 +137,8 @@ struct LinalgOpTilingInterface
Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{
+ {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
}
/// Utility to fetch the offsets and sizes when applied as per the indexing
@@ -260,7 +269,8 @@ struct LinalgOpTilingInterface
return TilingResult{
tilingResult->tiledOps,
- SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
+ SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
+ tilingResult->generatedSlices};
}
/// Method to generate the tiled implementation of an operation from the tile
@@ -406,8 +416,12 @@ struct LinalgOpPartialReductionInterface
}
// Step 2a: Extract a slice of the input operands.
- SmallVector<Value, 4> tiledInputs = makeTiledShapes(
+ SmallVector<Value> tiledInputs = makeTiledShapes(
b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
+ SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
+ llvm::make_filter_range(
+ tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }),
+ [](Value v) -> Operation * { return v.getDefiningOp(); });
// Step 2b: Extract a slice of the init operands.
SmallVector<Value, 1> tiledInits;
@@ -424,6 +438,7 @@ struct LinalgOpPartialReductionInterface
auto extractSlice = b.create<tensor::ExtractSliceOp>(
loc, valueToTile, initOffset, initSizes, initStride);
tiledInits.push_back(extractSlice);
+ generatedSlices.push_back(extractSlice);
}
// Update the indexing maps.
@@ -453,7 +468,8 @@ struct LinalgOpPartialReductionInterface
return TilingResult{
{genericOp.getOperation()},
llvm::map_to_vector(genericOp->getResults(),
- [](OpResult r) -> Value { return r; })};
+ [](OpResult r) -> Value { return r; }),
+ generatedSlices};
}
FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index fa0598dd96885c..6a3f2fc5fbc496 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -565,9 +565,9 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}
-static Value materializeTiledShape(OpBuilder &builder, Location loc,
- Value valueToTile,
- const SliceParameters &sliceParams) {
+static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
+ Value valueToTile,
+ const SliceParameters &sliceParams) {
auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
.Case([&](MemRefType) {
@@ -583,14 +583,15 @@ static Value materializeTiledShape(OpBuilder &builder, Location loc,
.Default([](ShapedType) -> Operation * {
llvm_unreachable("Unexpected shaped type");
});
- return sliceOp->getResult(0);
+ return sliceOp;
}
-Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
- ArrayRef<OpFoldResult> tileSizes, AffineMap map,
- ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
- ArrayRef<OpFoldResult> subShapeSizes,
- bool omitPartialTileCheck) {
+Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
+ ArrayRef<OpFoldResult> tileSizes, AffineMap map,
+ ArrayRef<OpFoldResult> lbs,
+ ArrayRef<OpFoldResult> ubs,
+ ArrayRef<OpFoldResult> subShapeSizes,
+ bool omitPartialTileCheck) {
SliceParameters sliceParams =
computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
ubs, subShapeSizes, omitPartialTileCheck);
@@ -841,6 +842,7 @@ SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
tiledShapes.push_back(
sliceParams.has_value()
? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
+ ->getResult(0)
: valueToTile);
}
return tiledShapes;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e404c01010a325..3729300588422e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -854,7 +854,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
if (llvm::all_of(tileSizes, isZeroIndex)) {
tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
tilingResult =
- TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
+ TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
+ /*generatedSlices=*/{}};
return success();
}
@@ -910,12 +911,14 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// op.
if (loops.empty()) {
return scf::SCFTilingResult{tilingResult->tiledOps, loops,
- tilingResult->tiledValues};
+ tilingResult->tiledValues,
+ tilingResult->generatedSlices};
}
SmallVector<Value> replacements = llvm::map_to_vector(
loops.front()->getResults(), [](OpResult r) -> Value { return r; });
- return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
+ return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+ tilingResult->generatedSlices};
}
FailureOr<scf::SCFReductionTilingResult>
@@ -1180,13 +1183,13 @@ mlir::scf::tileAndFuseProducerOfSlice(
->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
.set(origDestinationTensors[resultNumber]);
}
- return scf::SCFFuseProducerOfSliceResult{fusableProducer,
- tileAndFuseResult->tiledValues[0],
- tileAndFuseResult->tiledOps};
+ return scf::SCFFuseProducerOfSliceResult{
+ fusableProducer, tileAndFuseResult->tiledValues[0],
+ tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
}
/// Reconstruct the fused producer from within the tiled-and-fused code.
-LogicalResult mlir::scf::yieldReplacementForFusedProducer(
+FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<LoopLikeOpInterface> loops,
@@ -1214,6 +1217,7 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
}
}
+ SmallVector<Operation *> generatedSlices;
YieldTiledValuesFn newYieldValuesFn =
[&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
@@ -1284,6 +1288,7 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
loc, newRegionArg, offsetList[index], sizesList[index],
SmallVector<OpFoldResult>(offsetList[index].size(),
rewriter.getIndexAttr(1)));
+ generatedSlices.push_back(destSlice);
unsigned resultNumber = initNumberList[index];
rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
@@ -1303,8 +1308,11 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
return success();
};
- return addInitOperandsToLoopNest(rewriter, loops, initValueList,
- newYieldValuesFn);
+ if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
+ newYieldValuesFn))) {
+ return failure();
+ }
+ return generatedSlices;
}
/// Implementation of tile consumer and fuse producer greedily.
@@ -1358,52 +1366,62 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
// operations. If the producers of the source of the `tensor.extract_slice`
// can be tiled such that the tiled value is generated in-place, that
// effectively tiles + fuses the operations.
- auto addCandidateSlices = [](Operation *fusedOp,
- std::deque<tensor::ExtractSliceOp> &candidates) {
- for (Value operand : fusedOp->getOperands())
- if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
- candidates.push_back(sliceOp);
+ struct WorklistItem {
+ tensor::ExtractSliceOp candidateSlice;
+ SCFTileAndFuseOptions::ControlFnResult controlFnResult;
+ };
+ std::deque<WorklistItem> worklist;
+ auto addCandidateSlices = [&worklist, &options,
+ &loops](ArrayRef<Operation *> candidates) {
+ for (auto candidate : candidates) {
+ auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
+ if (!sliceOp || sliceOp.use_empty())
+ continue;
+
+ auto [fusableProducer, destinationInitArg] =
+ getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops);
+ if (!fusableProducer)
+ continue;
+ std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
+ options.fusionControlFn(sliceOp, fusableProducer,
+ destinationInitArg.has_value());
+ if (!controlFnResult)
+ continue;
+ worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()});
+ }
};
- std::deque<tensor::ExtractSliceOp> candidates;
- addCandidateSlices(tiledAndFusedOps.back(), candidates);
+ addCandidateSlices(tilingResult->generatedSlices);
OpBuilder::InsertionGuard g(rewriter);
- while (!candidates.empty()) {
+ while (!worklist.empty()) {
// Traverse the slices in BFS fashion.
- tensor::ExtractSliceOp candidateSliceOp = candidates.front();
- candidates.pop_front();
-
- // Find the original producer of the slice.
- auto [fusableProducer, destinationInitArg] =
- getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
- loops);
- if (!fusableProducer)
- continue;
-
- auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
- candidateSliceOp, fusableProducer, destinationInitArg.has_value());
- if (!fuseSlice)
- continue;
+ WorklistItem worklistItem = worklist.front();
+ worklist.pop_front();
// The operands of the fused producer might themselved be slices of
// values produced by operations that implement the `TilingInterface`.
// Add these operations to the worklist.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
- tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
+ tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
+ loops);
if (!fusedResult)
continue;
- if (yieldReplacement) {
+ if (worklistItem.controlFnResult.yieldProducerReplacement) {
// Reconstruct and yield all opResult of fusableProducerOp by default. The
// caller can specific which one to yield by designating optional argument
// named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
- Operation *fusableProducerOp = fusableProducer.getOwner();
- if (failed(yieldReplacementForFusedProducer(
- rewriter, candidateSliceOp, fusedResult.value(), loops))) {
+ Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
+ FailureOr<SmallVector<Operation *>> newSlices =
+ yieldReplacementForFusedProducer(rewriter,
+ worklistItem.candidateSlice,
+ fusedResult.value(), loops);
+ if (failed(newSlices)) {
return rewriter.notifyMatchFailure(
fusableProducerOp, "failed to replacement value for this "
"operation from within the tiled loop");
}
+ addCandidateSlices(newSlices.value());
for (auto [index, result] :
llvm::enumerate(fusableProducerOp->getResults())) {
origValToResultNumber[result] = loops.front()->getNumResults() -
@@ -1411,12 +1429,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
index;
}
}
-
+ addCandidateSlices(fusedResult->generatedSlices);
if (Operation *tiledAndFusedOp =
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
tiledAndFusedOps.insert(tiledAndFusedOp);
- addCandidateSlices(tiledAndFusedOp, candidates);
}
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 9e17184ebed794..104d6ae1f9f6b5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -187,8 +187,9 @@ struct PackOpTiling
SmallVector<OpFoldResult> strides(inputRank, oneAttr);
SmallVector<Value> tiledOperands;
- tiledOperands.push_back(b.create<ExtractSliceOp>(
- loc, packOp.getSource(), inputIndices, inputSizes, strides));
+ auto sourceSlice = b.create<ExtractSliceOp>(
+ loc, packOp.getSource(), inputIndices, inputSizes, strides);
+ tiledOperands.push_back(sourceSlice);
SmallVector<OpFoldResult> outputOffsets, outputSizes;
if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
@@ -196,9 +197,9 @@ struct PackOpTiling
return {};
strides.append(packOp.getDestRank() - inputRank, oneAttr);
- auto extractSlice = b.create<ExtractSliceOp>(
+ auto outSlice = b.create<ExtractSliceOp>(
loc, packOp.getDest(), outputOffsets, outputSizes, strides);
- tiledOperands.push_back(extractSlice);
+ tiledOperands.push_back(outSlice);
if (auto val = packOp.getPaddingValue())
tiledOperands.push_back(val);
@@ -206,10 +207,12 @@ struct PackOpTiling
tiledOperands.push_back(tile);
Operation *tiledPackOp = b.create<PackOp>(
- loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+ loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
- return TilingResult{{tiledPackOp},
- SmallVector<Value>(tiledPackOp->getResults())};
+ return TilingResult{
+ {tiledPackOp},
+ SmallVector<Value>(tiledPackOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
}
LogicalResult
@@ -348,8 +351,9 @@ struct PackOpTiling
SmallVector<OpFoldResult> strides(inputRank, oneAttr);
SmallVector<Value> tiledOperands;
- tiledOperands.push_back(b.create<ExtractSliceOp>(loc, packOp.getSource(),
- offsets, sizes, strides));
+ auto sourceSlice = b.create<ExtractSliceOp>(loc, packOp.getSource(),
+ offsets, sizes, strides);
+ tiledOperands.push_back(sourceSlice);
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
if (failed(getIterationDomainTileFromOperandTile(
@@ -363,19 +367,21 @@ struct PackOpTiling
return failure();
strides.append(packOp.getDestRank() - inputRank, oneAttr);
- auto extractSlice = b.create<ExtractSliceOp>(
+ auto outSlice = b.create<ExtractSliceOp>(
loc, packOp.getDest(), outputOffsets, outputSizes, strides);
- tiledOperands.push_back(extractSlice);
+ tiledOperands.push_back(outSlice);
assert(!packOp.getPaddingValue() && "Expect no padding semantic");
for (auto tile : packOp.getInnerTiles())
tiledOperands.push_back(tile);
Operation *tiledPackOp = b.create<PackOp>(
- loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+ loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
- return TilingResult{{tiledPackOp},
- SmallVector<Value>(tiledPackOp->getResults())};
+ return TilingResult{
+ {tiledPackOp},
+ SmallVector<Value>(tiledPackOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
}
};
@@ -554,9 +560,12 @@ struct UnPackOpTiling
SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
Value sliceDest;
+ SmallVector<Operation *> generatedSlices;
if (isPerfectTilingCase) {
- sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
- sizes, destStrides);
+ auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
+ offsets, sizes, destStrides);
+ sliceDest = destSliceOp;
+ generatedSlices.push_back(destSliceOp);
} else {
sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
unpackOp.getDestType().getElementType());
@@ -571,12 +580,15 @@ struct UnPackOpTiling
if (isPerfectTilingCase)
return TilingResult{{tiledUnpackOp},
- SmallVector<Value>(tiledUnpackOp->getResults())};
+ SmallVector<Value>(tiledUnpackOp->getResults()),
+ generatedSlices};
auto extractSlice =
b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
resultOffsetsFromDest, sizes, destStrides);
- return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
+ generatedSlices.push_back(extractSlice);
+ return TilingResult{
+ {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
}
LogicalResult
@@ -697,7 +709,9 @@ struct UnPackOpTiling
tiledOperands, op->getAttrs());
return TilingResult{{tiledUnPackOp},
- SmallVector<Value>(tiledUnPackOp->getResults())};
+ SmallVector<Value>(tiledUnPackOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{
+ extractSourceSlice, extractDestSlice})};
}
};
@@ -867,7 +881,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
// the result shape of the new SliceOp has a zero dimension.
auto createPadOfExtractSlice = [&]() {
// Create pad(extract_slice(x)).
- Value newSliceOp = b.create<tensor::ExtractSliceOp>(
+ auto newSliceOp = b.create<tensor::ExtractSliceOp>(
loc, padOp.getSource(), newOffsets, newLengths, newStrides);
auto newPadOp = b.create<PadOp>(
loc, Type(), newSliceOp, newLows, newHighs,
@@ -879,14 +893,16 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);
// Cast result and return.
- return newPadOp;
+ return std::make_tuple(newPadOp, newSliceOp);
};
// Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that
// the original data source x is not used.
if (hasZeroLen) {
Operation *generateOp = createGenerateOp();
- return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
+ return TilingResult{{generateOp},
+ {castResult(generateOp->getResult(0))},
+ /*generatedSlices=*/{}};
}
// If there are dynamic dimensions: Generate an scf.if check to avoid
@@ -894,6 +910,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
if (generateZeroSliceGuard && dynHasZeroLenCond) {
Operation *thenOp;
Operation *elseOp;
+ Operation *sliceOp;
auto result = b.create<scf::IfOp>(
loc, dynHasZeroLenCond,
/*thenBuilder=*/
@@ -903,14 +920,16 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
},
/*elseBuilder=*/
[&](OpBuilder &b, Location loc) {
- elseOp = createPadOfExtractSlice();
+ std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
});
- return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
+ return TilingResult{
+ {elseOp}, SmallVector<Value>(result->getResults()), {sliceOp}};
}
- Operation *newPadOp = createPadOfExtractSlice();
- return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
+ auto [newPadOp, sliceOp] = createPadOfExtractSlice();
+ return TilingResult{
+ {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
}
void mlir::tensor::registerTilingInterfaceExternalModels(
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index d1aed593f45451..3ea1929e4ed785 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -542,3 +542,48 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[INSERTSLICE:.+]] = tensor.insert_slice %[[GENERIC2]] into %[[ITERARG0]][%[[IV]], 0]
// CHECK: scf.yield %[[INSERTSLICE]]
// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @pad_producer_fusion(%arg0 : tensor<10xf32>) -> tensor<16xf32> {
+ %0 = tensor.empty() : tensor<10xf32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<10xf32>) outs(%0 : tensor<10xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %2 = arith.addf %b0, %b0: f32
+ linalg.yield %2 : f32
+ } -> tensor<10xf32>
+ %cst = arith.constant 0.0 : f32
+ %2 = tensor.pad %1 low[4] high[2] {
+ ^bb0(%arg1 : index):
+ tensor.yield %cst : f32
+ } : tensor<10xf32> to tensor<16xf32>
+ return %2 : tensor<16xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %pad = transform.structured.match ops{["tensor.pad"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.structured.fuse %pad [8]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func @pad_producer_fusion
+// CHECK-SAME: %[[ARG0:.+]]: tensor<10xf32>
+// CHECK: %[[FOR_RESULT:.+]] = scf.for
+// CHECK: %[[IF_RESULT:.+]] = scf.if
+// CHECK: else
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[SLICE]] :
+// CHECK: %[[PAD:.+]] = tensor.pad %[[GENERIC]]
+// CHECK: %[[CAST:.+]] = tensor.cast %[[PAD]]
+// CHECK: scf.yield %[[CAST]]
+// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]]
+// CHECK: scf.yield %[[INSERT_SLICE]]
+// CHECK: return %[[FOR_RESULT]]
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 7aa7b58433f36c..b6da47977cb4cf 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -91,11 +91,13 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
[&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
- bool isDestinationOperand) {
- Operation *owner = originalProducer.getOwner();
- bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
- return std::make_tuple(true, yieldProducerReplacement);
- };
+ bool isDestinationOperand)
+ -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
+ Operation *owner = originalProducer.getOwner();
+ bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
+ return scf::SCFTileAndFuseOptions::ControlFnResult{
+ yieldProducerReplacement};
+ };
tileAndFuseOptions.setFusionControlFn(controlFn);
rewriter.setInsertionPoint(target);
More information about the Mlir-commits
mailing list