[Mlir-commits] [mlir] 0422a44 - [mlir][scf][Transform] Refactor transform.fuse_into_containing_op so it is iterative and supports output fusion.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Sep 16 09:21:56 PDT 2022


Author: Nicolas Vasilache
Date: 2022-09-16T09:21:46-07:00
New Revision: 0422a4407f6b1dcb12d37c8b59841e3ccf0c9861

URL: https://github.com/llvm/llvm-project/commit/0422a4407f6b1dcb12d37c8b59841e3ccf0c9861
DIFF: https://github.com/llvm/llvm-project/commit/0422a4407f6b1dcb12d37c8b59841e3ccf0c9861.diff

LOG: [mlir][scf][Transform] Refactor transform.fuse_into_containing_op so it is iterative and supports output fusion.

This revision revisits the implementation of `transform.fuse_into_containing_op` so that it iterates on
producers one use at a time.

Support is added to fuse a producer through a foreach_thread shared tensor argument (`shared_outs`), in which case we
tile and fuse the op inside the containing op and replace the containing op's shared tensor operand with the producer's unique destination operand.
If no such unique destination operand exists, the transform fails.

Differential Revision: https://reviews.llvm.org/D134051

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 93b1274d0e884..bc3f10c717c3f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -18,7 +18,6 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Interfaces/TilingInterface.h"
-#include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/StringSet.h"
 
@@ -226,78 +225,167 @@ LogicalResult transform::FuseOp::verify() {
 // FuseIntoContainingOp
 //===----------------------------------------------------------------------===//
 
-static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,
-                                                       Operation *containingOp,
-                                                       RewriterBase &rewriter) {
+/// Tile `producerOp` (a TilingInterface) right before its first
+/// `tensor.extract_slice` user nested in `containingOp`, and replace that slice
+/// with the tiled result #0. Return the tiled op, or nullptr on any failure.
+static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
+                                             Operation *containingOp,
+                                             RewriterBase &rewriter) {
   auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
   if (!tileableProducer)
-    return failure();
+    return nullptr;
 
   // Search the producer slices accessed within the containing operation.
-  // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe
+  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
   // evolve into an interface.
-  SmallVector<tensor::ExtractSliceOp> sliceOps;
-  for (Operation *user : tileableProducer->getUsers()) {
+  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
     auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
-    if (!sliceOp)
-      continue;
-    if (!containingOp->isProperAncestor(sliceOp))
-      continue;
-    sliceOps.push_back(sliceOp);
-  }
+    return sliceOp && containingOp->isProperAncestor(sliceOp);
+  });
 
-  // Check for a non-empty list of fusion opportunities.
-  if (sliceOps.empty())
-    return failure();
+  // Bail out if no such extract_slice user exists (no fusion opportunity).
+  if (it == tileableProducer->getUsers().end())
+    return nullptr;
+  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
 
   // Try to fuse the producer in-place.
-  SmallVector<Operation *> fusedOps;
-  for (tensor::ExtractSliceOp sliceOp : sliceOps) {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPoint(sliceOp);
-
-    // Tile the producer.
-    FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
-        rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(),
-        sliceOp.getMixedSizes());
-    if (failed(tiledProducer))
-      return failure();
-    fusedOps.push_back(tiledProducer->getDefiningOp());
-  }
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(sliceOpToTile);
+
+  // Tile the producer. NOTE(review): the slice's strides are ignored.
+  FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
+      rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
+      sliceOpToTile.getMixedSizes());
+  if (failed(tiledProducer))
+    return nullptr;
+
+  // Replace all uses of the extract_slice with the tiled value.
+  Operation *fusedOp = tiledProducer->getDefiningOp();
+  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+  return fusedOp;
+}
+
+/// Find the first scf::ForeachThreadOp user of `producerOp`; bail unless it is
+/// exactly `containingOp`. Then tile the producer right before the first
+/// `tensor.extract_slice` user of the tied shared_outs block argument, with the
+/// producer's unique destination operand remapped to that block argument, and
+/// rewire the containing op's operand to that destination operand.
+/// Return the fused op on success or nullptr if anything fails.
+static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
+    Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) {
+
+  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
+  if (!tileableProducer)
+    return nullptr;
+
+  // Find the first use of the producer whose owner is an scf::ForeachThreadOp.
+  scf::ForeachThreadOp foreachThreadOp;
+  auto itProducerUses =
+      llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
+        foreachThreadOp = dyn_cast<scf::ForeachThreadOp>(use.getOwner());
+        return foreachThreadOp;
+      });
+  // Bail if there is no such use or it belongs to another foreach_thread op.
+  if (!foreachThreadOp || foreachThreadOp != containingOp)
+    return nullptr;
+
+  // The producer flows into the containing op through `pUse`, a shared_outs
+  // operand; inside the body it is accessed through the tied block argument.
+  // The fusion candidates are therefore the extract_slice users of that block
+  // argument rather than of the producer itself.
+  OpOperand *pUse = &(*itProducerUses);
+  BlockArgument bbArg = foreachThreadOp.getTiedBlockArgument(pUse);
+
+  // Search the first extract_slice user of the block argument nested in the
+  // containing op. TODO: Generalize to more extract/insert/parallel_insert
+  // triples, maybe evolve into an interface.
+  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
+    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
+    return sliceOp && containingOp->isProperAncestor(sliceOp);
+  });
+
+  // Bail out if there is no fusion opportunity.
+  if (itBBArgUsers == bbArg.getUsers().end())
+    return nullptr;
+  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
+
+  // Ensure `tileableProducer` has exactly one destination operand that we can
+  // replace the ForeachThreadOp bbArg with.
+  auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
+  if (destinationOperands.size() != 1)
+    return nullptr;
+
+  // Try to fuse the producer in-place.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(sliceOpToTile);
+
+  // Clone the producer with its destination operand remapped to `bbArg` and
+  // tile the clone; the clone itself is erased when leaving this function.
+  BlockAndValueMapping bvm;
+  bvm.map(destinationOperands.front(), bbArg);
+  auto tileableProducerClone =
+      cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
+  auto scopeGuard =
+      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
+
+  // Tile the clone. NOTE(review): the slice's strides are not forwarded.
+  FailureOr<Value> tiledProducer =
+      tileableProducerClone.generateResultTileValue(
+          rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
+          sliceOpToTile.getMixedSizes());
+  if (failed(tiledProducer))
+    return nullptr;
 
   // Replace the extract op.
-  for (const auto &en : enumerate(sliceOps))
-    rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0));
-  return fusedOps;
+  Operation *fusedOp = tiledProducer->getDefiningOp();
+  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+
+  // Rewire the containing op's shared_outs operand to the destination operand.
+  rewriter.updateRootInPlace(containingOp, [&]() {
+    containingOp->setOperand(pUse->getOperandNumber(),
+                             destinationOperands.front());
+  });
+
+  return fusedOp;
 }
 
-static FailureOr<SmallVector<Operation *>>
-cloneAndFuse(Operation *producerOp, Operation *containingOp,
-             RewriterBase &rewriter) {
+/// Clone `producerOp` right before its first use nested in `containingOp` and
+/// rewire that use to the clone. Return the clone, or nullptr if not possible.
+static Operation *cloneAndFuseFirstUse(Operation *producerOp,
+                                       Operation *containingOp,
+                                       RewriterBase &rewriter) {
   // Gather all uses inside the containing op.
   SmallVector<OpOperand *> uses;
-  for (OpResult result : producerOp->getOpResults())
-    for (OpOperand &use : result.getUses())
-      if (containingOp->isProperAncestor(use.getOwner()))
+  for (OpResult result : producerOp->getOpResults()) {
+    for (OpOperand &use : result.getUses()) {
+      if (containingOp->isProperAncestor(use.getOwner())) {
         uses.push_back(&use);
+        continue;
+      }
+      // Cannot clone and fuse if the use is by the containing op itself: fail
+      // immediately.
+      if (containingOp == use.getOwner())
+        return nullptr;
+    }
+  }
 
   // Check for a non-empty list of fusion opportunities.
   if (uses.empty())
-    return failure();
+    return nullptr;
 
   // Clone and fuse inside the containing op.
-  SmallVector<Operation *> fusedOps;
-  for (OpOperand *use : uses) {
-    unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPoint(use->getOwner());
-    Operation *cloned = rewriter.clone(*producerOp);
-    rewriter.updateRootInPlace(
-        use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); });
-    fusedOps.push_back(cloned);
-  }
-
-  return fusedOps;
+  OpOperand *use = uses.front();
+  // Parallel insert slice is not a valid clone destination: return nullptr so
+  // the caller reports a failure. TODO: Generalize to other type of ops.
+  if (isa<tensor::ParallelInsertSliceOp>(use->getOwner()))
+    return nullptr;
+  unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(use->getOwner());
+  Operation *fusedOp = rewriter.clone(*producerOp);
+  rewriter.updateRootInPlace(
+      use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
+  return fusedOp;
 }
 
 DiagnosedSilenceableFailure
@@ -312,7 +400,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   }
   for (Operation *producerOp : producerOps) {
     if (producerOp->getNumResults() != 1) {
-      Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
+      Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
       diag << "op with != 1 results not supported";
       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
     }
@@ -331,15 +419,17 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   auto getNextProducer = [&]() -> FailureOr<Operation *> {
     for (const auto &it : enumerate(remainingProducers)) {
       Operation *producerOp = it.value();
-      bool hasUseInContainingOp =
-          any_of(producerOp->getUsers(), [&](Operation *op) {
-            return containingOp->isProperAncestor(op);
+      // The containing op may be a user of producerOp: use isAncestor.
+      int64_t numUsesInContainingOp =
+          llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
+            return containingOp->isAncestor(op);
           });
-      // TODO: When resolving the TODO below (no duplicate ops), take an op that
-      // has no use among the remaining producers. This is a topological
+      // TODO: When resolving the TODO below (no duplicate ops), take an op
+      // that has no use among the remaining producers. This is a topological
       // sorting.
-      if (hasUseInContainingOp) {
-        remainingProducers.erase(remainingProducers.begin() + it.index());
+      if (numUsesInContainingOp > 0) {
+        if (numUsesInContainingOp == 1)
+          remainingProducers.erase(remainingProducers.begin() + it.index());
         return producerOp;
       }
     }
@@ -350,29 +440,42 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   while (!remainingProducers.empty()) {
     auto nextProducer = getNextProducer();
     if (failed(nextProducer)) {
-      Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note);
+      Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
       diag << "could not fuse ops into container";
       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
     }
 
     Operation *producerOp = *nextProducer;
-    // TODO: If there are multiple uses of the producer in the containing op, we
-    // currently tile/clone the op multiple times (once per use). In some cases,
-    // we can tile/clone once and reuse the value for each use. Futhermore,
-    // producers should then be traversed according to a topological sorting.
-    auto tiled = tileAndFuse(producerOp, containingOp, rewriter);
-    if (succeeded(tiled))
-      fusedOps.append(*tiled);
-
-    auto cloned = cloneAndFuse(producerOp, containingOp, rewriter);
-    if (succeeded(cloned))
-      fusedOps.append(*cloned);
-
-    if (failed(tiled) && failed(cloned)) {
-      Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
-      diag << "could not fuse into containing op";
-      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    // TODO: If there are multiple uses of the producer in the containing op,
+    // we currently tile/clone the op multiple times (once per use). In some
+    // cases, we can tile/clone once and reuse the value for each use.
+    // Futhermore, producers should then be traversed according to a
+    // topological sorting.
+    Operation *tiled =
+        tileAndFuseFirstExtractUse(producerOp, containingOp, rewriter);
+    if (tiled) {
+      fusedOps.push_back(tiled);
+      continue;
+    }
+
+    Operation *tiledContainingOpOperand =
+        tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
+            producerOp, containingOp, rewriter);
+    if (tiledContainingOpOperand) {
+      fusedOps.push_back(tiledContainingOpOperand);
+      continue;
     }
+
+    Operation *cloned =
+        cloneAndFuseFirstUse(producerOp, containingOp, rewriter);
+    if (cloned) {
+      fusedOps.push_back(cloned);
+      continue;
+    }
+
+    Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
+    diag << "could not fuse " << *producerOp << "into " << *containingOp;
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
 
   results.set(getFusedOp().cast<OpResult>(), fusedOps);
@@ -626,9 +729,9 @@ LogicalResult transform::PadOp::verify() {
       extractFromI64ArrayAttr(getPaddingDimensions());
   if (any_of(paddingDimensions,
              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
-    return emitOpError()
-           << "expects padding_dimensions to contain positive integers, found "
-           << getPaddingDimensions();
+    return emitOpError() << "expects padding_dimensions to contain positive "
+                            "integers, found "
+                         << getPaddingDimensions();
   }
 
   SmallVector<int64_t> hoistPaddings =
@@ -699,8 +802,8 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
                                    transform::TransformState &state) {
   LinalgTilingOptions tilingOptions;
   tilingOptions.scalarizeDynamicDims();
-  // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
-  // sizes and asserts that it is not already set.
+  // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the
+  // tile sizes and asserts that it is not already set.
   SmallVector<int64_t> emptyTileSizes;
   LinalgTilingPattern pattern(getContext(), tilingOptions);
   SimpleRewriter rewriter(getContext());
@@ -847,8 +950,8 @@ LogicalResult SplitOp::verify() {
   if ((static_cast<int64_t>(getStaticSplitPoint()) !=
        ShapedType::kDynamicSize) ^
       (getDynamicSplitPoint() == nullptr)) {
-    return emitOpError()
-           << "expects either a dynamic or a static split point to be provided";
+    return emitOpError() << "expects either a dynamic or a static split "
+                            "point to be provided";
   }
   return success();
 }
@@ -1225,8 +1328,8 @@ transform::VectorizeOp::applyToOne(Operation *target,
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Registers new ops and declares PDL as dependent dialect since the additional
-/// ops are using PDL types for operands and results.
+/// Registers new ops and declares PDL as dependent dialect since the
+/// additional ops are using PDL types for operands and results.
 class LinalgTransformDialectExtension
     : public transform::TransformDialectExtension<
           LinalgTransformDialectExtension> {

diff  --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index b3cd3283286ce..77bd3b2da13a7 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -99,3 +99,51 @@ module {
     }
   }
 }
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg
+  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_op_through_bbarg(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+    %1 = affine.apply #map0()[%d0, %arg0]
+
+    // CHECK: scf.foreach_thread {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[OUT]]) -> (tensor<?xf32>) {
+    %2 = scf.foreach_thread (%arg3) in (%1) shared_outs(%o = %0) -> (tensor<?xf32>) {
+      %3 = affine.apply #map1(%arg3)[%arg0]
+      %4 = affine.min #map2(%arg3)[%d0, %arg0]
+      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
+      // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
+      %6 = tensor.extract_slice %arg1[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T2:.*]] = linalg.elemwise_unary {{.*}} outs(%[[T1]]
+      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.foreach_thread.perform_concurrently {
+        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: }
+    func.return %2 : tensor<?xf32>
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+    %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
+
+    // linalg.fill only feeds the foreach_thread through its shared_outs operand: it is tiled and fused through the tied block argument.
+    transform.structured.fuse_into_containing_op %0 into %1
+  }
+}


        


More information about the Mlir-commits mailing list