[Mlir-commits] [mlir] [mlir][SCF] Unify tileUsingFor and tileReductionUsingFor implementation (PR #120115)
Kunwar Grover
llvmlistbot at llvm.org
Tue Dec 17 06:40:02 PST 2024
================
@@ -570,6 +570,146 @@ static LogicalResult generateLoopNest(
return rewriter.notifyMatchFailure(loc, "unhandled loop type");
}
+static LogicalResult
+createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
+ ArrayRef<OpFoldResult> tileSizes,
+ SmallVector<Value> &initTensors,
+ const scf::SCFTilingOptions &options) {
+ Location loc = op->getLoc();
+ switch (options.reductionStrategy) {
+ case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+ return tensor::getOrCreateDestinations(rewriter, loc, op, initTensors);
+ case scf::SCFTilingOptions::ReductionTilingStrategy::
+ PartialReductionOuterReduction: {
+ auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+ if (!redOp) {
+ return rewriter.notifyMatchFailure(
+ op, "PartialReductionOuterReduction tiling strategy is only "
+ "supported for operations "
+ "implementing PartialReductionOpInterface");
+ }
+ // Get reduction dimensions.
+ // TODO: PartialReductionOpInterface should really query TilingInterface
+ // itself and find reduction dimensions.
+ SmallVector<int> reductionDims;
+ for (auto [idx, iteratorType] :
+ llvm::enumerate(op.getLoopIteratorTypes())) {
+ if (iteratorType == utils::IteratorType::reduction)
+ reductionDims.push_back(idx);
+ }
+ FailureOr<SmallVector<Value>> maybeInitTensors =
+ redOp.generateInitialTensorForPartialReduction(rewriter, loc, tileSizes,
+ reductionDims);
+ if (failed(maybeInitTensors)) {
+ return failure();
+ }
+ initTensors = maybeInitTensors.value();
+ return success();
+ }
+ default:
+ return rewriter.notifyMatchFailure(op,
+ "unhandled reduction tiling strategy");
+ }
+}
+
+static FailureOr<TilingResult>
+getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
+ ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ const scf::SCFTilingOptions &options) {
+ switch (options.reductionStrategy) {
+ case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+ return op.getTiledImplementation(rewriter, offsets, sizes);
+ case scf::SCFTilingOptions::ReductionTilingStrategy::
+ PartialReductionOuterReduction: {
+ auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+ if (!redOp) {
+ return rewriter.notifyMatchFailure(
+ op, "PartialReductionOuterReduction tiling strategy is only "
+ "supported for operations "
+ "implementing PartialReductionOpInterface");
+ }
+ // Get reduction dimensions.
+ // TODO: PartialReductionOpInterface should really query TilingInterface
+ // itself and find reduction dimensions.
+ SmallVector<int> reductionDims;
+ for (auto [idx, iteratorType] :
+ llvm::enumerate(op.getLoopIteratorTypes())) {
+ if (iteratorType == utils::IteratorType::reduction)
+ reductionDims.push_back(idx);
+ }
+ return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
+ offsets, sizes, reductionDims);
+ }
+ default:
+ return rewriter.notifyMatchFailure(op,
+ "unhandled reduction tiling strategy");
+ }
+}
+
+static LogicalResult
+getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
+ TilingInterface op, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVector<OpFoldResult> &resultOffset,
+ SmallVector<OpFoldResult> &resultSize,
+ const scf::SCFTilingOptions &options) {
+
+ switch (options.reductionStrategy) {
+ case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
+ return op.getResultTilePosition(rewriter, index, offsets, sizes,
+ resultOffset, resultSize);
+ case scf::SCFTilingOptions::ReductionTilingStrategy::
----------------
Groverkss wrote:
No, getResultTilePosition will not work here (I tried it). In the original operation, this dimension was a reduction dimension, so getResultTilePosition ignores it and does not produce the correct result tile.
The current implementation only works when we tile reduction dimensions alone, because it ignores tiling of parallel dimensions.
But this is really just a copy-paste of the old implementation; I'm not adding anything new. This needs to be fixed in PartialReductionOpInterface. I have some ideas, but they are beyond the scope of this patch, since this is just code motion.
The correct way to do this would be to add a function to PartialReductionOpInterface, "getResultPartialTilePosition", that returns the correct position. Until then, we can't do much here. I was thinking of sending a follow-up patch to fix that.
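To make the mismatch concrete, here is a toy, MLIR-free C++ sketch. All names in it (`IteratorKind`, `Tile`, `resultTileDroppingReductions`, `partialResultTile`) are hypothetical and are not MLIR APIs. It only models the idea that a full-reduction result has no reduction dimensions, while a partial-reduction result keeps them. Using offset 0 and the tile size as the extent along reduction dimensions is an assumption about what a future "getResultPartialTilePosition" might return, not the actual design.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for utils::IteratorType; not the MLIR type.
enum class IteratorKind { Parallel, Reduction };

// Offsets and sizes of a tile, one entry per dimension.
struct Tile {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Toy model of a full-reduction result position: the result tensor has no
// reduction dimensions, so those entries of the iteration-space tile are
// dropped.
Tile resultTileDroppingReductions(const std::vector<IteratorKind> &kinds,
                                  const Tile &iterTile) {
  Tile out;
  for (std::size_t i = 0; i < kinds.size(); ++i) {
    if (kinds[i] == IteratorKind::Reduction)
      continue;
    out.offsets.push_back(iterTile.offsets[i]);
    out.sizes.push_back(iterTile.sizes[i]);
  }
  return out;
}

// Toy model of a partial-reduction result position (assumed behavior):
// parallel dimensions follow the iteration-space tile, and reduction
// dimensions are kept with offset 0 and an extent equal to the tile size.
Tile partialResultTile(const std::vector<IteratorKind> &kinds,
                       const Tile &iterTile) {
  Tile out;
  for (std::size_t i = 0; i < kinds.size(); ++i) {
    bool isReduction = kinds[i] == IteratorKind::Reduction;
    out.offsets.push_back(isReduction ? 0 : iterTile.offsets[i]);
    out.sizes.push_back(iterTile.sizes[i]);
  }
  return out;
}
```

With iterator kinds {Parallel, Reduction} and an iteration-space tile with offsets {8, 32} and sizes {4, 16}, the first function yields a rank-1 position (offsets {8}, sizes {4}), while the second yields a rank-2 position (offsets {8, 0}, sizes {4, 16}). The rank mismatch is why the first cannot describe a slice of the partial-result tensor in this model.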
https://github.com/llvm/llvm-project/pull/120115