[Mlir-commits] [mlir] 2d73190 - [mlir][linalg] Fix bug in FuseIntoContainingOp implementation
Matthias Springer
llvmlistbot at llvm.org
Tue May 30 06:57:22 PDT 2023
Author: Matthias Springer
Date: 2023-05-30T15:55:30+02:00
New Revision: 2d731904170f1e3b378bfc556d939032e50c9a3d
URL: https://github.com/llvm/llvm-project/commit/2d731904170f1e3b378bfc556d939032e50c9a3d
DIFF: https://github.com/llvm/llvm-project/commit/2d731904170f1e3b378bfc556d939032e50c9a3d.diff
LOG: [mlir][linalg] Fix bug in FuseIntoContainingOp implementation
Do not replace uses nested inside the body of an `scf.forall` op with results of that same op; doing so would make the op use its own results inside its body and produce invalid IR.
Differential Revision: https://reviews.llvm.org/D151706
Added:
Modified:
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9233ce9b89bfb..a6a3fbb2e23b8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -361,7 +361,8 @@ static Operation *replaceForAllWithNewSignature(
SetVector<Operation *> dominatedUsers;
DominanceInfo domInfo(containingOp);
for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
- if ((user != containingOp) && (domInfo.dominates(containingOp, user))) {
+ if (!containingOp->isAncestor(user) &&
+ (domInfo.dominates(containingOp, user))) {
dominatedUsers.insert(user);
}
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index d67b4802e772a..3854cceb6273d 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -560,3 +560,69 @@ module {
: (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
}
}
+
+// -----
+
+// This is a regression test. Make sure that the transform succeeds and valid
+// IR is generated.
+
+module {
+ // CHECK-LABEL: func.func @softmax_dispatch_0_generic_16x128x128_f32
+ func.func @softmax_dispatch_0_generic_16x128x128_f32() -> tensor<16x128x128xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<5.000000e+00> : tensor<16x128x128xf32>
+ %cst_1 = arith.constant 5.000000e+00 : f32
+ %1 = tensor.empty() : tensor<16x128xf32>
+ %2 = tensor.empty() : tensor<16x128x128xf32>
+ %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<16x128xf32>) -> tensor<16x128xf32>
+ %4 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<16x128xf32>) -> tensor<16x128xf32>
+ %5 = linalg.generic {producer, indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%cst : tensor<16x128x128xf32>) outs(%4 : tensor<16x128xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %8 = arith.maxf %in, %out : f32
+ linalg.yield %8 : f32
+ } -> tensor<16x128xf32>
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %7 = scf.forall (%arg0, %arg1) in (16, 32) shared_outs(%arg2 = %2) -> (tensor<16x128x128xf32>) {
+ %11 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg1)
+ %extracted_slice = tensor.extract_slice %5[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+ %extracted_slice_3 = tensor.extract_slice %2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
+ %extracted_slice_4 = tensor.extract_slice %3[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+ %15:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice : tensor<1x4xf32>) outs(%extracted_slice_3, %extracted_slice_4 : tensor<1x4x128xf32>, tensor<1x4xf32>) {
+ ^bb0(%in: f32, %out: f32, %out_9: f32):
+ %22 = arith.subf %cst_1, %in : f32
+ %23 = math.exp %22 : f32
+ %24 = arith.addf %23, %out_9 : f32
+ linalg.yield %23, %24 : f32, f32
+ } -> (tensor<1x4x128xf32>, tensor<1x4xf32>)
+ %extracted_slice_5 = tensor.extract_slice %5[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+ %extracted_slice_6 = tensor.extract_slice %2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
+ %extracted_slice_7 = tensor.extract_slice %3[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+ %19:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice_5 : tensor<1x4xf32>) outs(%extracted_slice_6, %extracted_slice_7 : tensor<1x4x128xf32>, tensor<1x4xf32>) {
+ ^bb0(%in: f32, %out: f32, %out_9: f32):
+ %22 = arith.subf %cst_1, %in : f32
+ %23 = math.exp %22 : f32
+ %24 = arith.addf %23, %out_9 : f32
+ linalg.yield %23, %24 : f32, f32
+ } -> (tensor<1x4x128xf32>, tensor<1x4xf32>)
+ %extracted_slice_8 = tensor.extract_slice %arg2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
+ %20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15#0, %19#1 : tensor<1x4x128xf32>, tensor<1x4xf32>) outs(%extracted_slice_8 : tensor<1x4x128xf32>) {
+ ^bb0(%in: f32, %in_9: f32, %out: f32):
+ %22 = arith.divf %in, %in_9 : f32
+ linalg.yield %22 : f32
+ } -> tensor<1x4x128xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %20 into %arg2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<1x4x128xf32> into tensor<16x128x128xf32>
+ }
+ }
+ return %7 : tensor<16x128x128xf32>
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match attributes{producer} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+ transform.structured.fuse_into_containing_op %0 into %1
+ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
+ }
+}
More information about the Mlir-commits mailing list