[Mlir-commits] [mlir] [mlir][acc] Ensure implicit declare hoisting works for compute_region (PR #192501)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 16 11:10:29 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Razvan Lupusoru (razvanlupusoru)
<details>
<summary>Changes</summary>
Any value hoisted out of `acc.compute_region` must be wired back in through block arguments because the region is `IsolatedFromAbove`. Update `ACCImplicitDeclare` to do so using the new API
`wireHoistedValueThroughIns`, which handles the value wiring after hoisting.
---
Full diff: https://github.com/llvm/llvm-project/pull/192501.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td (+10)
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp (+18)
- (modified) mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp (+7-2)
- (modified) mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir (+21)
- (modified) mlir/unittests/Dialect/OpenACC/CMakeLists.txt (+1)
- (added) mlir/unittests/Dialect/OpenACC/OpenACCCGOpsTest.cpp (+178)
``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
index 76902a6d2690e..69848101a5e4d 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
@@ -351,6 +351,16 @@ def OpenACC_ComputeRegionOp
/// the region block arguments. Returns the new block argument.
::mlir::BlockArgument appendInputArg(::mlir::Value);
+ /// After hoisting `value`'s defining op, wire it into this region: append it
+ /// as an `ins` operand, add the matching body entry argument, and replace
+ /// uses under this region with that argument (excluding this op's own `ins`
+ /// operand uses).
+ ///
+ /// Requires that `value` is defined outside `getRegion()` and is still used
+ /// inside the region. Otherwise returns `std::nullopt`.
+ std::optional<::mlir::BlockArgument>
+ wireHoistedValueThroughIns(::mlir::Value value);
+
/// Check whether all parallel dimensions have width 1.
bool isEffectivelySerial();
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
index de243f3dadd74..fe806186ce25d 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
@@ -23,6 +23,7 @@
#include "mlir/IR/Region.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -481,6 +482,23 @@ BlockArgument ComputeRegionOp::appendInputArg(Value value) {
return getBody()->addArgument(value.getType(), getLoc());
}
+std::optional<BlockArgument>
+ComputeRegionOp::wireHoistedValueThroughIns(Value value) {
+ Region &region = getRegion();
+
+ auto useIsInRegion = [&](OpOperand &use) -> bool {
+ return region.isAncestor(use.getOwner()->getParentRegion());
+ };
+
+ if (!areValuesDefinedAbove(ValueRange(value), region) ||
+ !llvm::any_of(value.getUses(), useIsInRegion))
+ return std::nullopt;
+
+ BlockArgument arg = appendInputArg(value);
+ replaceAllUsesInRegionWith(value, arg, region);
+ return arg;
+}
+
bool ComputeRegionOp::isEffectivelySerial() {
auto *ctx = getContext();
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp
index e99a27a7bb89a..3b8bf86fee11c 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp
@@ -272,7 +272,12 @@ static void hoistNonConstantDirectUses(AccConstructT accOp,
SymbolTable::lookupNearestSymbolFrom(addrOfOp, symRef);
if (isGlobalUseCandidateForHoisting(globalOp, addrOfOp, symRef,
accSupport)) {
+ auto computeRegionParent =
+ addrOfOp->getParentOfType<acc::ComputeRegionOp>();
addrOfOp->moveBefore(accOp);
+ if (computeRegionParent)
+ for (Value v : addrOfOp->getResults())
+ computeRegionParent.wireHoistedValueThroughIns(v);
LLVM_DEBUG(
llvm::dbgs() << "Hoisted:\n\t" << addrOfOp << "\n\tfrom:\n\t";
accOp->print(llvm::dbgs(),
@@ -341,7 +346,7 @@ class ACCImplicitDeclare
// polluting the device globals.
mod.walk([&](Operation *op) {
TypeSwitch<Operation *, void>(op)
- .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(
+ .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::ComputeRegionOp>(
[&](auto accOp) {
hoistNonConstantDirectUses(accOp, accSupport);
});
@@ -354,7 +359,7 @@ class ACCImplicitDeclare
GlobalOpSetT globalsToAccDeclare;
mod.walk([&](Operation *op) {
TypeSwitch<Operation *, void>(op)
- .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(
+ .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::ComputeRegionOp>(
[&](auto accOp) {
collectGlobalsFromDeviceRegion(
accOp.getRegion(), globalsToAccDeclare, accSupport, symTab);
diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir
index 74ff3384c093c..b1c9aa9016dc3 100644
--- a/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir
+++ b/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir
@@ -173,3 +173,24 @@ func.func @test_multiple_constructs() {
// CHECK: memref.get_global @global_kernels
// CHECK-NEXT: acc.kernels
+// -----
+
+memref.global @global_in_compute_region : memref<f32> = dense<0.0>
+
+func.func @test_scalar_in_compute_region() {
+ acc.compute_region {
+ %addr = memref.get_global @global_in_compute_region : memref<f32>
+ %load = memref.load %addr[] : memref<f32>
+ acc.yield
+ } {origin = "acc.parallel"}
+ return
+}
+
+// When hoisting out of acc.compute_region, block arguments must be used to
+// wire the value through.
+// CHECK-LABEL: func.func @test_scalar_in_compute_region
+// CHECK: %[[G:.*]] = memref.get_global @global_in_compute_region
+// CHECK: acc.compute_region ins(%[[INS_ARG:.*]] = %[[G]]) : (memref<f32>) {
+// CHECK: memref.load %[[INS_ARG]][] : memref<f32>
+// CHECK-NOT: memref.load %[[G]][] : memref<f32>
+
diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
index 7303ff581abc4..17d1721b82602 100644
--- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_unittest(MLIROpenACCTests
+ OpenACCCGOpsTest.cpp
OpenACCOpsTest.cpp
OpenACCOpsInterfacesTest.cpp
OpenACCTypeInterfacesTest.cpp
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCCGOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCCGOpsTest.cpp
new file mode 100644
index 0000000000000..bbc12b0f829d9
--- /dev/null
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCCGOpsTest.cpp
@@ -0,0 +1,178 @@
+//===- OpenACCCGOpsTest.cpp - Unit tests for OpenACC codegen ops ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtilsCG.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/PatternMatch.h"
+#include "gtest/gtest.h"
+#include <optional>
+
+using namespace mlir;
+using namespace mlir::acc;
+
+//===----------------------------------------------------------------------===//
+// Test Fixture
+//===----------------------------------------------------------------------===//
+
+class OpenACCCGOpsTest : public ::testing::Test {
+protected:
+ OpenACCCGOpsTest() : b(&context), loc(UnknownLoc::get(&context)) {
+ context.loadDialect<OpenACCDialect, arith::ArithDialect, func::FuncDialect,
+ gpu::GPUDialect>();
+ }
+
+ /// Module with a single no-arg `func.func`, insertion point at the start of
+ /// its entry block (ready for host values and `buildComputeRegion` IR).
+ struct HostContext {
+ OwningOpRef<ModuleOp> module;
+ IRRewriter rewriter;
+ func::FuncOp func;
+ Block *entry = nullptr;
+
+ HostContext(MLIRContext &ctx, Location loc, OpBuilder &builder,
+ StringRef funcName = "f")
+ : module(ModuleOp::create(builder, loc)), rewriter(&ctx) {
+ rewriter.setInsertionPointToEnd(module->getBody());
+ func = func::FuncOp::create(rewriter, loc, funcName,
+ builder.getFunctionType({}, {}));
+ entry = func.addEntryBlock();
+ rewriter.setInsertionPointToEnd(entry);
+ func::ReturnOp::create(rewriter, loc);
+ rewriter.setInsertionPointToStart(entry);
+ }
+ };
+
+ /// Build a single-block region for `buildComputeRegion`: optional one
+ /// argument (mapped to `ins`), optional `arith.addi` of that argument with
+ /// itself, then `acc.yield`. `regionOut` must be empty.
+ static void populateSourceRegionSingleBlock(Region &regionOut,
+ MLIRContext &ctx, Location loc,
+ std::optional<Type> mapArgType,
+ bool addSelfAddi) {
+ assert(regionOut.empty() && "expected an empty region");
+ Block *block = new Block();
+ regionOut.push_back(block);
+ OpBuilder regionBuilder(&ctx);
+ regionBuilder.setInsertionPointToStart(block);
+ if (mapArgType) {
+ BlockArgument arg = block->addArgument(*mapArgType, loc);
+ if (addSelfAddi)
+ arith::AddIOp::create(regionBuilder, loc, arg, arg);
+ }
+ YieldOp::create(regionBuilder, loc);
+ }
+
+ /// Single-block region with an `i32` producer (`arith.constant`) and a user
+ /// (`arith.addi`) both inside the region — valid clone source with no
+ /// external `ins` captures.
+ static void populateSourceRegionWithInternalI32Constant(Region &regionOut,
+ MLIRContext &ctx,
+ Location loc,
+ int64_t cst) {
+ assert(regionOut.empty() && "expected an empty region");
+ Block *block = new Block();
+ regionOut.push_back(block);
+ OpBuilder regionBuilder(&ctx);
+ regionBuilder.setInsertionPointToStart(block);
+ Value k = arith::ConstantIntOp::create(regionBuilder, loc,
+ IntegerType::get(&ctx, 32), cst);
+ arith::AddIOp::create(regionBuilder, loc, k, k);
+ YieldOp::create(regionBuilder, loc);
+ }
+
+ MLIRContext context;
+ OpBuilder b;
+ Location loc;
+};
+
+//===----------------------------------------------------------------------===//
+// ComputeRegionOp::wireHoistedValueThroughIns
+//===----------------------------------------------------------------------===//
+
+TEST_F(OpenACCCGOpsTest, WireHoistedValueThroughInsAfterHoisting) {
+ HostContext host(context, loc, b);
+
+ Region sourceRegion;
+ populateSourceRegionWithInternalI32Constant(sourceRegion, context, loc, 7);
+ IRMapping mapping;
+ auto cr = buildComputeRegion(loc, /*launchArgs=*/{}, /*inputArgs=*/{},
+ SerialOp::getOperationName(), sourceRegion,
+ host.rewriter, mapping);
+ ASSERT_TRUE(cr);
+
+ arith::ConstantIntOp producer;
+ arith::AddIOp addOp;
+ for (Operation &op : cr.getRegion().front().getOperations()) {
+ if (auto c = dyn_cast<arith::ConstantIntOp>(op))
+ producer = c;
+ else if (auto a = dyn_cast<arith::AddIOp>(op))
+ addOp = a;
+ }
+ ASSERT_TRUE(producer);
+ ASSERT_TRUE(addOp);
+ Value produced = producer.getResult();
+ ASSERT_EQ(addOp.getLhs(), produced);
+ ASSERT_EQ(addOp.getRhs(), produced);
+
+ // Hoist the producer out of the region (same idea as ACCImplicitDeclare).
+ producer->moveBefore(cr.getOperation());
+ ASSERT_TRUE(cr.wireHoistedValueThroughIns(produced).has_value());
+
+ EXPECT_EQ(addOp.getLhs(), addOp.getRhs());
+ EXPECT_TRUE(isa<BlockArgument>(addOp.getLhs()));
+ EXPECT_TRUE(isa<BlockArgument>(addOp.getRhs()));
+ EXPECT_TRUE(succeeded(host.module->verify()));
+}
+
+TEST_F(OpenACCCGOpsTest, WireHoistedValueThroughInsNoUseInside) {
+ HostContext host(context, loc, b);
+ Value v = arith::ConstantIntOp::create(host.rewriter, loc, b.getI32Type(), 1);
+ Value w = arith::ConstantIntOp::create(host.rewriter, loc, b.getI32Type(), 2);
+
+ Region sourceRegion;
+ populateSourceRegionSingleBlock(sourceRegion, context, loc,
+ /*mapArgType=*/std::nullopt,
+ /*addSelfAddi=*/false);
+ IRMapping mapping;
+ auto cr = buildComputeRegion(loc, /*launchArgs=*/{}, ValueRange(v),
+ SerialOp::getOperationName(), sourceRegion,
+ host.rewriter, mapping);
+ ASSERT_TRUE(cr);
+
+ EXPECT_FALSE(cr.wireHoistedValueThroughIns(w).has_value());
+ EXPECT_TRUE(succeeded(host.module->verify()));
+}
+
+TEST_F(OpenACCCGOpsTest, WireHoistedValueThroughInsDefinedInside) {
+ HostContext host(context, loc, b);
+ auto c128 = arith::ConstantIndexOp::create(host.rewriter, loc, 128);
+ auto threadXDim = GPUParallelDimAttr::threadXDim(&context);
+ auto pw = ParWidthOp::create(host.rewriter, loc, c128, threadXDim);
+
+ Region sourceRegion;
+ populateSourceRegionSingleBlock(sourceRegion, context, loc,
+ /*mapArgType=*/std::nullopt,
+ /*addSelfAddi=*/false);
+ IRMapping mapping;
+ auto cr = buildComputeRegion(loc, ValueRange(pw), /*inputArgs=*/{},
+ ParallelOp::getOperationName(), sourceRegion,
+ host.rewriter, mapping);
+ ASSERT_TRUE(cr);
+
+ BlockArgument launchArg = cr.getRegion().front().getArgument(0);
+ EXPECT_FALSE(cr.wireHoistedValueThroughIns(launchArg).has_value());
+ EXPECT_TRUE(succeeded(host.module->verify()));
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/192501
More information about the Mlir-commits
mailing list