[Mlir-commits] [mlir] [mlir][linalg] Emit a warning when tile_using_forall generates non thread-safe code (PR #80813)

Pablo Antonio Martinez llvmlistbot at llvm.org
Thu Mar 21 09:29:03 PDT 2024


https://github.com/pabloantoniom updated https://github.com/llvm/llvm-project/pull/80813

>From 5d459622a7a91cf212b0ff1c512b8043364d94bc Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Mon, 5 Feb 2024 16:57:59 +0000
Subject: [PATCH 1/4] [mlir][Linalg][Transform] Emit a warning when
 tile_using_forall generates non thread-safe code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This warning aims to complement the comment in the documentation that
says:

"It is the user’s responsibility to ensure that num_threads/tile_sizes
is a valid tiling specification (i.e. that only tiles parallel
dimensions, e.g. in the Linalg case)."

because:

1. Not all users of tile_using_forall know that tiling the wrong
dimension(s) (e.g., a non-parallel dimension) will generate
non-thread-safe code, so this warning informs the user about it.

2. Users of tile_using_forall may be aware of this limitation, but may
not realize that they are tiling a non-parallel dimension, so the
warning can help during debugging.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |   4 +-
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp |  57 +++++++++-
 mlir/test/Dialect/Linalg/tile-to-forall.mlir  | 100 ++++++++++++++++++
 3 files changed, 158 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 309573a562872f..e947720471f78c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1922,7 +1922,9 @@ def TileUsingForallOp :
 
     It is the user's responsibility to ensure that `num_threads/tile_sizes` is
     a valid tiling specification (i.e. that only tiles parallel dimensions,
-    e.g. in the Linalg case).
+    e.g. in the Linalg case). If the dimension is not parallelizable, a warning
+    is issued to notify the user that the generated code is not safe to
+    parallelize.
 
     If non-empty, the `mapping` is added as an attribute to the
     resulting `scf.forall`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 30aed850bed81e..ed97ad70e6e390 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -304,6 +304,50 @@ static void calculateTileOffsetsAndSizes(
   }
 }
 
+/// Returns a vector of bools representing if, for the given axis, `op` can be
+/// tiled by `numThreads` without incurring in a race condition and thus it is
+/// thread-safe to do the tiling. This is checked by iterating over the affine
+/// maps of the outputs in `op` and ensuring that all the results in the map are
+/// present in the affine map represented by the tiling sizes, which is derived
+/// from `numThreads` or `nominalTileSizes`.
+SmallVector<bool>
+safeToTileToForall(mlir::MLIRContext *ctx, TilingInterface op,
+                   ArrayRef<OpFoldResult> numThreads,
+                   std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
+                   int numDims) {
+  ArrayRef<OpFoldResult> tilingValues =
+      nominalTileSizes.has_value() ? *nominalTileSizes : numThreads;
+
+  SmallVector<bool> safeToTile(tilingValues.size(), true);
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
+  if (!linalgOp)
+    return safeToTile;
+
+  SmallVector<AffineExpr> dimExprs;
+  dimExprs.reserve(numDims);
+  for (unsigned i = 0; i < tilingValues.size(); i++) {
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(tilingValues[i])) {
+      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1)
+        dimExprs.push_back(mlir::getAffineDimExpr(i, ctx));
+    } else {
+      dimExprs.push_back(mlir::getAffineDimExpr(i, ctx));
+    }
+  }
+
+  for (uint32_t resNum = 0; resNum < op->getNumResults(); resNum++) {
+    AffineMap map =
+        linalgOp.getIndexingMapMatchingResult(op->getResult(resNum));
+
+    for (AffineExpr r : dimExprs) {
+      unsigned int axis = cast<AffineDimExpr>(r).getPosition();
+      if (!llvm::is_contained(map.getResults(), r))
+        safeToTile[axis] = false;
+    }
+  }
+
+  return safeToTile;
+}
+
 /// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
 /// tiling is specified by the number of tiles/threads `numThreads` and the
 /// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
@@ -314,8 +358,10 @@ static void calculateTileOffsetsAndSizes(
 /// size of data.
 /// It is the user's responsibility to ensure that `numThreads` is a valid
 /// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
-/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will
-/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
+/// Linalg case). If the dimension is not parallelizable, a warning is issued to
+/// notify the user that the generated code is not safe to parallelize. If
+/// `omitTileOffsetBoundsCheck` is true, then the function will assume that
+/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
 static FailureOr<ForallTilingResult> tileToForallOpImpl(
     RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
     std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
@@ -344,6 +390,13 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
         return getValueOrCreateConstantIndexOp(b, loc, ofr);
       }));
 
+  // Check if tiling is thread safe and print a warning if not.
+  SmallVector<bool> tilingSafety = safeToTileToForall(
+      b.getContext(), op, numThreads, nominalTileSizes, loopRanges.size());
+  for (size_t i = 0; i < tilingSafety.size(); i++)
+    if (!tilingSafety[i])
+      op.emitWarning() << "tiling is not thread safe at axis #" << i;
+
   // 1. Create the ForallOp. We don't use the lambda body-builder
   // version because we require the use of RewriterBase in the body, so we
   // manually move the insertion point to the body below.
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index abd807b3e4d3e1..e52f76c619575a 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -586,3 +586,103 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+func.func @tile_thread_safety1(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-warning at +1 {{tiling is not thread safe at axis #1}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<100xf32>
+  return %0 : tensor<100xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [4, 2]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+
+func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
+  // expected-warning at +1 {{tiling is not thread safe at axis #0}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<300x8xf32>
+  return %0 : tensor<300x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+func.func @tile_thread_safety3(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>) -> tensor<100x8xf32> {
+  // expected-warning at +1 {{tiling is not thread safe at axis #1}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<100x8xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<100x8xf32>
+  return %0 : tensor<100x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2)>
+
+func.func @tile_thread_safety4(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>, %arg2 : tensor<8xf32>) -> (tensor<100x8xf32>, tensor<8xf32>) {
+  // expected-warning at +2 {{tiling is not thread safe at axis #0}}
+  // expected-warning at +1 {{tiling is not thread safe at axis #1}}
+  %0:2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1, %arg2 : tensor<100x8xf32>, tensor<8xf32>) {
+  ^bb0(%in: f32, %out1: f32, %out2: f32):
+    %1 = arith.addf %in, %out1 : f32
+    %2 = arith.addf %in, %out2 : f32
+    linalg.yield %1, %2 : f32, f32
+  } -> (tensor<100x8xf32>, tensor<8xf32>)
+  return %0#0, %0#1 : tensor<100x8xf32>, tensor<8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+

>From cb774445e9b4069af84d072e06866497136c90e5 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Thu, 14 Mar 2024 16:57:02 +0000
Subject: [PATCH 2/4] [mlir][Linalg][Transform] Small nits, bugfix and tests.

Fix a bug when tile_size=1 was specified. Also add a test case that
uses tile_sizes, and another test case that tiles linalg.matmul to
show that the check also works for ops other than linalg.generic.
---
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 40 +++++++--------
 mlir/test/Dialect/Linalg/tile-to-forall.mlir  | 49 +++++++++++++++++--
 2 files changed, 65 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index ed97ad70e6e390..fe9a3c658e6d50 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -305,38 +305,35 @@ static void calculateTileOffsetsAndSizes(
 }
 
 /// Returns a vector of bools representing if, for the given axis, `op` can be
-/// tiled by `numThreads` without incurring in a race condition and thus it is
-/// thread-safe to do the tiling. This is checked by iterating over the affine
-/// maps of the outputs in `op` and ensuring that all the results in the map are
-/// present in the affine map represented by the tiling sizes, which is derived
-/// from `numThreads` or `nominalTileSizes`.
+/// tiled by without incurring in a race condition and thus it is thread-safe to
+/// do the tiling. This is checked by iterating over the affine maps of the
+/// outputs in `op` and ensuring that all the results in the map are present in
+/// the affine map represented by the tiling sizes, which is derived from
+/// `numThreads` or `nominalTileSizes`.
 SmallVector<bool>
-safeToTileToForall(mlir::MLIRContext *ctx, TilingInterface op,
+safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
                    ArrayRef<OpFoldResult> numThreads,
                    std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
                    int numDims) {
   ArrayRef<OpFoldResult> tilingValues =
       nominalTileSizes.has_value() ? *nominalTileSizes : numThreads;
+  int minTile = nominalTileSizes.has_value() ? 0 : 1;
 
   SmallVector<bool> safeToTile(tilingValues.size(), true);
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
-  if (!linalgOp)
-    return safeToTile;
-
   SmallVector<AffineExpr> dimExprs;
   dimExprs.reserve(numDims);
-  for (unsigned i = 0; i < tilingValues.size(); i++) {
+  for (unsigned i = 0, e = tilingValues.size(); i != e; i++) {
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(tilingValues[i])) {
-      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1)
+      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > minTile)
         dimExprs.push_back(mlir::getAffineDimExpr(i, ctx));
     } else {
       dimExprs.push_back(mlir::getAffineDimExpr(i, ctx));
     }
   }
 
-  for (uint32_t resNum = 0; resNum < op->getNumResults(); resNum++) {
+  for (unsigned resNum = 0; resNum < linalgOp->getNumResults(); resNum++) {
     AffineMap map =
-        linalgOp.getIndexingMapMatchingResult(op->getResult(resNum));
+        linalgOp.getIndexingMapMatchingResult(linalgOp->getResult(resNum));
 
     for (AffineExpr r : dimExprs) {
       unsigned int axis = cast<AffineDimExpr>(r).getPosition();
@@ -390,12 +387,15 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
         return getValueOrCreateConstantIndexOp(b, loc, ofr);
       }));
 
-  // Check if tiling is thread safe and print a warning if not.
-  SmallVector<bool> tilingSafety = safeToTileToForall(
-      b.getContext(), op, numThreads, nominalTileSizes, loopRanges.size());
-  for (size_t i = 0; i < tilingSafety.size(); i++)
-    if (!tilingSafety[i])
-      op.emitWarning() << "tiling is not thread safe at axis #" << i;
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
+  if (linalgOp) {
+    // Check if tiling is thread safe and print a warning if not.
+    SmallVector<bool> tilingSafety = safeToTileToForall(
+      b.getContext(), linalgOp, numThreads, nominalTileSizes, loopRanges.size());
+    for (size_t i = 0; i < tilingSafety.size(); i++)
+      if (!tilingSafety[i])
+        op.emitWarning() << "tiling is not thread safe at axis #" << i;
+  }
 
   // 1. Create the ForallOp. We don't use the lambda body-builder
   // version because we require the use of RewriterBase in the body, so we
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index e52f76c619575a..74eb0b12aa8d19 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -593,7 +593,7 @@ module attributes {transform.with_named_sequence} {
 #map1 = affine_map<(d0, d1) -> (d0)>
 
 func.func @tile_thread_safety1(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
-  // expected-warning at +1 {{tiling is not thread safe at axis #1}}
+  // expected-warning at below {{tiling is not thread safe at axis #1}}
   %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) {
   ^bb0(%in: f32, %out: f32):
     %1 = arith.addf %in, %out : f32
@@ -617,7 +617,7 @@ module attributes {transform.with_named_sequence} {
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
 
 func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
-  // expected-warning at +1 {{tiling is not thread safe at axis #0}}
+  // expected-warning at below {{tiling is not thread safe at axis #0}}
   %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) {
   ^bb0(%in: f32, %out: f32):
     %1 = arith.addf %in, %out : f32
@@ -641,7 +641,7 @@ module attributes {transform.with_named_sequence} {
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 
 func.func @tile_thread_safety3(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>) -> tensor<100x8xf32> {
-  // expected-warning at +1 {{tiling is not thread safe at axis #1}}
+  // expected-warning at below {{tiling is not thread safe at axis #1}}
   %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<100x8xf32>) {
   ^bb0(%in: f32, %out: f32):
     %1 = arith.addf %in, %out : f32
@@ -667,7 +667,7 @@ module attributes {transform.with_named_sequence} {
 
 func.func @tile_thread_safety4(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>, %arg2 : tensor<8xf32>) -> (tensor<100x8xf32>, tensor<8xf32>) {
   // expected-warning at +2 {{tiling is not thread safe at axis #0}}
-  // expected-warning at +1 {{tiling is not thread safe at axis #1}}
+  // expected-warning at below {{tiling is not thread safe at axis #1}}
   %0:2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1, %arg2 : tensor<100x8xf32>, tensor<8xf32>) {
   ^bb0(%in: f32, %out1: f32, %out2: f32):
     %1 = arith.addf %in, %out1 : f32
@@ -686,3 +686,44 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+func.func @tile_thread_safety5(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-warning at below {{tiling is not thread safe at axis #1}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<100xf32>
+  return %0 : tensor<100xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 tile_sizes [10, 1]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @tile_thread_safety6(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-warning at below {{tiling is not thread safe at axis #2}}
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [2, 4, 8]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
\ No newline at end of file

>From 3cfccde6c1170371ed1ca48949b69ccbd39fc255 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Thu, 14 Mar 2024 16:59:13 +0000
Subject: [PATCH 3/4] [mlir][Linalg][Transform] Simplify implementation

Rather than comparing the output affine maps against the maps inferred
from the tile sizes, we can simply check that the tiled dimensions do
not have the "reduction" iterator type. If they do, then we know for
certain that they are not safe to tile.
---
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 28 +++++++++----------
 1 file changed, 13 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index fe9a3c658e6d50..05d2d7cd945537 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -306,10 +306,11 @@ static void calculateTileOffsetsAndSizes(
 
 /// Returns a vector of bools representing if, for the given axis, `op` can be
 /// tiled by without incurring in a race condition and thus it is thread-safe to
-/// do the tiling. This is checked by iterating over the affine maps of the
-/// outputs in `op` and ensuring that all the results in the map are present in
-/// the affine map represented by the tiling sizes, which is derived from
-/// `numThreads` or `nominalTileSizes`.
+/// do the tiling. This is checked by iterating over the affine map represented
+/// by the tiling sizes (which is derived from `numThreads` or
+/// `nominalTileSizes`) and ensuring that the corresponding iterator type is
+/// not "reduction". If it is, then we know that such dimension is unsafe to
+/// tile.
 SmallVector<bool>
 safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
                    ArrayRef<OpFoldResult> numThreads,
@@ -331,15 +332,11 @@ safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
     }
   }
 
-  for (unsigned resNum = 0; resNum < linalgOp->getNumResults(); resNum++) {
-    AffineMap map =
-        linalgOp.getIndexingMapMatchingResult(linalgOp->getResult(resNum));
-
-    for (AffineExpr r : dimExprs) {
-      unsigned int axis = cast<AffineDimExpr>(r).getPosition();
-      if (!llvm::is_contained(map.getResults(), r))
-        safeToTile[axis] = false;
-    }
+  auto iterators = linalgOp.getIteratorTypesArray();
+  for (AffineExpr r : dimExprs) {
+    unsigned int axis = cast<AffineDimExpr>(r).getPosition();
+    if (iterators[axis] == utils::IteratorType::reduction)
+      safeToTile[axis] = false;
   }
 
   return safeToTile;
@@ -390,8 +387,9 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
   if (linalgOp) {
     // Check if tiling is thread safe and print a warning if not.
-    SmallVector<bool> tilingSafety = safeToTileToForall(
-      b.getContext(), linalgOp, numThreads, nominalTileSizes, loopRanges.size());
+    SmallVector<bool> tilingSafety =
+        safeToTileToForall(b.getContext(), linalgOp, numThreads,
+                           nominalTileSizes, loopRanges.size());
     for (size_t i = 0; i < tilingSafety.size(); i++)
       if (!tilingSafety[i])
         op.emitWarning() << "tiling is not thread safe at axis #" << i;

>From 903e883c316e44cd3361408b2580e667f8bba9cf Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Thu, 21 Mar 2024 16:26:19 +0000
Subject: [PATCH 4/4] [mlir][Linalg][Transform] Simplify even more the
 implementation

Simplify the implementation by removing the final loop and by only
checking numThreads (previously nominalTileSizes was also checked).
This is possible because, when num_threads is not specified, it is
automatically derived from tile_sizes, so numThreads always contains
the values we need.

Also slightly changed the tests to show that the implementation
properly handles the case where one of the thread counts is zero.
---
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 53 +++++++------------
 mlir/test/Dialect/Linalg/tile-to-forall.mlir  |  4 +-
 2 files changed, 20 insertions(+), 37 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 05d2d7cd945537..19a74b15c947f7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -304,41 +304,25 @@ static void calculateTileOffsetsAndSizes(
   }
 }
 
-/// Returns a vector of bools representing if, for the given axis, `op` can be
-/// tiled by without incurring in a race condition and thus it is thread-safe to
-/// do the tiling. This is checked by iterating over the affine map represented
-/// by the tiling sizes (which is derived from `numThreads` or
-/// `nominalTileSizes`) and ensuring that the corresponding iterator type is
-/// not "reduction". If it is, then we know that such dimension is unsafe to
-/// tile.
-SmallVector<bool>
-safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
-                   ArrayRef<OpFoldResult> numThreads,
-                   std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
-                   int numDims) {
-  ArrayRef<OpFoldResult> tilingValues =
-      nominalTileSizes.has_value() ? *nominalTileSizes : numThreads;
-  int minTile = nominalTileSizes.has_value() ? 0 : 1;
-
-  SmallVector<bool> safeToTile(tilingValues.size(), true);
-  SmallVector<AffineExpr> dimExprs;
-  dimExprs.reserve(numDims);
-  for (unsigned i = 0, e = tilingValues.size(); i != e; i++) {
-    if (auto attr = llvm::dyn_cast_if_present<Attribute>(tilingValues[i])) {
-      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > minTile)
-        dimExprs.push_back(mlir::getAffineDimExpr(i, ctx));
+/// Returns a vector of bools representing if, for each axis, `op` can be tiled
+/// without incurring in a race condition and thus it is thread-safe to do the
+/// tiling. This is checked by iterating over numThreads and ensuring that the
+/// corresponding iterator type is "parallel". If it is not, then we know that
+/// such dimension is unsafe to tile.
+SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
+                                     ArrayRef<OpFoldResult> numThreads) {
+  auto iterators = linalgOp.getIteratorTypesArray();
+  SmallVector<bool> safeToTile(numThreads.size(), true);
+
+  for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
+      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
+        safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
+      }
     } else {
-      dimExprs.push_back(mlir::getAffineDimExpr(i, ctx));
+      safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
     }
   }
-
-  auto iterators = linalgOp.getIteratorTypesArray();
-  for (AffineExpr r : dimExprs) {
-    unsigned int axis = cast<AffineDimExpr>(r).getPosition();
-    if (iterators[axis] == utils::IteratorType::reduction)
-      safeToTile[axis] = false;
-  }
-
   return safeToTile;
 }
 
@@ -387,9 +371,8 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
   if (linalgOp) {
     // Check if tiling is thread safe and print a warning if not.
-    SmallVector<bool> tilingSafety =
-        safeToTileToForall(b.getContext(), linalgOp, numThreads,
-                           nominalTileSizes, loopRanges.size());
+    SmallVector<bool> tilingSafety = safeToTileToForall(
+        b.getContext(), linalgOp, numThreads);
     for (size_t i = 0; i < tilingSafety.size(); i++)
       if (!tilingSafety[i])
         op.emitWarning() << "tiling is not thread safe at axis #" << i;
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 74eb0b12aa8d19..12e2dea5530b59 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -722,8 +722,8 @@ func.func @tile_thread_safety6(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: ten
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [2, 4, 8]
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [2, 0, 8]
           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
-}
\ No newline at end of file
+}



More information about the Mlir-commits mailing list