[Mlir-commits] [mlir] 2d73190 - [mlir][linalg] Fix bug in FuseIntoContainingOp implementation

Matthias Springer llvmlistbot at llvm.org
Tue May 30 06:57:22 PDT 2023

Author: Matthias Springer
Date: 2023-05-30T15:55:30+02:00
New Revision: 2d731904170f1e3b378bfc556d939032e50c9a3d

URL: https://github.com/llvm/llvm-project/commit/2d731904170f1e3b378bfc556d939032e50c9a3d
DIFF: https://github.com/llvm/llvm-project/commit/2d731904170f1e3b378bfc556d939032e50c9a3d.diff

LOG: [mlir][linalg] Fix bug in FuseIntoContainingOp implementation

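Do not replace uses that are nested inside the body of an `scf.forall` op with results of that same op. Previously, only the `scf.forall` op itself was excluded from the replacement, so users nested in its body were rewired to the op's own results, which produced invalid IR.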

Differential Revision: https://reviews.llvm.org/D151706




diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9233ce9b89bfb..a6a3fbb2e23b8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -361,7 +361,8 @@ static Operation *replaceForAllWithNewSignature(
   SetVector<Operation *> dominatedUsers;
   DominanceInfo domInfo(containingOp);
   for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
-    if ((user != containingOp) && (domInfo.dominates(containingOp, user))) {
+    if (!containingOp->isAncestor(user) &&
+        (domInfo.dominates(containingOp, user))) {

diff  --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index d67b4802e772a..3854cceb6273d 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -560,3 +560,69 @@ module {
       : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
+// -----
+// This is a regression test. Make sure that the transform succeeds and valid
+// IR is generated.
+module {
+  // CHECK-LABEL: func.func @softmax_dispatch_0_generic_16x128x128_f32
+  func.func @softmax_dispatch_0_generic_16x128x128_f32() -> tensor<16x128x128xf32> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant dense<5.000000e+00> : tensor<16x128x128xf32>
+    %cst_1 = arith.constant 5.000000e+00 : f32
+    %1 = tensor.empty() : tensor<16x128xf32>
+    %2 = tensor.empty() : tensor<16x128x128xf32>
+    %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<16x128xf32>) -> tensor<16x128xf32>
+    %4 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<16x128xf32>) -> tensor<16x128xf32>
+    %5 = linalg.generic {producer, indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%cst : tensor<16x128x128xf32>) outs(%4 : tensor<16x128xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %8 = arith.maxf %in, %out : f32
+      linalg.yield %8 : f32
+    } -> tensor<16x128xf32>
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %7 = scf.forall (%arg0, %arg1) in (16, 32) shared_outs(%arg2 = %2) -> (tensor<16x128x128xf32>) {
+      %11 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg1)
+      %extracted_slice = tensor.extract_slice %5[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+      %extracted_slice_3 = tensor.extract_slice %2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
+      %extracted_slice_4 = tensor.extract_slice %3[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+      %15:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice : tensor<1x4xf32>) outs(%extracted_slice_3, %extracted_slice_4 : tensor<1x4x128xf32>, tensor<1x4xf32>) {
+      ^bb0(%in: f32, %out: f32, %out_9: f32):
+        %22 = arith.subf %cst_1, %in : f32
+        %23 = math.exp %22 : f32
+        %24 = arith.addf %23, %out_9 : f32
+        linalg.yield %23, %24 : f32, f32
+      } -> (tensor<1x4x128xf32>, tensor<1x4xf32>)
+      %extracted_slice_5 = tensor.extract_slice %5[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+      %extracted_slice_6 = tensor.extract_slice %2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
+      %extracted_slice_7 = tensor.extract_slice %3[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+      %19:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice_5 : tensor<1x4xf32>) outs(%extracted_slice_6, %extracted_slice_7 : tensor<1x4x128xf32>, tensor<1x4xf32>) {
+      ^bb0(%in: f32, %out: f32, %out_9: f32):
+        %22 = arith.subf %cst_1, %in : f32
+        %23 = math.exp %22 : f32
+        %24 = arith.addf %23, %out_9 : f32
+        linalg.yield %23, %24 : f32, f32
+      } -> (tensor<1x4x128xf32>, tensor<1x4xf32>)
+      %extracted_slice_8 = tensor.extract_slice %arg2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
+      %20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15#0, %19#1 : tensor<1x4x128xf32>, tensor<1x4xf32>) outs(%extracted_slice_8 : tensor<1x4x128xf32>) {
+      ^bb0(%in: f32, %in_9: f32, %out: f32):
+        %22 = arith.divf %in, %in_9 : f32
+        linalg.yield %22 : f32
+      } -> tensor<1x4x128xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %20 into %arg2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<1x4x128xf32> into tensor<16x128x128xf32>
+      }
+    }
+    return %7 : tensor<16x128x128xf32>
+  }
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match attributes{producer} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+    transform.structured.fuse_into_containing_op %0 into %1
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
+  }


More information about the Mlir-commits mailing list