[Mlir-commits] [mlir] [mlir][acc] Capture explicit serial semantics for compute regions (PR #195158)
Razvan Lupusoru
llvmlistbot at llvm.org
Thu Apr 30 12:58:42 PDT 2026
================
@@ -194,61 +209,67 @@ getParallelDimensions(LoopOp loopOp, const ACCToGPUMappingPolicy &policy,
return parDims;
}
-/// Create acc.par_width operations from gang/worker/vector values of a
-/// compute construct. Queries the device-type-specific values first, falling
-/// back to the default (DeviceType::None) values.
+/// Build `acc.compute_region` launch operands: one sequential `acc.par_width`
+/// for `acc.serial`, for `acc.parallel` / `acc.kernels` when every num_gangs
+/// operand and num_workers / vector_length are the constant 1, and otherwise
+/// `acc.par_width` from gang/worker/vector (device-type operands first, then
+/// default DeviceType::None).
template <typename ComputeConstructT>
static SmallVector<Value>
assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
RewriterBase &rewriter,
const ACCToGPUMappingPolicy &policy) {
- SmallVector<Value> values;
auto *ctx = rewriter.getContext();
- auto indexTy = rewriter.getIndexType();
auto loc = computeOp->getLoc();
- auto numGangs = computeOp.getNumGangsValues(deviceType);
- if (numGangs.empty())
- numGangs = computeOp.getNumGangsValues();
- for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
- auto gangLevel = getGangParLevel(gangDimIdx + 1);
- values.push_back(
- ParWidthOp::create(rewriter, loc,
- getValueOrCreateCastToIndexLike(
- rewriter, gangSize.getLoc(), indexTy, gangSize),
- policy.gangDim(ctx, gangLevel)));
- }
+ if constexpr (std::is_same_v<ComputeConstructT, SerialOp>) {
+ return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))};
+ } else if constexpr (llvm::is_one_of<ComputeConstructT, ParallelOp,
+ KernelsOp>::value) {
+ if (isEffectivelySerial(computeOp))
+ return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))};
+
+ SmallVector<Value> values;
+ auto indexTy = rewriter.getIndexType();
+
+ auto numGangs = computeOp.getNumGangsValues(deviceType);
+ if (numGangs.empty())
+ numGangs = computeOp.getNumGangsValues();
+ for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
+ auto gangLevel = getGangParLevel(gangDimIdx + 1);
+ values.push_back(ParWidthOp::create(
+ rewriter, loc,
+ getValueOrCreateCastToIndexLike(rewriter, gangSize.getLoc(), indexTy,
+ gangSize),
+ policy.gangDim(ctx, gangLevel)));
+ }
- Value numWorkers = computeOp.getNumWorkersValue(deviceType);
- if (!numWorkers)
- numWorkers = computeOp.getNumWorkersValue();
- if (numWorkers) {
- values.push_back(ParWidthOp::create(
- rewriter, loc,
- getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(), indexTy,
- numWorkers),
- policy.workerDim(ctx)));
- }
+ Value numWorkers = computeOp.getNumWorkersValue(deviceType);
+ if (!numWorkers)
+ numWorkers = computeOp.getNumWorkersValue();
+ if (numWorkers) {
+ values.push_back(ParWidthOp::create(
+ rewriter, loc,
+ getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(),
+ indexTy, numWorkers),
+ policy.workerDim(ctx)));
+ }
----------------
razvanlupusoru wrote:
The recommended change is not correct, because the `if (!numWorkers)` branch makes a second attempt to obtain the num_workers value, this time the non-device-type-specific one.
https://github.com/llvm/llvm-project/pull/195158
More information about the Mlir-commits
mailing list