[Mlir-commits] [mlir] 19b9c74 - [mlir] Return new scf.forall handle in fuse_into_containing_op

Harsh Menon llvmlistbot at llvm.org
Thu May 25 09:47:04 PDT 2023


Author: Harsh Menon
Date: 2023-05-25T09:34:46-07:00
New Revision: 19b9c74b42824c5f7d73481d9c1fe8e385a4426c

URL: https://github.com/llvm/llvm-project/commit/19b9c74b42824c5f7d73481d9c1fe8e385a4426c
DIFF: https://github.com/llvm/llvm-project/commit/19b9c74b42824c5f7d73481d9c1fe8e385a4426c.diff

LOG: [mlir] Return new scf.forall handle in fuse_into_containing_op

Since fuse_into_containing_op now consumes the handle to the containing
scf.forall, it needs to return a handle to the new scf.forall.
This patch does that and also ensures that the new bbArg
added to the scf.forall is used in its body.

Differential Revision: https://reviews.llvm.org/D151418

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
    mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 179cafdbf274f..06ef84b43f04b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -183,7 +183,8 @@ def FuseIntoContainingOp :
 
   let arguments = (ins TransformHandleTypeInterface:$producer_op,
                        TransformHandleTypeInterface:$containing_op);
-  let results = (outs TransformHandleTypeInterface:$fused_op);
+  let results = (outs TransformHandleTypeInterface:$fused_op,
+                      TransformHandleTypeInterface:$new_containing_op);
   let assemblyFormat = "$producer_op `into` $containing_op attr-dict "
                        " `:` functional-type(operands, results)";
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4f476d1053827..a63f6647cf7b3 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -344,7 +344,8 @@ void transform::FuseIntoContainingOp::build(OpBuilder &builder,
                                             Value producerOp,
                                             Value containingOp) {
   result.addOperands({producerOp, containingOp});
-  result.addTypes(transform::AnyOpType::get(builder.getContext()));
+  auto resultType = transform::AnyOpType::get(builder.getContext());
+  result.addTypes({resultType, resultType});
 }
 
 /// Add new operands to the forall op for users of the producerOp
@@ -388,8 +389,16 @@ static Operation *replaceForAllWithNewSignature(
   newforallOp.getRegion().takeBody(forallOp.getRegion());
 
   // Add additional block argument for new value being returned
+  // and replace all uses of the new output with the corresponding bbArg
+  // inside the scf.forall to enable fusion into this new scf.forall.
   newforallOp.getBody()->addArgument(newOuts.back().getType(),
                                      newOuts.back().getLoc());
+  auto bbArgs = newforallOp.getBody()->getArguments();
+  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
+                             [&](OpOperand &use) {
+                               Operation *op = use.getOwner();
+                               return newforallOp->isProperAncestor(op);
+                             });
 
   // Fix terminator
   scf::InParallelOp terminatorOp = newforallOp.getTerminator();
@@ -749,14 +758,15 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   }
 
   results.set(cast<OpResult>(getFusedOp()), fusedOps);
+  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
   return DiagnosedSilenceableFailure::success();
 }
 
 void transform::FuseIntoContainingOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getProducerOp(), effects);
-  onlyReadsHandle(getContainingOp(), effects);
-  producesHandle(getFusedOp(), effects);
+  consumesHandle(getContainingOp(), effects);
+  producesHandle(getResults(), effects);
   modifiesPayload(effects);
 }
 

diff  --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index f3f480247f7ff..d67b4802e772a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -48,7 +48,7 @@ module {
 
     // linalg.fill is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
   }
 }
 
@@ -92,7 +92,7 @@ module {
 
     // tensor.empty is not tileable. The op is cloned and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
   }
 }
 
@@ -139,7 +139,7 @@ module {
 
     // linalg.fill is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
   }
 }
 
@@ -188,7 +188,7 @@ module {
 
     // linalg.fill is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
   }
 }
 
@@ -249,7 +249,7 @@ module {
 
     // linalg.generic is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
   }
 }
 
@@ -285,7 +285,7 @@ module {
     %2 = transform.merge_handles %0, %0 : !transform.any_op
 
     // It shouldn't be a problem to fuse this handle.
-    transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
   }
 }
 
@@ -351,7 +351,7 @@ module {
 
     // linalg.generic is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
   }
 }
 
@@ -417,7 +417,7 @@ module {
 
     // linalg.generic is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
   }
 }
 
@@ -482,6 +482,81 @@ module {
 
     // linalg.generic is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
+  }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+#map3 = affine_map<(d0) -> (d0)>
+
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_using_new_handle
+  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_using_new_handle(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+    -> (tensor<?xf32>, tensor<?xf32>) {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+
+    %0 = linalg.generic {
+      indexing_maps = [#map3, #map3], iterator_types = ["parallel"]
+      } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
+        ^bb0(%a: f32, %b: f32):
+          %d = arith.addf %a, %b : f32
+          linalg.yield %d : f32
+        } -> tensor<?xf32>
+
+    %1 = linalg.generic {
+      indexing_maps = [#map3, #map3], iterator_types = ["parallel"]
+      } ins(%0 : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
+        ^bb0(%a: f32, %b: f32):
+          %d = arith.mulf %a, %b : f32
+          linalg.yield %d : f32
+        } -> tensor<?xf32>
+    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+    %2 = affine.apply #map0()[%d0, %idx]
+
+    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+    %3 = scf.forall (%i) in (%2) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+      %4 = affine.apply #map1(%i)[%idx]
+      // CHECK: %[[I1:.*]] = affine.min {{.*}}
+      %5 = affine.min #map2(%i)[%d0, %idx]
+      %6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
+      // CHECK: %[[T2:.*]] = linalg.generic {{.*}}
+      %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
+
+      %8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        // CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+        tensor.parallel_insert_slice %8 into %o[%2] [%5] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: return %[[R0]]#0, %[[R0]]#1
+    func.return %3, %1 : tensor<?xf32>, tensor<?xf32>
+    // CHECK: }
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+    %add, %reduce = transform.split_handle %0 : (!transform.op<"linalg.generic">) -> (!transform.op<"linalg.generic">, !transform.op<"linalg.generic">)
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+    %fused_ops, %new_forall = transform.structured.fuse_into_containing_op %reduce into %1
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
+    %fused_ops_2, %new_forall_2 = transform.structured.fuse_into_containing_op %add into %new_forall
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
   }
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
index e824a55b92a64..7dd31835ce84f 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
@@ -52,7 +52,7 @@ module {
 
     // Fuse all producers.
     transform.structured.fuse_into_containing_op %producers into %forall_op
-      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
   }
 }
 
@@ -112,6 +112,6 @@ module {
 
     // Fuse all producers.
     transform.structured.fuse_into_containing_op %reversed_producers into %forall_op
-      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
   }
 }


        


More information about the Mlir-commits mailing list