[Mlir-commits] [mlir] [mlir][SCF] Unify tileUsingFor and tileReductionUsingFor implementation (PR #120115)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 17 09:05:10 PST 2024
================
@@ -570,6 +570,146 @@ static LogicalResult generateLoopNest(
return rewriter.notifyMatchFailure(loc, "unhandled loop type");
}
+/// Fills `initTensors` with the tensors used to initialize the tiled loop nest
+/// generated for `op`, dispatching on `options.reductionStrategy`:
+///  - FullReduction: forwards to `tensor::getOrCreateDestinations`.
+///  - PartialReductionOuterReduction: requires `op` to implement
+///    `PartialReductionOpInterface` and asks it to generate the initial
+///    tensors for a partial reduction, given `tileSizes` and the indices of
+///    all loops of `op` whose iterator type is `reduction`.
+/// Any other strategy, or an `op` lacking the required interface, is reported
+/// as a match failure through `rewriter`.
+static LogicalResult
+createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
+                              ArrayRef<OpFoldResult> tileSizes,
+                              SmallVector<Value> &initTensors,
+                              const scf::SCFTilingOptions &options) {
+  using Strategy = scf::SCFTilingOptions::ReductionTilingStrategy;
+  Location loc = op->getLoc();
+
+  switch (options.reductionStrategy) {
+  case Strategy::FullReduction:
+    return tensor::getOrCreateDestinations(rewriter, loc, op, initTensors);
+
+  case Strategy::PartialReductionOuterReduction: {
+    auto partialRedOp =
+        dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!partialRedOp) {
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations "
+              "implementing PartialReductionOpInterface");
+    }
+
+    // Record the index of every loop that iterates over a reduction.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    const auto loopKinds = op.getLoopIteratorTypes();
+    SmallVector<int> reductionDims;
+    for (int dim = 0, numLoops = static_cast<int>(loopKinds.size());
+         dim != numLoops; ++dim) {
+      if (loopKinds[dim] == utils::IteratorType::reduction)
+        reductionDims.push_back(dim);
+    }
+
+    FailureOr<SmallVector<Value>> partialInits =
+        partialRedOp.generateInitialTensorForPartialReduction(
+            rewriter, loc, tileSizes, reductionDims);
+    if (failed(partialInits))
+      return failure();
+
+    initTensors = *partialInits;
+    return success();
+  }
+
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+/// Produces the tiled implementation of `op` for the tile described by
+/// `offsets` and `sizes`, dispatching on `options.reductionStrategy`:
+///  - FullReduction: forwards to `TilingInterface::getTiledImplementation`;
+///    `regionIterArg` is not used on this path.
+///  - PartialReductionOuterReduction: requires `op` to implement
+///    `PartialReductionOpInterface` and forwards to its
+///    `tileToPartialReduction`, passing `regionIterArg` and the indices of all
+///    loops of `op` whose iterator type is `reduction`.
+/// Any other strategy, or an `op` lacking the required interface, is reported
+/// as a match failure through `rewriter`.
+static FailureOr<TilingResult>
+getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
+                       ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
+                       ArrayRef<OpFoldResult> sizes,
+                       const scf::SCFTilingOptions &options) {
+  using Strategy = scf::SCFTilingOptions::ReductionTilingStrategy;
+
+  switch (options.reductionStrategy) {
+  case Strategy::FullReduction:
+    return op.getTiledImplementation(rewriter, offsets, sizes);
+
+  case Strategy::PartialReductionOuterReduction: {
+    auto partialRedOp =
+        dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!partialRedOp) {
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations "
+              "implementing PartialReductionOpInterface");
+    }
+
+    // Record the index of every loop that iterates over a reduction.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    const auto loopKinds = op.getLoopIteratorTypes();
+    SmallVector<int> reductionDims;
+    for (int dim = 0, numLoops = static_cast<int>(loopKinds.size());
+         dim != numLoops; ++dim) {
+      if (loopKinds[dim] == utils::IteratorType::reduction)
+        reductionDims.push_back(dim);
+    }
+
+    Location loc = op.getLoc();
+    return partialRedOp.tileToPartialReduction(rewriter, loc, regionIterArg,
+                                               offsets, sizes, reductionDims);
+  }
+
+  default:
+    return rewriter.notifyMatchFailure(op,
+                                       "unhandled reduction tiling strategy");
+  }
+}
+
+static LogicalResult
+getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
+ TilingInterface op, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVector<OpFoldResult> &resultOffset,
+ SmallVector<OpFoldResult> &resultSize,
+ const scf::SCFTilingOptions &options) {
+
+ switch (options.reductionStrategy) {
+ case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+ return op.getResultTilePosition(rewriter, index, offsets, sizes,
+ resultOffset, resultSize);
+ case scf::SCFTilingOptions::ReductionTilingStrategy::
----------------
MaheshRavishankar wrote:
I think this won't work for an operation like this:
```
%result = linalg.generic {
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = [...., affine_map<(d0, d1, d2) -> (d1, d0)>]}
```
But, as you said, that is a problem with the current implementation anyway.
https://github.com/llvm/llvm-project/pull/120115
More information about the Mlir-commits
mailing list