[Mlir-commits] [mlir] [MLIR][TilingInterface] Extend consumer fusion for multi-use of producer shared by terminator ops (PR #110105)

Abhishek Varma llvmlistbot at llvm.org
Mon Sep 30 01:45:06 PDT 2024


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/110105

>From eae411492047e7a9714e4a4cfca06b0325daf568 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 26 Sep 2024 10:53:24 +0000
Subject: [PATCH 1/3] [MLIR][TilingInterface] Extend consumer fusion for
 multi-use of producer

-- This commit extends consumer fusion so that it can take place even when the
   producer has multiple uses.
-- Besides the use by the consumer op being fused, the only other uses of the
   producer that are allowed are in:
   1. scf.yield
   2. tensor.parallel_insert_slice

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../SCF/Transforms/TileUsingInterface.cpp     | 42 +++++++----
 .../tile-and-fuse-consumer.mlir               | 71 +++++++++++++++++++
 2 files changed, 98 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 7cfd772a72b175..cbf468b201653f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1481,21 +1481,33 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
 /// failure otherwise.
 static FailureOr<OpOperand *> getConsumerFromUses(Value val,
                                                   Block *containingOpBlock) {
-  // Step 1. Check that the value has exactly one use.
-  if (!llvm::hasSingleElement(val.getUses()))
-    return failure();
-  // Step 2. Get uses.
-  OpOperand &operand = (*val.getUses().begin());
-  Operation *consumerOp = operand.getOwner();
-  // TODO: We have to init result of consumer before scf.for, use
-  //       DestinationStyleOpInterface to get result shape from init for now.
-  //       Add support for other op such as op has InferTypeOpInterface.
-  if (!isa<TilingInterface>(consumerOp) ||
-      !isa<DestinationStyleOpInterface>(consumerOp))
-    return failure();
-  if (containingOpBlock != consumerOp->getBlock())
-    return failure();
-  return &operand;
+  // Check that the value has exactly one use which isn't a scf.yield or a
+  // tensor.parallel_insert_slice op.
+  Operation *visitedConsumerOp = nullptr;
+  for (OpOperand &opOperand : val.getUses()) {
+    Operation *consumerOp = opOperand.getOwner();
+    if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
+      continue;
+    if (visitedConsumerOp && visitedConsumerOp != consumerOp)
+      return failure();
+    // TODO: We have to init result of consumer before scf.for, use
+    //       DestinationStyleOpInterface to get result shape from init for now.
+    //       Add support for other op such as op has InferTypeOpInterface.
+    if (!isa<TilingInterface>(consumerOp) ||
+        !isa<DestinationStyleOpInterface>(consumerOp))
+      return failure();
+    if (containingOpBlock != consumerOp->getBlock())
+      return failure();
+    visitedConsumerOp = consumerOp;
+  }
+
+  for (OpOperand &opOperand : val.getUses()) {
+    Operation *consumerOp = opOperand.getOwner();
+    if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
+      continue;
+    return &opOperand;
+  }
+  return failure();
 }
 
 /// Find the perfectly nested loops outside of given loop(included) sorted from
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index fdefdcc453ae7a..f5f703d95e2d5b 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -437,3 +437,74 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:         scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
 //      CHECK:   }
 //      CHECK:   return %[[LOOP_RESULT1]]#1 :
+
+// -----
+
+// This test checks that a consumer is fused even if the producer has multiple uses.
+// Besides the use by the consumer op being fused, the only other uses of the
+// producer that are allowed are in:
+// 1. scf.yield
+// 2. tensor.parallel_insert_slice
+
+module {
+  module {
+    func.func @fuse_consumer_for_multi_use_producer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
+      %c0 = arith.constant 0 : index
+      %c64 = arith.constant 64 : index
+      %c256 = arith.constant 256 : index
+      %cst = arith.constant 0.000000e+00 : f32
+      %0 = tensor.empty() : tensor<256x256xf32>
+      %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      %2:2 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %1, %arg5 = %arg2) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
+        %3 = scf.for %arg6 = %c0 to %c256 step %c64 iter_args(%arg7 = %arg4) -> (tensor<256x256xf32>) {
+          %extracted_slice = tensor.extract_slice %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
+          %extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32>
+          %extracted_slice_1 = tensor.extract_slice %arg1[0, %arg6] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
+          %5 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice : tensor<64x64xf32>) -> tensor<64x64xf32>
+          %inserted_slice = tensor.insert_slice %5 into %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32>
+          scf.yield %inserted_slice : tensor<256x256xf32>
+        }
+        %4 = linalg.add ins(%3, %arg5 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+        scf.yield %3, %4 : tensor<256x256xf32>, tensor<256x256xf32>
+      }
+      return %2#0, %2#1 : tensor<256x256xf32>, tensor<256x256xf32>
+    }
+  }
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %consumer, %fused_consumer = transform.test.fuse_consumer %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+//      CHECK: func.func @fuse_consumer_for_multi_use_producer(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+//      CHECK:   %[[dest1:.*]] = linalg.fill
+// CHECK-SAME:          outs(%[[dest0]] :
+//      CHECK:   %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]]
+// CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[ARG2]])
+// CHECK-SAME:   {
+//      CHECK:       %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
+// CHECK-SAME:         iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[dest0]])
+// CHECK-SAME:         {
+//      CHECK:            %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
+//      CHECK:            %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
+//      CHECK:            %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
+//      CHECK:            %[[TILED_MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:                  outs(%[[MAT_OUT_SLICE]] :
+//      CHECK:            %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
+//      CHECK:            %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG1]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
+//      CHECK:            %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
+//      CHECK:            %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME:              ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
+// CHECK-SAME:              outs(%[[ADD_OUT_SLICE]] :
+//      CHECK:            %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
+//      CHECK:            scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
+//      CHECK:         }
+//      CHECK:         scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
+//      CHECK:   }
+//      CHECK:   return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 :

>From 85af1e3167aac4360943d25a17fdbb4fc46e5cc7 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 30 Sep 2024 07:13:50 +0000
Subject: [PATCH 2/3] Address review comments

---
 .../Dialect/SCF/Transforms/TileUsingInterface.cpp  | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index cbf468b201653f..657bb6a98d2308 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1483,12 +1483,12 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
                                                   Block *containingOpBlock) {
   // Check that the value has exactly one use which isn't a scf.yield or a
   // tensor.parallel_insert_slice op.
-  Operation *visitedConsumerOp = nullptr;
+  OpOperand *operand = nullptr;
   for (OpOperand &opOperand : val.getUses()) {
     Operation *consumerOp = opOperand.getOwner();
     if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
       continue;
-    if (visitedConsumerOp && visitedConsumerOp != consumerOp)
+    if (operand && *operand != opOperand)
       return failure();
     // TODO: We have to init result of consumer before scf.for, use
     //       DestinationStyleOpInterface to get result shape from init for now.
@@ -1498,16 +1498,10 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
       return failure();
     if (containingOpBlock != consumerOp->getBlock())
       return failure();
-    visitedConsumerOp = consumerOp;
+    operand = &opOperand;
   }
 
-  for (OpOperand &opOperand : val.getUses()) {
-    Operation *consumerOp = opOperand.getOwner();
-    if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
-      continue;
-    return &opOperand;
-  }
-  return failure();
+  return operand;
 }
 
 /// Find the perfectly nested loops outside of given loop(included) sorted from

>From bba3e9c39a63eee4db0baada37511984a588316e Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 30 Sep 2024 08:44:21 +0000
Subject: [PATCH 3/3] Address final review comment

---
 mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 657bb6a98d2308..a97296b4404a53 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1488,7 +1488,8 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
     Operation *consumerOp = opOperand.getOwner();
     if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
       continue;
-    if (operand && *operand != opOperand)
+    // Bail out if there is more than one candidate consumer use.
+    if (operand)
       return failure();
     // TODO: We have to init result of consumer before scf.for, use
     //       DestinationStyleOpInterface to get result shape from init for now.
@@ -1501,7 +1502,9 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
     operand = &opOperand;
   }
 
-  return operand;
+  if (operand)
+    return operand;
+  return failure();
 }
 
 /// Find the perfectly nested loops outside of given loop(included) sorted from



More information about the Mlir-commits mailing list