[Mlir-commits] [mlir] [mlir][acc] Ensure implicit declare hoisting works for compute_region (PR #192501)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 16 11:10:29 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Razvan Lupusoru (razvanlupusoru)

<details>
<summary>Changes</summary>

Any hoisting across `acc.compute_region` needs to be wired through block arguments because the region is `IsolatedFromAbove`. Update `ACCImplicitDeclare` to do this by using the new API
`wireHoistedValueThroughIns`, which handles the value wiring after hoisting.

---
Full diff: https://github.com/llvm/llvm-project/pull/192501.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td (+10) 
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp (+18) 
- (modified) mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp (+7-2) 
- (modified) mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir (+21) 
- (modified) mlir/unittests/Dialect/OpenACC/CMakeLists.txt (+1) 
- (added) mlir/unittests/Dialect/OpenACC/OpenACCCGOpsTest.cpp (+178) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
index 76902a6d2690e..69848101a5e4d 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
@@ -351,6 +351,16 @@ def OpenACC_ComputeRegionOp
     /// the region block arguments. Returns the new block argument.
     ::mlir::BlockArgument appendInputArg(::mlir::Value);
 
+    /// After hoisting `value`'s defining op, wire it into this region: append it
+    /// as an `ins` operand, add the matching body entry argument, and replace
+    /// uses under this region with that argument (excluding this op's own `ins`
+    /// operand uses).
+    ///
+    /// Requires that `value` is defined outside `getRegion()` and is still used
+    /// inside the region. Otherwise returns `std::nullopt`.
+    std::optional<::mlir::BlockArgument>
+    wireHoistedValueThroughIns(::mlir::Value value);
+
     /// Check whether all parallel dimensions have width 1.
     bool isEffectivelySerial();
 
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
index de243f3dadd74..fe806186ce25d 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/Region.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 
@@ -481,6 +482,23 @@ BlockArgument ComputeRegionOp::appendInputArg(Value value) {
   return getBody()->addArgument(value.getType(), getLoc());
 }
 
+std::optional<BlockArgument>
+ComputeRegionOp::wireHoistedValueThroughIns(Value value) {
+  Region &region = getRegion(); // region whose in-body uses get rewired
+  // True when the operand's owner op sits in `region` or a region nested in it.
+  auto useIsInRegion = [&](OpOperand &use) -> bool {
+    return region.isAncestor(use.getOwner()->getParentRegion());
+  };
+  // Nothing to wire unless `value` is defined above the region and used in it.
+  if (!areValuesDefinedAbove(ValueRange(value), region) ||
+      !llvm::any_of(value.getUses(), useIsInRegion))
+    return std::nullopt;
+  // Add `value` as an input and redirect its in-region uses to the new argument.
+  BlockArgument arg = appendInputArg(value);
+  replaceAllUsesInRegionWith(value, arg, region);
+  return arg;
+}
+
 bool ComputeRegionOp::isEffectivelySerial() {
   auto *ctx = getContext();
 
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp
index e99a27a7bb89a..3b8bf86fee11c 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp
@@ -272,7 +272,12 @@ static void hoistNonConstantDirectUses(AccConstructT accOp,
           SymbolTable::lookupNearestSymbolFrom(addrOfOp, symRef);
       if (isGlobalUseCandidateForHoisting(globalOp, addrOfOp, symRef,
                                           accSupport)) {
+        auto computeRegionParent =
+            addrOfOp->getParentOfType<acc::ComputeRegionOp>();
         addrOfOp->moveBefore(accOp);
+        if (computeRegionParent)
+          for (Value v : addrOfOp->getResults())
+            computeRegionParent.wireHoistedValueThroughIns(v);
         LLVM_DEBUG(
             llvm::dbgs() << "Hoisted:\n\t" << addrOfOp << "\n\tfrom:\n\t";
             accOp->print(llvm::dbgs(),
@@ -341,7 +346,7 @@ class ACCImplicitDeclare
     // polluting the device globals.
     mod.walk([&](Operation *op) {
       TypeSwitch<Operation *, void>(op)
-          .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(
+          .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::ComputeRegionOp>(
               [&](auto accOp) {
                 hoistNonConstantDirectUses(accOp, accSupport);
               });
@@ -354,7 +359,7 @@ class ACCImplicitDeclare
     GlobalOpSetT globalsToAccDeclare;
     mod.walk([&](Operation *op) {
       TypeSwitch<Operation *, void>(op)
-          .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(
+          .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::ComputeRegionOp>(
               [&](auto accOp) {
                 collectGlobalsFromDeviceRegion(
                     accOp.getRegion(), globalsToAccDeclare, accSupport, symTab);
diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir
index 74ff3384c093c..b1c9aa9016dc3 100644
--- a/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir
+++ b/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir
@@ -173,3 +173,24 @@ func.func @test_multiple_constructs() {
 // CHECK: memref.get_global @global_kernels
 // CHECK-NEXT: acc.kernels
 
+// -----
+
+memref.global @global_in_compute_region : memref<f32> = dense<0.0>
+
+func.func @test_scalar_in_compute_region() {
+  acc.compute_region {
+    %addr = memref.get_global @global_in_compute_region : memref<f32>
+    %load = memref.load %addr[] : memref<f32>
+    acc.yield
+  } {origin = "acc.parallel"}
+  return
+}
+
+// When hoisting out of acc.compute_region, block arguments must be used to
+// wire the value through.
+// CHECK-LABEL: func.func @test_scalar_in_compute_region
+// CHECK: %[[G:.*]] = memref.get_global @global_in_compute_region
+// CHECK: acc.compute_region ins(%[[INS_ARG:.*]] = %[[G]]) : (memref<f32>) {
+// CHECK: memref.load %[[INS_ARG]][] : memref<f32>
+// CHECK-NOT: memref.load %[[G]][] : memref<f32>
+
diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
index 7303ff581abc4..17d1721b82602 100644
--- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_unittest(MLIROpenACCTests
+  OpenACCCGOpsTest.cpp
   OpenACCOpsTest.cpp
   OpenACCOpsInterfacesTest.cpp
   OpenACCTypeInterfacesTest.cpp
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCCGOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCCGOpsTest.cpp
new file mode 100644
index 0000000000000..bbc12b0f829d9
--- /dev/null
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCCGOpsTest.cpp
@@ -0,0 +1,178 @@
+//===- OpenACCCGOpsTest.cpp - Unit tests for OpenACC codegen ops ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtilsCG.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/PatternMatch.h"
+#include "gtest/gtest.h"
+#include <optional>
+
+using namespace mlir;
+using namespace mlir::acc;
+
+//===----------------------------------------------------------------------===//
+// Test Fixture
+//===----------------------------------------------------------------------===//
+
+class OpenACCCGOpsTest : public ::testing::Test { // shared setup for the tests below
+protected:
+  OpenACCCGOpsTest() : b(&context), loc(UnknownLoc::get(&context)) {
+    context.loadDialect<OpenACCDialect, arith::ArithDialect, func::FuncDialect,
+                        gpu::GPUDialect>();
+  }
+
+  /// Module with a single no-arg `func.func`, insertion point at the start of
+  /// its entry block (ready for host values and `buildComputeRegion` IR).
+  struct HostContext {
+    OwningOpRef<ModuleOp> module;
+    IRRewriter rewriter;
+    func::FuncOp func;
+    Block *entry = nullptr;
+
+    HostContext(MLIRContext &ctx, Location loc, OpBuilder &builder,
+                StringRef funcName = "f")
+        : module(ModuleOp::create(builder, loc)), rewriter(&ctx) {
+      rewriter.setInsertionPointToEnd(module->getBody());
+      func = func::FuncOp::create(rewriter, loc, funcName,
+                                  builder.getFunctionType({}, {}));
+      entry = func.addEntryBlock();
+      rewriter.setInsertionPointToEnd(entry);
+      func::ReturnOp::create(rewriter, loc);
+      rewriter.setInsertionPointToStart(entry); // later ops go ahead of the return
+    }
+  };
+
+  /// Build a single-block region for `buildComputeRegion`: optional one
+  /// argument (mapped to `ins`), optional `arith.addi` of that argument with
+  /// itself, then `acc.yield`. `regionOut` must be empty.
+  static void populateSourceRegionSingleBlock(Region &regionOut,
+                                              MLIRContext &ctx, Location loc,
+                                              std::optional<Type> mapArgType,
+                                              bool addSelfAddi) {
+    assert(regionOut.empty() && "expected an empty region"); // NOTE(review): no <cassert> include in this file; presumably transitive - confirm
+    Block *block = new Block();
+    regionOut.push_back(block);
+    OpBuilder regionBuilder(&ctx);
+    regionBuilder.setInsertionPointToStart(block);
+    if (mapArgType) { // `addSelfAddi` is ignored when no argument type is given
+      BlockArgument arg = block->addArgument(*mapArgType, loc);
+      if (addSelfAddi)
+        arith::AddIOp::create(regionBuilder, loc, arg, arg);
+    }
+    YieldOp::create(regionBuilder, loc);
+  }
+
+  /// Single-block region with an `i32` producer (`arith.constant`) and a user
+  /// (`arith.addi`) both inside the region — valid clone source with no
+  /// external `ins` captures.
+  static void populateSourceRegionWithInternalI32Constant(Region &regionOut,
+                                                          MLIRContext &ctx,
+                                                          Location loc,
+                                                          int64_t cst) {
+    assert(regionOut.empty() && "expected an empty region");
+    Block *block = new Block();
+    regionOut.push_back(block);
+    OpBuilder regionBuilder(&ctx);
+    regionBuilder.setInsertionPointToStart(block);
+    Value k = arith::ConstantIntOp::create(regionBuilder, loc,
+                                           IntegerType::get(&ctx, 32), cst);
+    arith::AddIOp::create(regionBuilder, loc, k, k);
+    YieldOp::create(regionBuilder, loc);
+  }
+  // State shared by every test: the context, a builder on it, an unknown loc.
+  MLIRContext context;
+  OpBuilder b;
+  Location loc;
+};
+
+//===----------------------------------------------------------------------===//
+// ComputeRegionOp::wireHoistedValueThroughIns
+//===----------------------------------------------------------------------===//
+
+TEST_F(OpenACCCGOpsTest, WireHoistedValueThroughInsAfterHoisting) {
+  HostContext host(context, loc, b);
+  // Source body holds an i32 constant and an addi that uses it twice.
+  Region sourceRegion;
+  populateSourceRegionWithInternalI32Constant(sourceRegion, context, loc, 7);
+  IRMapping mapping;
+  auto cr = buildComputeRegion(loc, /*launchArgs=*/{}, /*inputArgs=*/{},
+                               SerialOp::getOperationName(), sourceRegion,
+                               host.rewriter, mapping);
+  ASSERT_TRUE(cr);
+  // Locate the constant and the addi in the compute region's entry block.
+  arith::ConstantIntOp producer;
+  arith::AddIOp addOp;
+  for (Operation &op : cr.getRegion().front().getOperations()) {
+    if (auto c = dyn_cast<arith::ConstantIntOp>(op))
+      producer = c;
+    else if (auto a = dyn_cast<arith::AddIOp>(op))
+      addOp = a;
+  }
+  ASSERT_TRUE(producer);
+  ASSERT_TRUE(addOp);
+  Value produced = producer.getResult();
+  ASSERT_EQ(addOp.getLhs(), produced);
+  ASSERT_EQ(addOp.getRhs(), produced);
+
+  // Hoist the producer out of the region (same idea as ACCImplicitDeclare).
+  producer->moveBefore(cr.getOperation());
+  ASSERT_TRUE(cr.wireHoistedValueThroughIns(produced).has_value());
+  // Both addi operands must now be one block argument, and the module must verify.
+  EXPECT_EQ(addOp.getLhs(), addOp.getRhs());
+  EXPECT_TRUE(isa<BlockArgument>(addOp.getLhs()));
+  EXPECT_TRUE(isa<BlockArgument>(addOp.getRhs()));
+  EXPECT_TRUE(succeeded(host.module->verify()));
+}
+
+TEST_F(OpenACCCGOpsTest, WireHoistedValueThroughInsNoUseInside) {
+  HostContext host(context, loc, b);
+  Value v = arith::ConstantIntOp::create(host.rewriter, loc, b.getI32Type(), 1);
+  Value w = arith::ConstantIntOp::create(host.rewriter, loc, b.getI32Type(), 2);
+  // Source body is only `acc.yield`; `v` is handed over as the input argument.
+  Region sourceRegion;
+  populateSourceRegionSingleBlock(sourceRegion, context, loc,
+                                  /*mapArgType=*/std::nullopt,
+                                  /*addSelfAddi=*/false);
+  IRMapping mapping;
+  auto cr = buildComputeRegion(loc, /*launchArgs=*/{}, ValueRange(v),
+                               SerialOp::getOperationName(), sourceRegion,
+                               host.rewriter, mapping);
+  ASSERT_TRUE(cr);
+  // `w` is never used inside the region, so the call must decline to wire it.
+  EXPECT_FALSE(cr.wireHoistedValueThroughIns(w).has_value());
+  EXPECT_TRUE(succeeded(host.module->verify()));
+}
+
+TEST_F(OpenACCCGOpsTest, WireHoistedValueThroughInsDefinedInside) {
+  HostContext host(context, loc, b);
+  auto c128 = arith::ConstantIndexOp::create(host.rewriter, loc, 128);
+  auto threadXDim = GPUParallelDimAttr::threadXDim(&context);
+  auto pw = ParWidthOp::create(host.rewriter, loc, c128, threadXDim);
+  // Source body is only `acc.yield`; `pw` goes in the launch-argument position.
+  Region sourceRegion;
+  populateSourceRegionSingleBlock(sourceRegion, context, loc,
+                                  /*mapArgType=*/std::nullopt,
+                                  /*addSelfAddi=*/false);
+  IRMapping mapping;
+  auto cr = buildComputeRegion(loc, ValueRange(pw), /*inputArgs=*/{},
+                               ParallelOp::getOperationName(), sourceRegion,
+                               host.rewriter, mapping);
+  ASSERT_TRUE(cr);
+  // The region's own entry argument is not defined above it, so wiring must fail.
+  BlockArgument launchArg = cr.getRegion().front().getArgument(0);
+  EXPECT_FALSE(cr.wireHoistedValueThroughIns(launchArg).has_value());
+  EXPECT_TRUE(succeeded(host.module->verify()));
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/192501


More information about the Mlir-commits mailing list