[Mlir-commits] [mlir] ba5591e - [mlir][Transform] Reuse bbArgs in FuseIntoContainingOp (#135066)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 15 01:29:38 PDT 2025


Author: Pablo Antonio Martinez
Date: 2025-05-15T09:29:33+01:00
New Revision: ba5591e39d50973bf60fe2716d085465b62768e8

URL: https://github.com/llvm/llvm-project/commit/ba5591e39d50973bf60fe2716d085465b62768e8
DIFF: https://github.com/llvm/llvm-project/commit/ba5591e39d50973bf60fe2716d085465b62768e8.diff

LOG: [mlir][Transform] Reuse bbArgs in FuseIntoContainingOp (#135066)

When fusing two ops with the same output operand using
FuseIntoContainingOp, the current implementation makes the two ops write
into different values that refer to the same tensor. These eventually
bufferize into two different buffers, which is sub-optimal. This patch
solves the problem by letting the consumer and the producer reuse the
same tensor.

More precisely, before FuseIntoContainingOp is applied, we may have two
ops that write into the same output tensor. However, the consumer is
tiled, so it writes into the loop iter_args (i.e., it does not write
directly into the original tensor). When the producer is fused into the
loop, its output tensor is left unchanged, so the consumer and the
producer write into two different values (the consumer writes into the
iter_args and the producer into the original tensor).

This patch clones the producer into the loop and checks whether each of
the producer's init operands is the same value as a loop init (possibly
reached through the iter_args of enclosing loops); in that case, it
makes the producer output use the corresponding loop iter_arg instead.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index fbe7593420102..a9370dc003830 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -718,6 +718,54 @@ static Operation *replaceForAllWithNewSignature(
   return newforallOp;
 }
 
+/// Return true if `src` is the same value as `dst`, or if `dst` is a loop
+/// region iter arg whose tied init value is, transitively, `src`.
+///
+/// Starting at `dst`, this walks outwards through the enclosing loop-like
+/// ops: whenever the current value is a region iter arg of a
+/// LoopLikeOpInterface op, it is replaced by the init operand tied to that
+/// specific iter arg. The walk returns true as soon as `src` is reached, and
+/// false once the current value is not such an iter arg.
+static bool sameOrEquivalentIterArg(Value src, Value dst) {
+  // The value currently compared against `src`; the walk starts at `dst`.
+  Value current = dst;
+  while (true) {
+    // `src` was reached, either directly or by following tied inits from
+    // `dst`: the two values are equivalent.
+    if (current == src)
+      return true;
+
+    // Only block arguments can be loop iter args; anything else ends the
+    // chain.
+    auto bbArg = dyn_cast<BlockArgument>(current);
+    if (!bbArg)
+      return false;
+
+    Block *parentBlock = bbArg.getOwner();
+    assert(parentBlock && "unlinked block argument");
+
+    Operation *parentOp = parentBlock->getParentOp();
+    assert(parentOp && "expected block argument with parent operation");
+
+    // Only block arguments of loop-like ops are followed.
+    auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
+    if (!parentLoop)
+      return false;
+
+    // Follow only the init tied to *this* iter arg. Following the inits of
+    // every iter arg of `parentLoop` would wrongly report `src` as equivalent
+    // when it initializes a different iter arg of the same loop. Null means
+    // `bbArg` is not a region iter arg (e.g. an induction variable).
+    OpOperand *tiedInit = parentLoop.getTiedLoopInit(bbArg);
+    if (!tiedInit)
+      return false;
+
+    // The tied init is an operand of `parentLoop`, hence defined outside of
+    // it: each step moves strictly outwards, so the walk terminates.
+    current = tiedInit->get();
+  }
+}
+
 /// Find the first "extract" user of `producerOp` and tile it right before its
 /// use. The tiled op is fused under the `containingOp`.
 /// Return this fused op on success or nullptr if anything fails.
@@ -755,6 +803,40 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(sliceOpToTile);
 
+  // Clone the producer inside the consumer and try to update the producer init
+  // operands using the loop bbArgs if applicable. More precisely, if the bbArg
+  // of the container loop points to a value that it is used by the consumer op,
+  // then, instead of using such value on the consumer, use the value coming
+  // from the bbArg instead. This allows to reuse the output tensor (instead of
+  // creating a new one) of the container when both producer and container write
+  // to the same output.
+  if (LoopLikeOpInterface containerLoop =
+          dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
+    Operation *clone = rewriter.clone(*producerOp);
+    rewriter.modifyOpInPlace(clone, [&]() {
+      // Iterate over the outputs of the producer and over the loop bbArgs and
+      // check if any bbArg points to the same value as the producer output. In
+      // such case, make the producer output point to the bbArg directly.
+      for (OpOperand &initOperandPtr :
+           cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
+        Value producerOperand =
+            clone->getOperand(initOperandPtr.getOperandNumber());
+        for (BlockArgument containerIterArg :
+             containerLoop.getRegionIterArgs()) {
+          OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
+          Value consumerOperand =
+              containerLoop->getOperand(bbArg->getOperandNumber());
+          // The producer has the same init as the loop bbArg, use it.
+          if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
+            initOperandPtr.set(containerIterArg);
+          }
+        }
+      }
+    });
+
+    tileableProducer = dyn_cast<TilingInterface>(clone);
+  }
+
   // Tile the producer.
   int64_t resultNumber =
       cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
@@ -797,6 +879,10 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
       rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
       resultNumber, offsets, sizes);
 
+  // Cleanup clone.
+  if (dyn_cast<LoopLikeOpInterface>(containingOp))
+    rewriter.eraseOp(tileableProducer);
+
   return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
 }
 

diff  --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 4115f2857a20c..572a2ae70e0a4 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -206,6 +206,106 @@ module {
 #map1 = affine_map<(d0)[s0] -> (d0 * s0)>
 #map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
 
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout
+  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  //  CHECK-SAME:   %[[INOUT:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_op_through_bbarg_inout(%arg0: index, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+    %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+    %1 = affine.apply #map0()[%d0, %arg0]
+    // The fill (producer) writes into %arg1, which is also the shared_outs init of the forall below.
+    // CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[INOUT]]) -> (tensor<?xf32>) {
+    %2 = scf.forall (%arg3) in (%1) shared_outs(%o = %arg1) -> (tensor<?xf32>) {
+      %3 = affine.apply #map1(%arg3)[%arg0]
+      %4 = affine.min #map2(%arg3)[%d0, %arg0]
+      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+      // After fusion, the consumer outs (T0) and the fused fill outs (T1) must both be slices of the bbArg.
+      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
+      // CHECK: %[[T1:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
+      // CHECK: %[[T2:.*]] = linalg.fill {{.*}} outs(%[[T1]]
+      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+      // The consumer then reads the tiled fill (T2) and writes into the bbArg slice (T0).
+      // CHECK: %[[T3:.*]] = linalg.elemwise_unary ins(%[[T2]] : tensor<?xf32>) outs(%[[T0]] : tensor<?xf32>)
+      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: }
+    func.return %2 : tensor<?xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+
+      // linalg.fill is tileable: it is tiled and fused into the scf.forall, writing into the shared_outs bbArg.
+      transform.structured.fuse_into_containing_op %0 into %1
+        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+        transform.yield
+    }
+  }
+}
+
+// -----
+
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested
+  //  CHECK-SAME:   %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32>
+  //  CHECK-SAME:   %[[ARG1:[0-9a-z]+]]: tensor<?x?x?xf32>
+  func.func @fuse_tileable_op_through_bbarg_inout_nested(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+    %c2 = arith.constant 2 : index
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>} ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
+    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
+    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
+    // CHECK:   scf.for {{.*}} iter_args(%[[BBARG0:.*]] = %[[ARG1]]) -> (tensor<?x?x?xf32>) {
+    // CHECK:     scf.for {{.*}} iter_args(%[[BBARG1:.*]] = %[[BBARG0]]) -> (tensor<?x?x?xf32>) {
+    // CHECK:       scf.for {{.*}} iter_args(%[[BBARG2:.*]] = %[[BBARG1]]) -> (tensor<?x?x?xf32>) {
+    %1 = scf.for %arg2 = %c0 to %dim step %c1 iter_args(%arg3 = %arg1) -> (tensor<?x?x?xf32>) {
+      %2 = scf.for %arg4 = %c0 to %dim_0 step %c1 iter_args(%arg5 = %arg3) -> (tensor<?x?x?xf32>) {
+        %3 = scf.for %arg6 = %c0 to %dim_1 step %c1 iter_args(%arg7 = %arg5) -> (tensor<?x?x?xf32>) {
+          // CHECK:  %[[EX1:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}}: tensor<?x?x?xf32> to tensor<1x1x1xf32>
+          // CHECK:  linalg.elemwise_unary {fun = #linalg.unary_fn<abs>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX1]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+          // CHECK:  %[[EX2:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}} : tensor<?x?x?xf32> to tensor<1x1x1xf32>
+          // CHECK:  linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX2]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+          %extracted_slice = tensor.extract_slice %0[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<1x1x1xf32>
+          %extracted_slice_2 = tensor.extract_slice %arg7[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<1x1x1xf32>
+          %4 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%extracted_slice : tensor<1x1x1xf32>) outs(%extracted_slice_2 : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+          %inserted_slice = tensor.insert_slice %4 into %arg7[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<1x1x1xf32> into tensor<?x?x?xf32>
+          scf.yield %inserted_slice : tensor<?x?x?xf32>
+        }
+        scf.yield %3 : tensor<?x?x?xf32>
+      }
+      scf.yield %2 : tensor<?x?x?xf32>
+    }
+    return %1 : tensor<?x?x?xf32>
+  }
+  // Fuse the abs producer (%0 in the function above) into the loop nest; the fused abs and the exp consumer must both write into slices of the innermost iter_arg.
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %1 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %2:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+      %3:3 = transform.split_handle %1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+      transform.structured.fuse_into_containing_op %2#0 into %3#2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
 module {
   // CHECK-LABEL: func.func @fuse_tileable_multi_output_op
   //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index


        


More information about the Mlir-commits mailing list