[Mlir-commits] [mlir] e479aec - Revert "[mlir][scf][Transform] Refactor transform.fuse_into_containing_op so it is iterative and supports output fusion."

Nicolas Vasilache llvmlistbot at llvm.org
Wed Sep 14 08:52:08 PDT 2022


Author: Nicolas Vasilache
Date: 2022-09-14T08:51:30-07:00
New Revision: e479aecd56d20bea409c507dd237c3f37a766702

URL: https://github.com/llvm/llvm-project/commit/e479aecd56d20bea409c507dd237c3f37a766702
DIFF: https://github.com/llvm/llvm-project/commit/e479aecd56d20bea409c507dd237c3f37a766702.diff

LOG: Revert "[mlir][scf][Transform] Refactor transform.fuse_into_containing_op so it is iterative and supports output fusion."

This reverts commit 54a5f606281d05203dca1d81d135e691b10bc513, which was a work-in-progress (WIP) change pushed by mistake.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 49328a6cb708..29b13e27de7e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -17,12 +17,9 @@
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
-#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringSet.h"
 
 using namespace mlir;
@@ -229,168 +226,78 @@ LogicalResult transform::FuseOp::verify() {
 // FuseIntoContainingOp
 //===----------------------------------------------------------------------===//
 
-/// Find the first "extract" user of `producerOp` and tile it right before its
-/// use. The tiled op is now fused under the `containingOp`.
-/// Return this fused op on success or nullptr if anything fails.
-static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
-                                             Operation *containingOp,
-                                             RewriterBase &rewriter) {
+static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,
+                                                       Operation *containingOp,
+                                                       RewriterBase &rewriter) {
   auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
   if (!tileableProducer)
-    return nullptr;
+    return failure();
 
   // Search the producer slices accessed within the containing operation.
-  // TODO: Generalize to more extract/insert/parallel_insert triples.
-  //   Maybe evolve into an interface.
-  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
+  // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe
+  // evolve into an interface.
+  SmallVector<tensor::ExtractSliceOp> sliceOps;
+  for (Operation *user : tileableProducer->getUsers()) {
     auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
-    return sliceOp && containingOp->isProperAncestor(sliceOp);
-  });
-
-  // Check for a non-empty fusion opportunity.
-  if (it == tileableProducer->getUsers().end())
-    return nullptr;
-  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
-
-  // Try to fuse the producer in-place.
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(sliceOpToTile);
-
-  // Tile the producer.
-  FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
-      rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
-      sliceOpToTile.getMixedSizes());
-  if (failed(tiledProducer))
-    return nullptr;
-
-  // Replace the extract op.
-  Operation *fusedOp = tiledProducer->getDefiningOp();
-  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
-  return fusedOp;
-}
-
-/// Find the first "extract" user of `producerOp` and tile it right before its
-/// use. The tiled op is now fused under the `containingOp`.
-/// Return this fused op on success or nullptr if anything fails.
-static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
-    Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) {
-
-  auto foreachThreadOp = dyn_cast<scf::ForeachThreadOp>(containingOp);
-  if (!foreachThreadOp)
-    return nullptr;
-
-  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
-  if (!tileableProducer)
-    return nullptr;
-
-  // Search the producer slices accessed within the containing
-  // operation.
-  // TODO: Generalize to more extract/insert/parallel_insert triples.
-  //   Maybe evolve into an interface.
-  OpOperand *pUse;
-  BlockArgument bbArg;
-  tensor::ExtractSliceOp sliceOpToTile;
-  // Only consider slices that may come from the containingOp args.
-  for (OpOperand &use : tileableProducer->getUses()) {
-    if (use.getOwner() != containingOp)
+    if (!sliceOp)
       continue;
-    pUse = &use;
-    bbArg = foreachThreadOp.getTiedBlockArgument(&use);
-    for (Operation *user : bbArg.getUsers()) {
-      auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
-      if (!sliceOp)
-        continue;
-      if (!containingOp->isAncestor(sliceOp))
-        continue;
-      sliceOpToTile = sliceOp;
-      break;
-    }
-    if (sliceOpToTile)
-      break;
+    if (!containingOp->isProperAncestor(sliceOp))
+      continue;
+    sliceOps.push_back(sliceOp);
   }
 
   // Check for a non-empty list of fusion opportunities.
-  if (!sliceOpToTile || !pUse)
-    return nullptr;
-
-  // Ensure there is exactly one destination operand that we can replace the
-  // ForeachThreadOp bbArg with.
-  auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
-  if (destinationOperands.size() != 1)
-    return nullptr;
+  if (sliceOps.empty())
+    return failure();
 
   // Try to fuse the producer in-place.
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(sliceOpToTile);
-
-  // Replace the use in the tileableProducer before tiling, replace and then
-  // tile.
-  BlockAndValueMapping bvm;
-  bvm.map(destinationOperands.front(), bbArg);
-  auto tileableProducerClone =
-      cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
-  auto scopeGuard =
-      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
-
-  // Tile the producer.
-  FailureOr<Value> tiledProducer =
-      tileableProducerClone.generateResultTileValue(
-          rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
-          sliceOpToTile.getMixedSizes());
-  if (failed(tiledProducer))
-    return nullptr;
-
-  // Replace the extract op.
-  Operation *fusedOp = tiledProducer->getDefiningOp();
-  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+  SmallVector<Operation *> fusedOps;
+  for (tensor::ExtractSliceOp sliceOp : sliceOps) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(sliceOp);
 
-  // Replace the use in containingOp.
-  rewriter.startRootUpdate(fusedOp);
-  containingOp->setOperand(pUse->getOperandNumber(),
-                           destinationOperands.front());
-  rewriter.finalizeRootUpdate(fusedOp);
+    // Tile the producer.
+    FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
+        rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(),
+        sliceOp.getMixedSizes());
+    if (failed(tiledProducer))
+      return failure();
+    fusedOps.push_back(tiledProducer->getDefiningOp());
+  }
 
-  return fusedOp;
+  // Replace the extract op.
+  for (const auto &en : enumerate(sliceOps))
+    rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0));
+  return fusedOps;
 }
 
-static Operation *cloneAndFuseFirstUse(Operation *producerOp,
-                                       Operation *containingOp,
-                                       RewriterBase &rewriter) {
+static FailureOr<SmallVector<Operation *>>
+cloneAndFuse(Operation *producerOp, Operation *containingOp,
+             RewriterBase &rewriter) {
   // Gather all uses inside the containing op.
   SmallVector<OpOperand *> uses;
-  for (OpResult result : producerOp->getOpResults()) {
-    for (OpOperand &use : result.getUses()) {
-      if (containingOp->isProperAncestor(use.getOwner())) {
+  for (OpResult result : producerOp->getOpResults())
+    for (OpOperand &use : result.getUses())
+      if (containingOp->isProperAncestor(use.getOwner()))
         uses.push_back(&use);
-        continue;
-      }
-      // Cannot clone and fuse if the use is fom the containing op itself: fail.
-      if (containingOp == use.getOwner())
-        return nullptr;
-    }
-  }
 
   // Check for a non-empty list of fusion opportunities.
   if (uses.empty())
-    return nullptr;
+    return failure();
 
   // Clone and fuse inside the containing op.
-  Operation *fusedOp = nullptr;
+  SmallVector<Operation *> fusedOps;
   for (OpOperand *use : uses) {
-    // Parallel insert slice is not a valid clone destination.
-    // TODO: Generalize to other type of ops.
-    assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
-           "Parallel insert slice is not a valid clone destination");
     unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(use->getOwner());
-    fusedOp = rewriter.clone(*producerOp);
+    Operation *cloned = rewriter.clone(*producerOp);
     rewriter.updateRootInPlace(
-        use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
-    break;
+        use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); });
+    fusedOps.push_back(cloned);
   }
 
-  return fusedOp;
+  return fusedOps;
 }
 
 DiagnosedSilenceableFailure
@@ -405,7 +312,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   }
   for (Operation *producerOp : producerOps) {
     if (producerOp->getNumResults() != 1) {
-      Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
+      Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
       diag << "op with != 1 results not supported";
       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
     }
@@ -424,17 +331,15 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   auto getNextProducer = [&]() -> FailureOr<Operation *> {
     for (const auto &it : enumerate(remainingProducers)) {
       Operation *producerOp = it.value();
-      // The containing op may be a user of producerOp: use isAncestor.
-      int64_t numUsesInContainingOp =
-          llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
-            return containingOp->isAncestor(op);
+      bool hasUseInContainingOp =
+          any_of(producerOp->getUsers(), [&](Operation *op) {
+            return containingOp->isProperAncestor(op);
           });
-      // TODO: When resolving the TODO below (no duplicate ops), take an op
-      // that has no use among the remaining producers. This is a topological
+      // TODO: When resolving the TODO below (no duplicate ops), take an op that
+      // has no use among the remaining producers. This is a topological
       // sorting.
-      if (numUsesInContainingOp > 0) {
-        if (numUsesInContainingOp == 1)
-          remainingProducers.erase(remainingProducers.begin() + it.index());
+      if (hasUseInContainingOp) {
+        remainingProducers.erase(remainingProducers.begin() + it.index());
         return producerOp;
       }
     }
@@ -445,42 +350,29 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   while (!remainingProducers.empty()) {
     auto nextProducer = getNextProducer();
     if (failed(nextProducer)) {
-      Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
+      Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note);
       diag << "could not fuse ops into container";
       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
     }
 
     Operation *producerOp = *nextProducer;
-    // TODO: If there are multiple uses of the producer in the containing op,
-    // we currently tile/clone the op multiple times (once per use). In some
-    // cases, we can tile/clone once and reuse the value for each use.
-    // Futhermore, producers should then be traversed according to a
-    // topological sorting.
-    Operation *tiled =
-        tileAndFuseFirstExtractUse(producerOp, containingOp, rewriter);
-    if (tiled) {
-      fusedOps.push_back(tiled);
-      continue;
-    }
-
-    Operation *tiledContainingOpOperand =
-        tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
-            producerOp, containingOp, rewriter);
-    if (tiledContainingOpOperand) {
-      fusedOps.push_back(tiledContainingOpOperand);
-      continue;
-    }
-
-    Operation *cloned =
-        cloneAndFuseFirstUse(producerOp, containingOp, rewriter);
-    if (cloned) {
-      fusedOps.push_back(cloned);
-      continue;
+    // TODO: If there are multiple uses of the producer in the containing op, we
+    // currently tile/clone the op multiple times (once per use). In some cases,
+    // we can tile/clone once and reuse the value for each use. Futhermore,
+    // producers should then be traversed according to a topological sorting.
+    auto tiled = tileAndFuse(producerOp, containingOp, rewriter);
+    if (succeeded(tiled))
+      fusedOps.append(*tiled);
+
+    auto cloned = cloneAndFuse(producerOp, containingOp, rewriter);
+    if (succeeded(cloned))
+      fusedOps.append(*cloned);
+
+    if (failed(tiled) && failed(cloned)) {
+      Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
+      diag << "could not fuse into containing op";
+      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
     }
-
-    Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
-    diag << "could not fuse " << *producerOp << "into " << *containingOp;
-    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
 
   results.set(getFusedOp().cast<OpResult>(), fusedOps);
@@ -734,9 +626,9 @@ LogicalResult transform::PadOp::verify() {
       extractFromI64ArrayAttr(getPaddingDimensions());
   if (any_of(paddingDimensions,
              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
-    return emitOpError() << "expects padding_dimensions to contain positive "
-                            "integers, found "
-                         << getPaddingDimensions();
+    return emitOpError()
+           << "expects padding_dimensions to contain positive integers, found "
+           << getPaddingDimensions();
   }
 
   SmallVector<int64_t> hoistPaddings =
@@ -807,8 +699,8 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
                                    transform::TransformState &state) {
   LinalgTilingOptions tilingOptions;
   tilingOptions.scalarizeDynamicDims();
-  // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the
-  // tile sizes and asserts that it is not already set.
+  // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
+  // sizes and asserts that it is not already set.
   SmallVector<int64_t> emptyTileSizes;
   LinalgTilingPattern pattern(getContext(), tilingOptions);
   SimpleRewriter rewriter(getContext());
@@ -955,8 +847,8 @@ LogicalResult SplitOp::verify() {
   if ((static_cast<int64_t>(getStaticSplitPoint()) !=
        ShapedType::kDynamicSize) ^
       (getDynamicSplitPoint() == nullptr)) {
-    return emitOpError() << "expects either a dynamic or a static split "
-                            "point to be provided";
+    return emitOpError()
+           << "expects either a dynamic or a static split point to be provided";
   }
   return success();
 }
@@ -1310,8 +1202,8 @@ transform::VectorizeOp::applyToOne(Operation *target,
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Registers new ops and declares PDL as dependent dialect since the
-/// additional ops are using PDL types for operands and results.
+/// Registers new ops and declares PDL as dependent dialect since the additional
+/// ops are using PDL types for operands and results.
 class LinalgTransformDialectExtension
     : public transform::TransformDialectExtension<
           LinalgTransformDialectExtension> {


        


More information about the Mlir-commits mailing list