[Mlir-commits] [mlir] [mlir][linalg] Emit a warning when tile_using_forall generates non thread-safe code (PR #80813)

Pablo Antonio Martinez llvmlistbot at llvm.org
Tue Mar 19 02:43:47 PDT 2024


================
@@ -304,6 +304,50 @@ static void calculateTileOffsetsAndSizes(
   }
 }
 
+/// Returns a vector of bools indicating whether, for each axis, `op` can be
+/// tiled by `numThreads` without incurring a race condition, i.e. whether it
+/// is thread-safe to do the tiling. This is checked by iterating over the
+/// affine maps of the outputs of `op` and ensuring that every tiled dimension
+/// appears in each output map; the set of tiled dimensions is derived from
+/// `numThreads` or `nominalTileSizes`.
+SmallVector<bool>
+safeToTileToForall(mlir::MLIRContext *ctx, TilingInterface op,
+                   ArrayRef<OpFoldResult> numThreads,
+                   std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
+                   int numDims) {
+  ArrayRef<OpFoldResult> tilingValues =
+      nominalTileSizes.has_value() ? *nominalTileSizes : numThreads;
+
+  SmallVector<bool> safeToTile(tilingValues.size(), true);
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
+  if (!linalgOp)
+    return safeToTile;
+
+  SmallVector<AffineExpr> dimExprs;
+  dimExprs.reserve(numDims);
+  for (unsigned i = 0; i < tilingValues.size(); i++) {
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(tilingValues[i])) {
+      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1)
+        dimExprs.push_back(mlir::getAffineDimExpr(i, ctx));
+    } else {
+      dimExprs.push_back(mlir::getAffineDimExpr(i, ctx));
+    }
+  }
+
+  for (uint32_t resNum = 0; resNum < op->getNumResults(); resNum++) {
+    AffineMap map =
+        linalgOp.getIndexingMapMatchingResult(op->getResult(resNum));
+
+    for (AffineExpr r : dimExprs) {
+      unsigned int axis = cast<AffineDimExpr>(r).getPosition();
+      if (!llvm::is_contained(map.getResults(), r))
+        safeToTile[axis] = false;
+    }
+  }
----------------
pabloantoniom wrote:

I agree, we can simply check the iterator type. I have simplified my implementation to do this.

I imagine that, just as you pointed out, the semantics of a `linalg.generic` that performs a reduction but uses the parallel iterator type are undefined. My original implementation would still catch those cases, but yes, I guess having a race there is fine. I wonder, however, whether my previous implementation could be reused to verify that the iterator types are correct (e.g., to flag code that performs a reduction but declares the corresponding iterator type as "parallel"). I don't know how useful or relevant that would be.
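For reference, here is a minimal sketch of the simplified, iterator-type-based check. This is only an illustration of the idea (it assumes `numThreads` carries one entry per loop dimension of the op, and that a static value of 0 or 1 means the axis is not actually distributed); it is not necessarily the exact code in the updated revision:

```cpp
// Sketch: an axis is unsafe to tile with scf.forall only when more than one
// thread would be assigned to it and its iterator type is not "parallel".
SmallVector<bool> safeToTileToForall(LinalgOp linalgOp,
                                     ArrayRef<OpFoldResult> numThreads) {
  SmallVector<utils::IteratorType> iterators =
      linalgOp.getIteratorTypesArray();
  SmallVector<bool> safeToTile(numThreads.size(), true);

  for (unsigned i = 0, e = numThreads.size(); i != e; ++i) {
    // A statically known thread count of 0 or 1 on this axis cannot create
    // concurrent writers, so it is always safe.
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i]))
      if (cast<IntegerAttr>(attr).getValue().getSExtValue() <= 1)
        continue;
    // Otherwise, tiling this axis is only safe if the loop is parallel;
    // distributing a reduction/window loop would make several threads
    // update the same output element.
    if (iterators[i] != utils::IteratorType::parallel)
      safeToTile[i] = false;
  }
  return safeToTile;
}
```

The caller would then emit the warning for any axis whose entry is false, e.g. the reduction dimension of a `linalg.matmul` when it is distributed over more than one thread.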

https://github.com/llvm/llvm-project/pull/80813

