[Mlir-commits] [mlir] [mlir][linalg] Emit a warning when tile_using_forall generates non thread-safe code (PR #80813)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Thu Mar 14 07:12:59 PDT 2024


================
@@ -304,6 +304,50 @@ static void calculateTileOffsetsAndSizes(
   }
 }
 
+/// Returns a vector of bools indicating, for each axis, whether `op` can be
+/// tiled along it by `numThreads` without incurring a race condition, i.e.
+/// whether tiling it is thread-safe. This is checked by iterating over the
+/// affine maps of the outputs in `op` and ensuring that all the results in
+/// each map are present in the affine map represented by the tiling sizes,
+/// which is derived from `numThreads` or `nominalTileSizes`.
+SmallVector<bool>
+safeToTileToForall(mlir::MLIRContext *ctx, TilingInterface op,
+                   ArrayRef<OpFoldResult> numThreads,
+                   std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
+                   int numDims) {
+  ArrayRef<OpFoldResult> tilingValues =
+      nominalTileSizes.has_value() ? *nominalTileSizes : numThreads;
+
+  SmallVector<bool> safeToTile(tilingValues.size(), true);
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
+  if (!linalgOp)
+    return safeToTile;
+
+  SmallVector<AffineExpr> dimExprs;
+  dimExprs.reserve(numDims);
+  for (unsigned i = 0; i < tilingValues.size(); i++) {
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(tilingValues[i])) {
+      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1)
----------------
ftynse wrote:

This doesn't look correct, given that `tilingValues` holds either tile sizes or num threads (that is, the number of tiles). When the tile size is 1, there will be as many tiles (and therefore as many hypothetical threads in the forall loop) as there are elements along the axis, so it's clearly not safe to ignore this axis here.
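To illustrate the point, here is a minimal, self-contained C++ sketch of how the "is this axis actually split across threads?" test could depend on which kind of value `tilingValues` holds. It deliberately avoids MLIR types and is not the code from the PR: `TilingKind`, `OutputMap`, `isSplitAcrossThreads`, and `safeToTileSketch` are hypothetical names; output indexing maps are modeled as lists of the loop dimensions they use; a tiling value of 0 is assumed to mean "leave this dimension untiled"; and dynamic values or extents are conservatively treated as splitting the axis.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of an output indexing map: the loop dimensions it uses,
// e.g. {0, 1} for (d0, d1, d2) -> (d0, d1).
using OutputMap = std::vector<unsigned>;

// Which interpretation the tiling values have (hypothetical enum).
enum class TilingKind { NumThreads, TileSizes };

// Returns true if the loop dimension is split into more than one forall
// iteration. std::nullopt stands for a dynamic value or loop extent.
bool isSplitAcrossThreads(TilingKind kind, std::optional<int64_t> value,
                          std::optional<int64_t> extent) {
  if (!value)
    return true; // Dynamic tiling value: conservatively assume a split.
  if (*value == 0)
    return false; // Assumed convention: 0 leaves the dimension untiled.
  if (kind == TilingKind::NumThreads)
    return *value > 1; // A single thread means no concurrency on this axis.
  // Tile sizes: there are ceil(extent / size) tiles, so a tile size of 1
  // produces the *most* tiles, and only size >= extent yields a single tile.
  if (!extent)
    return true; // Unknown extent: conservatively assume a split.
  return *value < *extent;
}

// For each loop dimension, returns whether tiling it is race-free: a split
// dimension is safe only if every output is indexed by it; otherwise several
// threads would write the same output elements.
std::vector<bool>
safeToTileSketch(TilingKind kind,
                 const std::vector<std::optional<int64_t>> &values,
                 const std::vector<std::optional<int64_t>> &extents,
                 const std::vector<OutputMap> &outputMaps) {
  assert(values.size() == extents.size() && "one extent per loop dimension");
  std::vector<bool> safe(values.size(), true);
  for (unsigned dim = 0; dim < values.size(); ++dim) {
    if (!isSplitAcrossThreads(kind, values[dim], extents[dim]))
      continue;
    for (const OutputMap &map : outputMaps) {
      if (std::find(map.begin(), map.end(), dim) == map.end()) {
        safe[dim] = false;
        break;
      }
    }
  }
  return safe;
}
```

For a matmul-like op with output map (d0, d1, d2) -> (d0, d1) and loop extents {4, 4, 8}, the values {0, 0, 1} mark d2 as unsafe when read as tile sizes (eight tiles, all writing the same output elements), but leave every dimension safe when read as num threads (a single thread along d2). A `> 1` check alone would treat both cases the same way.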

https://github.com/llvm/llvm-project/pull/80813

