[Mlir-commits] [mlir] [mlir][linalg] Enable fuse consumer (PR #85528)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 16 07:22:10 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: donald chen (cxy-1993)

<details>
<summary>Changes</summary>

This patch adds consumer-fusion support to the tiling interface and implements consumer fusion in FuseIntoContainingOp.

- Add the interface method 'getIterDomainTilePositionFromOperandPosition' to the tiling interface; it computes the iteration domain tile position from an operand tile position.
- Add the interface method 'getTiledImplementationFromOperandPosition' to the tiling interface; it generates the tiled implementation from an operand tile position.
- Implement both methods and support consumer fusion in FuseIntoContainingOp.
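
To illustrate what mapping an operand tile to an iteration domain tile can look like, here is a minimal standalone sketch. It does not use the MLIR API: `Tile`, `ProjectedPermutation`, and `iterDomainTileFromOperandTile` are hypothetical names, offsets and sizes are plain integers rather than `OpFoldResult`, and it assumes the operand is indexed by a projected permutation of the loops, with loops not referenced by the operand keeping their full range.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// A tile is described by one offset and one size per dimension.
struct Tile {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Hypothetical stand-in for an indexing map restricted to projected
// permutations: entry i is the loop that indexes operand dimension i.
using ProjectedPermutation = std::vector<unsigned>;

// Maps an operand tile to an iteration domain tile (toy model, not MLIR).
// Loops the operand does not reference keep their full range
// [0, loopBounds[d]). Returns std::nullopt if the map names a loop that
// does not exist.
std::optional<Tile> iterDomainTileFromOperandTile(
    const ProjectedPermutation &operandDimToLoop,
    const std::vector<int64_t> &loopBounds, const Tile &operandTile) {
  // Precondition: the tile has one offset and one size per operand dimension.
  assert(operandTile.offsets.size() == operandDimToLoop.size());
  assert(operandTile.sizes.size() == operandDimToLoop.size());

  // Start from the full iteration domain.
  Tile iterTile{std::vector<int64_t>(loopBounds.size(), 0), loopBounds};
  for (std::size_t i = 0; i < operandDimToLoop.size(); ++i) {
    unsigned loop = operandDimToLoop[i];
    if (loop >= loopBounds.size())
      return std::nullopt;
    // The loop that indexes operand dimension i is restricted to the tile.
    iterTile.offsets[loop] = operandTile.offsets[i];
    iterTile.sizes[loop] = operandTile.sizes[i];
  }
  return iterTile;
}
```

For a matmul-like op with loops (m, n, k) of bounds (128, 256, 64) and an operand indexed by (m, k), the operand tile with offsets (32, 0) and sizes (16, 64) maps to the iteration domain tile with offsets (32, 0, 0) and sizes (16, 256, 64) in this model.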

---

Patch is 40.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85528.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+29-19) 
- (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+55) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+307-62) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+75-21) 
- (modified) mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir (+124-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index bdeab55091b9f3..2c501a3ecb14f7 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -310,51 +310,61 @@ def FuseIntoContainingOp :
           ["allowsRepeatedHandleOperands"]>,
        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
        ReportTrackingListenerFailuresOpTrait]> {
-  let summary = "Fuse a producer into a containing operation.";
+  let summary = "Fuse a target into a containing operation.";
 
   let description = [{
-    Fuses the `producer_op` into the `containing_op`.
+    Fuses the `target_op` into the `containing_op`.
     Returns a handle to the fused ops and the `new_containing_op`.
 
-    The producer is typically a slice of a tileable op (i.e., implements
-    TilingInterface). In that case, this transform computes the accessed
-    producer slice inside of the containing op ("tile and fuse") and if required,
-    creates a new containing op with outputs from the fused producer. Otherwise,
-    the entire producer is cloned inside the containing op ("clone and fuse").
+    This operation supports fusion of producer or fusion of consumer. We will
+    refer to the value connecting the containing operation and the target
+    operation as the "bridge" below.
+
+    When fusing a producer, the bridge is typically a slice of a tileable
+    op (i.e., one implementing TilingInterface). In that case, this transform
+    computes the accessed bridge slice inside of the containing op ("tile and
+    fuse") and, if required, creates a new containing op with outputs from the
+    fused target. Otherwise, the entire target is cloned inside the containing
+    op ("clone and fuse").
+
+    When fusing a consumer, the bridge is a result of the containing op and an
+    operand of a tileable op (i.e., implements TilingInterface). In this case,
+    this transform computes the accessed bridge slice inside the containing op
+    ("tile and fuse") and creates a new containing op with the consumer's output.
 
     The containing op handle must be associated with exactly one payload op. The
-    producer op handle may be associated with multiple payload ops. This
-    transform fuses producers one-by-one, always picking an unspecified producer
+    target op handle may be associated with multiple payload ops. This
+    transform fuses targets one-by-one, always picking an unspecified target
     that has at least one use inside the containing op among the
-    producers. A producer can be listed multiple times in the handle.
+    targets. A target can be listed multiple times in the handle.
 
-    Note: If a producer has multiple uses inside the containing op, it is
+    Note: If a target has multiple uses inside the containing op, it is
     currently tiled and/or cloned multiple times into the containing op.
     TODO: Reuse already fused OpResults instead of tiling/cloning a second time
-    when possible. Fuse producers according to a topological sorting to achieve
+    when possible. Fuse targets according to a topological sorting to achieve
     the largest amount of reuse.
 
     #### Return modes
 
-    If at least one producer could not be fused, this operation produces a
+    If at least one target could not be fused, this operation produces a
     silenceable failure.  This is the case when tiling fails or when no
-    producer op could be found among the remaining producers that has at least
-    one use within the containing op. I.e., "producers" that are not consumed
+    target op could be found among the remaining targets that has at least
+    one use within the containing op. I.e., "targets" that are not consumed
     within the containing op are rejected by this operation.
 
-    This operation consumes the producer handle.
+    This operation consumes the target handle.
     This operation only reads the containing op handle.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$producer_op,
+  let arguments = (ins TransformHandleTypeInterface:$target_op,
                        TransformHandleTypeInterface:$containing_op);
   let results = (outs TransformHandleTypeInterface:$fused_op,
                       TransformHandleTypeInterface:$new_containing_op);
-  let assemblyFormat = "$producer_op `into` $containing_op attr-dict "
+  let assemblyFormat = "$target_op `into` $containing_op attr-dict "
                        " `:` functional-type(operands, results)";
 
   let builders = [
-    OpBuilder<(ins "Value":$producerOp, "Value":$containingOp)>
+    OpBuilder<(ins "Value":$targetOp, "Value":$containingOp)>
   ];
 }
 
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..4c62d45822ad44 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return {};
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the iteration domain position computed from the
+          input operand position.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getIterDomainTilePositionFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVector<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to return the position of the result tile computed by the tiled operation.
@@ -96,6 +115,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from
+          operand position.
+
+          Generates the IR that computes the tiled implementation of an
+          operation from operand position.  The `offsets` and `sizes`
+          describe the tile of the operand required. This is different from
+          `getTiledImplementation` which generates the tiled
+          implementation of the operation given a tile of the
+          iteration space. This method generates a tiled
+          implementation of the operation based on the position of the
+          operand required. This method enables consumer fusion by using
+          tile and fuse. The method returns failure if the operation
+          can't be tiled to generate the operand tile. In practical terms
+          this implies it cannot be tiled and fused with its producers.
+
+          - `offsets` provides the offset of the tile in the coordinate system
+            of the original iteration space, i.e., if an iteration space
+            dimension had non-zero offset, it must be included in the offset
+            provided here (as opposed to zero-based offset "relative" to the
+            iteration space).
+          - `sizes` provides the size of the tile.
+        }],
+        /*retType=*/"FailureOr<TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ae28049f02e391..ecffb910b236e8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -546,9 +546,9 @@ LogicalResult transform::FuseOp::verify() {
 
 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
                                             OperationState &result,
-                                            Value producerOp,
+                                            Value targetOp,
                                             Value containingOp) {
-  result.addOperands({producerOp, containingOp});
+  result.addOperands({targetOp, containingOp});
   auto resultType = transform::AnyOpType::get(builder.getContext());
   result.addTypes({resultType, resultType});
 }
@@ -631,6 +631,223 @@ static Operation *replaceForAllWithNewSignature(
   return newforallOp;
 }
 
+static std::tuple<SmallVector<Operation *>, Operation *>
+tileAndFuseParallelInsertSlice(RewriterBase &rewriter, Diagnostic &diag,
+                               Operation *consumerOp, Operation *containingOp) {
+  // Check consumer has tiling interface.
+  LLVM_DEBUG(DBGS() << "Try to fuse a consumer\n");
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    diag.attachNote(consumerOp->getLoc())
+        << "consumer is not a TileableInterface: " << *consumerOp;
+    return {};
+  }
+
+  // Check containing op is "scf::ForallOp".
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    diag.attachNote(containingOp->getLoc())
+        << "containing op is not a scf.forall: " << containingOp;
+    return {};
+  }
+
+  // Check dominance.
+  DominanceInfo domInfo(
+      containingOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>());
+  if (llvm::any_of(consumerOp->getOperands(), [&](Value v) {
+        return v.getDefiningOp() != containingOp &&
+               !domInfo.properlyDominates(v, containingOp);
+      })) {
+    diag.attachNote(consumerOp->getLoc())
+        << "consumer's operand can't dominate containing op";
+    return {};
+  }
+
+  // Check that the consumer doesn't use more than one result of containingOp.
+  Value bridge(nullptr);
+  SmallVector<unsigned> operandNums;
+  for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+    if (opd.getDefiningOp() == containingOp) {
+      operandNums.push_back(idx);
+      if (!bridge) {
+        bridge = opd;
+      } else if (bridge != opd) {
+        diag.attachNote(consumerOp->getLoc())
+            << "consumer's operand use more than one containingOp's result";
+        return {};
+      }
+    }
+  }
+
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    diag.attachNote(consumerOp->getLoc())
+        << "consumer op should have destination style op interface";
+    return {};
+  }
+
+  // Check that the consumer doesn't use scf.forall's output as init.
+  SmallVector<Value> dpsInits = llvm::to_vector<4>(
+      llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+  if (llvm::is_contained(dpsInits, bridge)) {
+    diag.attachNote(consumerOp->getLoc())
+        << "consumer op take result of scf.forall as init";
+    return {};
+  }
+
+  // Check result was inserted only once.
+  int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+  auto bridgeBlockArg = forallOp.getRegionOutArgs()[bridgeResultIdx];
+  scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+  tensor::ParallelInsertSliceOp targetInsertOp(nullptr);
+  for (Operation &op : terminatorOp.getRegion().front().getOperations()) {
+    auto parallelInsertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+    if (parallelInsertSliceOp.getDest() == bridgeBlockArg) {
+      if (!targetInsertOp) {
+        targetInsertOp = parallelInsertSliceOp;
+      } else {
+        diag.attachNote(containingOp->getLoc())
+            << "containingOp's result inserted multi time";
+        return {};
+      }
+    }
+  }
+
+  if (!targetInsertOp) {
+    diag.attachNote(containingOp->getLoc())
+        << "containingOp's result was not inserted";
+    return {};
+  }
+
+  SmallVector<OpFoldResult> offsets = targetInsertOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = targetInsertOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = targetInsertOp.getMixedStrides();
+
+  // Check that all insert strides are 1.
+  if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+        if (auto attr = foldRes.dyn_cast<Attribute>()) {
+          return cast<IntegerAttr>(attr).getInt() != 1;
+        }
+        return true;
+      })) {
+    diag.attachNote(containingOp->getLoc())
+        << "containingOp's result yield with stride";
+    return {};
+  }
+
+  Location loc = forallOp.getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(terminatorOp);
+
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+  // Try to get iter domain position from input position.
+  if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    diag.attachNote(consumerOp->getLoc())
+        << "can't get iter domain position from input position";
+    return {};
+  }
+
+  // Try to get each consumer result's position from the iter domain position.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(consumerOp->getNumResults());
+  for (auto [idx, v] : llvm::enumerate(consumerOp->getResults())) {
+    if (failed(tileableConsumer.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      diag.attachNote(consumerOp->getLoc())
+          << "can't get result domain position from iter domain position";
+      return {};
+    }
+  }
+
+  // All checks passed, try to fuse the consumer.
+  // Create the tiled implementation of the consumer op.
+  FailureOr<TilingResult> tileAndFuseResult =
+      tileableConsumer.getTiledImplementationFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes);
+  if (failed(tileAndFuseResult)) {
+    diag.attachNote(consumerOp->getLoc()) << "get tiled implementation failed";
+    return {};
+  }
+
+  auto tiledOps = tileAndFuseResult->tiledOps;
+  if (failed(tileAndFuseResult) || tiledOps.size() != 1) {
+    diag.attachNote(tileableConsumer->getLoc())
+        << "failed to tile consumer op: " << *tileableConsumer;
+    return {};
+  }
+
+  // Replace the tiled op's operands.
+  for (auto operandNum : operandNums) {
+    tiledOps[0]->setOperand(operandNum, targetInsertOp.getSource());
+  }
+  rewriter.replaceUsesWithIf(bridge, forallOp.getOutputs()[bridgeResultIdx],
+                             [&](OpOperand &use) {
+                               Operation *op = use.getOwner();
+                               return forallOp->isProperAncestor(op);
+                             });
+
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.append(dpsInits);
+
+  // Create new scf.forall op.
+  rewriter.setInsertionPoint(forallOp);
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+  rewriter.eraseBlock(newforallOp.getBody());
+  newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+  for (auto v : dpsInits) {
+    newforallOp.getBody()->addArgument(v.getType(), v.getLoc());
+    auto bbArgs = newforallOp.getBody()->getArguments();
+    rewriter.replaceUsesWithIf(v, bbArgs.back(), [&](OpOperand &use) {
+      Operation *op = use.getOwner();
+      return newforallOp->isProperAncestor(op);
+    });
+  }
+
+  // Fix terminator.
+  scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
+      newTerminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
+  Operation *firstYieldOp = yieldingOps.front();
+  rewriter.setInsertionPoint(firstYieldOp);
+  auto bbArgs = newforallOp.getBody()->getArguments();
+  for (auto [idx, v] : llvm::enumerate(tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+                                      rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        firstYieldOp->getLoc(), v,
+        bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+        resultPositions[idx].first, resultPositions[idx].second, strides);
+  }
+
+  // Replace the result of forall and consumer op.
+  for (auto result : llvm::enumerate(forallOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforallOp->getResult(result.index()));
+  }
+
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforallOp->getResult(forallOp.getOutputs().size() +
+                               consumerResult.index()));
+  }
+
+  return std::make_tuple(tileAndFuseResult->tiledOps, newforallOp);
+}
+
 /// Find the first "extract" user of `producerOp` and tile it right before its
 /// use. The tiled op is fused under the `containingOp`.
 /// Return this fused op on success or nullptr if anything fails.
@@ -880,7 +1097,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
                                        transform::TransformResults &results,
                                        transform::TransformState &state) {
   SmallVector<Operation *> fusedOps;
-  auto producerOps = state.getPayloadOps(getProducerOp());
+  auto targetOps = state.getPayloadOps(getTargetOp());
   auto containingOps = state.getPayloadOps(getContainingOp());
   if (!llvm::hasSingleElement(containingOps)) {
     return emitDefiniteFailure()
@@ -890,69 +1107,115 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
   Operation *containingOp = *containingOps.begin();
 
   // If nothing to fuse, propagate success.
-  if (std::empty(producerOps)) {
+  if (std::empty(targetOps)) {
     results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
     results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
     return DiagnosedSilenceableFailure::success();
   }
 
-  // Helper function to find the next producer that should be fused. Take any
-  // producer that has a use inside the containing op.
-  SetVector<Operation *> remainingProducers(producerOps.begin(),
-                                            producerOps.end());
-  auto getNextProducer = [&]() -> FailureOr<Operation *> {
-    for (const auto &it : enumerate(remainingProducers)) {
-      Operation *producerOp = it.value();
-      // The containing op may be a user of producerOp: use isAncestor.
+  // Helper function to find the next target that should be fused. Take any
+  // target that has a use inside the containing op. Return the target operation
+  // and a bool indicating whether this target op is a producer.
+  SetVector<Operation *> remainingTargets(targetOps.begin(), targetOps.end());
+  auto getNextTarget = [&]() -> FailureOr<std::pair<Operation *, bool>> {
+    for (const auto &it : enumerate(remainingTargets)) {
+      Operation *targetOp = it.value();
+      // The containing op may be a user of targetOp: use isAncestor.
       int64_t numUsesInContainingOp =
-          llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
+          llvm::count_if(targetOp->getUsers(), [&](Operation *op) {
             return containingOp->isAncestor(op);
           });
       // TODO: When resolving the TODO below (no duplicate ops), take an op
-      // that has no use among the remaining producers. This is a topological
+      // that has no use among the remaining targets. Th...
[truncated]

``````````
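
The index arithmetic that `tileAndFuseParallelInsertSlice` uses when rewiring the new `scf.forall` can be summarized with a small standalone sketch. It is a toy model rather than MLIR code: `ForallShape` and both helper functions are hypothetical names, and it assumes the layout implied by the patch, where the loop body's block arguments are the induction variables, then the original shared outputs, then the appended consumer inits, and where the loop results follow the outputs one-to-one.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical summary of the loop before and after consumer fusion.
struct ForallShape {
  std::size_t rank;             // number of induction variables
  std::size_t numOutputs;       // shared outputs before fusion
  std::size_t numConsumerInits; // consumer inits appended by fusion
};

// Index of the block argument that carries the consumer's idx-th init,
// mirroring `forallOp.getRank() + forallOp.getOutputs().size() + idx`.
std::size_t consumerInitBlockArgIndex(const ForallShape &s, std::size_t idx) {
  assert(idx < s.numConsumerInits && "init index out of range");
  return s.rank + s.numOutputs + idx;
}

// Index of the new loop result that replaces the consumer's idx-th result,
// mirroring `forallOp.getOutputs().size() + consumerResult.index()`.
std::size_t consumerResultIndex(const ForallShape &s, std::size_t idx) {
  assert(idx < s.numConsumerInits && "result index out of range");
  return s.numOutputs + idx;
}
```

In this model, a rank-2 loop with one original output and a consumer with two inits places the consumer inits at block arguments 3 and 4 and exposes the consumer results as loop results 1 and 2.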

</details>


https://github.com/llvm/llvm-project/pull/85528

