[Mlir-commits] [mlir] 16cef47 - [mlir][acc] Capture explicit serial semantics for compute regions (#195158)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 30 13:59:33 PDT 2026


Author: Razvan Lupusoru
Date: 2026-04-30T13:59:28-07:00
New Revision: 16cef47f1cd936022e18c99b679edabeb57b2ae3

URL: https://github.com/llvm/llvm-project/commit/16cef47f1cd936022e18c99b679edabeb57b2ae3
DIFF: https://github.com/llvm/llvm-project/commit/16cef47f1cd936022e18c99b679edabeb57b2ae3.diff

LOG: [mlir][acc] Capture explicit serial semantics for compute regions (#195158)

This PR makes it more robust to capture when the user's intent is to treat
an OpenACC region as sequential. It does so in the following ways:
- Ensure that a `seq` acc.par_width is explicitly used when the region is
serial. Previously no acc.par_width was assigned, which caused ambiguity:
it was impossible to distinguish whether a region was explicitly serial
or whether it still needed implicitly assigned parallelism.
- Treat `acc parallel` and `acc kernels` with `num_gangs(1)`,
`num_workers(1)`, and `vector_length(1)` exactly the same as `acc serial`.
These clauses cover every parallelism dimension expressible in OpenACC,
so setting all of them to 1 yields semantics consistent with those
defined for `acc serial`.

Added: 
    

Modified: 
    mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp
    mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir
    mlir/test/Dialect/OpenACC/acc-compute-lowering-loop.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp
index 9cc36312d3615..9868584a4b699 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp
@@ -24,15 +24,17 @@
 // ----------------
 // 1. Compute constructs: acc.parallel, acc.serial, and acc.kernels are
 //    replaced by acc.kernel_environment containing a single acc.compute_region.
-//    Launch arguments (num_gangs, num_workers, vector_length) become
-//    acc.par_width ops (each result is `index`) and are passed as
-//    compute_region launch operands (still required to be acc.par_width
-//    results by the compute_region verifier).
+//    For acc.parallel / acc.kernels, launch arguments (num_gangs, num_workers,
+//    vector_length) become acc.par_width ops (each result is `index`) and are
+//    passed as compute_region launch operands. Compute regions with
+//    num_gangs(1), num_workers(1), and vector_length(1) and acc serial use a
+//    single sequential acc.par_width launch operand.
 //
 // 2. acc.loop: Converted according to context and attributes:
 //    - Unstructured: body wrapped in scf.execute_region.
-//    - Sequential (serial region or seq clause): scf.parallel with
-//      par_dims = sequential.
+//    - Sequential (serial region, seq clause, or compute region with
+//    num_gangs(1), num_workers(1), and vector_length(1)):
+//      scf.parallel with par_dims = sequential.
 //    - Auto (in parallel/kernels): scf.for with collapse when
 //    multi-dimensional.
 //    - Orphan (not inside a compute construct): scf.for, no collapse.
@@ -56,6 +58,7 @@
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 namespace acc {
@@ -82,24 +85,34 @@ static Value stripIndexCasts(Value val) {
   return val;
 }
 
-/// A parallel construct is "effectively serial" when it specifies
-/// num_gangs(1), num_workers(1), and vector_length(1). This matches
-/// the semantics of acc.serial but expressed through acc.parallel.
-static bool isEffectivelySerial(ParallelOp op) {
+template <typename ComputeOpT>
+static bool isGangWorkerVectorAllOne(ComputeOpT op) {
   auto numGangs = op.getNumGangsValues();
-  if (numGangs.size() != 1)
+  if (numGangs.empty())
     return false;
+  for (Value gangSize : numGangs) {
+    if (!isConstantIntValue(stripIndexCasts(gangSize), 1))
+      return false;
+  }
   Value numWorkers = op.getNumWorkersValue();
   if (!numWorkers)
     return false;
   Value vectorLength = op.getVectorLengthValue();
   if (!vectorLength)
     return false;
-  return isConstantIntValue(stripIndexCasts(numGangs.front()), 1) &&
-         isConstantIntValue(stripIndexCasts(numWorkers), 1) &&
+  return isConstantIntValue(stripIndexCasts(numWorkers), 1) &&
          isConstantIntValue(stripIndexCasts(vectorLength), 1);
 }
 
+/// A compute construct is "effectively serial" when it specifies
+/// num_gangs(1), num_workers(1), and vector_length(1). This is because
+/// these are the only parallelism dimensions expressible from OpenACC spec
+/// point-of-view and is consistent with how `serial` semantics are defined.
+template <typename ComputeOpT>
+static bool isEffectivelySerial(ComputeOpT op) {
+  return isGangWorkerVectorAllOne(op);
+}
+
 static bool isOpInComputeRegion(Operation *op) {
   Region *region = op->getBlock()->getParent();
   return getEnclosingComputeOp(*region) != nullptr;
@@ -108,10 +121,12 @@ static bool isOpInComputeRegion(Operation *op) {
 static bool isOpInSerialRegion(Operation *op) {
   if (auto parallelOp = op->getParentOfType<ParallelOp>())
     return isEffectivelySerial(parallelOp);
-  if (auto computeRegion = op->getParentOfType<ComputeRegionOp>())
-    return computeRegion.isEffectivelySerial();
+  if (auto kernelsOp = op->getParentOfType<KernelsOp>())
+    return isEffectivelySerial(kernelsOp);
   if (op->getParentOfType<SerialOp>())
     return true;
+  if (auto computeRegion = op->getParentOfType<ComputeRegionOp>())
+    return computeRegion.isEffectivelySerial();
   if (auto funcOp = op->getParentOfType<FunctionOpInterface>()) {
     if (isSpecializedAccRoutine(funcOp)) {
       auto attr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
@@ -194,61 +209,67 @@ getParallelDimensions(LoopOp loopOp, const ACCToGPUMappingPolicy &policy,
   return parDims;
 }
 
-/// Create acc.par_width operations from gang/worker/vector values of a
-/// compute construct. Queries the device-type-specific values first, falling
-/// back to the default (DeviceType::None) values.
+/// Build `acc.compute_region` launch operands: one sequential `acc.par_width`
+/// for `acc.serial`, for `acc.parallel` / `acc.kernels` when every num_gangs
+/// operand and num_workers / vector_length are the constant 1, and otherwise
+/// `acc.par_width` from gang/worker/vector (device-type operands first, then
+/// default DeviceType::None).
 template <typename ComputeConstructT>
 static SmallVector<Value>
 assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
                       RewriterBase &rewriter,
                       const ACCToGPUMappingPolicy &policy) {
-  SmallVector<Value> values;
   auto *ctx = rewriter.getContext();
-  auto indexTy = rewriter.getIndexType();
   auto loc = computeOp->getLoc();
 
-  auto numGangs = computeOp.getNumGangsValues(deviceType);
-  if (numGangs.empty())
-    numGangs = computeOp.getNumGangsValues();
-  for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
-    auto gangLevel = getGangParLevel(gangDimIdx + 1);
-    values.push_back(
-        ParWidthOp::create(rewriter, loc,
-                           getValueOrCreateCastToIndexLike(
-                               rewriter, gangSize.getLoc(), indexTy, gangSize),
-                           policy.gangDim(ctx, gangLevel)));
-  }
+  if constexpr (std::is_same_v<ComputeConstructT, SerialOp>) {
+    return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))};
+  } else if constexpr (llvm::is_one_of<ComputeConstructT, ParallelOp,
+                                       KernelsOp>::value) {
+    if (isEffectivelySerial(computeOp))
+      return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))};
+
+    SmallVector<Value> values;
+    auto indexTy = rewriter.getIndexType();
+
+    auto numGangs = computeOp.getNumGangsValues(deviceType);
+    if (numGangs.empty())
+      numGangs = computeOp.getNumGangsValues();
+    for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
+      auto gangLevel = getGangParLevel(gangDimIdx + 1);
+      values.push_back(ParWidthOp::create(
+          rewriter, loc,
+          getValueOrCreateCastToIndexLike(rewriter, gangSize.getLoc(), indexTy,
+                                          gangSize),
+          policy.gangDim(ctx, gangLevel)));
+    }
 
-  Value numWorkers = computeOp.getNumWorkersValue(deviceType);
-  if (!numWorkers)
-    numWorkers = computeOp.getNumWorkersValue();
-  if (numWorkers) {
-    values.push_back(ParWidthOp::create(
-        rewriter, loc,
-        getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(), indexTy,
-                                        numWorkers),
-        policy.workerDim(ctx)));
-  }
+    Value numWorkers = computeOp.getNumWorkersValue(deviceType);
+    if (!numWorkers)
+      numWorkers = computeOp.getNumWorkersValue();
+    if (numWorkers) {
+      values.push_back(ParWidthOp::create(
+          rewriter, loc,
+          getValueOrCreateCastToIndexLike(rewriter, numWorkers.getLoc(),
+                                          indexTy, numWorkers),
+          policy.workerDim(ctx)));
+    }
 
-  Value vectorLength = computeOp.getVectorLengthValue(deviceType);
-  if (!vectorLength)
-    vectorLength = computeOp.getVectorLengthValue();
-  if (vectorLength) {
-    values.push_back(ParWidthOp::create(
-        rewriter, loc,
-        getValueOrCreateCastToIndexLike(rewriter, vectorLength.getLoc(),
-                                        indexTy, vectorLength),
-        policy.vectorDim(ctx)));
+    Value vectorLength = computeOp.getVectorLengthValue(deviceType);
+    if (!vectorLength)
+      vectorLength = computeOp.getVectorLengthValue();
+    if (vectorLength) {
+      values.push_back(ParWidthOp::create(
+          rewriter, loc,
+          getValueOrCreateCastToIndexLike(rewriter, vectorLength.getLoc(),
+                                          indexTy, vectorLength),
+          policy.vectorDim(ctx)));
+    }
+    return values;
+  } else {
+    llvm_unreachable("assignKnownLaunchArgs: expected parallel, kernels, or "
+                     "serial");
   }
-  return values;
-}
-
-/// SerialOp has no gang/worker/vector clauses.
-template <>
-SmallVector<Value>
-assignKnownLaunchArgs<SerialOp>(SerialOp, DeviceType, RewriterBase &,
-                                const ACCToGPUMappingPolicy &) {
-  return {};
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir b/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir
index ee177aaf6e7a7..c2049dab676e3 100644
--- a/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir
+++ b/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir
@@ -64,8 +64,8 @@ func.func @serial_loop(%buf: memref<4xi32>) {
   %dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
   // CHECK-NOT: acc.serial
   // CHECK: acc.kernel_environment
-  // CHECK-NOT: acc.par_width
-  // CHECK: acc.compute_region
+  // CHECK: acc.par_width {par_dim = #acc.par_dim<sequential>}
+  // CHECK: acc.compute_region launch(
   // CHECK: scf.parallel
   // CHECK: acc.par_dims = #acc<par_dims[sequential]>
   acc.serial dataOperands(%dev : memref<4xi32>) {
@@ -117,7 +117,9 @@ func.func @constant_livein_materialized_into_compute_region(%buf: memref<1xi32>)
   %c42 = arith.constant 42 : i32
   %dev = acc.copyin varPtr(%buf : memref<1xi32>) -> memref<1xi32>
   // CHECK: acc.kernel_environment
-  // CHECK: acc.compute_region ins({{.*}}) : (memref<1xi32>) {
+  // CHECK: acc.par_width {par_dim = #acc.par_dim<sequential>}
+  // CHECK: acc.compute_region launch(
+  // CHECK-SAME: ins({{.*}}) : (memref<1xi32>) {
   // CHECK-DAG: arith.constant 42 : i32
   // CHECK-DAG: arith.constant 0 : index
   // CHECK: memref.store
@@ -129,3 +131,142 @@ func.func @constant_livein_materialized_into_compute_region(%buf: memref<1xi32>)
   acc.copyout accPtr(%dev : memref<1xi32>) to varPtr(%buf : memref<1xi32>)
   return
 }
+
+// -----
+
+// acc.parallel with num_gangs(1), num_workers(1), and vector_length(1) is
+// treated like acc.serial: sequential acc.par_width launch args and sequential
+// par_dims on lowered loops.
+
+// CHECK-LABEL: func.func @parallel_unit_launch_serial_loops
+func.func @parallel_unit_launch_serial_loops(%buf: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c1_i32 = arith.constant 1 : i32
+
+  %dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
+  // CHECK-NOT: acc.parallel
+  // CHECK: acc.kernel_environment
+  // CHECK: acc.par_width {par_dim = #acc.par_dim<sequential>}
+  // CHECK: acc.compute_region launch(
+  // CHECK: scf.parallel
+  // CHECK: acc.par_dims = #acc<par_dims[sequential]>
+  acc.parallel num_gangs({%c1_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c1_i32 : i32) dataOperands(%dev : memref<4xi32>) {
+    acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) {
+      %vi = arith.index_cast %i : index to i32
+      memref.store %vi, %dev[%i] : memref<4xi32>
+      acc.yield
+    } attributes {independent = [#acc.device_type<none>]}
+    acc.yield
+  }
+  acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>)
+  return
+}
+
+// -----
+
+// acc.kernels with num_gangs(1), num_workers(1), and vector_length(1) is
+// treated like acc.serial: sequential acc.par_width launch args and sequential
+// par_dims on lowered loops.
+
+// CHECK-LABEL: func.func @kernels_unit_launch_serial_loops
+func.func @kernels_unit_launch_serial_loops(%buf: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c1_i32 = arith.constant 1 : i32
+
+  %dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
+  // CHECK-NOT: acc.kernels
+  // CHECK: acc.kernel_environment
+  // CHECK: acc.par_width {par_dim = #acc.par_dim<sequential>}
+  // CHECK: acc.compute_region launch(
+  // CHECK: scf.parallel
+  // CHECK: acc.par_dims = #acc<par_dims[sequential]>
+  acc.kernels num_gangs({%c1_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c1_i32 : i32) dataOperands(%dev : memref<4xi32>) {
+    acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) {
+      %vi = arith.index_cast %i : index to i32
+      memref.store %vi, %dev[%i] : memref<4xi32>
+      acc.yield
+    } attributes {independent = [#acc.device_type<none>]}
+    acc.terminator
+  }
+  acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_vector_length32_independent
+func.func @parallel_vector_length32_independent(%buf: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c1_i32 = arith.constant 1 : i32
+  %c32_i32 = arith.constant 32 : i32
+
+  %dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
+  // CHECK-NOT: acc.par_dims = #acc<par_dims[sequential]>
+  // CHECK: acc.par_dims = #acc<par_dims[thread_x]>
+  acc.parallel num_gangs({%c1_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c32_i32 : i32) dataOperands(%dev : memref<4xi32>) {
+    acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) {
+      %vi = arith.index_cast %i : index to i32
+      memref.store %vi, %dev[%i] : memref<4xi32>
+      acc.yield
+    } attributes {independent = [#acc.device_type<none>], vector = [#acc.device_type<none>]}
+    acc.yield
+  }
+  acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @kernels_num_gangs4_independent
+func.func @kernels_num_gangs4_independent(%buf: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c1_i32 = arith.constant 1 : i32
+  %c4_i32 = arith.constant 4 : i32
+
+  %dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
+  // CHECK-NOT: acc.par_dims = #acc<par_dims[sequential]>
+  // CHECK: acc.par_dims = #acc<par_dims[thread_x]>
+  acc.kernels num_gangs({%c4_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c1_i32 : i32) dataOperands(%dev : memref<4xi32>) {
+    acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) {
+      %vi = arith.index_cast %i : index to i32
+      memref.store %vi, %dev[%i] : memref<4xi32>
+      acc.yield
+    } attributes {independent = [#acc.device_type<none>], vector = [#acc.device_type<none>]}
+    acc.terminator
+  }
+  acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_num_gangs_1_2_independent
+func.func @parallel_num_gangs_1_2_independent(%buf: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c1_i32 = arith.constant 1 : i32
+  %c2_i32 = arith.constant 2 : i32
+
+  %dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
+  // CHECK-NOT: acc.par_dims = #acc<par_dims[sequential]>
+  // CHECK: acc.par_dims = #acc<par_dims[thread_x]>
+  acc.parallel num_gangs({%c1_i32 : i32, %c2_i32 : i32}) num_workers(%c1_i32 : i32) vector_length(%c1_i32 : i32) dataOperands(%dev : memref<4xi32>) {
+    acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) {
+      %vi = arith.index_cast %i : index to i32
+      memref.store %vi, %dev[%i] : memref<4xi32>
+      acc.yield
+    } attributes {independent = [#acc.device_type<none>], vector = [#acc.device_type<none>]}
+    acc.yield
+  }
+  acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>)
+  return
+}

diff  --git a/mlir/test/Dialect/OpenACC/acc-compute-lowering-loop.mlir b/mlir/test/Dialect/OpenACC/acc-compute-lowering-loop.mlir
index bd2f006396c6e..4032fed217b62 100644
--- a/mlir/test/Dialect/OpenACC/acc-compute-lowering-loop.mlir
+++ b/mlir/test/Dialect/OpenACC/acc-compute-lowering-loop.mlir
@@ -92,8 +92,8 @@ func.func @serial_loop_normalized(%buf: memref<1xi32>) {
   %dev = acc.copyin varPtr(%buf : memref<1xi32>) -> memref<1xi32>
   // CHECK-NOT: acc.serial
   // CHECK: acc.kernel_environment
-  // CHECK-NOT: acc.par_width
-  // CHECK: acc.compute_region
+  // CHECK: acc.par_width {par_dim = #acc.par_dim<sequential>}
+  // CHECK: acc.compute_region launch(
   // CHECK: scf.parallel
   // CHECK-DAG: arith.muli
   // CHECK-DAG: arith.addi


        


More information about the Mlir-commits mailing list