[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 15 03:58:13 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Abhishek Varma (Abhishek-Varma)

<details>
<summary>Changes</summary>

-- This commit adds an API to fuse a consumer into its producer within an scf.for/scf.forall loop.

Signed-off-by: Abhishek Varma <abhvarma@<!-- -->amd.com>

---

Patch is 41.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88712.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+13) 
- (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+55) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+75-21) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+511) 
- (added) mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir (+110) 
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+53) 
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td (+21) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be28..75b48a2cdd8dc3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -126,6 +126,19 @@ struct SCFTileAndFuseOptions {
   }
 };
 
+/// Fuse the consumer of the source of `candidateSliceOp` by computing the
+/// required slice of the consumer in-place.  Note that the method
+/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
+/// value but does not delete the slice operation.
+struct SCFFuseConsumerOfSliceResult {
+  Operation *origConsumer;     // Original untiled consumer.
+  Value tiledAndFusedConsumer; // Tile and fused consumer value.
+  SmallVector<Operation *> tiledOps;
+};
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
+                           bool useSCFFor);
+
 /// Fuse the producer of the source of `candidateSliceOp` by computing the
 /// required slice of the producer in-place.  Note that the method
 /// replaces the uses of `candidateSliceOp` with the tiled and fused producer
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..4c62d45822ad44 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return {};
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the position of the iteration domain tile computed
+          from the tile of the given operand.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getIterDomainTilePositionFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVector<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to return the position of the result tile computed by the tiled operation.
@@ -96,6 +115,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from
+          operand position.
+
+          Generates the IR that generates the tiled implementation of an
+          operation from operand position.  The `offsets` and `sizes`
+          describe the tile of the operand required. This is different from
+          `getTiledImplementation` which generates the tiled
+          implementation of the operation given a tile of the
+          iteration space. This method generates a tiled
+          implementation of the operation based on the position of the
+          operand required. This method enables consumer fusion by using
+          tile and fuse. The method returns failure if the operation
+          can't be tiled to generate the operand tile. In practical terms
+          this implies it cannot be tiled and fused with its producers.
+
+          - `offsets` provides the offset of the tile in the coordinate system
+            of the original iteration space, i.e., if an iteration space
+            dimension had non-zero offset, it must be included in the offset
+            provided here (as opposed to zero-based offset "relative" to the
+            iteration space).
+          - `sizes` provides the size of the tile.
+        }],
+        /*retType=*/"FailureOr<TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..01bf19764b0938 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -132,6 +132,59 @@ struct LinalgOpTilingInterface
     return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
   }
 
+  void getMappedOffsetAndSize(Operation *op, OpBuilder &b,
+                              AffineMap indexingMap,
+                              ArrayRef<OpFoldResult> offsets,
+                              ArrayRef<OpFoldResult> sizes,
+                              SmallVector<OpFoldResult> &mappedOffsets,
+                              SmallVector<OpFoldResult> &mappedSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+    auto numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    mappedOffsets.resize(numLoops);
+    mappedSizes.resize(numLoops);
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (const auto &range : llvm::enumerate(iterationDomain)) {
+        mappedOffsets[range.index()] = range.value().offset;
+        mappedSizes[range.index()] = range.value().size;
+      }
+    }
+    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition =
+          cast<AffineDimExpr>(resultExpr.value()).getPosition();
+      mappedOffsets[dimPosition] = offsets[resultExpr.index()];
+      mappedSizes[dimPosition] = sizes[resultExpr.index()];
+    }
+  }
+
+  // Return the iteration domain tile (offsets and sizes) that corresponds to
+  // the given tile of the operand at position `operandNumber`.
+  LogicalResult getIterDomainTilePositionFromOperandPosition(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVector<OpFoldResult> &iterDomainOffsets,
+      SmallVector<OpFoldResult> &iterDomainSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the operand is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the operand to iteration space tiles
+    // (filling in full extent for dimensions not used to access the operand).
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return op->emitOpError(
+          "unhandled get iter domain position when operand is not "
+          "accessed using a permuted projection");
+    }
+
+    getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes,
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
   // Return the details of the output tile generated by the tiled
   // implementation.
   LogicalResult
@@ -160,6 +213,20 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandPosition(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterDomainTilePositionFromOperandPosition(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return op->emitOpError(
+          "unable to obtain the iter domain position of the operation.");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
+  }
+
   FailureOr<TilingResult>
   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                           ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
           "unhandled tiled implementation generation when result is not "
           "accessed using a permuted projection");
     }
-
-    auto numLoops = linalgOp.getNumLoops();
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes, mappedOffsets,
+                           mappedSizes);
     auto tilingInterfaceOp = cast<TilingInterface>(op);
-    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
-        iterationTileSizes(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &range : llvm::enumerate(iterationDomain)) {
-        iterationTileOffsets[range.index()] = range.value().offset;
-        iterationTileSizes[range.index()] = range.value().size;
-      }
-    }
-    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition =
-          cast<AffineDimExpr>(resultExpr.value()).getPosition();
-      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
-      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
-    }
-
     FailureOr<TilingResult> tilingResult =
-        tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
-                                                 iterationTileSizes);
+        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+    if (failed(tilingResult))
+      return failure();
+
     if (tilingResult->tiledOps.size() != 1)
       return op->emitOpError("failed to generate tiled implementation");
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69df..45c8f8362ad581 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
@@ -798,6 +799,59 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 // tileConsumerAndFuseProducersUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFFor(OpOperand &source,
+                                      unsigned &operandNumber) {
+  // Step 1. Fetch the corresponding output
+  // TODO(avarma): Make it generic for multiple values yielding scf.for.
+  unsigned yieldOperandNumber = source.getOperandNumber();
+  Value resultingValue =
+      source.getOwner()->getParentOp()->getResult(yieldOperandNumber);
+
+  // Step 2. Get users.
+  std::optional<OpOperand *> destinationIterArg;
+  Operation *untiledConsumer;
+  for (Operation *user : resultingValue.getUsers()) {
+    // TODO(avarma): Address the case where the consumer op itself can return
+    //               more than one result.
+    for (Value operand : user->getOperands()) {
+      if (operand == resultingValue) {
+        untiledConsumer = user;
+        break;
+      }
+      operandNumber++;
+    }
+    break;
+  }
+  return {untiledConsumer, destinationIterArg};
+}
+
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFForall(OpOperand *source,
+                                         unsigned &operandNumber) {
+  // Step 1. Fetch the corresponding output
+  // TODO(avarma): Make it generic for multiple values yielding scf.forall.
+  auto iterArg = dyn_cast<BlockArgument>(source->get());
+  Value resultingValue = iterArg.getOwner()->getParentOp()->getResult(0);
+
+  // Step 2. Get users.
+  std::optional<OpOperand *> destinationIterArg;
+  Operation *untiledConsumer;
+  for (Operation *user : resultingValue.getUsers()) {
+    // TODO(avarma): Address the case where the consumer op itself can return
+    //               more than one result.
+    for (Value operand : user->getOperands()) {
+      if (operand == resultingValue) {
+        untiledConsumer = user;
+        break;
+      }
+      operandNumber++;
+    }
+    break;
+  }
+  return {untiledConsumer, destinationIterArg};
+}
+
 /// Return the untiled producer whose slice is used in a tiled consumer. The
 /// method traverses the tile loop nest (`loops`) if needed, and returns the
 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
@@ -820,6 +874,463 @@ getUntiledProducerFromSliceSource(OpOperand *source,
   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
 }
 
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the source.
+  unsigned operandNumber = 0;
+  auto [consumerOp, destinationInitArg] =
+      getUntiledConsumerFromSliceDestSCFForall(
+          &candidateSliceOp.getDestMutable(), operandNumber);
+  if (!consumerOp)
+    return failure();
+  OpBuilder::InsertionGuard g(rewriter);
+  // Using candidateSliceOp->getParentOp() because we have the following case :-
+  // scf.forall.in_parallel {
+  //   tensor.parallel_insert_slice ...
+  // }
+  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+  Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    llvm::outs() << "consumer is not a TileableInterface: " << *consumerOp
+                 << "\n";
+    return failure();
+  }
+
+  // Check containing op is "scf::ForallOp".
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    llvm::outs() << "containing op is not a scf.forall: " << containingOp
+                 << "\n";
+    return failure();
+  }
+
+  // Check that the consumer doesn't use more than one result of containingOp.
+  Value bridge(nullptr);
+  SmallVector<unsigned> operandNums;
+  for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+    if (opd.getDefiningOp() == containingOp) {
+      operandNums.push_back(idx);
+      if (!bridge) {
+        bridge = opd;
+      } else if (bridge != opd) {
+        llvm::outs()
+            << "consumer's operand use more than one containingOp's result"
+            << "\n";
+        return failure();
+      }
+    }
+  }
+
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    llvm::outs() << "consumer op should have destination style op interface"
+                 << "\n";
+    return failure();
+  }
+
+  // Check that the consumer doesn't use scf.forall's output as init.
+  SmallVector<Value> dpsInits = llvm::to_vector<4>(
+      llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+  if (llvm::is_contained(dpsInits, bridge)) {
+    llvm::outs() << "consumer op take result of scf.forall as init"
+                 << "\n";
+    return failure();
+  }
+
+  // Check result was inserted only once.
+  int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+  auto bridgeBlockArg = forallOp.getRegionOutArgs()[bridgeResultIdx];
+  scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+  tensor::ParallelInsertSliceOp targetInsertOp(nullptr);
+  for (Operation &op : terminatorOp.getRegion().front().getOperations()) {
+    auto parallelInsertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+    if (parallelInsertSliceOp.getDest() == bridgeBlockArg) {
+      if (!targetInsertOp) {
+        targetInsertOp = parallelInsertSliceOp;
+      } else {
+        llvm::outs() << "containingOp's result inserted multi time"
+                     << "\n";
+        return failure();
+      }
+    }
+  }
+
+  if (!targetInsertOp) {
+    llvm::outs() << "containingOp's result was not inserted"
+                 << "\n";
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> offsets = targetInsertOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = targetInsertOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = targetInsertOp.getMixedStrides();
+
+  // Check that all insert strides are 1.
+  if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+        if (auto attr = foldRes.dyn_cast<Attribute>()) {
+          return cast<IntegerAttr>(attr).getInt() != 1;
+        }
+        return true;
+      })) {
+    llvm::outs() << "containingOp's result yield with stride"
+                 << "\n";
+    return failure();
+  }
+
+  Location loc = forallOp.getLoc();
+  rewriter.setInsertionPoint(terminatorOp);
+
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+  // Try to get iter domain position from input position.
+  if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    llvm::outs() << "can't get iter domain position from input position"
+                 << "\n";
+    return failure();
+  }
+
+  // Try to get the position of each consumer result from iter domain position.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(consumerOp->getNumResults());
+  for (auto [idx, v] : llvm::enumerate(consumerOp->getResults())) {
+    if (failed(tileableConsumer.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      llvm::outs()
+          << "can't get result domain position from iter domain position"
+          << "\n";
+      return failure();
+    }
+  }
+
+  // All checks passed, try to fuse the consumer.
+  // Create the tiled implementation of the consumer op.
+  FailureOr<TilingResult> tileAndFuseResult =
+      tileableConsumer.getTiledImplementationFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes);
+  if (failed(tileAndFuseResult)) {
+    llvm::outs() << "get tiled implementation failed"
+                 << "\n";
+    return failure();
+  }
+
+  auto tiledOps = tileAndFuseResult->tiledOps;
+  if (failed(tileAndFuseResult) || tiledOps.size() != 1) {
+    llvm::outs() << "failed to tile consumer op: " << *tileableConsumer << "\n";
+    return failure();
+  }
+
+  // Replace tiled op's operand.
+  for (auto operandNum : operandNums) {
+    tiledOps[0]->setOperand(operandNum, targetInsertOp.getSource());
+  }
+  rewriter.replaceUsesWithIf(bridge, forallOp.getOutputs()[bridgeResultIdx],
+                             [&](OpOperand &use) {
+                               Operation *op = use.getOwner();
+                               return forallOp->isProperAncestor(op);
+                             });
+
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.append(dpsInits);
+
+  // Create new scf.forall op.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+  rewriter.eraseBlock(newforallOp.getBody());
+  newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+  for (auto v : dpsInits) ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/88712


More information about the Mlir-commits mailing list