[Mlir-commits] [mlir] [mlir][acc] Add acc.compute_region and acc.par_width operations (PR #184864)

Valentin Clement バレンタイン クレメン llvmlistbot at llvm.org
Thu Mar 5 13:10:30 PST 2026


================
@@ -325,6 +326,248 @@ void ReductionCombineOp::getEffects(
                        SideEffects::DefaultResource::get());
 }
 
+//===----------------------------------------------------------------------===//
+// ComputeRegionOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the acc.par_width op defining the launch argument of `op` that
+/// corresponds to parallel dimension `parDim`, or a null op if none matches.
+/// Launch arguments that are not produced by acc.par_width are skipped.
+static ParWidthOp getParWidthOpForLaunchArg(ComputeRegionOp op,
+                                            GPUParallelDimAttr parDim) {
+  for (Value candidate : op.getLaunchArgs()) {
+    ParWidthOp widthOp = candidate.getDefiningOp<ParWidthOp>();
+    if (widthOp && cast<GPUParallelDimAttr>(widthOp.getParDim()) == parDim)
+      return widthOp;
+  }
+  return ParWidthOp();
+}
+
+/// Returns the launch-argument value (the acc.par_width result) for `parDim`,
+/// or std::nullopt when no launch argument covers that dimension.
+std::optional<Value>
+ComputeRegionOp::getLaunchArg(GPUParallelDimAttr parDim) {
+  ParWidthOp widthOp = getParWidthOpForLaunchArg(*this, parDim);
+  if (!widthOp)
+    return std::nullopt;
+  return widthOp.getResult();
+}
+
+/// Returns the explicit width operand attached to the acc.par_width op for
+/// `parDim`. Returns std::nullopt if there is no such op or if the op carries
+/// no width operand.
+std::optional<Value>
+ComputeRegionOp::getKnownLaunchArg(GPUParallelDimAttr parDim) {
+  ParWidthOp widthOp = getParWidthOpForLaunchArg(*this, parDim);
+  if (!widthOp)
+    return std::nullopt;
+  Value widthOperand = widthOp.getLaunchArg();
+  if (!widthOperand)
+    return std::nullopt;
+  return widthOperand;
+}
+
+/// Returns the width for `parDim` when it is a known, non-negative integer
+/// constant. Returns std::nullopt if there is no known width operand, if the
+/// operand is not a constant, or if the constant is negative.
+std::optional<uint64_t>
+ComputeRegionOp::getKnownConstantLaunchArg(GPUParallelDimAttr parDim) {
+  std::optional<Value> knownParWidth = getKnownLaunchArg(parDim);
+  if (!knownParWidth)
+    return std::nullopt;
+  std::optional<int64_t> constWidth = getConstantIntValue(*knownParWidth);
+  // A negative constant is not a meaningful width. Reject it instead of letting
+  // the implicit int64_t -> uint64_t conversion wrap it to a huge value.
+  if (!constWidth || *constWidth < 0)
+    return std::nullopt;
+  return static_cast<uint64_t>(*constWidth);
+}
+
+/// Appends `value` as a new input operand and adds a matching body argument of
+/// the same type. Both go at the end, because body arguments are laid out as
+/// launch args followed by input args. Returns the new body argument.
+BlockArgument ComputeRegionOp::appendInputArg(Value value) {
+  getInputArgsMutable().append(value);
+  auto *body = getBody();
+  Type argType = value.getType();
+  return body->addArgument(argType, getLoc());
+}
+
+/// Returns true when the region is known to run serially. That holds either
+/// when a seq launch dimension is present, or when every thread (x/y/z) and
+/// block (x/y/z) launch width is a known constant equal to 1.
+bool ComputeRegionOp::isEffectivelySerial() {
+  auto *context = getContext();
+
+  // An explicit seq dimension makes the region serial on its own.
+  if (getLaunchArg(GPUParallelDimAttr::seqDim(context)).has_value())
+    return true;
+
+  // A dimension counts as serial only if its width is provably 1. Missing or
+  // non-constant widths are treated as potentially parallel.
+  auto isUnitWidth = [this](GPUParallelDimAttr dim) {
+    std::optional<uint64_t> width = getKnownConstantLaunchArg(dim);
+    return width.has_value() && *width == 1;
+  };
+
+  if (!isUnitWidth(GPUParallelDimAttr::threadXDim(context)) ||
+      !isUnitWidth(GPUParallelDimAttr::threadYDim(context)) ||
+      !isUnitWidth(GPUParallelDimAttr::threadZDim(context)))
+    return false;
+  return isUnitWidth(GPUParallelDimAttr::blockXDim(context)) &&
+         isUnitWidth(GPUParallelDimAttr::blockYDim(context)) &&
+         isUnitWidth(GPUParallelDimAttr::blockZDim(context));
+}
+
+/// Returns the body block argument that mirrors the launch argument for
+/// `parDim`. Launch args occupy the leading body-argument positions in
+/// operand order. It is a fatal error to request a dimension that has no
+/// launch argument.
+BlockArgument ComputeRegionOp::parDimToWidth(GPUParallelDimAttr parDim) {
+  unsigned argIndex = 0;
+  for (auto launchArg : getLaunchArgs()) {
+    auto widthOp = launchArg.getDefiningOp<ParWidthOp>();
+    assert(widthOp);
+    if (cast<GPUParallelDimAttr>(widthOp.getParDim()) == parDim) {
+      auto &entryBlock = getRegion().front();
+      assert(argIndex < entryBlock.getNumArguments() &&
+             "launch arg position out of range");
+      return entryBlock.getArgument(argIndex);
+    }
+    ++argIndex;
+  }
+  llvm_unreachable("attempting to get unspecified parDim");
+}
+
+/// Returns the parallel dimension of each launch argument, in operand order.
+SmallVector<GPUParallelDimAttr> ComputeRegionOp::getLaunchParDims() {
+  SmallVector<GPUParallelDimAttr> parDims;
+  parDims.reserve(getLaunchArgs().size());
+  for (auto launchArg : getLaunchArgs()) {
+    auto parOp = launchArg.getDefiningOp<ParWidthOp>();
+    // Guard the dereference below, matching the check in parDimToWidth.
+    assert(parOp && "launch argument must be defined by acc.par_width");
+    auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
+    int64_t dimInt = launchArgDim.getValue().getInt();
+    parDims.push_back(intToParDim(getContext(), dimInt));
+  }
+  return parDims;
+}
+
+/// Returns the op operand that feeds body argument `blockArg`. Body arguments
+/// are laid out as all launch args followed by all input args.
+Value ComputeRegionOp::getOperand(BlockArgument blockArg) {
+  // An argument from another block could still pass the range check below and
+  // silently map to an unrelated operand, so check ownership explicitly.
+  assert(blockArg.getOwner() == getBody() &&
+         "block argument does not belong to this compute region's body");
+  unsigned argNumber = blockArg.getArgNumber();
+  unsigned numLaunchArgs = getLaunchArgs().size();
+  // The input-arg count is computed inside the assert so that NDEBUG builds do
+  // not warn about an unused local.
+  assert(argNumber < (numLaunchArgs + getInputArgs().size()) &&
+         "invalid block argument");
+  if (argNumber < numLaunchArgs)
+    return getLaunchArgs()[argNumber];
+  return getInputArgs()[argNumber - numLaunchArgs];
+}
+
+/// Returns the body argument holding the launch width for GPU `processor`.
+/// It converts the processor to its parallel-dimension attribute and then
+/// defers to parDimToWidth.
+BlockArgument ComputeRegionOp::gpuParWidth(gpu::Processor processor) {
+  auto parDim = GPUParallelDimAttr::get(getContext(), processor);
+  return parDimToWidth(parDim);
+}
+
+LogicalResult ComputeRegionOp::verify() {
+  for (auto op : getLaunchArgs())
+    if (!op.getDefiningOp<acc::ParWidthOp>())
+      return emitOpError(
+          "launch arguments must be results of acc.par_width operations");
----------------
clementval wrote:

Thanks!

https://github.com/llvm/llvm-project/pull/184864


More information about the Mlir-commits mailing list