[Mlir-commits] [mlir] [mlir] Fix condition for fusability in consumer fusion API (PR #115768)
Quinn Dawkins
llvmlistbot at llvm.org
Mon Nov 11 13:04:21 PST 2024
https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/115768
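Previously, an op implementing either TilingInterface or DestinationStyleOpInterface (DPS) was accepted as a fusable consumer. Consumer fusion requires both interfaces, so ops implementing only one of them (e.g. a DPS-only consumer) must be skipped.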
>From 16fff4894b099ced1bd0e239230d0c3dad5c5ef1 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 7 Nov 2024 16:41:17 -0500
Subject: [PATCH] [mlir] Fix condition for fusability in consumer fusion API
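Previously, an op implementing either TilingInterface or
DestinationStyleOpInterface (DPS) was accepted as a fusable consumer.
Consumer fusion requires both interfaces, so ops implementing only one
of them (e.g. a DPS-only consumer) must be skipped.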
---
.../SCF/Transforms/TileUsingInterface.cpp | 3 +-
.../tile-and-fuse-consumer.mlir | 43 +++++++++++++++++++
2 files changed, 45 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 02e58141bdc303..98403c0a7a91b6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1710,7 +1710,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
for (OpOperand &opOperand : val.getUses()) {
Operation *consumerOp = opOperand.getOwner();
// Step 1. Check if the user is tilable.
- if (!isa<TilingInterface, DestinationStyleOpInterface>(consumerOp)) {
+ if (!isa<TilingInterface>(consumerOp) ||
+ !isa<DestinationStyleOpInterface>(consumerOp)) {
// TODO: We have to init result of consumer before scf.for, use
// DestinationStyleOpInterface to get result shape from init for now. Add
// support for other op such as op has InferTypeOpInterface.
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index af836d18e8c028..30166748d73929 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -570,3 +570,46 @@ module attributes {transform.with_named_sequence} {
// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
// CHECK: }
// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
+
+// -----
+
+module {
+ func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dest0 = tensor.empty() : tensor<256x256xf32>
+ %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
+ %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
+ %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+ scf.yield %insert_slice : tensor<256x256xf32>
+ }
+ %dest1 = tensor.empty() : tensor<258x258xf32>
+ %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32>
+ %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @no_fuse_only_dps_consumer(
+// CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} {
+// CHECK: linalg.add
+// CHECK: linalg.mul
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
+// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
More information about the Mlir-commits
mailing list