[Mlir-commits] [mlir] [mlir][Transform] Reuse bbArgs in FuseIntoContainingOp (PR #135066)
Pablo Antonio Martinez
llvmlistbot at llvm.org
Wed Apr 9 11:41:19 PDT 2025
https://github.com/pabloantoniom created https://github.com/llvm/llvm-project/pull/135066
When fusing two ops with the same output operand using FuseIntoContainingOp, the current implementation makes the two ops write into different values that point to the same tensor. After bufferization this results in two different buffers, which is sub-optimal. This patch solves the problem by letting the consumer and the producer reuse the same tensor.
More precisely, before FuseIntoContainingOp is applied, we may have two ops that write into the same output tensor. However, the consumer is tiled, so it writes into the loop iter_args (i.e., it does not write directly into the original tensor). When the producer is fused into the loop, its output tensor remains the same, so the consumer and the producer write into two different values (the consumer into the iter_args and the producer into the original tensor).
This patch clones the producer into the loop and checks whether the producer writes to the same value that one of the loop inits points to; if so, it makes the producer output point to the corresponding loop iter_arg instead.
Let's consider this example:
```mlir
%0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins (%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>) outs (%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = linalg.elemwise_binary {fun = #linalg.binary_fn<mul>} ins (%0, %arg3: tensor<?x?xf32>, tensor<?x?xf32>) outs (%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
```
Now assume we tile the consumer (`mul`) and fuse the producer (`add`) into the resulting loop with `fuse_into_containing_op`.
Before this patch:
```mlir
%0 = scf.for %arg4 = %c0 to %dim step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x?xf32>) {
%extracted_slice_2 = tensor.extract_slice %arg2[%arg4, 0] [1, %dim_0] [1, 1] : tensor<?x?xf32> to tensor<1x?xf32>
...
%1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins (...) outs (%extracted_slice_2 : tensor<1x?xf32>) -> tensor<1x?xf32>
...
%extracted_slice_4 = tensor.extract_slice %arg5[%arg4, 0] [1, %dim_0] [1, 1] : tensor<?x?xf32> to tensor<1x?xf32>
%2 = linalg.elemwise_binary {fun = #linalg.binary_fn<mul>} ins (...) outs (%extracted_slice_4 : tensor<1x?xf32>) -> tensor<1x?xf32>
...
}
```
After this patch:
```mlir
%0 = scf.for %arg4 = %c0 to %dim step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x?xf32>) {
%extracted_slice_2 = tensor.extract_slice %arg5[%arg4, 0] [1, %dim_0] [1, 1] : tensor<?x?xf32> to tensor<1x?xf32>
...
%1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins (...) outs (%extracted_slice_2 : tensor<1x?xf32>) -> tensor<1x?xf32>
...
%extracted_slice_4 = tensor.extract_slice %arg5[%arg4, 0] [1, %dim_0] [1, 1] : tensor<?x?xf32> to tensor<1x?xf32>
%2 = linalg.elemwise_binary {fun = #linalg.binary_fn<mul>} ins (...) outs (%extracted_slice_4 : tensor<1x?xf32>) -> tensor<1x?xf32>
...
}
```
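The check and the rewiring described above can be sketched outside of MLIR as a toy C++ model. Everything here is an assumption made for illustration: `Value`, `Loop`, and `pickProducerInit` are hypothetical stand-ins, not MLIR classes. Only the name `sameOrEquivalentIterArg` mirrors the helper added by the patch, and its body is a simplified restatement rather than the patched code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Loop;

// Hypothetical stand-in for an SSA value. `owner` is the loop whose iter_arg
// this value is, or nullptr for a value defined outside of any loop.
struct Value {
  const Loop *owner = nullptr;
};

// Hypothetical stand-in for a loop: iterArgs[i] is initialized from inits[i].
struct Loop {
  std::vector<const Value *> inits;
  std::vector<const Value *> iterArgs;
};

// True if `src` is `dst`, or if `src` is reachable from `dst` by repeatedly
// stepping from an iter_arg to the inits of the loop that owns it.
bool sameOrEquivalentIterArg(const Value *src, const Value *dst) {
  if (src == dst)
    return true;
  if (const Loop *parent = dst->owner)
    for (const Value *init : parent->inits)
      if (sameOrEquivalentIterArg(src, init))
        return true;
  return false;
}

// Returns the value the fused producer should use as its init: an iter_arg of
// `container` when one of them is tied to an equivalent init, otherwise the
// original init.
const Value *pickProducerInit(const Value *producerInit,
                              const Loop &container) {
  assert(container.inits.size() == container.iterArgs.size());
  const Value *result = producerInit;
  for (std::size_t i = 0; i < container.inits.size(); ++i)
    if (sameOrEquivalentIterArg(producerInit, container.inits[i]))
      result = container.iterArgs[i];
  return result;
}
```

In this model, if an outer loop initializes its iter_arg from `%arg2` and an inner loop initializes its iter_arg from the outer one, calling `pickProducerInit` with `%arg2` and the inner loop yields the inner iter_arg, while an unrelated tensor is returned unchanged.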
>From 9e65799a3b4802469999328aa76f9fc75ddc677c Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Wed, 9 Apr 2025 19:20:06 +0100
Subject: [PATCH] [mlir][Transform] Reuse bbArgs in FuseIntoContainingOp
When fusing two ops with the same output operand using
FuseIntoContainingOp, the current implementation makes the two ops
write into different values that point to the same tensor. After
bufferization this results in two different buffers, which is
sub-optimal. This patch solves the problem by letting the consumer and
the producer reuse the same tensor.
More precisely, before FuseIntoContainingOp is applied, we may have
two ops that write into the same output tensor. However, the consumer
is tiled, so it writes into the loop iter_args (i.e., it does not
write directly into the original tensor). When the producer is fused
into the loop, its output tensor remains the same, so the consumer and
the producer write into two different values (the consumer into the
iter_args and the producer into the original tensor).
This patch clones the producer into the loop and checks whether the
producer writes to the same value that one of the loop inits points
to; if so, it makes the producer output point to the corresponding
loop iter_arg instead.
---
.../TransformOps/LinalgTransformOps.cpp | 73 +++++++++++++
.../transform-op-fuse-into-containing.mlir | 100 ++++++++++++++++++
2 files changed, 173 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c90ebe4487ca4..cd1dfb8f214eb 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -718,6 +718,42 @@ static Operation *replaceForAllWithNewSignature(
return newforallOp;
}
+/// Given two operands coming from a loop iter arg, 'src' and 'dst', return true
+/// if the operand 'src' is equal to 'dst' or equal to an iter arg present in
+/// an outer loop. To determine the second condition, this function iterates
+/// recursively over the enclosing loops, trying to find 'src' in any of the
+/// parent loop's iter args.
+static bool sameOrEquivalentIterArg(Value src, Value dst) {
+ // Base case.
+ if (src == dst)
+ return true;
+
+ // Recursively look for equivalent iter args in enclosing loops.
+ if (auto bbArg = dyn_cast<BlockArgument>(dst)) {
+ Block *parentBlock = bbArg.getOwner();
+ assert(parentBlock && "unlinked block argument");
+
+ // Because we stop doing recursive calls when we find a non loop-like op,
+ // this should never happen.
+ assert(parentBlock->getParentOp() &&
+ "expected block argument with parent operation");
+
+ // Check if parent is loop-like.
+ if (auto parentLoop =
+ dyn_cast<LoopLikeOpInterface>(parentBlock->getParentOp())) {
+ for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
+ OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
+ Value loopBlockArgument =
+ parentLoop->getOperand(operand->getOperandNumber());
+ if (sameOrEquivalentIterArg(src, loopBlockArgument))
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
@@ -755,6 +791,39 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(sliceOpToTile);
+ // Clone the producer inside the container and try to update the producer init
+ // operands using the loop bbArgs if applicable. More precisely, if a bbArg of
+ // the container loop is tied to a value that is also used as an init by the
+ // producer op, then, instead of using that value in the producer, use the
+ // bbArg. This allows reusing the output tensor of the container (instead of
+ // creating a new one) when both producer and container write to the same
+ // output.
+ if (LoopLikeOpInterface containerLoop =
+ dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
+ Operation *clone = rewriter.clone(*producerOp);
+ rewriter.modifyOpInPlace(clone, [&]() {
+ // Iterate over the outputs of the producer and over the loop bbArgs and
+ // check if any bbArg points to the same value as the producer output. In
+ // such case, make the producer output point to the bbArg directly.
+ for (auto &initOperandPtr :
+ cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
+ Value producerOperand =
+ clone->getOperand(initOperandPtr.getOperandNumber());
+ for (auto containerIterArg : containerLoop.getRegionIterArgs()) {
+ OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
+ Value consumerOperand =
+ containerLoop->getOperand(bbArg->getOperandNumber());
+ // The producer has the same init as the loop bbArg, use it.
+ if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
+ initOperandPtr.set(containerIterArg);
+ }
+ }
+ }
+ });
+
+ tileableProducer = dyn_cast<TilingInterface>(clone);
+ }
+
// Tile the producer.
int64_t resultNumber =
cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
@@ -797,6 +866,10 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
resultNumber, offsets, sizes);
+ // Cleanup clone.
+ if (isa<LoopLikeOpInterface>(containingOp))
+ rewriter.eraseOp(tileableProducer);
+
return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 4115f2857a20c..572a2ae70e0a4 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -206,6 +206,106 @@ module {
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[INOUT:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_op_through_bbarg_inout(%arg0: index, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+ %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+ %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+ %1 = affine.apply #map0()[%d0, %arg0]
+
+ // CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[INOUT]]) -> (tensor<?xf32>) {
+ %2 = scf.forall (%arg3) in (%1) shared_outs(%o = %arg1) -> (tensor<?xf32>) {
+ %3 = affine.apply #map1(%arg3)[%arg0]
+ %4 = affine.min #map2(%arg3)[%d0, %arg0]
+ %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ // CHECK: %[[T1:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ // CHECK: %[[T2:.*]] = linalg.fill {{.*}} outs(%[[T1]]
+ %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T3:.*]] = linalg.elemwise_unary ins(%[[T2]] : tensor<?xf32>) outs(%[[T0]] : tensor<?xf32>)
+ %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ // CHECK: }
+ func.return %2 : tensor<?xf32>
+ }
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+
+ // linalg.fill is tileable. The op is tiled and fused.
+ transform.structured.fuse_into_containing_op %0 into %1
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
+
+// -----
+
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested
+ // CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32>
+ // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x?xf32>
+ func.func @fuse_tileable_op_through_bbarg_inout_nested(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>} ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
+ %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
+ %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
+ // CHECK: scf.for {{.*}} iter_args(%[[BBARG0:.*]] = %[[ARG1]]) -> (tensor<?x?x?xf32>) {
+ // CHECK: scf.for {{.*}} iter_args(%[[BBARG1:.*]] = %[[BBARG0]]) -> (tensor<?x?x?xf32>) {
+ // CHECK: scf.for {{.*}} iter_args(%[[BBARG2:.*]] = %[[BBARG1]]) -> (tensor<?x?x?xf32>) {
+ %1 = scf.for %arg2 = %c0 to %dim step %c1 iter_args(%arg3 = %arg1) -> (tensor<?x?x?xf32>) {
+ %2 = scf.for %arg4 = %c0 to %dim_0 step %c1 iter_args(%arg5 = %arg3) -> (tensor<?x?x?xf32>) {
+ %3 = scf.for %arg6 = %c0 to %dim_1 step %c1 iter_args(%arg7 = %arg5) -> (tensor<?x?x?xf32>) {
+ // CHECK: %[[EX1:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}}: tensor<?x?x?xf32> to tensor<1x1x1xf32>
+ // CHECK: linalg.elemwise_unary {fun = #linalg.unary_fn<abs>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX1]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+ // CHECK: %[[EX2:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}} : tensor<?x?x?xf32> to tensor<1x1x1xf32>
+ // CHECK: linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX2]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+ %extracted_slice = tensor.extract_slice %0[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<1x1x1xf32>
+ %extracted_slice_2 = tensor.extract_slice %arg7[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<1x1x1xf32>
+ %4 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%extracted_slice : tensor<1x1x1xf32>) outs(%extracted_slice_2 : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+ %inserted_slice = tensor.insert_slice %4 into %arg7[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<1x1x1xf32> into tensor<?x?x?xf32>
+ scf.yield %inserted_slice : tensor<?x?x?xf32>
+ }
+ scf.yield %3 : tensor<?x?x?xf32>
+ }
+ scf.yield %2 : tensor<?x?x?xf32>
+ }
+ return %1 : tensor<?x?x?xf32>
+ }
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %2:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %3:3 = transform.split_handle %1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.structured.fuse_into_containing_op %2#0 into %3#2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
module {
// CHECK-LABEL: func.func @fuse_tileable_multi_output_op
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index