[Mlir-commits] [mlir] [mlir][acc] Capture explicit serial semantics for compute regions (PR #195158)

Razvan Lupusoru llvmlistbot at llvm.org
Thu Apr 30 12:58:42 PDT 2026


================
@@ -194,61 +209,67 @@ getParallelDimensions(LoopOp loopOp, const ACCToGPUMappingPolicy &policy,
   return parDims;
 }
 
-/// Create acc.par_width operations from gang/worker/vector values of a
-/// compute construct. Queries the device-type-specific values first, falling
-/// back to the default (DeviceType::None) values.
+/// Build `acc.compute_region` launch operands: one sequential `acc.par_width`
+/// for `acc.serial`, for `acc.parallel` / `acc.kernels` when every num_gangs
+/// operand and num_workers / vector_length are the constant 1, and otherwise
+/// `acc.par_width` from gang/worker/vector (device-type operands first, then
+/// default DeviceType::None).
 template <typename ComputeConstructT>
 static SmallVector<Value>
 assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
                       RewriterBase &rewriter,
                       const ACCToGPUMappingPolicy &policy) {
-  SmallVector<Value> values;
   auto *ctx = rewriter.getContext();
-  auto indexTy = rewriter.getIndexType();
   auto loc = computeOp->getLoc();
 
-  auto numGangs = computeOp.getNumGangsValues(deviceType);
-  if (numGangs.empty())
-    numGangs = computeOp.getNumGangsValues();
-  for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
-    auto gangLevel = getGangParLevel(gangDimIdx + 1);
-    values.push_back(
-        ParWidthOp::create(rewriter, loc,
-                           getValueOrCreateCastToIndexLike(
-                               rewriter, gangSize.getLoc(), indexTy, gangSize),
-                           policy.gangDim(ctx, gangLevel)));
-  }
+  if constexpr (std::is_same_v<ComputeConstructT, SerialOp>) {
+    return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))};
+  } else if constexpr (llvm::is_one_of<ComputeConstructT, ParallelOp,
+                                       KernelsOp>::value) {
+    if (isEffectivelySerial(computeOp))
+      return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))};
+
+    SmallVector<Value> values;
+    auto indexTy = rewriter.getIndexType();
+
+    auto numGangs = computeOp.getNumGangsValues(deviceType);
+    if (numGangs.empty())
+      numGangs = computeOp.getNumGangsValues();
+    for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
+      auto gangLevel = getGangParLevel(gangDimIdx + 1);
+      values.push_back(ParWidthOp::create(
+          rewriter, loc,
+          getValueOrCreateCastToIndexLike(rewriter, gangSize.getLoc(), indexTy,
+                                          gangSize),
+          policy.gangDim(ctx, gangLevel)));
+    }
 
-  Value numWorkers = computeOp.getNumWorkersValue(deviceType);
-  if (!numWorkers)
-    numWorkers = computeOp.getNumWorkersValue();
-  if (numWorkers) {
-    values.push_back(ParWidthOp::create(
-        rewriter, loc,
-        getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(), indexTy,
-                                        numWorkers),
-        policy.workerDim(ctx)));
-  }
+    Value numWorkers = computeOp.getNumWorkersValue(deviceType);
+    if (!numWorkers)
+      numWorkers = computeOp.getNumWorkersValue();
+    if (numWorkers) {
+      values.push_back(ParWidthOp::create(
+          rewriter, loc,
+          getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(),
+                                          indexTy, numWorkers),
+          policy.workerDim(ctx)));
+    }
----------------
razvanlupusoru wrote:

The recommended change is not correct, because the `if (!numWorkers)` branch makes a second attempt to obtain the num_workers value, this time the non-device-type-specific one.

https://github.com/llvm/llvm-project/pull/195158


More information about the Mlir-commits mailing list