[Mlir-commits] [mlir] [mlir] Don't require extract_slice in fusion with transform op (PR #112755)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 17 11:04:21 PDT 2024


https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/112755

>From 432341c07ac5a1fc90aecd66e8ebeae3a4507c08 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 17 Oct 2024 10:04:01 -0500
Subject: [PATCH] [mlir] Don't require extract_slice in fusion with transform
 op

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
 .../TransformOps/LinalgTransformOps.cpp       | 54 +++++++++++++------
 .../transform-op-fuse-into-containing.mlir    | 40 ++++++++++++++
 2 files changed, 77 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ad72b5d7beccde..2bc1d5dde6b5d9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -818,27 +818,23 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
   // Search the producer slices accessed within the containing operation.
   // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
   // evolve into an interface.
+  if (bbArg.getUsers().empty()) {
+    diag.attachNote(containingOp->getLoc())
+        << "could not find fusion opportunity for bbArg: " << bbArg;
+    return {};
+  }
   auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
     auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
     return sliceOp && containingOp->isProperAncestor(sliceOp);
   });
-
-  // Find a fusion opportunity.
+  OpBuilder::InsertionGuard guard(rewriter);
+  tensor::ExtractSliceOp sliceOpToTile;
   if (itBBArgUsers == bbArg.getUsers().end()) {
-    diag.attachNote(containingOp->getLoc())
-        << "could not find fusion opportunity for bbArg: " << bbArg;
-    return {};
+    rewriter.setInsertionPoint(&bbArg.getOwner()->front());
+  } else {
+    sliceOpToTile = llvm::cast<tensor::ExtractSliceOp>(*itBBArgUsers);
+    rewriter.setInsertionPoint(sliceOpToTile);
   }
-  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
-
-  // Try to fuse the producer in-place.
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(sliceOpToTile);
-
-  // Replace the use in the tileableProducer before tiling: clone, replace and
-  // then tile.
-  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
-  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
 
   // Gather destination tensors.
   SmallVector<Value> destinationTensors;
@@ -850,14 +846,38 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
     return {};
   }
 
+  // Replace the use in the tileableProducer before tiling: clone, replace and
+  // then tile.
+  SmallVector<Operation *> oldBbArgUsers(bbArg.getUsers());
+  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
+  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
   IRMapping bvm;
   bvm.map(destinationTensors[resultNumber], bbArg);
   auto tileableProducerClone =
       cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
-  auto scopeGuard =
-      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
+
+  // If there was no extract_slice user, then no need to tile.
+  if (!sliceOpToTile) {
+    LLVM_DEBUG(DBGS() << "No extract_slice user. No need to tile cloned op.\n");
+    // Replace the old uses of bbArg with the cloned op, except for any parallel
+    // insert ops.
+    rewriter.replaceUsesWithIf(
+        bbArg, tileableProducerClone->getResult(resultNumber),
+        [&](OpOperand &operand) {
+          return !isa<tensor::ParallelInsertSliceOp>(operand.getOwner()) &&
+                 operand.getOwner() != tileableProducerClone.getOperation();
+        });
+    // Replace the use in containingOp.
+    rewriter.modifyOpInPlace(containingOp, [&]() {
+      containingOp->setOperand(pUse->getOperandNumber(),
+                               destinationTensors.front());
+    });
+    return {tileableProducerClone};
+  }
 
   // Tile the producer.
+  auto scopeGuard =
+      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
   FailureOr<TilingResult> tileAndFuseResult =
       tileableProducerClone.generateResultTileValue(
           rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 4115f2857a20c6..c0c7b8ec9598bc 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -202,6 +202,46 @@ module {
 
 // -----
 
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_no_slice
+  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_op_through_bbarg_no_slice(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+
+    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+    // CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[OUT]]) -> (tensor<?xf32>) {
+    %1 = scf.forall (%arg3) in (%arg0) shared_outs(%o = %0) -> (tensor<?xf32>) {
+      // CHECK: %[[T0:.*]] = linalg.fill {{.*}} outs(%[[BBARGOUT]]
+
+      // CHECK: %[[T1:.*]] = linalg.elemwise_unary {{.*}} outs(%[[T0]]
+      %2 = linalg.elemwise_unary ins(%arg1 : tensor<?xf32>) outs(%o : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %2 into %o[0] [%d0] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: }
+    func.return %1 : tensor<?xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+
+      // linalg.fill is tileable, but %o has no extract_slice user, so the op is cloned into the loop and fused without tiling.
+      transform.structured.fuse_into_containing_op %0 into %1
+        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+        transform.yield
+    }
+  }
+}
+
+// -----
+
 #map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
 #map1 = affine_map<(d0)[s0] -> (d0 * s0)>
 #map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>



More information about the Mlir-commits mailing list