[Mlir-commits] [mlir] [mlir][linalg] Emit a warning when tile_using_forall generates non thread-safe code (PR #80813)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Thu Mar 21 03:52:11 PDT 2024


================
@@ -304,6 +304,44 @@ static void calculateTileOffsetsAndSizes(
   }
 }
 
+/// Returns a vector of bools indicating whether, for each axis, `op` can be
+/// tiled without incurring a race condition, i.e. whether it is thread-safe
+/// to tile. This is checked by iterating over the affine map represented
+/// by the tiling sizes (which is derived from `numThreads` or
+/// `nominalTileSizes`) and ensuring that the corresponding iterator type is
+/// not "reduction". If it is, then we know that such dimension is unsafe to
+/// tile.
+SmallVector<bool>
+safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
+                   ArrayRef<OpFoldResult> numThreads,
+                   std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
+                   int numDims) {
+  ArrayRef<OpFoldResult> tilingValues =
+      nominalTileSizes.has_value() ? *nominalTileSizes : numThreads;
+  int minTile = nominalTileSizes.has_value() ? 0 : 1;
+
+  SmallVector<bool> safeToTile(tilingValues.size(), true);
+  SmallVector<AffineExpr> dimExprs;
+  dimExprs.reserve(numDims);
+  for (unsigned i = 0, e = tilingValues.size(); i != e; i++) {
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(tilingValues[i])) {
+      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > minTile)
+        dimExprs.push_back(mlir::getAffineDimExpr(i, ctx));
+    } else {
+      dimExprs.push_back(mlir::getAffineDimExpr(i, ctx));
+    }
+  }
+
+  auto iterators = linalgOp.getIteratorTypesArray();
+  for (AffineExpr r : dimExprs) {
+    unsigned int axis = cast<AffineDimExpr>(r).getPosition();
----------------
ftynse wrote:

We no longer need affine expressions here. The loop above can just push `i` into the vector, and this loop can then iterate over those indices. Furthermore, it may not even be necessary to have a second loop and a vector to communicate between them.
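For illustration, here is a rough, standalone sketch of the single-loop shape suggested above. It deliberately avoids MLIR. `IteratorType`, `TileValue` and this simplified `safeToTileToForall` signature are hypothetical stand-ins: a `std::nullopt` tile value models a dynamic (non-constant) `OpFoldResult`. The reduction check is inferred from the doc comment, because the quoted hunk is cut off before the body of the second loop.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for MLIR's iterator types (illustration only).
enum class IteratorType { parallel, reduction };

// Hypothetical stand-in for OpFoldResult: a static integer, or nullopt for a
// dynamic value whose size is unknown at compile time.
using TileValue = std::optional<int64_t>;

// Single pass: no affine expressions, no second loop, no intermediate vector.
std::vector<bool>
safeToTileToForall(const std::vector<IteratorType> &iterators,
                   const std::vector<TileValue> &numThreads,
                   const std::optional<std::vector<TileValue>> &nominalTileSizes) {
  const std::vector<TileValue> &tilingValues =
      nominalTileSizes ? *nominalTileSizes : numThreads;
  // A tile size of 0, or a thread count of 1, leaves the dimension unsplit.
  const int64_t minTile = nominalTileSizes ? 0 : 1;

  std::vector<bool> safeToTile(tilingValues.size(), true);
  for (std::size_t i = 0, e = tilingValues.size(); i != e; ++i) {
    // Dynamic values are conservatively assumed to split the dimension.
    bool isTiled = !tilingValues[i].has_value() || *tilingValues[i] > minTile;
    // Splitting a reduction dimension across threads races on the output.
    if (isTiled && i < iterators.size() &&
        iterators[i] == IteratorType::reduction)
      safeToTile[i] = false;
  }
  return safeToTile;
}
```

Treating dynamic values as "tiled" mirrors the `else` branch of the quoted code, which records the dimension whenever the value is not a constant attribute.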

https://github.com/llvm/llvm-project/pull/80813

