[all-commits] [llvm/llvm-project] c41286: [mlir][linalg] Emit a warning when tile_using_fora...
Pablo Antonio Martinez via All-commits
all-commits at lists.llvm.org
Fri Mar 22 04:53:50 PDT 2024
Branch: refs/heads/main
Home: https://github.com/llvm/llvm-project
Commit: c41286af3f30e099556c6edbef0001466afaefcb
https://github.com/llvm/llvm-project/commit/c41286af3f30e099556c6edbef0001466afaefcb
Author: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: 2024-03-22 (Fri, 22 Mar 2024)
Changed paths:
M mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
M mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
M mlir/test/Dialect/Linalg/tile-to-forall.mlir
Log Message:
-----------
[mlir][linalg] Emit a warning when tile_using_forall generates non thread-safe code (#80813)
**Description**
The documentation of `transform.structured.tile_using_forall` says:
_"It is the user’s responsibility to ensure that num_threads/tile_sizes
is a valid tiling specification (i.e. that only tiles parallel
dimensions, e.g. in the Linalg case)."_
In other words, tiling a non-parallel (e.g. reduction) dimension
generates code with data races, which is not safe to execute in
parallel. For example, consider the following IR (included in the tests
in this PR):
```
func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
  %0 = scf.forall (%arg2) in (8) shared_outs(%arg3 = %arg1) -> (tensor<300x8xf32>) {
    %1 = affine.min #map(%arg2)
    %2 = affine.max #map1(%1)
    %3 = affine.apply #map2(%arg2)
    %extracted_slice = tensor.extract_slice %arg0[%3, 0, 0] [%2, 300, 8] [1, 1, 1] : tensor<100x300x8xf32> to tensor<?x300x8xf32>
    %4 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["reduction", "parallel", "parallel"]} ins(%extracted_slice : tensor<?x300x8xf32>) outs(%arg3 : tensor<300x8xf32>) {
    ^bb0(%in: f32, %out: f32):
      %5 = arith.addf %in, %out : f32
      linalg.yield %5 : f32
    } -> tensor<300x8xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %4 into %arg3[0, 0] [300, 8] [1, 1] : tensor<300x8xf32> into tensor<300x8xf32>
    }
  }
  return %0 : tensor<300x8xf32>
}
```
This is not safe to parallelize: every thread inserts its partial
result into the same slice of `%arg3` (`[0, 0] [300, 8]`) in the
`scf.forall.in_parallel` region, so all threads write to the same
output positions.
This PR detects whether it is safe to `tile_using_forall` and emits a
warning when it is not.
**Brief explanation**
The check first builds a vector of affine dimension expressions,
`dimExprs`, with one entry per tiled loop dimension. These expressions
are then compared against the results of the indexing map of each
output operand of the linalg op. Going back to the previous example, the
original IR and transform script (from the test) are:
```
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
  // expected-warning @+1 {{tiling is not thread safe at axis #0}}
  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = arith.addf %in, %out : f32
    linalg.yield %1 : f32
  } -> tensor<300x8xf32>
  return %0 : tensor<300x8xf32>
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
```
With `num_threads [8]`, only the first loop dimension is tiled, so
`dimExprs` is `(d0)`. Because the linalg op has a single output
(`%arg1`), the check only compares against the results of its indexing
map `#map1`, which are `(d1, d2)`. The rule is that every expression in
`dimExprs` must appear among the results of each output's indexing map;
otherwise different threads write to the same output elements. In this
example, `d0` is not in `(d1, d2)`, so tiling that axis is reported as
not thread safe.