[Mlir-commits] [mlir] [mlir] Fix condition for fusability in consumer fusion API (PR #115768)

Quinn Dawkins llvmlistbot at llvm.org
Mon Nov 11 13:04:21 PST 2024


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/115768

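The fusability check previously accepted a consumer op that implemented either `TilingInterface` or `DestinationStyleOpInterface`. Consumer fusion requires the op to implement both, so the check now rejects ops that implement only one of them.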

>From 16fff4894b099ced1bd0e239230d0c3dad5c5ef1 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 7 Nov 2024 16:41:17 -0500
Subject: [PATCH] [mlir] Fix condition for fusability in consumer fusion API

The fusability check previously accepted a consumer op that implemented
either TilingInterface or DestinationStyleOpInterface. Consumer fusion
requires the op to implement both.
---
 .../SCF/Transforms/TileUsingInterface.cpp     |  3 +-
 .../tile-and-fuse-consumer.mlir               | 43 +++++++++++++++++++
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 02e58141bdc303..98403c0a7a91b6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1710,7 +1710,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
   for (OpOperand &opOperand : val.getUses()) {
     Operation *consumerOp = opOperand.getOwner();
     // Step 1. Check if the user is tilable.
-    if (!isa<TilingInterface, DestinationStyleOpInterface>(consumerOp)) {
+    if (!isa<TilingInterface>(consumerOp) ||
+        !isa<DestinationStyleOpInterface>(consumerOp)) {
       // TODO: We have to init result of consumer before scf.for, use
       // DestinationStyleOpInterface to get result shape from init for now. Add
       // support for other op such as op has InferTypeOpInterface.
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index af836d18e8c028..30166748d73929 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -570,3 +570,46 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:          scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
 //      CHECK:   }
 //      CHECK:   return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
+
+// -----
+
+module {
+  func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c256 = arith.constant 256 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
+        %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
+        %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+        scf.yield %insert_slice : tensor<256x256xf32>
+    }
+    %dest1 = tensor.empty() : tensor<258x258xf32>
+    %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32>
+    %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @no_fuse_only_dps_consumer(
+//      CHECK:   %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} {
+//      CHECK:     linalg.add
+//      CHECK:     linalg.mul
+//      CHECK:     scf.yield
+//      CHECK:   }
+//      CHECK:   %[[RES_SLICE:.+]] = tensor.insert_slice
+//      CHECK:   return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]



More information about the Mlir-commits mailing list