[Mlir-commits] [mlir] c41286a - [mlir][linalg] Emit a warning when tile_using_forall generates non thread-safe code (#80813)

llvmlistbot at llvm.org
Fri Mar 22 04:53:32 PDT 2024


Author: Pablo Antonio Martinez
Date: 2024-03-22T12:53:29+01:00
New Revision: c41286af3f30e099556c6edbef0001466afaefcb

URL: https://github.com/llvm/llvm-project/commit/c41286af3f30e099556c6edbef0001466afaefcb
DIFF: https://github.com/llvm/llvm-project/commit/c41286af3f30e099556c6edbef0001466afaefcb.diff

LOG: [mlir][linalg] Emit a warning when tile_using_forall generates non thread-safe code (#80813)

**Description**

The documentation of `transform.structured.tile_using_forall` says:

_"It is the user’s responsibility to ensure that num_threads/tile_sizes
is a valid tiling specification (i.e. that only tiles parallel
dimensions, e.g. in the Linalg case)."_

In other words, tiling a non-parallel dimension would generate code with
data races, which is not safe to parallelize. For example, consider the
following IR (included in the tests in this PR):

```
func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
  %0 = scf.forall (%arg2) in (8) shared_outs(%arg3 = %arg1) -> (tensor<300x8xf32>) {
    %1 = affine.min #map(%arg2)
    %2 = affine.max #map1(%1)
    %3 = affine.apply #map2(%arg2)
    %extracted_slice = tensor.extract_slice %arg0[%3, 0, 0] [%2, 300, 8] [1, 1, 1] : tensor<100x300x8xf32> to tensor<?x300x8xf32>
    %4 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["reduction", "parallel", "parallel"]} ins(%extracted_slice : tensor<?x300x8xf32>) outs(%arg3 : tensor<300x8xf32>) {
    ^bb0(%in: f32, %out: f32):
      %5 = arith.addf %in, %out : f32
      linalg.yield %5 : f32
    } -> tensor<300x8xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %4 into %arg3[0, 0] [300, 8] [1, 1] : tensor<300x8xf32> into tensor<300x8xf32>
    }
  }
  return %0 : tensor<300x8xf32>
}
```

We can easily see that this is not safe to parallelize, because all
threads would be writing to the same positions of `%arg3` (in the
`scf.forall.in_parallel` region).

This PR detects whether a `tile_using_forall` transform is thread safe,
and emits a warning when it is not.

**Brief explanation**

The check walks `num_threads` (or the tile sizes) and, for every axis that
is actually tiled (i.e. whose thread count is dynamic or statically greater
than 1), verifies that the corresponding iterator type of the linalg op is
"parallel". If it is not, several threads would update the same elements of
a shared output, which is a data race. Going back to the previous example,
the original transform is:

```
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>

func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
  // expected-warning@+1 {{tiling is not thread safe at axis #0}}
  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = arith.addf %in, %out : f32
    linalg.yield %1 : f32
  } -> tensor<300x8xf32>
  return %0 : tensor<300x8xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8]
          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
```

Here `num_threads [8]` tiles only axis #0, whose iterator type is
"reduction". The problem is also visible in the indexing maps: because the
linalg op has only one output (`%arg1`), its access is described by `#map1`,
whose results are `(d1, d2)`. Since `d0` does not appear there, every thread
writes to the same output elements, so tiling that axis is reported as not
thread safe.
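
For contrast, here is a minimal sketch of the safe case (a hypothetical
variant of the example above, not one of the tests added by this patch):
passing `num_threads [0, 8]` leaves the reduction axis `d0` untiled (a `0`
entry means the axis is not tiled, as in the matmul test below) and
distributes the parallel axis `d1` over 8 threads, so no warning is emitted:

```
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>

func.func @tile_thread_safe(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
  // No warning expected: only the parallel axis d1 is tiled.
  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = arith.addf %in, %out : f32
    linalg.yield %1 : f32
  } -> tensor<300x8xf32>
  return %0 : tensor<300x8xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [0, 8]
          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
```

Each of the 8 threads then writes a disjoint `d1` slice of the output, so
the resulting `tensor.parallel_insert_slice` destinations no longer overlap.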

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/test/Dialect/Linalg/tile-to-forall.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4f34016066b4ce..c260fe3f7a46a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1918,7 +1918,9 @@ def TileUsingForallOp :
 
     It is the user's responsibility to ensure that `num_threads/tile_sizes` is
     a valid tiling specification (i.e. that only tiles parallel dimensions,
-    e.g. in the Linalg case).
+    e.g. in the Linalg case). If the dimension is not parallelizable, a warning
+    is issued to notify the user that the generated code is not safe to
+    parallelize.
 
     If non-empty, the `mapping` is added as an attribute to the
     resulting `scf.forall`.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 30aed850bed81e..462f692615faad 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -304,6 +304,28 @@ static void calculateTileOffsetsAndSizes(
   }
 }
 
+/// Returns a vector of bools representing whether, for each axis, `op` can be
+/// tiled without incurring a race condition, i.e. whether the tiling is
+/// thread safe. This is checked by iterating over `numThreads` and ensuring
+/// that the corresponding iterator type is "parallel". If it is not, then we
+/// know that such a dimension is unsafe to tile.
+SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
+                                     ArrayRef<OpFoldResult> numThreads) {
+  auto iterators = linalgOp.getIteratorTypesArray();
+  SmallVector<bool> safeToTile(numThreads.size(), true);
+
+  for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
+      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
+        safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
+      }
+    } else {
+      safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
+    }
+  }
+  return safeToTile;
+}
+
 /// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
 /// tiling is specified by the number of tiles/threads `numThreads` and the
 /// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
@@ -314,8 +336,10 @@ static void calculateTileOffsetsAndSizes(
 /// size of data.
 /// It is the user's responsibility to ensure that `numThreads` is a valid
 /// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
-/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will
-/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
+/// Linalg case). If the dimension is not parallelizable, a warning is issued to
+/// notify the user that the generated code is not safe to parallelize. If
+/// `omitTileOffsetBoundsCheck` is true, then the function will assume that
+/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
 static FailureOr<ForallTilingResult> tileToForallOpImpl(
     RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
     std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
@@ -344,6 +368,16 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
         return getValueOrCreateConstantIndexOp(b, loc, ofr);
       }));
 
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
+  if (linalgOp) {
+    // Check if tiling is thread safe and print a warning if not.
+    SmallVector<bool> tilingSafety =
+        safeToTileToForall(b.getContext(), linalgOp, numThreads);
+    for (size_t i = 0; i < tilingSafety.size(); i++)
+      if (!tilingSafety[i])
+        op.emitWarning() << "tiling is not thread safe at axis #" << i;
+  }
+
   // 1. Create the ForallOp. We don't use the lambda body-builder
   // version because we require the use of RewriterBase in the body, so we
   // manually move the insertion point to the body below.

diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index abd807b3e4d3e1..12e2dea5530b59 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -586,3 +586,144 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+func.func @tile_thread_safety1(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #1}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<100xf32>
+  return %0 : tensor<100xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [4, 2]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+
+func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #0}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<300x8xf32>
+  return %0 : tensor<300x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+func.func @tile_thread_safety3(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>) -> tensor<100x8xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #1}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<100x8xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<100x8xf32>
+  return %0 : tensor<100x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2)>
+
+func.func @tile_thread_safety4(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>, %arg2 : tensor<8xf32>) -> (tensor<100x8xf32>, tensor<8xf32>) {
+  // expected-warning@+2 {{tiling is not thread safe at axis #0}}
+  // expected-warning@below {{tiling is not thread safe at axis #1}}
+  %0:2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1, %arg2 : tensor<100x8xf32>, tensor<8xf32>) {
+  ^bb0(%in: f32, %out1: f32, %out2: f32):
+    %1 = arith.addf %in, %out1 : f32
+    %2 = arith.addf %in, %out2 : f32
+    linalg.yield %1, %2 : f32, f32
+  } -> (tensor<100x8xf32>, tensor<8xf32>)
+  return %0#0, %0#1 : tensor<100x8xf32>, tensor<8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+func.func @tile_thread_safety5(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #1}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<100xf32>
+  return %0 : tensor<100xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 tile_sizes [10, 1]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @tile_thread_safety6(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #2}}
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [2, 0, 8]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}

