[Mlir-commits] [mlir] 54a5f60 - [mlir][scf][Transform] Refactor transform.fuse_into_containing_op so it is iterative and supports output fusion.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Sep 14 08:51:05 PDT 2022


Author: Nicolas Vasilache
Date: 2022-09-14T08:50:32-07:00
New Revision: 54a5f606281d05203dca1d81d135e691b10bc513

URL: https://github.com/llvm/llvm-project/commit/54a5f606281d05203dca1d81d135e691b10bc513
DIFF: https://github.com/llvm/llvm-project/commit/54a5f606281d05203dca1d81d135e691b10bc513.diff

LOG: [mlir][scf][Transform] Refactor transform.fuse_into_containing_op so it is iterative and supports output fusion.

This revision revisits the implementation of `transform.fuse_into_containing_op` so that it iterates on
producers one use at a time.

Support is added for fusing a producer through a shared tensor argument of `scf.foreach_thread`. In this case, we
tile and fuse the op inside the containing op and rewire the shared tensor argument to the producer's unique destination operand.
If no such unique destination operand exists, the transform fails.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 29b13e27de7e..49328a6cb708 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -17,9 +17,12 @@
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringSet.h"
 
 using namespace mlir;
@@ -226,78 +229,168 @@ LogicalResult transform::FuseOp::verify() {
 // FuseIntoContainingOp
 //===----------------------------------------------------------------------===//
 
-static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,
-                                                       Operation *containingOp,
-                                                       RewriterBase &rewriter) {
+/// Find the first "extract" user of `producerOp` and tile it right before its
+/// use. The tiled op is now fused under the `containingOp`.
+/// Return this fused op on success or nullptr if anything fails.
+static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
+                                             Operation *containingOp,
+                                             RewriterBase &rewriter) {
   auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
   if (!tileableProducer)
-    return failure();
+    return nullptr;
 
   // Search the producer slices accessed within the containing operation.
-  // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe
-  // evolve into an interface.
-  SmallVector<tensor::ExtractSliceOp> sliceOps;
-  for (Operation *user : tileableProducer->getUsers()) {
+  // TODO: Generalize to more extract/insert/parallel_insert triples.
+  //   Maybe evolve into an interface.
+  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
     auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
-    if (!sliceOp)
-      continue;
-    if (!containingOp->isProperAncestor(sliceOp))
+    return sliceOp && containingOp->isProperAncestor(sliceOp);
+  });
+
+  // Check for a non-empty fusion opportunity.
+  if (it == tileableProducer->getUsers().end())
+    return nullptr;
+  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
+
+  // Try to fuse the producer in-place.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(sliceOpToTile);
+
+  // Tile the producer.
+  FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
+      rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
+      sliceOpToTile.getMixedSizes());
+  if (failed(tiledProducer))
+    return nullptr;
+
+  // Replace the extract op.
+  Operation *fusedOp = tiledProducer->getDefiningOp();
+  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+  return fusedOp;
+}
+
+/// Find the first "extract" user of the `containingOp` (scf.foreach_thread)
+/// block argument tied to a use of `producerOp`, then tile the producer right
+/// before that extract. Return the fused op on success or nullptr on failure.
+static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
+    Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) {
+
+  auto foreachThreadOp = dyn_cast<scf::ForeachThreadOp>(containingOp);
+  if (!foreachThreadOp)
+    return nullptr;
+
+  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
+  if (!tileableProducer)
+    return nullptr;
+
+  // Search the producer slices accessed within the containing
+  // operation.
+  // TODO: Generalize to more extract/insert/parallel_insert triples.
+  //   Maybe evolve into an interface.
+  OpOperand *pUse;
+  BlockArgument bbArg;
+  tensor::ExtractSliceOp sliceOpToTile;
+  // Only consider slices that may come from the containingOp args.
+  for (OpOperand &use : tileableProducer->getUses()) {
+    if (use.getOwner() != containingOp)
       continue;
-    sliceOps.push_back(sliceOp);
+    pUse = &use;
+    bbArg = foreachThreadOp.getTiedBlockArgument(&use);
+    for (Operation *user : bbArg.getUsers()) {
+      auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
+      if (!sliceOp)
+        continue;
+      if (!containingOp->isAncestor(sliceOp))
+        continue;
+      sliceOpToTile = sliceOp;
+      break;
+    }
+    if (sliceOpToTile)
+      break;
   }
 
   // Check for a non-empty list of fusion opportunities.
-  if (sliceOps.empty())
-    return failure();
+  if (!sliceOpToTile || !pUse)
+    return nullptr;
 
-  // Try to fuse the producer in-place.
-  SmallVector<Operation *> fusedOps;
-  for (tensor::ExtractSliceOp sliceOp : sliceOps) {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPoint(sliceOp);
+  // Ensure there is exactly one destination operand that we can replace the
+  // ForeachThreadOp bbArg with.
+  auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
+  if (destinationOperands.size() != 1)
+    return nullptr;
 
-    // Tile the producer.
-    FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
-        rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(),
-        sliceOp.getMixedSizes());
-    if (failed(tiledProducer))
-      return failure();
-    fusedOps.push_back(tiledProducer->getDefiningOp());
-  }
+  // Try to fuse the producer in-place.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(sliceOpToTile);
+
+  // Replace the use in the tileableProducer before tiling, replace and then
+  // tile.
+  BlockAndValueMapping bvm;
+  bvm.map(destinationOperands.front(), bbArg);
+  auto tileableProducerClone =
+      cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
+  auto scopeGuard =
+      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
+
+  // Tile the producer.
+  FailureOr<Value> tiledProducer =
+      tileableProducerClone.generateResultTileValue(
+          rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
+          sliceOpToTile.getMixedSizes());
+  if (failed(tiledProducer))
+    return nullptr;
 
   // Replace the extract op.
-  for (const auto &en : enumerate(sliceOps))
-    rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0));
-  return fusedOps;
+  Operation *fusedOp = tiledProducer->getDefiningOp();
+  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+
+  // Replace the use in containingOp.
+  rewriter.startRootUpdate(fusedOp);
+  containingOp->setOperand(pUse->getOperandNumber(),
+                           destinationOperands.front());
+  rewriter.finalizeRootUpdate(fusedOp);
+
+  return fusedOp;
 }
 
-static FailureOr<SmallVector<Operation *>>
-cloneAndFuse(Operation *producerOp, Operation *containingOp,
-             RewriterBase &rewriter) {
+static Operation *cloneAndFuseFirstUse(Operation *producerOp,
+                                       Operation *containingOp,
+                                       RewriterBase &rewriter) {
   // Gather all uses inside the containing op.
   SmallVector<OpOperand *> uses;
-  for (OpResult result : producerOp->getOpResults())
-    for (OpOperand &use : result.getUses())
-      if (containingOp->isProperAncestor(use.getOwner()))
+  for (OpResult result : producerOp->getOpResults()) {
+    for (OpOperand &use : result.getUses()) {
+      if (containingOp->isProperAncestor(use.getOwner())) {
         uses.push_back(&use);
+        continue;
+      }
+      // Cannot clone and fuse if the use is from the containing op itself: fail.
+      if (containingOp == use.getOwner())
+        return nullptr;
+    }
+  }
 
   // Check for a non-empty list of fusion opportunities.
   if (uses.empty())
-    return failure();
+    return nullptr;
 
   // Clone and fuse inside the containing op.
-  SmallVector<Operation *> fusedOps;
+  Operation *fusedOp = nullptr;
   for (OpOperand *use : uses) {
+    // Parallel insert slice is not a valid clone destination.
+    // TODO: Generalize to other type of ops.
+    assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
+           "Parallel insert slice is not a valid clone destination");
     unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(use->getOwner());
-    Operation *cloned = rewriter.clone(*producerOp);
+    fusedOp = rewriter.clone(*producerOp);
     rewriter.updateRootInPlace(
-        use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); });
-    fusedOps.push_back(cloned);
+        use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
+    break;
   }
 
-  return fusedOps;
+  return fusedOp;
 }
 
 DiagnosedSilenceableFailure
@@ -312,7 +405,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   }
   for (Operation *producerOp : producerOps) {
     if (producerOp->getNumResults() != 1) {
-      Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
+      Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
       diag << "op with != 1 results not supported";
       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
     }
@@ -331,15 +424,17 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   auto getNextProducer = [&]() -> FailureOr<Operation *> {
     for (const auto &it : enumerate(remainingProducers)) {
       Operation *producerOp = it.value();
-      bool hasUseInContainingOp =
-          any_of(producerOp->getUsers(), [&](Operation *op) {
-            return containingOp->isProperAncestor(op);
+      // The containing op may be a user of producerOp: use isAncestor.
+      int64_t numUsesInContainingOp =
+          llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
+            return containingOp->isAncestor(op);
           });
-      // TODO: When resolving the TODO below (no duplicate ops), take an op that
-      // has no use among the remaining producers. This is a topological
+      // TODO: When resolving the TODO below (no duplicate ops), take an op
+      // that has no use among the remaining producers. This is a topological
       // sorting.
-      if (hasUseInContainingOp) {
-        remainingProducers.erase(remainingProducers.begin() + it.index());
+      if (numUsesInContainingOp > 0) {
+        if (numUsesInContainingOp == 1)
+          remainingProducers.erase(remainingProducers.begin() + it.index());
         return producerOp;
       }
     }
@@ -350,29 +445,42 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   while (!remainingProducers.empty()) {
     auto nextProducer = getNextProducer();
     if (failed(nextProducer)) {
-      Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note);
+      Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
       diag << "could not fuse ops into container";
       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
     }
 
     Operation *producerOp = *nextProducer;
-    // TODO: If there are multiple uses of the producer in the containing op, we
-    // currently tile/clone the op multiple times (once per use). In some cases,
-    // we can tile/clone once and reuse the value for each use. Futhermore,
-    // producers should then be traversed according to a topological sorting.
-    auto tiled = tileAndFuse(producerOp, containingOp, rewriter);
-    if (succeeded(tiled))
-      fusedOps.append(*tiled);
-
-    auto cloned = cloneAndFuse(producerOp, containingOp, rewriter);
-    if (succeeded(cloned))
-      fusedOps.append(*cloned);
-
-    if (failed(tiled) && failed(cloned)) {
-      Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
-      diag << "could not fuse into containing op";
-      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    // TODO: If there are multiple uses of the producer in the containing op,
+    // we currently tile/clone the op multiple times (once per use). In some
+    // cases, we can tile/clone once and reuse the value for each use.
+    // Furthermore, producers should then be traversed according to a
+    // topological sorting.
+    Operation *tiled =
+        tileAndFuseFirstExtractUse(producerOp, containingOp, rewriter);
+    if (tiled) {
+      fusedOps.push_back(tiled);
+      continue;
+    }
+
+    Operation *tiledContainingOpOperand =
+        tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
+            producerOp, containingOp, rewriter);
+    if (tiledContainingOpOperand) {
+      fusedOps.push_back(tiledContainingOpOperand);
+      continue;
+    }
+
+    Operation *cloned =
+        cloneAndFuseFirstUse(producerOp, containingOp, rewriter);
+    if (cloned) {
+      fusedOps.push_back(cloned);
+      continue;
     }
+
+    Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
+    diag << "could not fuse " << *producerOp << "into " << *containingOp;
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
 
   results.set(getFusedOp().cast<OpResult>(), fusedOps);
@@ -626,9 +734,9 @@ LogicalResult transform::PadOp::verify() {
       extractFromI64ArrayAttr(getPaddingDimensions());
   if (any_of(paddingDimensions,
              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
-    return emitOpError()
-           << "expects padding_dimensions to contain positive integers, found "
-           << getPaddingDimensions();
+    return emitOpError() << "expects padding_dimensions to contain positive "
+                            "integers, found "
+                         << getPaddingDimensions();
   }
 
   SmallVector<int64_t> hoistPaddings =
@@ -699,8 +807,8 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
                                    transform::TransformState &state) {
   LinalgTilingOptions tilingOptions;
   tilingOptions.scalarizeDynamicDims();
-  // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
-  // sizes and asserts that it is not already set.
+  // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the
+  // tile sizes and asserts that it is not already set.
   SmallVector<int64_t> emptyTileSizes;
   LinalgTilingPattern pattern(getContext(), tilingOptions);
   SimpleRewriter rewriter(getContext());
@@ -847,8 +955,8 @@ LogicalResult SplitOp::verify() {
   if ((static_cast<int64_t>(getStaticSplitPoint()) !=
        ShapedType::kDynamicSize) ^
       (getDynamicSplitPoint() == nullptr)) {
-    return emitOpError()
-           << "expects either a dynamic or a static split point to be provided";
+    return emitOpError() << "expects either a dynamic or a static split "
+                            "point to be provided";
   }
   return success();
 }
@@ -1202,8 +1310,8 @@ transform::VectorizeOp::applyToOne(Operation *target,
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Registers new ops and declares PDL as dependent dialect since the additional
-/// ops are using PDL types for operands and results.
+/// Registers new ops and declares PDL as dependent dialect since the
+/// additional ops are using PDL types for operands and results.
 class LinalgTransformDialectExtension
     : public transform::TransformDialectExtension<
           LinalgTransformDialectExtension> {


        


More information about the Mlir-commits mailing list