[Mlir-commits] [mlir] 54ae9e7 - [mlir][SCF] Fix condition for fusability in consumer fusion API (#115768)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 11 13:44:28 PST 2024
Author: Quinn Dawkins
Date: 2024-11-11T16:44:24-05:00
New Revision: 54ae9e7bbae60a2ddc629e0a7a854492c241774d
URL: https://github.com/llvm/llvm-project/commit/54ae9e7bbae60a2ddc629e0a7a854492c241774d
DIFF: https://github.com/llvm/llvm-project/commit/54ae9e7bbae60a2ddc629e0a7a854492c241774d.diff
LOG: [mlir][SCF] Fix condition for fusability in consumer fusion API (#115768)
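Previously, an op was accepted for consumer fusion if it implemented
either TilingInterface or DestinationStyleOpInterface. Both interfaces
are required for consumer fusion, so the check now rejects any op that
lacks either one.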
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 02e58141bdc303..98403c0a7a91b6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1710,7 +1710,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
for (OpOperand &opOperand : val.getUses()) {
Operation *consumerOp = opOperand.getOwner();
// Step 1. Check if the user is tilable.
- if (!isa<TilingInterface, DestinationStyleOpInterface>(consumerOp)) {
+ if (!isa<TilingInterface>(consumerOp) ||
+ !isa<DestinationStyleOpInterface>(consumerOp)) {
// TODO: We have to init result of consumer before scf.for, use
// DestinationStyleOpInterface to get result shape from init for now. Add
// support for other op such as op has InferTypeOpInterface.
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index af836d18e8c028..30166748d73929 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -570,3 +570,46 @@ module attributes {transform.with_named_sequence} {
// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
// CHECK: }
// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
+
+// -----
+
+module {
+ func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dest0 = tensor.empty() : tensor<256x256xf32>
+ %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
+ %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
+ %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+ scf.yield %insert_slice : tensor<256x256xf32>
+ }
+ %dest1 = tensor.empty() : tensor<258x258xf32>
+ %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32>
+ %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @no_fuse_only_dps_consumer(
+// CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} {
+// CHECK: linalg.add
+// CHECK: linalg.mul
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
+// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
More information about the Mlir-commits mailing list