[Mlir-commits] [mlir] [mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate `scf::forall`. (PR #67083)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 21 18:56:36 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

<details>
<summary>Changes</summary>

Similar to `scf::tileUsingSCFForOp`, which tiles operations that
implement the `TilingInterface` using `scf.for` operations, this
patch adds a method that tiles such operations using `scf.forall`.
Most of this implementation is derived from the
`linalg::tileToForallOp` method. Eventually that method will either be
deprecated or moved to use the method introduced here.

---

Patch is 38.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67083.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+19-2) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+5-9) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+203-42) 
- (added) mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir (+37) 
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp (+117-51) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index ca641c596c7b7bb..06cce19894e9f5a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -51,6 +51,17 @@ struct SCFTilingOptions {
     interchangeVector = llvm::to_vector(interchange);
     return *this;
   }
+
+  /// Specify the mapping of loops to devices. This is only respected when the
+  /// loop constructs support such a mapping (like `scf.forall`). It is ignored
+  /// when using loop constructs that don't support such a mapping (like
+  /// `scf.for`).
+  SmallVector<Attribute> mappingVector = {};
+  SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
+    mappingVector = llvm::to_vector(
+        llvm::map_range(mapping, [](auto attr) -> Attribute { return attr; }));
+    return *this;
+  }
 };
 
 /// Transformation information returned after tiling.
@@ -60,7 +71,7 @@ struct SCFTilingResult {
   /// of the last op.
   SmallVector<Operation *> tiledOps;
   /// The `scf.for` operations that iterate over the tiles.
-  SmallVector<scf::ForOp> loops;
+  SmallVector<Operation *> loops;
   /// Values to use as replacements for the untiled op. Is the same size as the
   /// number of results of the untiled op.
   SmallVector<Value> replacements;
@@ -82,6 +93,12 @@ struct SCFTileAndFuseOptions {
   }
 };
 
+/// Method to tile an op that implements the `TilingInterface` using
+/// `scf.forall`.
+FailureOr<SCFTilingResult>
+tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+                     const SCFTilingOptions &options);
+
 /// Fuse the producer of the source of `candidateSliceOp` by computing the
 /// required slice of the producer in-place.  Note that the method
 /// replaces the uses of `candidateSliceOp` with the tiled and fused producer
@@ -160,7 +177,7 @@ struct SCFTileAndFuseResult {
   /// generated operation.
   llvm::SetVector<Operation *> tiledAndFusedOps;
   /// The `scf.for` operations that iterate over the tiles.
-  SmallVector<scf::ForOp> loops;
+  SmallVector<Operation *> loops;
   /// The replacement values to use for the tiled and fused operations.
   llvm::DenseMap<Value, Value> replacements;
 };
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1819ca614a060fd..ca3db7401e38caa 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -434,16 +434,12 @@ static LogicalResult applyTilingToAll(
     SmallVector<Operation *> opsToReplace{target};
     llvm::append_range(opsToReplace, tiledResults->fusedProducers);
     for (Operation *toReplace : opsToReplace) {
-      SmallVector<Value> replacements;
-      replacements.reserve(toReplace->getNumResults());
-      for (OpResult res : toReplace->getResults()) {
-        auto it = tiledResults->replacements.find(res);
-        if (it == tiledResults->replacements.end())
-          replacements.push_back(res);
-        else
-          replacements.push_back(it->getSecond());
+      for (OpResult res : toReplace->getResults())
+        if (auto replacement = tiledResults->replacements.lookup(res))
+          rewriter.replaceAllUsesWith(res, replacement);
+      if (toReplace->use_empty()) {
+        rewriter.eraseOp(toReplace);
       }
-      rewriter.replaceOp(toReplace, replacements);
     }
 
     // Report back the relevant handles to the transform op.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 6cfba3fef15ebda..9054f7bcdde7e15 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -55,6 +55,30 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
   return filledVector;
 }
 
+/// Convert a list of ops of type `SrcOpTy` to a list of `Operation *`.
+template <typename SrcOpTy>
+static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
+  return llvm::to_vector(
+      llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
+}
+template <typename SrcOpTy>
+static SmallVector<Operation *>
+getAsOperations(const SmallVector<SrcOpTy> &ops) {
+  return getAsOperations(ArrayRef<SrcOpTy>(ops));
+}
+
+/// Convert a list of `Operation *` to a list of `DstOpTy`
+template <typename DstOpTy>
+static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
+  return llvm::to_vector(
+      llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
+}
+template <typename DstOpTy>
+static SmallVector<DstOpTy>
+castToTypedOperations(const SmallVector<Operation *> &ops) {
+  return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
+}
+
 //===----------------------------------------------------------------------===//
 // tileUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
@@ -77,10 +101,9 @@ static bool tileDividesIterationDomain(Range loopRange) {
 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
                                        Range loopRange, Value iv,
-                                       Value tileSize) {
-  std::optional<int64_t> ts = getConstantIntValue(tileSize);
-  if (ts && ts.value() == 1)
-    return getAsOpFoldResult(tileSize);
+                                       OpFoldResult tileSize) {
+  if (isConstantIntValue(tileSize, 1))
+    return tileSize;
 
   if (tileDividesIterationDomain(
           Range{loopRange.offset, loopRange.size, tileSize}))
@@ -98,6 +121,24 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
       b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
 }
 
+/// Clones the operation and updates the destination if the operation
+/// implements the `DestinationStyleOpInterface`.
+static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
+                                                  Operation *op,
+                                                  ValueRange newDestArgs) {
+  Operation *clonedOp = rewriter.clone(*op);
+  if (auto destinationStyleOp =
+          dyn_cast<DestinationStyleOpInterface>(clonedOp)) {
+    // Note that this assumes the inits form a contiguous range of operands.
+    auto [start, end] = destinationStyleOp.getDpsInitsPositionRange();
+    assert((end - start == newDestArgs.size()) &&
+           "expected as many new destination args as number of inits of the "
+           "operation");
+    clonedOp->setOperands(start, end - start, newDestArgs);
+  }
+  return clonedOp;
+}
+
 /// Generate an empty loop nest that represents the tiled loop nest shell.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
@@ -295,8 +336,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
   }
 
-  scf::SCFTilingResult tilingResult;
   SmallVector<OpFoldResult> offsets, sizes;
+  SmallVector<scf::ForOp> forLoops;
   {
     // If there is an interchange specified, permute the iteration domain and
     // the tile sizes.
@@ -319,8 +360,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
     // 3. Materialize an empty loop nest that iterates over the tiles. These
     // loops for now do not return any values even if the original operation has
     // results.
-    tilingResult.loops = generateTileLoopNest(
-        rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
+    forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
+                                    tileSizeVector, offsets, sizes);
 
     if (!interchangeVector.empty()) {
       auto inversePermutation = invertPermutationVector(interchangeVector);
@@ -330,30 +371,30 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   }
 
   LLVM_DEBUG({
-    if (!tilingResult.loops.empty()) {
+    if (!forLoops.empty()) {
       llvm::dbgs() << "LoopNest shell :\n";
-      tilingResult.loops.front().dump();
+      forLoops.front().dump();
       llvm::dbgs() << "\n";
     }
   });
 
   // 4. Generate the tiled implementation within the inner most loop.
-  if (!tilingResult.loops.empty())
-    rewriter.setInsertionPoint(
-        tilingResult.loops.back().getBody()->getTerminator());
+  if (!forLoops.empty())
+    rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
   FailureOr<TilingResult> tiledImplementation =
       op.getTiledImplementation(rewriter, offsets, sizes);
-  tilingResult.tiledOps.append(tiledImplementation->tiledOps);
+
   if (op->getNumResults() == 0) {
-    // nothing more to do.
-    return tilingResult;
+    return scf::SCFTilingResult{
+        tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
   }
 
   // If loops are empty, the tiled op is used as the replacement for the untiled
   // op.
-  if (tilingResult.loops.empty()) {
-    tilingResult.replacements = tiledImplementation->tiledValues;
-    return tilingResult;
+  if (forLoops.empty()) {
+    return scf::SCFTilingResult{tiledImplementation->tiledOps,
+                                getAsOperations(forLoops),
+                                tiledImplementation->tiledValues};
   }
 
   // 5. Yield all the results of the tiled operation. The surrounding loop
@@ -377,18 +418,18 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
                                              destinationTensors)))
     return rewriter.notifyMatchFailure(op, "failed to get destinations");
 
-  tilingResult.replacements = yieldTiledValues(
+  SmallVector<Value> replacements = yieldTiledValues(
       rewriter, destinationTensors, tiledImplementation.value(),
-      resultOffsetsList, resultSizesList, tilingResult.loops);
-
+      resultOffsetsList, resultSizesList, forLoops);
   LLVM_DEBUG({
-    if (!tilingResult.loops.empty()) {
+    if (!forLoops.empty()) {
       llvm::dbgs() << "After tiled implementation :\n";
-      tilingResult.loops.front().dump();
+      forLoops.front().dump();
       llvm::dbgs() << "\n";
     }
   });
-  return tilingResult;
+  return scf::SCFTilingResult{tiledImplementation->tiledOps,
+                              getAsOperations(forLoops), replacements};
 }
 
 FailureOr<scf::SCFReductionTilingResult>
@@ -466,6 +507,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   results.mergeOp = mergeOp;
   return results;
 }
+
 //===----------------------------------------------------------------------===//
 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
@@ -636,7 +678,9 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
   }
 
   // 1. First tile the consumer.
-  scf::SCFTileAndFuseResult tileAndFuseResult;
+  SmallVector<scf::ForOp> forLoops;
+  SetVector<Operation *> fusedProducers, tiledAndFusedOps;
+  DenseMap<Value, Value> replacements;
   llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
   {
     FailureOr<scf::SCFTilingResult> tilingResult =
@@ -644,20 +688,21 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     if (failed(tilingResult))
       return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
     for (auto *tiledOp : tilingResult->tiledOps)
-      tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
-    tileAndFuseResult.loops = std::move(tilingResult->loops);
-    for (const auto &result : llvm::enumerate(
-             llvm::zip(consumer->getResults(), tilingResult->replacements))) {
-      tileAndFuseResult.replacements[std::get<0>(result.value())] =
-          std::get<1>(result.value());
+      tiledAndFusedOps.insert(tiledOp);
+    forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
+    for (auto [index, origValue, replacement] :
+         llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
+      replacements[origValue] = replacement;
       yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
-          result.index())] = result.index();
+          index)] = index;
     }
   }
 
   // If there are no loops generated, fusion is immaterial.
-  if (tileAndFuseResult.loops.empty())
-    return tileAndFuseResult;
+  if (forLoops.empty()) {
+    return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
+                                     getAsOperations(forLoops), replacements};
+  }
 
   // 2. Typically, the operands of the tiled operation are slices of the
   //    operands of the untiled operation. These are expressed in IR using
@@ -674,7 +719,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
   };
 
   std::deque<tensor::ExtractSliceOp> candidates;
-  addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
+  addCandidateSlices(tiledAndFusedOps.back(), candidates);
   OpBuilder::InsertionGuard g(rewriter);
   while (!candidates.empty()) {
     // Traverse the slices in BFS fashion.
@@ -684,19 +729,135 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     // The operands of the fused producer might themselved be slices of
     // values produced by operations that implement the `TilingInterface`.
     // Add these operations to the worklist.
-    std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
-        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
-                                   tileAndFuseResult.loops);
-    if (!fusedProducer)
+    std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
+        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
+    if (!fusedResult)
       continue;
 
     if (Operation *tiledAndFusedOp =
-            fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
-      tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
+            fusedResult->tiledAndFusedProducer.getDefiningOp()) {
+      fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
+      tiledAndFusedOps.insert(tiledAndFusedOp);
       addCandidateSlices(tiledAndFusedOp, candidates);
     }
   }
-  return tileAndFuseResult;
+  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
+                                   getAsOperations(forLoops), replacements};
+}
+
+//===----------------------------------------------------------------------===//
+// tileUsingSCFForallOp implementation.
+//===----------------------------------------------------------------------===//
+
+FailureOr<scf::SCFTilingResult>
+mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+                                const scf::SCFTilingOptions &options) {
+  Location loc = op->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+
+  // 1. Get the range of loops that are represented by the operation.
+  SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
+  if (loopRanges.empty())
+    return op->emitOpError("expected non-empty loop ranges");
+  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
+  if (llvm::any_of(loopRanges, hasStrideOne))
+    return op->emitOpError("only stride-1 supported atm");
+
+  // 2. Get the tile sizes. If tile size is 0, the loop is neither tiled nor distributed.
+  // To make it easier, pad the tile sizes to loopRanges.size with value 0.
+  SmallVector<OpFoldResult> tileSizeVector =
+      options.tileSizeComputationFunction(rewriter, op);
+  tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
+
+  // 3. Build the offsets, sizes and steps for the tile and distributed loops.
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  for (auto [index, tileSize, loopRange] :
+       llvm::enumerate(tileSizeVector, loopRanges)) {
+    if (isConstantIntValue(tileSize, 0))
+      continue;
+    lbs.push_back(loopRange.offset);
+    ubs.push_back(loopRange.size);
+    steps.push_back(tileSize);
+  }
+
+  // 4. Gather destination tensors.
+  SmallVector<Value> dest;
+  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
+    return op->emitOpError("failed to get destination tensors");
+
+  // 5. Build the device mapping attribute.
+  std::optional<ArrayAttr> mappingAttr;
+  if (!options.mappingVector.empty()) {
+    mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
+  }
+
+  // 6. Create the ForallOp. We don't use the lambda body-builder
+  // version because we require the use of RewriterBase in the body, so we
+  // manually move the insertion point to the body below.
+  auto forallOp =
+      rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
+
+  // 7. Get the tile offset and sizes.
+  rewriter.setInsertionPoint(forallOp.getTerminator());
+  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+  tiledOffsets.reserve(loopRanges.size());
+  tiledSizes.reserve(loopRanges.size());
+  ValueRange ivs = forallOp.getInductionVars();
+  {
+    int materializedLoopNum = 0;
+    for (auto [index, tileSize, loopRange] :
+         llvm::enumerate(tileSizeVector, loopRanges)) {
+      if (isConstantIntValue(tileSize, 0)) {
+        tiledOffsets.push_back(loopRange.offset);
+        tiledSizes.push_back(loopRange.size);
+        continue;
+      }
+      Value iv = ivs[materializedLoopNum++];
+      tiledOffsets.push_back(iv);
+      tiledSizes.push_back(
+          getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+    }
+  }
+
+  // 8. Tile the operation. Clone the operation so that its destination
+  // operands can be updated to the loop's output block arguments.
+  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+  Operation *clonedOp =
+      cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
+  FailureOr<TilingResult> tilingResult =
+      cast<TilingInterface>(clonedOp).getTiledImplementation(
+          rewriter, tiledOffsets, tiledSizes);
+  if (failed(tilingResult))
+    return clonedOp->emitError("Failed to tile op: ");
+  rewriter.eraseOp(clonedOp);
+
+  // 9. Parallel insert back into the result tensor.
+  for (auto [index, tiledValue, destBBArg] :
+       llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
+    // 9.a. Partial subset information is inserted just before the terminator.
+    rewriter.setInsertionPoint(forallOp.getTerminator());
+
+    SmallVector<OpFoldResult> resultOffsets, resultSizes;
+    if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
+                                        tiledSizes, resultOffsets,
+                                        resultSizes)))
+      return op->emitOpError("output offsets couldn't be calculated");
+    SmallVector<OpFoldResult> strides(resultSizes.size(),
+                                      rewriter.getIndexAttr(1));
+
+    // 9.b. Parallel insertions are inserted at the end of the combining
+    // terminator.
+    rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
+  }
+
+  // 10. Return the tiling result.
+  return scf::SCFTilingResult{
+      tilingResult->tiledOps,
+      {forallOp.getOperation()},
+      llvm::to_vector(llvm::map_range(forallOp.getResults(),
+                                      [](auto val) -> Value { return val; }))};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
new file mode 100644
index 000000000000000..f40374b7b5485da
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s
+
+func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul {__internal_linalg_transform__ = "simple_gemm"}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/67083


More information about the Mlir-commits mailing list