[Mlir-commits] [mlir] [mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate `scf::forall`. (PR #67083)
lorenzo chelini
llvmlistbot at llvm.org
Thu Oct 19 03:04:30 PDT 2023
================
@@ -728,6 +741,121 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
getAsOperations(forLoops), replacements};
}
+//===----------------------------------------------------------------------===//
+// tileUsingSCFForallOp implementation.
+//===----------------------------------------------------------------------===//
+
+FailureOr<scf::SCFTilingResult>
+mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+ const scf::SCFTilingOptions &options) {
+ Location loc = op->getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+
+ // 1. Get the range of loops that are represented by the operation.
+ SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
+ if (loopRanges.empty())
+ return op->emitOpError("expected non-empty loop ranges");
+ auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
+ if (llvm::any_of(loopRanges, hasStrideOne))
+ return op->emitOpError("only stride-1 supported atm");
+
+ // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
+ // To make it easier, pad the tile sizes to loopRanges.size with value 0.
+ SmallVector<OpFoldResult> tileSizeVector =
+ options.tileSizeComputationFunction(rewriter, op);
+ tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
+
+ // 3. Build the offsets, sizes and steps for the tile and distributed loops.
+ SmallVector<OpFoldResult> lbs, ubs, steps;
+ for (auto [index, tileSize, loopRange] :
+ llvm::enumerate(tileSizeVector, loopRanges)) {
+ if (isConstantIntValue(tileSize, 0))
+ continue;
+ lbs.push_back(loopRange.offset);
+ ubs.push_back(loopRange.size);
+ steps.push_back(tileSize);
+ }
+
+ // 4. Gather destination tensors.
+ SmallVector<Value> dest;
+ if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
+ return op->emitOpError("failed to get destination tensors");
+
+ // 5. Build the device mapping attribute;
+ std::optional<ArrayAttr> mappingAttr;
+ if (!options.mappingVector.empty()) {
----------------
chelini wrote:
[Optional] I would add a test to make sure we preserve the device mapping attribute (`options.mappingVector`) on the generated `scf.forall`.
https://github.com/llvm/llvm-project/pull/67083
More information about the Mlir-commits mailing list