[Mlir-commits] [mlir] [mlir][linalg] Enable fuse consumer (PR #85528)

donald chen llvmlistbot at llvm.org
Wed Mar 27 21:29:29 PDT 2024


https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/85528

From 996d7a9854d1d23a46f3a2ccbd25e1d8e3591776 Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Sat, 16 Mar 2024 09:12:12 +0800
Subject: [PATCH] [mlir][linalg] Enable fuse consumer

This patch adds support for consumer fusion to the tiling interface and
implements consumer fusion in FuseIntoContainingOp.

- Add the interface method 'getIterDomainTilePositionFromOperandPosition' to
  TilingInterface; it computes the iteration domain tile position from an
  operand tile position.
- Add the interface method 'getTiledImplementationFromOperandPosition' to
  TilingInterface; it generates the tiled implementation from an operand
  tile position.
- Implement both methods for Linalg ops and support consumer fusion in
  FuseIntoContainingOp.
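For a Linalg op, the mapping in the first bullet amounts to pushing the operand tile through the operand's indexing map. A minimal sketch, assuming static offsets and sizes (the interface itself uses OpFoldResult) and a projected-permutation indexing map encoded as a list of iteration dimensions; `Tile` and `operandTileToIterDomainTile` are hypothetical names used only for illustration:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified static tile: one offset and one size per
// dimension. (The real interface works on OpFoldResult, which may be dynamic.)
struct Tile {
  std::vector<std::int64_t> offsets;
  std::vector<std::int64_t> sizes;
};

// Maps a tile of one operand to a tile of the op's iteration domain.
// `operandMap[i] == d` encodes an indexing map whose i-th result is the
// iteration dimension d, i.e. a projected permutation. Iteration dimensions
// the operand does not access keep their full extent from `iterDomain`.
Tile operandTileToIterDomainTile(const std::vector<unsigned> &operandMap,
                                 const Tile &operandTile,
                                 const Tile &iterDomain) {
  assert(operandMap.size() == operandTile.offsets.size() &&
         operandMap.size() == operandTile.sizes.size());
  Tile iterTile = iterDomain; // start from the full iteration domain
  for (std::size_t i = 0; i < operandMap.size(); ++i) {
    unsigned dim = operandMap[i];
    assert(dim < iterTile.offsets.size() && "map uses an unknown dimension");
    iterTile.offsets[dim] = operandTile.offsets[i];
    iterTile.sizes[dim] = operandTile.sizes[i];
  }
  return iterTile;
}
```

For example, for a matmul with iteration domain (i, j, k) = [0, 128) x [0, 256) x [0, 64) whose LHS is indexed by (d0, d2), the LHS tile at offsets [32, 0] with sizes [16, 64] maps to the iteration tile at offsets [32, 0, 0] with sizes [16, 256, 64]; j is not accessed by the LHS and keeps its full extent.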
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  48 ++-
 .../mlir/Interfaces/TilingInterface.td        |  55 +++
 .../TransformOps/LinalgTransformOps.cpp       | 369 +++++++++++++++---
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  96 ++++-
 .../transform-op-fuse-into-containing.mlir    | 125 +++++-
 5 files changed, 590 insertions(+), 103 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4f34016066b4ce..a46096dc167f2c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -310,51 +310,61 @@ def FuseIntoContainingOp :
           ["allowsRepeatedHandleOperands"]>,
        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
        ReportTrackingListenerFailuresOpTrait]> {
-  let summary = "Fuse a producer into a containing operation.";
+  let summary = "Fuse a target into a containing operation.";
 
   let description = [{
-    Fuses the `producer_op` into the `containing_op`.
+    Fuses the `target_op` into the `containing_op`.
     Returns a handle to the fused ops and the `new_containing_op`.
 
-    The producer is typically a slice of a tileable op (i.e., implements
-    TilingInterface). In that case, this transform computes the accessed
-    producer slice inside of the containing op ("tile and fuse") and if required,
-    creates a new containing op with outputs from the fused producer. Otherwise,
-    the entire producer is cloned inside the containing op ("clone and fuse").
+    This operation supports both producer fusion and consumer fusion. Below,
+    the value connecting the containing operation and the target operation
+    is referred to as the "bridge".
+
+    For producer fusion, the bridge is a result of the target, typically used
+    through a slice inside the containing op. If the target is tileable (i.e.,
+    implements TilingInterface), this transform computes the accessed bridge
+    slice inside the containing op ("tile and fuse") and, if required, creates
+    a new containing op with outputs from the fused target. Otherwise, the
+    entire target is cloned inside the containing op ("clone and fuse").
+
+    For consumer fusion, the bridge is a result of the containing op that is
+    used by the tileable target. The target is tiled according to the bridge
+    slice inserted inside the containing op ("tile and fuse"), and a new
+    containing op that also yields the target's results is created.
 
     The containing op handle must be associated with exactly one payload op. The
-    producer op handle may be associated with multiple payload ops. This
-    transform fuses producers one-by-one, always picking an unspecified producer
+    target op handle may be associated with multiple payload ops. This
+    transform fuses targets one-by-one, always picking an unspecified target
     that has at least one use inside the containing op among the
-    producers. A producer can be listed multiple times in the handle.
+    targets. A target can be listed multiple times in the handle.
 
-    Note: If a producer has multiple uses inside the containing op, it is
+    Note: If a target has multiple uses inside the containing op, it is
     currently tiled and/or cloned multiple times into the containing op.
     TODO: Reuse already fused OpResults instead of tiling/cloning a second time
-    when possible. Fuse producers according to a topological sorting to achieve
+    when possible. Fuse targets according to a topological sorting to achieve
     the largest amount of reuse.
 
     #### Return modes
 
-    If at least one producer could not be fused, this operation produces a
+    If at least one target could not be fused, this operation produces a
     silenceable failure.  This is the case when tiling fails or when no
-    producer op could be found among the remaining producers that has at least
-    one use within the containing op. I.e., "producers" that are not consumed
+    target op could be found among the remaining targets that has at least
+    one use within the containing op. I.e., "targets" that are not consumed
     within the containing op are rejected by this operation.
 
-    This operation consumes the producer handle.
+    This operation consumes the target handle.
     This operation only reads the containing op handle.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$producer_op,
+  let arguments = (ins TransformHandleTypeInterface:$target_op,
                        TransformHandleTypeInterface:$containing_op);
   let results = (outs TransformHandleTypeInterface:$fused_op,
                       TransformHandleTypeInterface:$new_containing_op);
-  let assemblyFormat = "$producer_op `into` $containing_op attr-dict "
+  let assemblyFormat = "$target_op `into` $containing_op attr-dict "
                        " `:` functional-type(operands, results)";
 
   let builders = [
-    OpBuilder<(ins "Value":$producerOp, "Value":$containingOp)>
+    OpBuilder<(ins "Value":$targetOp, "Value":$containingOp)>
   ];
 }
 
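Which fusion mode applies to a given target depends only on how it is connected to the containing op. A rough model of that decision on a toy IR is sketched below; `ToyOp`, `classifyTarget`, and the integer op ids are hypothetical and not MLIR APIs. As in the transform's implementation, the producer check runs first: a use nested inside the containing op makes the target a producer, and otherwise reading a result of the containing op (while itself having uses) makes it a consumer.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical toy IR (not MLIR): ops are identified by their index in a
// vector; each op records its enclosing op (-1 for none), the ops defining
// its operands, and the ops using its results.
struct ToyOp {
  int parent = -1;
  std::vector<int> operandDefs;
  std::vector<int> users;
};

enum class FusionKind { Producer, Consumer };

// True if `ancestor` is `op` itself or (transitively) encloses it.
bool isAncestor(const std::vector<ToyOp> &ops, int ancestor, int op) {
  for (int cur = op; cur != -1; cur = ops[cur].parent)
    if (cur == ancestor)
      return true;
  return false;
}

// Classifies `target` relative to `containing`, checking the producer case
// first; returns std::nullopt if the target cannot be fused either way.
std::optional<FusionKind> classifyTarget(const std::vector<ToyOp> &ops,
                                         int target, int containing) {
  const ToyOp &t = ops[target];
  for (int user : t.users)
    if (isAncestor(ops, containing, user))
      return FusionKind::Producer;
  for (int def : t.operandDefs)
    if (def == containing && !t.users.empty())
      return FusionKind::Consumer;
  return std::nullopt;
}
```

A target that satisfies neither condition is never picked, which corresponds to the "could not find next target to fuse into container" failure.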
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..4c62d45822ad44 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return {};
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the iteration domain tile position corresponding
+          to a tile (`offsets`, `sizes`) of operand `operandNumber`.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getIterDomainTilePositionFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes,
+          "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVector<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to return the position of the result tile computed by the tiled operation.
@@ -96,6 +115,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from a
+          tile of one of its operands.
+
+          Generates the IR that computes the tiled implementation of an
+          operation from the position of an operand tile. The `offsets` and
+          `sizes` describe the required tile of operand `operandNumber`. This
+          differs from `getTiledImplementation`, which generates the tiled
+          implementation of the operation given a tile of the iteration
+          space; here, the iteration space tile is derived from the operand
+          tile. This method enables consumer fusion using tile and fuse: the
+          consumer is tiled so that it consumes exactly the operand tile
+          produced inside the containing op. The method returns failure if
+          the operation can't be tiled to consume the given operand tile. In
+          practical terms, this means it cannot be fused as a consumer.
+
+          - `offsets` provides the offset of the tile in the coordinate system
+            of operand `operandNumber` (not of the iteration space), i.e.,
+            the absolute offset of the tile within the full operand value.
+            Implementations are responsible for mapping it to the
+            corresponding iteration space tile.
+          - `sizes` provides the size of the tile.
+        }],
+        /*retType=*/"FailureOr<TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the code that produces a tile of the result.
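When a consumer is fused, the position of each of its result tiles is recovered from the iteration domain tile through `getResultTilePosition`. For a Linalg op whose result indexing map is a projected permutation, this reduces to selecting the iteration domain entries named by the map. A minimal sketch under that assumption, with static sizes; `Tile` and `iterDomainTileToResultTile` are hypothetical names, not MLIR APIs:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified static tile: one offset and one size per dimension.
struct Tile {
  std::vector<std::int64_t> offsets;
  std::vector<std::int64_t> sizes;
};

// Projects an iteration-domain tile onto a result whose indexing map is a
// projected permutation, encoded as `resultMap[i] == d`: result dimension i
// is indexed by iteration dimension d.
Tile iterDomainTileToResultTile(const std::vector<unsigned> &resultMap,
                                const Tile &iterTile) {
  Tile resultTile;
  for (unsigned dim : resultMap) {
    assert(dim < iterTile.offsets.size() && "map uses an unknown dimension");
    resultTile.offsets.push_back(iterTile.offsets[dim]);
    resultTile.sizes.push_back(iterTile.sizes[dim]);
  }
  return resultTile;
}
```

For a matmul iteration tile at offsets [32, 0, 0] with sizes [16, 256, 64], an output indexed by (d0, d1) gets offsets [32, 0] and sizes [16, 256]; the reduction dimension d2 does not appear in the result tile.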
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ecf9983124821a..d067888c17a915 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -546,9 +546,9 @@ LogicalResult transform::FuseOp::verify() {
 
 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
                                             OperationState &result,
-                                            Value producerOp,
+                                            Value targetOp,
                                             Value containingOp) {
-  result.addOperands({producerOp, containingOp});
+  result.addOperands({targetOp, containingOp});
   auto resultType = transform::AnyOpType::get(builder.getContext());
   result.addTypes({resultType, resultType});
 }
@@ -631,6 +631,223 @@ static Operation *replaceForAllWithNewSignature(
   return newforallOp;
 }
 
+static std::tuple<SmallVector<Operation *>, Operation *>
+tileAndFuseParallelInsertSlice(RewriterBase &rewriter, Diagnostic &diag,
+                               Operation *consumerOp, Operation *containingOp) {
+  // Check that the consumer implements TilingInterface.
+  LLVM_DEBUG(DBGS() << "Try to fuse a consumer\n");
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    diag.attachNote(consumerOp->getLoc())
+        << "consumer does not implement TilingInterface: " << *consumerOp;
+    return {};
+  }
+
+  // Check containing op is "scf::ForallOp".
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    diag.attachNote(containingOp->getLoc())
+        << "containing op is not an scf.forall: " << *containingOp;
+    return {};
+  }
+
+  // Check dominance.
+  DominanceInfo domInfo(
+      containingOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>());
+  if (llvm::any_of(consumerOp->getOperands(), [&](Value v) {
+        return v.getDefiningOp() != containingOp &&
+               !domInfo.properlyDominates(v, containingOp);
+      })) {
+    diag.attachNote(consumerOp->getLoc())
+        << "consumer operand does not dominate the containing op";
+    return {};
+  }
+
+  // Check that the consumer uses at most one result of the containing op.
+  Value bridge(nullptr);
+  SmallVector<unsigned> operandNums;
+  for (auto [idx, opd] : llvm::enumerate(consumerOp->getOperands())) {
+    if (opd.getDefiningOp() == containingOp) {
+      operandNums.push_back(idx);
+      if (!bridge) {
+        bridge = opd;
+      } else if (bridge != opd) {
+        diag.attachNote(consumerOp->getLoc())
+            << "consumer uses more than one result of the containing op";
+        return {};
+      }
+    }
+  }
+
+  // TODO: The consumer's results must be initialized before the scf.forall.
+  //       For now, get the result shapes from the DPS inits; add support for
+  //       other ops, such as those implementing InferTypeOpInterface.
+  // Check that the consumer implements DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    diag.attachNote(consumerOp->getLoc())
+        << "consumer op must implement DestinationStyleOpInterface";
+    return {};
+  }
+
+  // Check that the consumer doesn't use the scf.forall result as an init.
+  SmallVector<Value> dpsInits = llvm::to_vector<4>(
+      llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+  if (llvm::is_contained(dpsInits, bridge)) {
+    diag.attachNote(consumerOp->getLoc())
+        << "consumer op takes a result of scf.forall as init";
+    return {};
+  }
+
+  // Check that the bridge result is inserted exactly once.
+  int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+  auto bridgeBlockArg = forallOp.getRegionOutArgs()[bridgeResultIdx];
+  scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+  tensor::ParallelInsertSliceOp targetInsertOp(nullptr);
+  for (Operation &op : terminatorOp.getRegion().front().getOperations()) {
+    auto insertOp = dyn_cast<tensor::ParallelInsertSliceOp>(op);
+    if (insertOp && insertOp.getDest() == bridgeBlockArg) {
+      if (!targetInsertOp) {
+        targetInsertOp = insertOp;
+      } else {
+        diag.attachNote(containingOp->getLoc())
+            << "containing op's result is inserted more than once";
+        return {};
+      }
+    }
+  }
+
+  if (!targetInsertOp) {
+    diag.attachNote(containingOp->getLoc())
+        << "containingOp's result was not inserted";
+    return {};
+  }
+
+  SmallVector<OpFoldResult> offsets = targetInsertOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = targetInsertOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = targetInsertOp.getMixedStrides();
+
+  // Check that all insert strides are 1.
+  if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+        if (auto attr = foldRes.dyn_cast<Attribute>()) {
+          return cast<IntegerAttr>(attr).getInt() != 1;
+        }
+        return true;
+      })) {
+    diag.attachNote(containingOp->getLoc())
+        << "containing op's result is inserted with a non-unit stride";
+    return {};
+  }
+
+  Location loc = forallOp.getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(terminatorOp);
+
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+  // Compute the iteration domain tile from the operand tile.
+  if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    diag.attachNote(consumerOp->getLoc())
+        << "can't get iteration domain position from operand position";
+    return {};
+  }
+
+  // Compute each consumer result's tile position from the iteration tile.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(consumerOp->getNumResults());
+  for (auto [idx, v] : llvm::enumerate(consumerOp->getResults())) {
+    if (failed(tileableConsumer.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      diag.attachNote(consumerOp->getLoc())
+          << "can't get result tile position from iteration domain position";
+      return {};
+    }
+  }
+
+  // All checks passed; fuse the consumer.
+  // Create the tiled implementation of the consumer op.
+  FailureOr<TilingResult> tileAndFuseResult =
+      tileableConsumer.getTiledImplementationFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes);
+  if (failed(tileAndFuseResult)) {
+    diag.attachNote(consumerOp->getLoc()) << "get tiled implementation failed";
+    return {};
+  }
+
+  auto tiledOps = tileAndFuseResult->tiledOps;
+  if (tiledOps.size() != 1) {
+    diag.attachNote(tileableConsumer->getLoc())
+        << "failed to tile consumer op: " << *tileableConsumer;
+    return {};
+  }
+
+  // Make the tiled consumer read the inserted tile instead of the bridge.
+  for (auto operandNum : operandNums) {
+    tiledOps[0]->setOperand(operandNum, targetInsertOp.getSource());
+  }
+  rewriter.replaceUsesWithIf(bridge, forallOp.getOutputs()[bridgeResultIdx],
+                             [&](OpOperand &use) {
+                               Operation *op = use.getOwner();
+                               return forallOp->isProperAncestor(op);
+                             });
+
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.append(dpsInits);
+
+  // Create new scf.forall op.
+  rewriter.setInsertionPoint(forallOp);
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+  rewriter.eraseBlock(newforallOp.getBody());
+  newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+  for (auto v : dpsInits) {
+    newforallOp.getBody()->addArgument(v.getType(), v.getLoc());
+    auto bbArgs = newforallOp.getBody()->getArguments();
+    rewriter.replaceUsesWithIf(v, bbArgs.back(), [&](OpOperand &use) {
+      Operation *op = use.getOwner();
+      return newforallOp->isProperAncestor(op);
+    });
+  }
+
+  // Add parallel_insert_slice ops for the tiled consumer's results.
+  scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
+      newTerminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
+  Operation *firstYieldOp = yieldingOps.front();
+  rewriter.setInsertionPoint(firstYieldOp);
+  auto bbArgs = newforallOp.getBody()->getArguments();
+  for (auto [idx, v] : llvm::enumerate(tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+                                      rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        firstYieldOp->getLoc(), v,
+        bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+        resultPositions[idx].first, resultPositions[idx].second, strides);
+  }
+
+  // Replace the uses of the old scf.forall and consumer results.
+  for (auto result : llvm::enumerate(forallOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforallOp->getResult(result.index()));
+  }
+
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforallOp->getResult(forallOp.getOutputs().size() +
+                               consumerResult.index()));
+  }
+
+  return std::make_tuple(tileAndFuseResult->tiledOps, newforallOp);
+}
+
 /// Find the first "extract" user of `producerOp` and tile it right before its
 /// use. The tiled op is fused under the `containingOp`.
 /// Return this fused op on success or nullptr if anything fails.
@@ -880,7 +1097,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
                                        transform::TransformResults &results,
                                        transform::TransformState &state) {
   SmallVector<Operation *> fusedOps;
-  auto producerOps = state.getPayloadOps(getProducerOp());
+  auto targetOps = state.getPayloadOps(getTargetOp());
   auto containingOps = state.getPayloadOps(getContainingOp());
   if (!llvm::hasSingleElement(containingOps)) {
     return emitDefiniteFailure()
@@ -890,69 +1107,115 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
   Operation *containingOp = *containingOps.begin();
 
   // If nothing to fuse, propagate success.
-  if (std::empty(producerOps)) {
+  if (std::empty(targetOps)) {
     results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
     results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
     return DiagnosedSilenceableFailure::success();
   }
 
-  // Helper function to find the next producer that should be fused. Take any
-  // producer that has a use inside the containing op.
-  SetVector<Operation *> remainingProducers(producerOps.begin(),
-                                            producerOps.end());
-  auto getNextProducer = [&]() -> FailureOr<Operation *> {
-    for (const auto &it : enumerate(remainingProducers)) {
-      Operation *producerOp = it.value();
-      // The containing op may be a user of producerOp: use isAncestor.
+  // Helper function to find the next target that should be fused. Take any
+  // target that has a use inside the containing op (a producer) or uses one of
+  // its results (a consumer). Return the target and whether it is a producer.
+  SetVector<Operation *> remainingTargets(targetOps.begin(), targetOps.end());
+  auto getNextTarget = [&]() -> FailureOr<std::pair<Operation *, bool>> {
+    for (const auto &it : enumerate(remainingTargets)) {
+      Operation *targetOp = it.value();
+      // The containing op may be a user of targetOp: use isAncestor.
       int64_t numUsesInContainingOp =
-          llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
+          llvm::count_if(targetOp->getUsers(), [&](Operation *op) {
             return containingOp->isAncestor(op);
           });
       // TODO: When resolving the TODO below (no duplicate ops), take an op
-      // that has no use among the remaining producers. This is a topological
+      // that has no use among the remaining targets. This is a topological
       // sorting.
       if (numUsesInContainingOp > 0) {
         if (numUsesInContainingOp == 1)
-          remainingProducers.erase(remainingProducers.begin() + it.index());
-        return producerOp;
+          remainingTargets.erase(remainingTargets.begin() + it.index());
+        return std::make_pair(targetOp, true);
+      }
+
+      // The containing op may be a producer of targetOp: use getDefiningOp.
+      if (llvm::any_of(
+              targetOp->getOperands(),
+              [&](Value v) { return v.getDefiningOp() == containingOp; }) &&
+          !targetOp->getUses().empty()) {
+        remainingTargets.erase(remainingTargets.begin() + it.index());
+        return std::make_pair(targetOp, false);
       }
     }
     return failure();
   };
 
-  while (!remainingProducers.empty()) {
-    auto nextProducer = getNextProducer();
-    if (failed(nextProducer)) {
+  while (!remainingTargets.empty()) {
+    auto nextTarget = getNextTarget();
+    if (failed(nextTarget)) {
       auto diag = mlir::emitSilenceableFailure(getLoc())
-                  << "could not find next producer to fuse into container";
+                  << "could not find next target to fuse into container";
       diag.attachNote(containingOp->getLoc()) << "containing op";
       return diag;
     }
 
-    Operation *producerOp = *nextProducer;
+    Operation *targetOp = nextTarget->first;
 
     // Default diagnostic, to be complemented with more failure information.
-    Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
-    diag << "could not fuse " << *producerOp << " into " << *containingOp;
-
-    // TODO: If there are multiple uses of the producer in the containing op,
-    // we currently tile/clone the op multiple times (once per use). In some
-    // cases, we can tile/clone once and reuse the value for each use.
-    // Futhermore, producers should then be traversed according to a
-    // topological sorting.
-    auto [tiledOps, newContainingOp] =
-        tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
-    if (!tiledOps.empty()) {
-      LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
-      fusedOps.append(tiledOps);
-      if (newContainingOp) {
-        // Update handles associated with the containing op so we don't need to
-        // invalidate them. This is a hack to support better composability
-        // between tiling and fusion while a proper mechanism is being
-        // investigated.
-        //
-        // DO NOT replicate this elsewhere unless you understand what you are
-        // doing.
+    Diagnostic diag(targetOp->getLoc(), DiagnosticSeverity::Remark);
+    diag << "could not fuse " << *targetOp << " into " << *containingOp;
+
+    if (nextTarget->second) {
+      // Fuse producer.
+      // TODO: If there are multiple uses of the target in the containing op,
+      // we currently tile/clone the op multiple times (once per use). In some
+      // cases, we can tile/clone once and reuse the value for each use.
+      // Furthermore, targets should then be traversed according to a
+      // topological sorting.
+      auto [tiledOps, newContainingOp] =
+          tileAndFuseFirstExtractUse(rewriter, diag, targetOp, containingOp);
+      if (!tiledOps.empty()) {
+        LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
+        fusedOps.append(tiledOps);
+        if (newContainingOp) {
+          // Update handles associated with the containing op so we don't need
+          // to invalidate them. This is a hack to support better composability
+          // between tiling and fusion while a proper mechanism is being
+          // investigated.
+          //
+          // DO NOT replicate this elsewhere unless you understand what you are
+          // doing.
+          LogicalResult replacementStatus =
+              rewriter.notifyPayloadOperationReplaced(containingOp,
+                                                      newContainingOp);
+          (void)replacementStatus;
+          assert(succeeded(replacementStatus) &&
+                 "unable to update transform state mapping");
+          rewriter.eraseOp(containingOp);
+          containingOp = newContainingOp;
+        }
+        continue;
+      }
+
+      SmallVector<Operation *> tiledContainingOpOperand =
+          tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
+              rewriter, diag, targetOp, containingOp);
+      if (!tiledContainingOpOperand.empty()) {
+        LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
+                          << *containingOp);
+        fusedOps.append(tiledContainingOpOperand);
+        continue;
+      }
+
+      Operation *cloned =
+          cloneAndFuseFirstUse(rewriter, diag, targetOp, containingOp);
+      if (cloned) {
+        LLVM_DEBUG(DBGS() << "\nFused a use by cloning\n" << *containingOp);
+        fusedOps.push_back(cloned);
+        continue;
+      }
+    } else {
+      // Fuse consumer.
+      auto [tiledOps, newContainingOp] = tileAndFuseParallelInsertSlice(
+          rewriter, diag, targetOp, containingOp);
+      if (!tiledOps.empty()) {
+        fusedOps.append(tiledOps);
         LogicalResult replacementStatus =
             rewriter.notifyPayloadOperationReplaced(containingOp,
                                                     newContainingOp);
@@ -961,26 +1224,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
                "unable to update transform state mapping");
         rewriter.eraseOp(containingOp);
         containingOp = newContainingOp;
+        continue;
       }
-      continue;
-    }
-
-    SmallVector<Operation *> tiledContainingOpOperand =
-        tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
-            rewriter, diag, producerOp, containingOp);
-    if (!tiledContainingOpOperand.empty()) {
-      LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
-                        << *containingOp);
-      fusedOps.append(tiledContainingOpOperand);
-      continue;
-    }
-
-    Operation *cloned =
-        cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
-    if (cloned) {
-      LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
-      fusedOps.push_back(cloned);
-      continue;
     }
     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
@@ -992,7 +1237,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
 
 void transform::FuseIntoContainingOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  consumesHandle(getProducerOp(), effects);
+  consumesHandle(getTargetOp(), effects);
   onlyReadsHandle(getContainingOp(), effects);
   producesHandle(getResults(), effects);
   modifiesPayload(effects);
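The consumer path in `tileAndFuseParallelInsertSlice` only applies when the bridge is written back simply: exactly one `tensor.parallel_insert_slice` in the `scf.forall` terminator targets the bridge's shared output, and all of its strides are the constant 1. A sketch of those two checks on a toy model follows; `ToyInsertSlice` and `findFusableInsert` are hypothetical, and a dynamic stride is modeled as `std::nullopt` because the patch rejects any stride that is not a constant attribute.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of one tensor.parallel_insert_slice in the terminator:
// the shared output it writes and its strides (std::nullopt = dynamic).
struct ToyInsertSlice {
  unsigned destOutput;
  std::vector<std::optional<std::int64_t>> strides;
};

// Returns the index of the unique insert into `bridgeOutput` if it passes the
// checks: inserted exactly once, and every stride is statically 1.
std::optional<std::size_t>
findFusableInsert(const std::vector<ToyInsertSlice> &inserts,
                  unsigned bridgeOutput) {
  std::optional<std::size_t> found;
  for (std::size_t i = 0; i < inserts.size(); ++i) {
    if (inserts[i].destOutput != bridgeOutput)
      continue;
    if (found)
      return std::nullopt; // inserted more than once
    found = i;
  }
  if (!found)
    return std::nullopt; // never inserted
  for (const std::optional<std::int64_t> &stride : inserts[*found].strides)
    if (stride != 1) // dynamic (std::nullopt) or non-unit stride
      return std::nullopt;
  return found;
}
```

When either check fails, the transform attaches a note to the diagnostic and reports a silenceable failure instead of fusing.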
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..01bf19764b0938 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -132,6 +132,59 @@ struct LinalgOpTilingInterface
     return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
   }
 
+  void getMappedOffsetAndSize(Operation *op, OpBuilder &b,
+                              AffineMap indexingMap,
+                              ArrayRef<OpFoldResult> offsets,
+                              ArrayRef<OpFoldResult> sizes,
+                              SmallVector<OpFoldResult> &mappedOffsets,
+                              SmallVector<OpFoldResult> &mappedSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+    auto numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    mappedOffsets.resize(numLoops);
+    mappedSizes.resize(numLoops);
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (const auto &range : llvm::enumerate(iterationDomain)) {
+        mappedOffsets[range.index()] = range.value().offset;
+        mappedSizes[range.index()] = range.value().size;
+      }
+    }
+    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition =
+          cast<AffineDimExpr>(resultExpr.value()).getPosition();
+      mappedOffsets[dimPosition] = offsets[resultExpr.index()];
+      mappedSizes[dimPosition] = sizes[resultExpr.index()];
+    }
+  }
+
+  // Return the iteration domain tile that accesses the given tile of operand
+  // `operandNumber`.
+  LogicalResult getIterDomainTilePositionFromOperandPosition(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVector<OpFoldResult> &iterDomainOffsets,
+      SmallVector<OpFoldResult> &iterDomainSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the operand is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the operand to iteration space tiles
+    // (filling in full extent for dimensions not used to access the operand).
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return op->emitOpError(
+          "unhandled iteration domain position computation when operand is "
+          "not accessed using a projected permutation");
+    }
+
+    getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes,
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
   // Return the details of the output tile generated by the tiled
   // implementation.
   LogicalResult
@@ -160,6 +213,20 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandPosition(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterDomainTilePositionFromOperandPosition(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return op->emitOpError(
+          "unable to obtain the iteration domain tile for the operand");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
+  }
+
   FailureOr<TilingResult>
   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                           ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
           "unhandled tiled implementation generation when result is not "
           "accessed using a permuted projection");
     }
-
-    auto numLoops = linalgOp.getNumLoops();
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes, mappedOffsets,
+                           mappedSizes);
     auto tilingInterfaceOp = cast<TilingInterface>(op);
-    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
-        iterationTileSizes(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &range : llvm::enumerate(iterationDomain)) {
-        iterationTileOffsets[range.index()] = range.value().offset;
-        iterationTileSizes[range.index()] = range.value().size;
-      }
-    }
-    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition =
-          cast<AffineDimExpr>(resultExpr.value()).getPosition();
-      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
-      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
-    }
-
     FailureOr<TilingResult> tilingResult =
-        tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
-                                                 iterationTileSizes);
+        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+    if (failed(tilingResult))
+      return failure();
+
     if (tilingResult->tiledOps.size() != 1)
       return op->emitOpError("failed to generate tiled implementation");
 
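The index bookkeeping that `getIterDomainTilePositionFromOperandPosition` and `generateResultTileValue` now share through `getMappedOffsetAndSize` can be illustrated outside MLIR. The sketch below is a minimal, hypothetical model, not code from this patch: plain `int64_t` values stand in for `OpFoldResult`, a projected-permutation indexing map is encoded as the list of loop dimensions its results use, and the names `Tile`, `mapOperandTileToIterationTile` and `mapIterationTileToOperandSlice` are illustrative only. Unlike the patch, it always starts from the full iteration domain; this gives the same answer, because a full permutation overwrites every entry anyway.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for lists of offsets/sizes (OpFoldResult in MLIR).
struct Tile {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Operand tile -> iteration-domain tile, modelled on getMappedOffsetAndSize.
// `resultDims[i]` is the loop dimension used by result i of a projected
// permutation; `loopBounds` is the full iteration domain (all offsets 0).
Tile mapOperandTileToIterationTile(const std::vector<unsigned> &resultDims,
                                   const std::vector<int64_t> &loopBounds,
                                   const std::vector<int64_t> &operandOffsets,
                                   const std::vector<int64_t> &operandSizes) {
  assert(resultDims.size() == operandOffsets.size() &&
         resultDims.size() == operandSizes.size() && "rank mismatch");
  // Start from the full domain so that loops the operand does not index keep
  // their full extent.
  Tile tile{std::vector<int64_t>(loopBounds.size(), 0), loopBounds};
  for (std::size_t i = 0; i < resultDims.size(); ++i) {
    assert(resultDims[i] < loopBounds.size() && "map result out of range");
    tile.offsets[resultDims[i]] = operandOffsets[i];
    tile.sizes[resultDims[i]] = operandSizes[i];
  }
  return tile;
}

// Iteration-domain tile -> operand (or result) slice through the same kind of
// map, e.g. the transposed output map (d0, d1) -> (d1, d0).
Tile mapIterationTileToOperandSlice(const std::vector<unsigned> &resultDims,
                                    const Tile &iterTile) {
  Tile slice;
  for (unsigned dim : resultDims) {
    assert(dim < iterTile.offsets.size() && "map result out of range");
    slice.offsets.push_back(iterTile.offsets[dim]);
    slice.sizes.push_back(iterTile.sizes[dim]);
  }
  return slice;
}
```

For example, take the `A` operand of a `linalg.matmul` with loops (d0, d1, d2) = (123, 789, 456) and map (d0, d1, d2) -> (d0, d2). An operand tile at offsets [50, 0] with sizes [50, 456] maps to iteration offsets [50, 0, 0] and sizes [50, 789, 456], so d1 keeps its full extent. In the other direction, the transposed output map (d0, d1) -> (d1, d0) from the multi-output test turns an iteration tile at [50, 16] into the slice [16, 50]. That is the kind of transposed slice the fused consumer's second output would be expected to use.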
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 4115f2857a20c6..5a1f1b01483750 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -585,6 +585,129 @@ module {
 
 // -----
 
+#map = affine_map<()[s0] -> (64 ceildiv s0)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_consumer
+  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
+  // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
+  func.func @fuse_tileable_consumer(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    // CHECK: %[[SLICE:.*]] = tensor.empty(%[[CHUNK_SIZE]]) : tensor<?xf32>
+    %0 = tensor.empty(%arg0) : tensor<?xf32>
+    %1 = affine.apply #map()[%arg0]
+    // CHECK: %[[EMPTY0:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+    %2 = tensor.empty() : tensor<64xf32>
+    // CHECK: %[[EMPTY1:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+    %3 = tensor.empty() : tensor<64xf32>
+    // CHECK: %[[RES:[0-9a-z]+]]:2 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[OUT]], %[[LOOP_ARG1:.*]] = %[[EMPTY1]]
+    %4 = scf.forall (%arg3) in (%1) shared_outs(%arg4 = %arg2) -> (tensor<64xf32>) {
+      %6 = affine.apply #map1(%arg3)[%arg0]
+      %7 = affine.min #map2(%arg3)[%arg0]
+      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[LOOP_ARG0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+      %extracted_slice = tensor.extract_slice %arg4[%6] [%7] [1] : tensor<64xf32> to tensor<?xf32>
+      // CHECK: %[[T1:[0-9a-z]+]] = linalg.elemwise_unary
+      %8 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%extracted_slice : tensor<?xf32>) -> tensor<?xf32>
+
+      // CHECK: %[[T2:.*]] = tensor.extract_slice %[[EMPTY0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+      // CHECK: %[[T3:.*]] = tensor.extract_slice %[[LOOP_ARG1]][%{{.*}}] [%{{.*}}] [{{.*}}]
+      // CHECK: %[[T4:.*]] = linalg.elemwise_binary {{.*}} ins(%[[T1]], %[[T2]] : {{.*}} outs(%[[T3]]
+
+      scf.forall.in_parallel {
+        // CHECK: tensor.parallel_insert_slice %[[T4]] into %[[LOOP_ARG1]]
+        // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[LOOP_ARG0]]
+        tensor.parallel_insert_slice %8 into %arg4[%6] [%7] [1] : tensor<?xf32> into tensor<64xf32>
+      }
+    }
+    // CHECK: %[[ORI_OUTPUT:.*]] = linalg.elemwise_binary
+    %5 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%4, %2 : tensor<64xf32>, tensor<64xf32>) outs(%3 : tensor<64xf32>) -> tensor<64xf32>
+    // CHECK: return %[[RES]]#1
+    return %5 : tensor<64xf32>
+  }
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.forall">
+      %1 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg0 : (!transform.any_op) -> !transform.op<"linalg.elemwise_binary">
+      %fused_op, %new_containing_op = transform.structured.fuse_into_containing_op %1 into %0 : (!transform.op<"linalg.elemwise_binary">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * -50 + 123, 50)>
+#map1 = affine_map<(d0) -> (d0 * -16 + 789, 16)>
+#map2 = affine_map<(d0) -> (d0 * 50)>
+#map3 = affine_map<(d0) -> (d0 * 16)>
+#map4 = affine_map<(d0, d1) -> (d0, d1)>
+#map5 = affine_map<(d0, d1) -> (d1, d0)>
+module {
+  // CHECK-LABEL: func.func @fuse_consumer_multi_output
+  // CHECK-SAME: %[[IN0:[0-9a-z]+]]: tensor<123x456xf32>
+  // CHECK-SAME: %[[IN1:[0-9a-z]+]]: tensor<456x789xf32>
+  // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<123x789xf32>
+  func.func @fuse_consumer_multi_output(%arg0: tensor<123x456xf32>, %arg1: tensor<456x789xf32>, %arg2: tensor<123x789xf32>) -> (tensor<123x789xf32>, tensor<789x123xf32>) {
+    %cst = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[INIT:.*]] = linalg.fill
+    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<123x789xf32>) -> tensor<123x789xf32>
+    // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<123x789xf32>
+    %1 = tensor.empty() : tensor<123x789xf32>
+    // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<789x123xf32>
+    %2 = tensor.empty() : tensor<789x123xf32>
+    // CHECK: %[[RES:[0-9a-z]+]]:3 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[INIT]], %[[LOOP_ARG1:.*]] = %[[EMPTY0]], %[[LOOP_ARG2:.*]] = %[[EMPTY1]]
+    %3 = scf.forall (%arg3, %arg4) in (3, 50) shared_outs(%arg5 = %0) -> (tensor<123x789xf32>) {
+      %5 = affine.min #map(%arg3)
+      %6 = affine.min #map1(%arg4)
+      %7 = affine.apply #map2(%arg3)
+      %8 = affine.apply #map3(%arg4)
+      %9 = affine.apply #map2(%arg3)
+      %10 = affine.apply #map3(%arg4)
+      // CHECK: %[[EXTRACT_IN0:.*]] = tensor.extract_slice %[[IN0]]
+      %extracted_slice = tensor.extract_slice %arg0[%7, 0] [%5, 456] [1, 1] : tensor<123x456xf32> to tensor<?x456xf32>
+      // CHECK: %[[EXTRACT_IN1:.*]] = tensor.extract_slice %[[IN1]]
+      %extracted_slice_0 = tensor.extract_slice %arg1[0, %8] [456, %6] [1, 1] : tensor<456x789xf32> to tensor<456x?xf32>
+      // CHECK: %[[EXTRACT_OUT:.*]] = tensor.extract_slice %[[LOOP_ARG0]]
+      %extracted_slice_1 = tensor.extract_slice %arg5[%9, %10] [%5, %6] [1, 1] : tensor<123x789xf32> to tensor<?x?xf32>
+      // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul ins(%[[EXTRACT_IN0]], %[[EXTRACT_IN1]] {{.*}} outs(%[[EXTRACT_OUT]]
+      %11 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<?x456xf32>, tensor<456x?xf32>) outs(%extracted_slice_1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+      // CHECK: %[[EXTRACT_EMPTY0:.*]] = tensor.extract_slice %[[LOOP_ARG1]]
+      // CHECK: %[[EXTRACT_EMPTY1:.*]] = tensor.extract_slice %[[LOOP_ARG2]]
+      // CHECK: %[[GENERIC_RES:.*]]:2 = linalg.generic {{.*}} ins(%[[MATMUL_RES]] : tensor<?x?xf32>) outs(%[[EXTRACT_EMPTY0]], %[[EXTRACT_EMPTY1]]
+
+      %12 = affine.apply #map2(%arg3)
+      %13 = affine.apply #map3(%arg4)
+      scf.forall.in_parallel {
+        // CHECK: tensor.parallel_insert_slice %[[GENERIC_RES]]#0 into %[[LOOP_ARG1]]
+        // CHECK: tensor.parallel_insert_slice %[[GENERIC_RES]]#1 into %[[LOOP_ARG2]]
+        // CHECK: tensor.parallel_insert_slice %[[MATMUL_RES]] into %[[LOOP_ARG0]]
+        tensor.parallel_insert_slice %11 into %arg5[%12, %13] [%5, %6] [1, 1] : tensor<?x?xf32> into tensor<123x789xf32>
+      }
+    }
+    // CHECK: %[[ORI_OUTPUT:.*]]:2 = linalg.generic
+    %4:2 = linalg.generic {indexing_maps = [#map4, #map4, #map5], iterator_types = ["parallel", "parallel"]} ins(%3 : tensor<123x789xf32>) outs(%1, %2 : tensor<123x789xf32>, tensor<789x123xf32>) {
+    ^bb0(%in: f32, %out: f32, %out_0: f32):
+      %5 = arith.addf %in, %out : f32
+      %6 = arith.addf %5, %out_0 : f32
+      linalg.yield %5, %6 : f32, f32
+    } -> (tensor<123x789xf32>, tensor<789x123xf32>)
+    // CHECK: return %[[RES]]#1, %[[RES]]#2
+    return %4#0, %4#1 : tensor<123x789xf32>, tensor<789x123xf32>
+  }
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.forall">
+      %1 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.op<"linalg.generic">
+      %fused_op, %new_containing_op = transform.structured.fuse_into_containing_op %1 into %0 : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+
 // This is a regression test. Make sure that the transform succeeds and valid
 // IR is generated.
 
@@ -677,7 +800,7 @@ module attributes {transform.with_named_sequence} {
       num_threads [] tile_sizes [50, 16]
       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     // Note that we pass in %tiled_op, which isn't a container op.
-    // expected-error @+2 {{could not find next producer to fuse into container}}
+    // expected-error @+2 {{could not find next target to fuse into container}}
     %fused_op, %new_containing_op =
       transform.structured.fuse_into_containing_op %0 into %tiled_op
         : (!transform.any_op, !transform.any_op)
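Both new tests build their `scf.forall` tiles from the same three affine pieces. The trip count is ceildiv(extent, chunk). Each offset is iv * chunk, and each size is clamped by `affine.min` to min(extent - offset, chunk). The second test hard-codes `in (3, 50)`, which matches ceildiv(123, 50) and ceildiv(789, 16). The sketch below is a hypothetical helper, not part of the patch. It assumes a positive chunk size (for `%arg0` in the first test) and non-negative extents, and enumerates the (offset, size) pairs so the tiling can be checked by hand.

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// ceildiv for a non-negative extent and positive chunk, as computed by
// #map = affine_map<()[s0] -> (64 ceildiv s0)> for the scf.forall trip count.
int64_t ceilDiv(int64_t extent, int64_t chunk) {
  return (extent + chunk - 1) / chunk;
}

// (offset, size) of every tile along one dimension, mirroring #map1
// (d0 * s0) and the affine.min over #map2 (-(d0 * s0) + 64, s0).
std::vector<std::pair<int64_t, int64_t>> tiles1D(int64_t extent,
                                                 int64_t chunk) {
  std::vector<std::pair<int64_t, int64_t>> tiles;
  for (int64_t iv = 0; iv < ceilDiv(extent, chunk); ++iv) {
    int64_t offset = iv * chunk;
    tiles.emplace_back(offset, std::min(extent - offset, chunk));
  }
  return tiles;
}

// True if the tiles cover [0, extent) back to back with no gap or overlap.
bool coversExactly(const std::vector<std::pair<int64_t, int64_t>> &tiles,
                   int64_t extent) {
  int64_t next = 0;
  for (const auto &tile : tiles) {
    if (tile.first != next || tile.second <= 0)
      return false;
    next = tile.first + tile.second;
  }
  return next == extent;
}
```

For the 123x789 case this yields row tiles (0, 50), (50, 50), (100, 23) and a last column tile at offset 784 with size 5. Because the producer's tiles partition the tensor without overlap, the fused consumer can be tiled with the same offsets and sizes. Each `tensor.parallel_insert_slice` then writes a disjoint region.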