[Mlir-commits] [mlir] [mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate `scf::forall`. (PR #67083)

lorenzo chelini llvmlistbot at llvm.org
Thu Oct 19 03:04:30 PDT 2023


================
@@ -728,6 +741,121 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
                                    getAsOperations(forLoops), replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileUsingSCFForallOp implementation.
+//===----------------------------------------------------------------------===//
+
+FailureOr<scf::SCFTilingResult>
+mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+                                const scf::SCFTilingOptions &options) {
+  Location loc = op->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+
+  // 1. Get the range of loops that are represented by the operation.
+  SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
+  if (loopRanges.empty())
+    return op->emitOpError("expected non-empty loop ranges");
+  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
+  if (llvm::any_of(loopRanges, hasStrideOne))
+    return op->emitOpError("only stride-1 supported atm");
+
+  // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
+  // To make it easier, pad the tile sizes to loopRanges.size with value 0.
+  SmallVector<OpFoldResult> tileSizeVector =
+      options.tileSizeComputationFunction(rewriter, op);
+  tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
+
+  // 3. Build the offsets, sizes and steps for the tile and distributed loops.
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  for (auto [index, tileSize, loopRange] :
+       llvm::enumerate(tileSizeVector, loopRanges)) {
+    if (isConstantIntValue(tileSize, 0))
+      continue;
+    lbs.push_back(loopRange.offset);
+    ubs.push_back(loopRange.size);
+    steps.push_back(tileSize);
+  }
+
+  // 4. Gather destination tensors.
+  SmallVector<Value> dest;
+  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
+    return op->emitOpError("failed to get destination tensors");
+
+  // 5. Build the device mapping attribute;
+  std::optional<ArrayAttr> mappingAttr;
+  if (!options.mappingVector.empty()) {
----------------
chelini wrote:

[Optional] I would add a test to make sure we preserve the mapping.

https://github.com/llvm/llvm-project/pull/67083


More information about the Mlir-commits mailing list