[Mlir-commits] [mlir] [mlir][TilingInterface] Avoid looking at operands for getting slices to continue tile + fuse. (PR #107882)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 9 16:04:50 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (MaheshRavishankar)
<details>
<summary>Changes</summary>
The current implementation of `scf::tileConsumerAndFuseProducerUsingSCF` looks at the operands of tiled (or tiled-and-fused) operations to check whether they are produced by `extract_slice` operations, and uses those to populate the worklist that drives further fusion. This implicit assumption does not always hold. Instead, make the implementations of `getTiledImplementation` return the slices to use to continue fusion.
This is a breaking change:
- To keep the existing behavior of `scf::tileConsumerAndFuseProducerUsingSCF`, change all out-of-tree implementations of `TilingInterface::getTiledImplementation` to return the slices along which fusion should continue. All in-tree implementations have been adapted accordingly.
- This change touches code that required simplifying the `ControlFn` in `scf::SCFTileAndFuseOptions`. The control function now returns a `std::optional<scf::SCFTileAndFuseOptions::ControlFnResult>`, which should be `std::nullopt` if fusion is not to be performed.
---
Patch is 39.81 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/107882.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (+6-5)
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+22-11)
- (modified) mlir/include/mlir/Interfaces/TilingInterface.h (+5-2)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+53-28)
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+17-5)
- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+11-9)
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+55-38)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+45-26)
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir (+45)
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+7-5)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 65a1a8b42e1495..f1df49ce3eaa36 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -178,11 +178,12 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
/// controls whether to omit the partial/boundary tile condition check in
/// cases where we statically know that it is unnecessary.
-Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
- ArrayRef<OpFoldResult> tileSizes, AffineMap map,
- ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
- ArrayRef<OpFoldResult> subShapeSizes,
- bool omitPartialTileCheck);
+Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
+ ArrayRef<OpFoldResult> tileSizes, AffineMap map,
+ ArrayRef<OpFoldResult> lbs,
+ ArrayRef<OpFoldResult> ubs,
+ ArrayRef<OpFoldResult> subShapeSizes,
+ bool omitPartialTileCheck);
/// Creates extract_slice/subview ops for all `valuesToTile` of the given
/// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1f21af6d6a29ac..77c812cde71533 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -106,6 +106,9 @@ struct SCFTilingResult {
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
+ /// Slices generated after tiling that can be used for fusing with the tiled
+ /// producer.
+ SmallVector<Operation *> generatedSlices;
};
/// Method to tile an op that implements the `TilingInterface` using
@@ -129,18 +132,22 @@ struct SCFTileAndFuseOptions {
/// 2) the producer value that is to be fused
/// 3) a boolean value set to `true` if the fusion is from
/// a destination operand.
- /// It retuns two booleans
- /// - returns `true` if the fusion should be done through the candidate slice
- /// - returns `true` if a replacement for the fused producer needs to be
- /// yielded from within the tiled loop. Note that it is valid to return
- /// `true` only if the slice fused is disjoint across all iterations of the
- /// tiled loop. It is up to the caller to ensure that this is true for the
- /// fused producers.
- using ControlFnTy = std::function<std::tuple<bool, bool>(
+ /// The control function returns an `std::optiona<ControlFnResult>`.
+ /// If the return value is `std::nullopt`, that implies no fusion
+ /// is to be performed along that slice.
+ struct ControlFnResult {
+ /// Set to true if the loop nest has to return a replacement value
+ /// for the fused producer.
+ bool yieldProducerReplacement = false;
+ };
+ using ControlFnTy = std::function<std::optional<ControlFnResult>(
tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
bool isDestinationOperand)>;
- ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) {
- return std::make_tuple(true, false);
+ /// The default control function implements greedy fusion without yielding
+ /// a replacement for any of the fused results.
+ ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
+ bool) -> std::optional<ControlFnResult> {
+ return ControlFnResult{};
};
SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
fusionControlFn = controlFn;
@@ -156,6 +163,7 @@ struct SCFFuseProducerOfSliceResult {
OpResult origProducer; // Original untiled producer.
Value tiledAndFusedProducer; // Tile and fused producer value.
SmallVector<Operation *> tiledOps;
+ SmallVector<Operation *> generatedSlices;
};
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
@@ -215,7 +223,10 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
///
/// The @param `yieldResultNumber` decides which result would be yield. If not
/// given, yield all `opResult` of fused producer.
-LogicalResult yieldReplacementForFusedProducer(
+///
+/// The method returns the list of new slices added during the process (which
+/// can be used to fuse along).
+FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<LoopLikeOpInterface> loops,
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index 2f51496d1b110a..b33aa1489c3116 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -25,12 +25,15 @@ namespace mlir {
/// Container for result values of tiling.
/// - `tiledOps` contains operations created by the tiling implementation that
-/// are returned to the caller for further transformations.
+/// are returned to the caller for further transformations.
/// - `tiledValues` contains the tiled value corresponding to the result of the
-/// untiled operation.
+/// untiled operation.
+/// - `generatedSlices` contains the list of slices that are generated during
+/// tiling. These slices can be used for fusing producers.
struct TilingResult {
SmallVector<Operation *> tiledOps;
SmallVector<Value> tiledValues;
+ SmallVector<Operation *> generatedSlices;
};
/// Container for the result of merge operation of tiling.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76df3ecf2d2bd4..d423c94487d16c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -66,20 +66,20 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
-static Value getSlice(OpBuilder &b, Location loc, Value source,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
- return TypeSwitch<Type, Value>(source.getType())
- .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
+static Operation *getSlice(OpBuilder &b, Location loc, Value source,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides) {
+ return TypeSwitch<Type, Operation *>(source.getType())
+ .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
strides);
})
- .Case<MemRefType>([&](MemRefType type) -> Value {
+ .Case<MemRefType>([&](MemRefType type) -> Operation * {
return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
strides);
})
- .Default([&](Type t) { return nullptr; });
+ .Default([&](Type t) -> Operation * { return nullptr; });
}
//===----------------------------------------------------------------------===//
@@ -2603,10 +2603,18 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
auto oneAttr = builder.getI64IntegerAttr(1);
SmallVector<OpFoldResult> strides(rank, oneAttr);
SmallVector<Value> tiledOperands;
- tiledOperands.emplace_back(
- getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
- tiledOperands.emplace_back(
- getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
+ Operation *inputSlice =
+ getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
+ if (!inputSlice) {
+ return emitOpError("failed to compute input slice");
+ }
+ tiledOperands.emplace_back(inputSlice->getResult(0));
+ Operation *outputSlice =
+ getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
+ if (!outputSlice) {
+ return emitOpError("failed to compute output slice");
+ }
+ tiledOperands.emplace_back(outputSlice->getResult(0));
SmallVector<Type, 4> resultTypes;
if (hasPureTensorSemantics())
@@ -2614,7 +2622,9 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{{tiledOp},
+ SmallVector<Value>(tiledOp->getResults()),
+ {inputSlice, outputSlice}};
}
LogicalResult SoftmaxOp::getResultTilePosition(
@@ -2961,8 +2971,9 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
int64_t filterRank = getFilterOperandRank();
SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
Location loc = getLoc();
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getFilter(), sliceOffsets, sliceSizes, filterStrides));
+ auto filterSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
+ tiledOperands.emplace_back(filterSlice);
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -2971,15 +2982,19 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+ auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getOutput(), resultOffsets, resultSizes, outputStrides);
+ tiledOperands.emplace_back(outputSlice);
SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{
+ {tiledOp},
+ SmallVector<Value>(tiledOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
@@ -3128,8 +3143,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
{sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
int64_t inputRank = getInputOperandRank();
SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
+ auto inputSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
+ tiledOperands.emplace_back(inputSlice);
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3138,15 +3154,19 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+ auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getOutput(), resultOffsets, resultSizes, outputStrides);
+ tiledOperands.emplace_back(outputSlice);
SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{
+ {tiledOp},
+ SmallVector<Value>(tiledOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
@@ -3290,8 +3310,9 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
sizes[getValueFDim()]});
int64_t valueRank = getValueOperandRank();
SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
+ auto valueSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
+ tiledOperands.emplace_back(valueSlice);
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
@@ -3300,15 +3321,19 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> strides(outputRank, oneAttr);
- tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), resultOffsets, resultSizes, strides));
+ auto outputSlice = builder.create<tensor::ExtractSliceOp>(
+ loc, getOutput(), resultOffsets, resultSizes, strides);
+ tiledOperands.emplace_back(outputSlice);
SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{
+ {tiledOp},
+ SmallVector<Value>(tiledOp->getResults()),
+ llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index fbff91a94219cc..ca05467534a5b3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -120,8 +120,12 @@ struct LinalgOpTilingInterface
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);
SmallVector<Value> valuesToTile = linalgOp->getOperands();
- SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+ SmallVector<Value> tiledOperands = makeTiledShapes(
b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+ SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
+ llvm::make_filter_range(
+ tiledOperands, [](Value v) -> bool { return v.getDefiningOp(); }),
+ [](Value v) -> Operation * { return v.getDefiningOp(); });
SmallVector<Type> resultTensorTypes =
getTensorOutputTypes(linalgOp, tiledOperands);
@@ -129,7 +133,8 @@ struct LinalgOpTilingInterface
Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+ return TilingResult{
+ {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
}
/// Utility to fetch the offsets and sizes when applied as per the indexing
@@ -260,7 +265,8 @@ struct LinalgOpTilingInterface
return TilingResult{
tilingResult->tiledOps,
- SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
+ SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
+ tilingResult->generatedSlices};
}
/// Method to generate the tiled implementation of an operation from the tile
@@ -406,8 +412,12 @@ struct LinalgOpPartialReductionInterface
}
// Step 2a: Extract a slice of the input operands.
- SmallVector<Value, 4> tiledInputs = makeTiledShapes(
+ SmallVector<Value> tiledInputs = makeTiledShapes(
b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
+ SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
+ llvm::make_filter_range(
+ tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }),
+ [](Value v) -> Operation * { return v.getDefiningOp(); });
// Step 2b: Extract a slice of the init operands.
SmallVector<Value, 1> tiledInits;
@@ -424,6 +434,7 @@ struct LinalgOpPartialReductionInterface
auto extractSlice = b.create<tensor::ExtractSliceOp>(
loc, valueToTile, initOffset, initSizes, initStride);
tiledInits.push_back(extractSlice);
+ generatedSlices.push_back(extractSlice);
}
// Update the indexing maps.
@@ -453,7 +464,8 @@ struct LinalgOpPartialReductionInterface
return TilingResult{
{genericOp.getOperation()},
llvm::map_to_vector(genericOp->getResults(),
- [](OpResult r) -> Value { return r; })};
+ [](OpResult r) -> Value { return r; }),
+ generatedSlices};
}
FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index fa0598dd96885c..6a3f2fc5fbc496 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -565,9 +565,9 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}
-static Value materializeTiledShape(OpBuilder &builder, Location loc,
- Value valueToTile,
- const SliceParameters &sliceParams) {
+static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
+ Value valueToTile,
+ const SliceParameters &sliceParams) {
auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
.Case([&](MemRefType) {
@@ -583,14 +583,15 @@ static Value materializeTiledShape(OpBuilder &builder, Location loc,
.Default([](ShapedType) -> Operation * {
llvm_unreachable("Unexpected shaped type");
});
- return sliceOp->getResult(0);
+ return sliceOp;
}
-Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
- ArrayRef<OpFoldResult> tileSizes, AffineMap map,
- ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
- ArrayRef<OpFoldResult> subShapeSizes,
- bool omitPartialTileCheck) {
+Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
+ ArrayRef<OpFoldResult> tileSizes, AffineMap map,
+ ArrayRef<OpFoldResult> lbs,
+ ArrayRef<OpFoldResult> ubs,
+ ArrayRef<OpFoldResult> subShapeSizes,
+ bool omitPartialTileCheck) {
SliceParameters sliceParams =
computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
ubs, subShapeSizes, omitPartialTileCheck);
@@ -841,6 +842,7 @@ SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
tiledShapes.push_back(
sliceParams.has_value()
? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
+ ->getResult(0)
: valueToTile);
}
return tiledShapes;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e404c01010a325..3729300588422e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/107882
More information about the Mlir-commits mailing list