[Mlir-commits] [mlir] [mlir][SCF] Deprecate `linalg::tileToForallOp` and `linalg::tileToForallOpUsingTileSizes` (PR #91878)
Nicolas Vasilache
llvmlistbot at llvm.org
Fri May 24 02:47:15 PDT 2024
================
@@ -223,32 +451,40 @@ static LogicalResult generateLoopNestUsingForOp(
/// populated.
static LogicalResult generateLoopNestUsingForallOp(
RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
- ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector,
- ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn,
- SmallVector<LoopLikeOpInterface> &loops) {
- SmallVector<OpFoldResult> lbs, ubs, steps;
+ ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
+ ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
+ YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
assert(!loopRanges.empty() && "unexpected empty loop ranges");
assert(loopRanges.size() == tileSizes.size() &&
"expected as many tile sizes as loop ranges");
OpBuilder::InsertionGuard guard(rewriter);
SmallVector<OpFoldResult> offsets(loopRanges.size()),
sizes(loopRanges.size());
- for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
- if (isConstantIntValue(tileSize, 0))
- continue;
- lbs.push_back(loopRange.offset);
- ubs.push_back(loopRange.size);
- steps.push_back(tileSize);
- }
- assert(!lbs.empty() && "Expected at least one loop range");
-
std::optional<ArrayAttr> mappingAttr;
if (!mappingVector.empty())
mappingAttr = rewriter.getArrayAttr(mappingVector);
- auto forallOp = rewriter.create<scf::ForallOp>(
- loc, lbs, ubs, steps, destinationTensors, mappingAttr);
+ scf::ForallOp forallOp;
+ bool useNumThreads = !numThreads.empty();
+
+ if (useNumThreads) {
+ // Prune the zero numthreads.
+ SmallVector<OpFoldResult> nonZeroNumThreads;
----------------
nicolasvasilache wrote:
Something with `filter_range` should let us write this as a one-liner.
It still feels like common indexing utilities should exist (or be added) for this: the `numThreads` case is similar to `lbs = 0`, `steps = 1`, `ubs = numThreads`.
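
For illustration, here is a minimal standalone sketch of the filter-style pruning suggested above. It is not the code from the PR: `FoldResult`, `isConstantInt`, and `pruneZeros` are hypothetical stand-ins (for `OpFoldResult`, `isConstantIntValue`, and the hand-written pruning loop) so that the example compiles without MLIR. In-tree, the same shape would presumably be built from `llvm::make_filter_range` and a predicate on `isConstantIntValue(ofr, 0)`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <optional>
#include <vector>

// Hypothetical stand-in for mlir::OpFoldResult: either a known constant
// or a dynamic value (modeled here as std::nullopt).
using FoldResult = std::optional<int64_t>;

// Stand-in for isConstantIntValue: true only for a known constant equal to c.
inline bool isConstantInt(const FoldResult &v, int64_t c) {
  return v.has_value() && *v == c;
}

// Keep every entry that is not the constant 0; dynamic entries are kept.
inline std::vector<FoldResult>
pruneZeros(const std::vector<FoldResult> &numThreads) {
  std::vector<FoldResult> nonZero;
  std::copy_if(numThreads.begin(), numThreads.end(),
               std::back_inserter(nonZero),
               [](const FoldResult &v) { return !isConstantInt(v, 0); });
  // Filtering can only drop entries, never add them.
  assert(nonZero.size() <= numThreads.size());
  return nonZero;
}
```

The predicate is the only problem-specific part; the rest is a generic filter, which is why a shared utility (or a filter range) would replace the explicit loop.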
https://github.com/llvm/llvm-project/pull/91878