[Mlir-commits] [mlir] 1365ff7 - [mlir] allow repeated payload in structured.fuse_into_containing

Alex Zinenko llvmlistbot at llvm.org
Mon May 15 07:30:48 PDT 2023


Author: Alex Zinenko
Date: 2023-05-15T14:30:19Z
New Revision: 1365ff74cb7d9f15feafdd4fbe2996d2f9e42a5e

URL: https://github.com/llvm/llvm-project/commit/1365ff74cb7d9f15feafdd4fbe2996d2f9e42a5e
DIFF: https://github.com/llvm/llvm-project/commit/1365ff74cb7d9f15feafdd4fbe2996d2f9e42a5e.diff

LOG: [mlir] allow repeated payload in structured.fuse_into_containing

Structured fusion proceeds by iteratively finding the next suitable
producer to be fused into the loop. Therefore, it shouldn't matter if
the same producer is listed multiple times (e.g., because it is used as
multiple operands). Adjust the implementation of the transform op to
support this case.

Also fix the checking code in the interpreter to actually respect the
TransformOpInterface indication that repeated payload is allowed; it
seems to have been accidentally dropped in one of the refactorings.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D150561

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index cdeabb7435519..c7bc3767b27cf 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -143,7 +143,8 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
 
 def FuseIntoContainingOp :
     Op<Transform_Dialect, "structured.fuse_into_containing_op",
-      [DeclareOpInterfaceMethods<TransformOpInterface>,
+      [DeclareOpInterfaceMethods<TransformOpInterface,
+          ["allowsRepeatedHandleOperands"]>,
        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Fuse a producer into a containing operation.";
 
@@ -160,7 +161,7 @@ def FuseIntoContainingOp :
     producer op handle may be associated with multiple payload ops. This
     transform fuses producers one-by-one, always picking an unspecified producer
     that has at least one use inside the containing op among the
-    producers.
+    producers. A producer can be listed multiple times in the handle.
 
     Note: If a producer has multiple uses inside the containing op, it is
     currently tiled and/or cloned multiple times into the containing op.
@@ -176,8 +177,8 @@ def FuseIntoContainingOp :
     containing op. I.e., "producers" that are not consumed within the containing
     op are rejected by this operation.
 
-    This operation reads and frees the producer handle.
-    This operation reads the containing op handle.
+    This operation consumes the producer handle.
+    This operation only reads the containing op handle.
   }];
 
   let arguments = (ins PDL_Operation:$producer_op,

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index afef59990afc1..0703ca31f402c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -571,6 +571,11 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
   return fusedOp;
 }
 
+bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
+  // Safe: apply() deduplicates producers (SetVector) before fusing them.
+  return true;
+}
+
 DiagnosedSilenceableFailure
 transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
                                        transform::TransformState &state) {
@@ -591,8 +596,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
 
   // Helper function to find the next producer that should be fused. Take any
   // producer that has a use inside the containing op.
-  SmallVector<Operation *> remainingProducers(producerOps.begin(),
-                                              producerOps.end());
+  SetVector<Operation *> remainingProducers(producerOps.begin(),
+                                            producerOps.end());
   auto getNextProducer = [&]() -> FailureOr<Operation *> {
     for (const auto &it : enumerate(remainingProducers)) {
       Operation *producerOp = it.value();

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index bad1d74fb473c..5685187e853f5 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -724,6 +724,10 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
         FULL_LDBG("--handle not consumed -> SKIP\n");
         continue;
       }
+      if (transform.allowsRepeatedHandleOperands()) {
+        FULL_LDBG("--op allows repeated handles -> SKIP\n");
+        continue;
+      }
       FULL_LDBG("--handle is consumed\n");
 
       Type operandType = operand.get().getType();

diff  --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index d6b3ff3181b29..537ee8664df47 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -247,3 +247,39 @@ module {
     transform.structured.fuse_into_containing_op %0 into %1
   }
 }
+
+// -----
+
+module {
+  // CHECK-LABEL: func.func @fuse_repeated
+  func.func @fuse_repeated(%fill: tensor<2xf32>, %output: tensor<2xf32>) -> tensor<2xf32> {
+    %c0 = arith.constant 0.0 : f32
+    %0 = linalg.fill ins(%c0 : f32) outs(%fill : tensor<2xf32>) -> tensor<2xf32>
+
+    // CHECK: scf.forall
+    %1 = scf.forall (%i) in (2) shared_outs(%arg1 = %output) -> (tensor<2xf32>) {
+      %2 = tensor.extract_slice %0[%i][1][1] : tensor<2xf32> to tensor<1xf32>
+      %3 = tensor.extract_slice %arg1[%i][1][1] : tensor<2xf32> to tensor<1xf32>
+      // CHECK: %[[FUSED:.+]] = linalg.fill
+      // CHECK: elemwise_unary ins(%[[FUSED]]
+      %4 = linalg.elemwise_unary ins(%2 : tensor<1xf32>) outs(%3 : tensor<1xf32>) -> tensor<1xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %4 into %arg1[%i][1][1] : tensor<1xf32> into tensor<2xf32>
+      }
+    }
+
+    return %1 : tensor<2xf32>
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !pdl.operation
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !pdl.operation
+
+    // Create a new handle that points to `linalg.fill` twice.
+    %2 = transform.merge_handles %0, %0 : !pdl.operation
+
+    // It shouldn't be a problem to fuse this handle.
+    transform.structured.fuse_into_containing_op %2 into %1
+  }
+}


        


More information about the Mlir-commits mailing list