[Mlir-commits] [flang] [mlir] [mlir][acc] Add OffloadLiveInValueCanonicalization pass (PR #174671)
Razvan Lupusoru
llvmlistbot at llvm.org
Wed Jan 7 09:25:52 PST 2026
https://github.com/razvanlupusoru updated https://github.com/llvm/llvm-project/pull/174671
>From 5f8de784cf18e5b171ee3edfd842dc5128b54166 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Tue, 6 Jan 2026 16:07:00 -0800
Subject: [PATCH 1/3] [mlir][acc] Add OffloadLiveInValueCanonicalization pass
Introduce a pass to canonicalize live-in values for regions that
will be outlined for device execution.
When a region is outlined, values defined outside but used inside
become arguments to the outlined function. However, some values
cannot or should not be passed as arguments:
- Synthetic types (shape metadata, field indices)
- Constants better recreated inside the region
- Address-of operations for device-resident globals
This pass identifies such values and either sinks the defining
operation into the region (when all uses are inside) or clones
it inside (when uses exist both inside and outside).
To identify target regions in a dialect-agnostic way, this patch
introduces `OffloadRegionOpInterface`. This marker interface allows
the pass to work uniformly across OpenACC compute constructs,
GPU operations, and other offload dialects without hardcoding
operation types. The interface is attached to `acc.parallel`,
`acc.kernels`, and `acc.serial` directly through TableGen. It is also
attached to `gpu.launch` (via an external model registered by the
OpenACC dialect) and to `cuf.kernel` (through the FIR OpenACC extensions).
The pass leverages existing interfaces for candidate detection:
`OutlineRematerializationOpInterface` marks operations producing
non-argument-passable values, while `ViewLikeOpInterface` and
`PartialEntityAccessOpInterface` allow tracing through casts and
views to find original defining operations. OpenACCSupport
analysis provides symbol validation for address-of operations.
---
.../OpenACC/Support/FIROpenACCOpsInterfaces.h | 8 +
.../Optimizer/OpenACC/Support/CMakeLists.txt | 2 +
.../Support/FIROpenACCOpsInterfaces.cpp | 24 +-
.../Support/RegisterOpenACCExtensions.cpp | 7 +
.../mlir/Dialect/OpenACC/OpenACCBase.td | 4 +-
.../mlir/Dialect/OpenACC/OpenACCOps.td | 3 +
.../Dialect/OpenACC/OpenACCOpsInterfaces.td | 11 +
.../mlir/Dialect/OpenACC/Transforms/Passes.td | 76 +++++
mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 6 +
.../Dialect/OpenACC/Transforms/CMakeLists.txt | 1 +
.../OffloadLiveInValueCanonicalization.cpp | 303 ++++++++++++++++++
12 files changed, 437 insertions(+), 9 deletions(-)
create mode 100644 mlir/lib/Dialect/OpenACC/Transforms/OffloadLiveInValueCanonicalization.cpp
diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
index b017cb4733b6c..c6f52bbd0c64b 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
@@ -86,6 +86,14 @@ struct OutlineRematerializationModel
: public mlir::acc::OutlineRematerializationOpInterface::ExternalModel<
OutlineRematerializationModel<Op>, Op> {};
+/// External model for OffloadRegionOpInterface.
+/// This interface marks operations whose regions are targets for offloading
+/// and outlining. The interface declares no methods, so the model body is
+/// intentionally empty: attaching it is all that is needed to opt an op in.
+template <typename Op>
+struct OffloadRegionModel
+ : public mlir::acc::OffloadRegionOpInterface::ExternalModel<
+ OffloadRegionModel<Op>, Op> {};
+
} // namespace fir::acc
#endif // FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_
diff --git a/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt
index 9c6f0ee74f4cf..cac3ca93207e7 100644
--- a/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt
@@ -8,6 +8,7 @@ add_flang_library(FIROpenACCSupport
RegisterOpenACCExtensions.cpp
DEPENDS
+ CUFDialect
FIRBuilder
FIRDialect
FIRDialectSupport
@@ -15,6 +16,7 @@ add_flang_library(FIROpenACCSupport
HLFIRDialect
LINK_LIBS
+ CUFDialect
FIRBuilder
FIRCodeGenDialect
FIRDialect
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
index e4d02e93b041f..dacafb1eeb4b2 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
@@ -40,26 +40,34 @@ mlir::Value PartialEntityAccessModel<hlfir::DesignateOp>::getBaseEntity(
mlir::Value PartialEntityAccessModel<fir::DeclareOp>::getBaseEntity(
mlir::Operation *op) const {
- return mlir::cast<fir::DeclareOp>(op).getStorage();
+  auto declareOp = mlir::cast<fir::DeclareOp>(op);
+  // A storage operand marks a partial view and is itself the base entity.
+  // Without one, the declare is a complete view of its memref operand.
+  mlir::Value base = declareOp.getStorage();
+  if (!base)
+    base = declareOp.getMemref();
+  return base;
}
bool PartialEntityAccessModel<fir::DeclareOp>::isCompleteView(
mlir::Operation *op) const {
- // Return false (partial view) only if storage is present
- // Return true (complete view) if storage is absent
- return !getBaseEntity(op);
+  // Only a declare without a storage operand views the entire entity.
+  auto declareOp = mlir::cast<fir::DeclareOp>(op);
+  const bool hasStorage = static_cast<bool>(declareOp.getStorage());
+  return !hasStorage;
}
mlir::Value PartialEntityAccessModel<hlfir::DeclareOp>::getBaseEntity(
mlir::Operation *op) const {
- return mlir::cast<hlfir::DeclareOp>(op).getStorage();
+  auto declareOp = mlir::cast<hlfir::DeclareOp>(op);
+  // A storage operand marks a partial view and is itself the base entity.
+  // Without one, the declare is a complete view of its memref operand.
+  mlir::Value base = declareOp.getStorage();
+  if (!base)
+    base = declareOp.getMemref();
+  return base;
}
bool PartialEntityAccessModel<hlfir::DeclareOp>::isCompleteView(
mlir::Operation *op) const {
- // Return false (partial view) only if storage is present
- // Return true (complete view) if storage is absent
- return !getBaseEntity(op);
+  // Only a declare without a storage operand views the entire entity.
+  auto declareOp = mlir::cast<hlfir::DeclareOp>(op);
+  const bool hasStorage = static_cast<bool>(declareOp.getStorage());
+  return !hasStorage;
}
mlir::SymbolRefAttr AddressOfGlobalModel::getSymbol(mlir::Operation *op) const {
diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
index d7e9ae4ec85b9..4c514599df414 100644
--- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
@@ -12,6 +12,8 @@
#include "flang/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.h"
+#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
@@ -84,6 +86,11 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
PartialEntityAccessModel<hlfir::DeclareOp>>(*ctx);
});
+ // Register CUF operation interfaces
+ registry.addExtension(+[](mlir::MLIRContext *ctx, cuf::CUFDialect *dialect) {
+ cuf::KernelOp::attachInterface<OffloadRegionModel<cuf::KernelOp>>(*ctx);
+ });
+
registerAttrsExtensions(registry);
}
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCBase.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCBase.td
index 2f7dfb2751c91..e01cfeef57a83 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCBase.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCBase.td
@@ -21,7 +21,9 @@ def OpenACC_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
let cppNamespace = "::mlir::acc";
- let dependentDialects = ["::mlir::memref::MemRefDialect","::mlir::LLVM::LLVMDialect"];
+ let dependentDialects = ["::mlir::memref::MemRefDialect",
+ "::mlir::LLVM::LLVMDialect",
+ "::mlir::gpu::GPUDialect"];
}
#endif // OPENACC_BASE
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 73ca362c6dc3d..644d1f8e9e649 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1685,6 +1685,7 @@ def OpenACC_ParallelOp
RecursiveMemoryEffects,
DeclareOpInterfaceMethods<ComputeRegionOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ OffloadRegionOpInterface,
MemoryEffects<[MemWrite<OpenACC_ConstructResource>,
MemRead<OpenACC_CurrentDeviceIdResource>]>]> {
let summary = "parallel construct";
@@ -1885,6 +1886,7 @@ def OpenACC_SerialOp
RecursiveMemoryEffects,
DeclareOpInterfaceMethods<ComputeRegionOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ OffloadRegionOpInterface,
MemoryEffects<[MemWrite<OpenACC_ConstructResource>,
MemRead<OpenACC_CurrentDeviceIdResource>]>]> {
let summary = "serial construct";
@@ -2025,6 +2027,7 @@ def OpenACC_KernelsOp
RecursiveMemoryEffects,
DeclareOpInterfaceMethods<ComputeRegionOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ OffloadRegionOpInterface,
MemoryEffects<[MemWrite<OpenACC_ConstructResource>,
MemRead<OpenACC_CurrentDeviceIdResource>]>]> {
let summary = "kernels construct";
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
index 3242f25c44399..44632eb4cdac4 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
@@ -122,4 +122,15 @@ def OutlineRematerializationOpInterface : OpInterface<"OutlineRematerializationO
}];
}
+// Marker interface with no methods: attaching it (directly in ODS or as an
+// external model) is what identifies an operation's region as an offload
+// target, e.g. for the offload-livein-value-canonicalization pass.
+def OffloadRegionOpInterface : OpInterface<"OffloadRegionOpInterface"> {
+ let cppNamespace = "::mlir::acc";
+
+ let description = [{
+ An interface for operations whose regions are targets for offloading
+ and outlining. Operations implementing this interface indicate that
+ their regions will be extracted and compiled separately (e.g., as
+ device kernels or outlined functions).
+ }];
+}
+
#endif // OPENACC_OPS_INTERFACES
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index 68a52e0706d60..94a4f8732fafa 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -287,4 +287,80 @@ def ACCIfClauseLowering : Pass<"acc-if-clause-lowering", "mlir::func::FuncOp"> {
"mlir::scf::SCFDialect"];
}
+def OffloadLiveInValueCanonicalization : Pass<"offload-livein-value-canonicalization", "mlir::func::FuncOp"> {
+ let summary = "Canonicalize live-in values for regions destined for outlining";
+ let description = [{
+ This pass canonicalizes live-in values for regions destined for outlining.
+ It handles operations that produce synthetic types or values that cannot
+ be passed as arguments to outlined regions.
+
+ The pass performs the following transformations:
+
+ 1. **Sinking**: Operations whose results are only used inside the region
+ are moved into the region. This reduces the number of live-in values
+ and keeps related operations together.
+
+ 2. **Rematerialization**: Operations whose results are used both inside
+ and outside the region are cloned into the region. The uses inside
+ the region are updated to use the cloned operation's results.
+
+ Operations are considered candidates for these transformations if they
+ implement the `OutlineRematerializationOpInterface`, match constant
+ patterns, or are address-of operations (`AddressOfGlobalOpInterface`)
+ whose symbol is valid for device use or whose global is constant. Casts
+ and views (`ViewLikeOpInterface`, `PartialEntityAccessOpInterface`) are
+ traced through to the original defining operation before this check.
+ Interface-based candidates typically produce synthetic types (shapes,
+ bounds, field indices) that cannot be passed as function arguments.
+
+ Target regions are those of operations implementing
+ `OffloadRegionOpInterface`.
+
+ The pass iterates until convergence since canonicalizing one value may
+ expose new candidates (e.g., a bounds operation's operands may themselves
+ be constants that should be rematerialized).
+
+ Example transformation (rematerialization):
+ ```mlir
+ // Before:
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %bounds = acc.bounds lowerbound(%c0 : index) upperbound(%c10 : index)
+ acc.parallel {
+ %priv = acc.private varPtr(%ptr : ...) bounds(%bounds) -> ...
+ acc.yield
+ }
+ // %bounds is also used elsewhere
+
+ // After:
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %bounds = acc.bounds lowerbound(%c0 : index) upperbound(%c10 : index)
+ acc.parallel {
+ %c0_clone = arith.constant 0 : index
+ %c10_clone = arith.constant 10 : index
+ %bounds_clone = acc.bounds lowerbound(%c0_clone : index) upperbound(%c10_clone : index)
+ %priv = acc.private varPtr(%ptr : ...) bounds(%bounds_clone) -> ...
+ acc.yield
+ }
+ ```
+
+ Example transformation (sinking):
+ ```mlir
+ // Before:
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %bounds = acc.bounds lowerbound(%c0 : index) upperbound(%c10 : index)
+ acc.parallel {
+ %priv = acc.private varPtr(%ptr : ...) bounds(%bounds) -> ...
+ acc.yield
+ }
+ // %bounds is NOT used elsewhere
+
+ // After:
+ acc.parallel {
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %bounds = acc.bounds lowerbound(%c0 : index) upperbound(%c10 : index)
+ %priv = acc.private varPtr(%ptr : ...) bounds(%bounds) -> ...
+ acc.yield
+ }
+ ```
+ }];
+ let dependentDialects = ["mlir::acc::OpenACCDialect"];
+}
+
#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
index ed7425bd52525..2bd41d99a3661 100644
--- a/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIROpenACCDialect
LINK_LIBS PUBLIC
MLIRIR
+ MLIRGPUDialect
MLIRLLVMDialect
MLIRMemRefDialect
MLIROpenACCMPCommon
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bb8e6881d7d9d..bb643d6d1466b 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -291,6 +292,10 @@ struct MemrefGlobalVariableModel
}
};
+/// External model that marks `gpu.launch` as an offload region. The
+/// interface declares no methods, so the model is empty; it is attached in
+/// OpenACCDialect::initialize().
+struct GPULaunchOffloadRegionModel
+ : public acc::OffloadRegionOpInterface::ExternalModel<
+ GPULaunchOffloadRegionModel, gpu::LaunchOp> {};
+
/// Helper function for any of the times we need to modify an ArrayAttr based on
/// a device type list. Returns a new ArrayAttr with all of the
/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
@@ -387,6 +392,7 @@ void OpenACCDialect::initialize() {
memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
*getContext());
memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
+ gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index 3a0ca338766e4..1e2f86964ac0d 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIROpenACCTransforms
ACCSpecializeForDevice.cpp
ACCSpecializeForHost.cpp
LegalizeDataValues.cpp
+ OffloadLiveInValueCanonicalization.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/OffloadLiveInValueCanonicalization.cpp b/mlir/lib/Dialect/OpenACC/Transforms/OffloadLiveInValueCanonicalization.cpp
new file mode 100644
index 0000000000000..4651068e1e251
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/OffloadLiveInValueCanonicalization.cpp
@@ -0,0 +1,303 @@
+//===- OffloadLiveInValueCanonicalization.cpp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass canonicalizes live-in values for regions destined for offloading.
+//
+// Overview:
+// ---------
+// When a region is outlined (extracted into a separate function for device
+// execution), values defined outside the region but used inside become
+// arguments to the outlined function. However, some values cannot be passed
+// as arguments because they represent synthetic types (e.g., shape metadata,
+// field indices) or are better handled by recreating them inside the region.
+//
+// This pass identifies such values and either:
+// 1. Sinks the defining operation into the region (if all uses are inside)
+// 2. Rematerializes (clones) the operation inside the region (if there are
+// uses both inside and outside)
+//
+// Transforms:
+// -----------
+// The pass performs two main transformations on live-in values:
+//
+// 1. Sinking: If a candidate operation's result is only used inside the
+// offload region, the operation is moved into the region.
+//
+// 2. Rematerialization: If a candidate operation's result is used both
+// inside and outside the region, the operation is cloned inside the
+// region and uses within the region are updated to use the clone.
+//
+// Candidate operations are:
+// - Constants (matching arith.constant, etc.)
+// - Operations implementing `acc::OutlineRematerializationOpInterface`
+// - Address-of operations (`acc::AddressOfGlobalOpInterface`) referencing
+// symbols that are valid in GPU regions or constant globals
+//
+// The pass traces through view-like operations (`ViewLikeOpInterface`) and
+// partial entity access operations (`acc::PartialEntityAccessOpInterface`)
+// to find the original defining operation before making candidate decisions.
+//
+// Requirements:
+// -------------
+// To use this pass in a pipeline, the following requirements must be met:
+//
+// 1. Target Region Identification: Operations representing offload regions
+// must implement `acc::OffloadRegionOpInterface`. This interface marks
+// regions that will be outlined for device execution.
+//
+// 2. Rematerialization Candidates: Operations producing values that should
+// be rematerialized (rather than passed as arguments) should implement
+// `acc::OutlineRematerializationOpInterface`. Examples include operations
+// producing shape metadata, field indices, or other synthetic types.
+//
+// 3. Analysis Registration (Optional): If custom behavior is needed for
+// symbol validation (e.g., determining if a global is valid on device),
+// pre-register `acc::OpenACCSupport` analysis on the parent module.
+// If not registered, default behavior will be used.
+//
+// 4. View-Like Operations: Operations that create views or casts should
+// implement `ViewLikeOpInterface` or `acc::PartialEntityAccessOpInterface`
+// to allow the pass to trace through to the original defining operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Analysis/TopologicalSortUtils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_OFFLOADLIVEINVALUECANONICALIZATION
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "offload-livein-value-canonicalization"
+
+using namespace mlir;
+
+namespace {
+
+/// Returns true if every user of \p val is nested inside \p region (directly
+/// or in a nested region). A value with no users trivially satisfies this.
+static bool allUsersAreInsideRegion(Value val, Region &region) {
+  return llvm::all_of(val.getUsers(), [&region](Operation *user) {
+    return region.isAncestor(user->getParentRegion());
+  });
+}
+
+/// Follows view-like and partial-entity-access producers back to the value
+/// they were derived from. Stops as soon as one step leaves the value
+/// unchanged (or the value becomes null) and returns the last value reached.
+static Value getOriginalValue(Value val) {
+  while (val) {
+    Value next = val;
+    // Step through casts/views to the value they were created from.
+    if (auto viewLikeOp = next.getDefiningOp<ViewLikeOpInterface>())
+      next = viewLikeOp.getViewSource();
+    // Step through partial accesses to their base entity, when one exists.
+    if (auto partialAccess =
+            next.getDefiningOp<acc::PartialEntityAccessOpInterface>()) {
+      if (Value base = partialAccess.getBaseEntity())
+        next = base;
+    }
+    if (next == val)
+      break;
+    val = next;
+  }
+  return val;
+}
+
+/// Decides whether the live-in \p val should be recreated inside an offload
+/// region instead of being passed to the outlined function as an argument.
+///
+/// \p val is first traced through view-like and partial-entity-access
+/// operations; the decision is then made on the original defining operation,
+/// which qualifies when it:
+///   1. matches the constant pattern (arith.constant, etc.),
+///   2. implements OutlineRematerializationOpInterface, or
+///   3. is an address-of operation whose symbol \p accSupport reports as a
+///      valid use, or whose referenced global is constant.
+/// Values without a defining operation (block arguments) never qualify.
+static bool isRematerializationCandidate(Value val,
+                                         acc::OpenACCSupport &accSupport) {
+  Operation *definingOp = getOriginalValue(val).getDefiningOp();
+  if (!definingOp)
+    return false;
+
+  LLVM_DEBUG(llvm::dbgs() << "\tChecking candidate: " << *definingOp << "\n");
+
+  // Cheap-to-recreate constants.
+  if (matchPattern(definingOp, m_Constant())) {
+    LLVM_DEBUG(llvm::dbgs() << "\t\t-> constant pattern matched\n");
+    return true;
+  }
+
+  // Operations that explicitly opt in to rematerialization.
+  if (isa<acc::OutlineRematerializationOpInterface>(definingOp)) {
+    LLVM_DEBUG(llvm::dbgs() << "\t\t-> OutlineRematerializationOpInterface\n");
+    return true;
+  }
+
+  auto addrOfOp = dyn_cast<acc::AddressOfGlobalOpInterface>(definingOp);
+  if (!addrOfOp) {
+    LLVM_DEBUG(llvm::dbgs() << "\t\t-> not a candidate\n");
+    return false;
+  }
+
+  SymbolRefAttr symbol = addrOfOp.getSymbol();
+  LLVM_DEBUG(llvm::dbgs()
+             << "\t\tAddressOfGlobalOpInterface, symbol: " << symbol << "\n");
+
+  // A symbol that is already usable on the device (e.g. via acc.declare):
+  // recreating the address inside the region refers to the device copy.
+  Operation *globalOp = nullptr;
+  if (accSupport.isValidSymbolUse(definingOp, symbol, &globalOp)) {
+    LLVM_DEBUG(llvm::dbgs() << "\t\t-> isValidSymbolUse: true\n");
+    return true;
+  }
+  LLVM_DEBUG(llvm::dbgs() << "\t\t-> isValidSymbolUse: false\n");
+
+  // Otherwise a constant global is still worth rematerializing so that the
+  // constant can be placed in GPU memory.
+  if (globalOp) {
+    auto globalVarOp = dyn_cast<acc::GlobalVariableOpInterface>(globalOp);
+    if (globalVarOp && globalVarOp.isConstant()) {
+      LLVM_DEBUG(llvm::dbgs() << "\t\t-> constant global\n");
+      return true;
+    }
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "\t\t-> not a candidate\n");
+  return false;
+}
+
+class OffloadLiveInValueCanonicalization
+    : public acc::impl::OffloadLiveInValueCanonicalizationBase<
+          OffloadLiveInValueCanonicalization> {
+public:
+  using acc::impl::OffloadLiveInValueCanonicalizationBase<
+      OffloadLiveInValueCanonicalization>::
+      OffloadLiveInValueCanonicalizationBase;
+
+  /// Canonicalizes the live-in values of \p region by sinking or
+  /// rematerializing their defining operations. Returns true if any change
+  /// was made; runOnOperation re-runs this until it returns false, because
+  /// the operands of sunk/cloned operations can become new live-ins.
+  ///
+  /// Decisions are made per defining operation, not per value:
+  ///  - an operation is sunk only if every one of its results is used
+  ///    exclusively inside the region, so a multi-result operation is never
+  ///    moved while another of its results still has users outside;
+  ///  - an operation is sunk or cloned at most once per call, even when
+  ///    several of its results are live-in.
+  bool canonicalizeLiveInValues(Region &region,
+                                acc::OpenACCSupport &accSupport) {
+    // 1) Collect live-in values.
+    SetVector<Value> liveInValues;
+    getUsedValuesDefinedAbove(region, liveInValues);
+    LLVM_DEBUG(llvm::dbgs()
+               << "\tFound " << liveInValues.size() << " live-in value(s)\n");
+
+    // 2) Partition the defining ops of candidate values into a sink set and a
+    //    rematerialization (clone) set, evaluating each value only once.
+    SetVector<Operation *> sinkCandidates;
+    SetVector<Operation *> rematerializationCandidates;
+    for (Value liveIn : liveInValues) {
+      if (!isRematerializationCandidate(liveIn, accSupport))
+        continue;
+      Operation *definingOp = liveIn.getDefiningOp();
+      assert(definingOp && "must have op to be considered");
+      bool onlyUsedInside =
+          llvm::all_of(definingOp->getResults(), [&region](Value result) {
+            return allUsersAreInsideRegion(result, region);
+          });
+      if (onlyUsedInside)
+        sinkCandidates.insert(definingOp);
+      else
+        rematerializationCandidates.insert(definingOp);
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "\tSink candidates: " << sinkCandidates.size()
+                            << ", clone candidates: "
+                            << rematerializationCandidates.size() << "\n");
+
+    if (rematerializationCandidates.empty() && sinkCandidates.empty())
+      return false;
+
+    LLVM_DEBUG(llvm::dbgs() << "\tCanonicalizing values into "
+                            << *region.getParentOp() << "\n");
+
+    // 3) Move the sink set to the start of the region. Sink candidates never
+    //    consume one another (the consumer would be an outside user), so
+    //    their relative order inside the region does not matter.
+    for (Operation *sinkOp : sinkCandidates) {
+      sinkOp->moveBefore(&region.front().front());
+      LLVM_DEBUG(llvm::dbgs() << "\t\tSunk: " << *sinkOp << "\n");
+    }
+
+    // 4) Clone the rematerialization set to the start of the region in
+    //    topological order, and only then redirect in-region uses. Because
+    //    every clone already exists when uses are redirected, a clone that
+    //    consumes another candidate (e.g. a bounds op and one of its constant
+    //    operands) is rewired to that candidate's clone, instead of keeping
+    //    the outer value live-in and cloning it a second time on the next
+    //    iteration.
+    SmallVector<Operation *> opsToRematerialize(
+        rematerializationCandidates.begin(), rematerializationCandidates.end());
+    computeTopologicalSorting(opsToRematerialize);
+    OpBuilder builder(region);
+    SmallVector<std::pair<Operation *, Operation *>> clonedOps;
+    clonedOps.reserve(opsToRematerialize.size());
+    for (Operation *rematerializationOp : opsToRematerialize)
+      clonedOps.emplace_back(rematerializationOp,
+                             builder.clone(*rematerializationOp));
+    for (auto [originalOp, clonedOp] : clonedOps) {
+      for (auto [oldResult, newResult] :
+           llvm::zip(originalOp->getResults(), clonedOp->getResults()))
+        replaceAllUsesInRegionWith(oldResult, newResult, region);
+      LLVM_DEBUG(llvm::dbgs() << "\t\tCloned: " << *clonedOp << "\n");
+    }
+
+    return true;
+  }
+
+  void runOnOperation() override {
+    LLVM_DEBUG(llvm::dbgs() << "Enter OffloadLiveInValueCanonicalization\n");
+
+    // Since OpenACCSupport is normally registered on modules, attempt to
+    // get it from the parent module first (if available), then fallback
+    // to the per-function analysis.
+    acc::OpenACCSupport *accSupportPtr = nullptr;
+    if (auto parentAnalysis = getCachedParentAnalysis<acc::OpenACCSupport>())
+      accSupportPtr = &parentAnalysis->get();
+    else
+      accSupportPtr = &getAnalysis<acc::OpenACCSupport>();
+    acc::OpenACCSupport &accSupport = *accSupportPtr;
+
+    func::FuncOp func = getOperation();
+    LLVM_DEBUG(llvm::dbgs()
+               << "Processing function: " << func.getName() << "\n");
+
+    func.walk([&](Operation *op) {
+      if (!isa<acc::OffloadRegionOpInterface>(op))
+        return;
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Found offload region: " << op->getName() << "\n");
+      assert(op->getNumRegions() == 1 && "must have 1 region");
+
+      // Canonicalization of values changes the live-in set, so rerun the
+      // algorithm until convergence. The counter is only read by
+      // LLVM_DEBUG, which compiles away in NDEBUG builds; [[maybe_unused]]
+      // keeps those builds free of unused-variable warnings.
+      [[maybe_unused]] int iteration = 0;
+      bool changes = false;
+      do {
+        LLVM_DEBUG(llvm::dbgs() << "\tIteration " << iteration++ << "\n");
+        changes = canonicalizeLiveInValues(op->getRegion(0), accSupport);
+      } while (changes);
+      LLVM_DEBUG(llvm::dbgs()
+                 << "\tConverged after " << iteration << " iteration(s)\n");
+    });
+
+    LLVM_DEBUG(llvm::dbgs() << "Exit OffloadLiveInValueCanonicalization\n");
+  }
+};
+
+} // namespace
>From b6321609a4e620064f53b028954d8d2db01e745f Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Wed, 7 Jan 2026 09:05:12 -0800
Subject: [PATCH 2/3] Add tests exercising the pass
---
...-offload-livein-value-canonicalization.fir | 79 ++++++
.../offload-livein-value-canonicalization.fir | 255 ++++++++++++++++++
...offload-livein-value-canonicalization.mlir | 240 +++++++++++++++++
3 files changed, 574 insertions(+)
create mode 100644 flang/test/Fir/CUDA/cuf-offload-livein-value-canonicalization.fir
create mode 100644 flang/test/Fir/OpenACC/offload-livein-value-canonicalization.fir
create mode 100644 mlir/test/Dialect/OpenACC/offload-livein-value-canonicalization.mlir
diff --git a/flang/test/Fir/CUDA/cuf-offload-livein-value-canonicalization.fir b/flang/test/Fir/CUDA/cuf-offload-livein-value-canonicalization.fir
new file mode 100644
index 0000000000000..50ca2d02bffad
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuf-offload-livein-value-canonicalization.fir
@@ -0,0 +1,79 @@
+// RUN: fir-opt %s -offload-livein-value-canonicalization -split-input-file | FileCheck %s
+
+// -----
+
+// Test constant sinking into cuf.kernel. %c2 is only used inside the kernel
+// body, so it is moved into the region. %c1 feeds the kernel's loop bounds,
+// which are operands of cuf.kernel itself (outside the body), so it is
+// deliberately not used in the body: that would make it a clone candidate.
+func.func @test_constant_sink() {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c1_i32 = arith.constant 1 : i32
+  cuf.kernel<<<%c1_i32, %c1_i32>>> (%arg0 : index) = (%c1 : index) to (%c1 : index) step (%c1 : index) {
+    %res = arith.addi %c2, %c2 : index
+    "fir.end"() : () -> ()
+  }
+  return
+}
+
+// CHECK-LABEL: @test_constant_sink
+// CHECK-NOT: arith.constant 2 : index
+// CHECK: cuf.kernel
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: arith.addi %[[C2]], %[[C2]] : index
+
+// -----
+
+// Test constant rematerialization with cuf.kernel. %c1 is used outside the
+// kernel (by the outer arith.addi and the kernel's loop bounds) and inside
+// the body, so it is cloned into the region and only the in-region uses are
+// redirected to the clone.
+func.func @test_constant_rematerialize() {
+  %c1 = arith.constant 1 : index
+  %c1_i32 = arith.constant 1 : i32
+  %res = arith.addi %c1, %c1 : index
+  cuf.kernel<<<%c1_i32, %c1_i32>>> (%arg0 : index) = (%c1 : index) to (%c1 : index) step (%c1 : index) {
+    %res2 = arith.addi %c1, %c1 : index
+    "fir.end"() : () -> ()
+  }
+  return
+}
+
+// CHECK-LABEL: @test_constant_rematerialize
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: arith.addi %[[C1]], %[[C1]] : index
+// CHECK: cuf.kernel
+// CHECK: %[[C1_IN:.*]] = arith.constant 1 : index
+// CHECK: arith.addi %[[C1_IN]], %[[C1_IN]] : index
+
+// -----
+
+// Test fir.shape sinking into cuf.kernel. %shape is only used inside the
+// kernel body, so it is sunk; %c0, then only used by the sunk shape, follows
+// on the next iteration. The checks match the op ("fir.shape %"), not the
+// "!fir.shape<1>" type text that also appears on the fir.embox line.
+func.func @test_firshape_sink() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %shape = fir.shape %c0 : (index) -> !fir.shape<1>
+  %c1_i32 = arith.constant 1 : i32
+  cuf.kernel<<<%c1_i32, %c1_i32>>> (%arg0 : index) = (%c1 : index) to (%c1 : index) step (%c1 : index) {
+    %zeroaddr = fir.zero_bits !fir.heap<!fir.array<?xf32>>
+    %box = fir.embox %zeroaddr(%shape) : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
+    "fir.end"() : () -> ()
+  }
+  return
+}
+
+// CHECK-LABEL: @test_firshape_sink
+// CHECK-NOT: fir.shape %
+// CHECK-NOT: arith.constant 0 : index
+// CHECK: cuf.kernel
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[SHAPE:.*]] = fir.shape %[[C0]]
+// CHECK: fir.embox %{{.*}}(%[[SHAPE]])
+
+// -----
+
+// Test 2D shape rematerialization with cuf.kernel and array operations.
+// %shape is used by fir.declare outside the kernel and by fir.array_coor
+// inside it, so it is cloned into the kernel rather than sunk; the outer
+// fir.declare keeps using the original shape.
+func.func @test_2d_shape_rematerialize() {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c32 = arith.constant 32 : index
+  %0 = cuf.alloc !fir.array<3x32xf32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFEa"} -> !fir.ref<!fir.array<3x32xf32>>
+  %shape = fir.shape %c3, %c32 : (index, index) -> !fir.shape<2>
+  %decl = fir.declare %0(%shape) {data_attr = #cuf.cuda<device>, uniq_name = "_QFEa"} : (!fir.ref<!fir.array<3x32xf32>>, !fir.shape<2>) -> !fir.ref<!fir.array<3x32xf32>>
+  %c1_i32 = arith.constant 1 : i32
+  cuf.kernel<<<*, *>>> (%arg0 : index) = (%c1 : index) to (%c32 : index) step (%c1 : index) {
+    %coor = fir.array_coor %decl(%shape) %c1, %arg0 : (!fir.ref<!fir.array<3x32xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
+    "fir.end"() : () -> ()
+  }
+  return
+}
+
+// CHECK-LABEL: @test_2d_shape_rematerialize
+// CHECK: %[[SHAPE:.*]] = fir.shape
+// CHECK: fir.declare %{{.*}}(%[[SHAPE]])
+// CHECK: cuf.kernel
+// CHECK: %[[SHAPE_IN:.*]] = fir.shape
+// CHECK: fir.array_coor %{{.*}}(%[[SHAPE_IN]])
diff --git a/flang/test/Fir/OpenACC/offload-livein-value-canonicalization.fir b/flang/test/Fir/OpenACC/offload-livein-value-canonicalization.fir
new file mode 100644
index 0000000000000..6ecccde39d3fb
--- /dev/null
+++ b/flang/test/Fir/OpenACC/offload-livein-value-canonicalization.fir
@@ -0,0 +1,255 @@
+// RUN: fir-opt %s -offload-livein-value-canonicalization -split-input-file | FileCheck %s
+
+// -----
+
+// Test fir.shape sinking: the shape and its extent constant are only used inside acc.serial, so both are moved into it.
+func.func private @use_box(!fir.box<!fir.heap<!fir.array<?xf32>>>) -> ()
+
+func.func @test_firshape_sink() {
+ %c10 = arith.constant 10 : index
+ %shape = fir.shape %c10 : (index) -> !fir.shape<1>
+ acc.serial {
+ %zeroaddr = fir.zero_bits !fir.heap<!fir.array<?xf32>>
+ %box = fir.embox %zeroaddr(%shape) : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
+ fir.call @use_box(%box) : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_firshape_sink
+// CHECK: acc.serial {
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK: %[[SHAPE:.*]] = fir.shape %[[C10]]
+// CHECK: fir.embox {{.*}}(%[[SHAPE]])
+
+// -----
+
+// Test fir.shape rematerialization: the shape is used outside and inside acc.serial, so a copy is created inside.
+func.func private @use_box(!fir.box<!fir.heap<!fir.array<?xf32>>>) -> ()
+
+func.func @test_firshape_rematerialize() {
+ %c10 = arith.constant 10 : index
+ %shape = fir.shape %c10 : (index) -> !fir.shape<1>
+ %zeroaddr = fir.zero_bits !fir.heap<!fir.array<?xf32>>
+ %box = fir.embox %zeroaddr(%shape) : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
+ fir.call @use_box(%box) : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> ()
+ acc.serial {
+ %addr = fir.box_addr %box : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> !fir.ref<!fir.array<?xf32>>
+ %box2 = fir.embox %addr(%shape) : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
+ fir.call @use_box(%box2) : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_firshape_rematerialize
+// CHECK: %[[SHAPE_OUTER:.*]] = fir.shape
+// CHECK: fir.embox {{.*}}(%[[SHAPE_OUTER]])
+// CHECK: acc.serial {
+// CHECK: %[[SHAPE_INNER:.*]] = fir.shape
+// CHECK: fir.embox {{.*}}(%[[SHAPE_INNER]])
+
+// -----
+
+// Test fir.shape_shift sinking: its only use (fir.embox) is inside acc.serial.
+func.func private @use_box(!fir.box<!fir.array<?xf32>>) -> ()
+
+func.func @test_shapeshift_sink(%arg0: !fir.ref<!fir.array<?xf32>>) {
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %shapeshift = fir.shape_shift %c1, %c10 : (index, index) -> !fir.shapeshift<1>
+ acc.serial {
+ %box = fir.embox %arg0(%shapeshift) : (!fir.ref<!fir.array<?xf32>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<?xf32>>
+ fir.call @use_box(%box) : (!fir.box<!fir.array<?xf32>>) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_shapeshift_sink
+// CHECK: acc.serial {
+// CHECK: fir.shape_shift
+// CHECK: fir.embox
+
+// -----
+
+// Test fir.field_index sinking: its only use (fir.coordinate_of) is inside acc.serial.
+func.func private @use_ref(!fir.ref<f32>) -> ()
+
+func.func @test_fieldindex_sink() {
+ %var = fir.alloca !fir.type<_QTmytype{field:f32}>
+ %fieldidx = fir.field_index field, !fir.type<_QTmytype{field:f32}>
+ acc.serial {
+ %coor = fir.coordinate_of %var, %fieldidx : (!fir.ref<!fir.type<_QTmytype{field:f32}>>, !fir.field) -> !fir.ref<f32>
+ fir.call @use_ref(%coor) : (!fir.ref<f32>) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_fieldindex_sink
+// CHECK: acc.serial {
+// CHECK: %[[FIELD:.*]] = fir.field_index field
+// CHECK: fir.coordinate_of {{.*}}, %[[FIELD]]
+
+// -----
+
+// Test sinking of fir.address_of for a global carrying an acc.declare attribute.
+fir.global @global_with_declare {acc.declare = #acc.declare<dataClause = acc_copyin>} : f32 {
+ %0 = arith.constant 0.0 : f32
+ fir.has_value %0 : f32
+}
+
+func.func private @use_ref(!fir.ref<f32>) -> ()
+
+func.func @test_address_of_with_declare_sink() {
+ %addr = fir.address_of(@global_with_declare) : !fir.ref<f32>
+ acc.serial {
+ fir.call @use_ref(%addr) : (!fir.ref<f32>) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_address_of_with_declare_sink
+// CHECK: acc.serial {
+// CHECK: %[[ADDR:.*]] = fir.address_of(@global_with_declare)
+// CHECK: fir.call @use_ref(%[[ADDR]])
+
+// -----
+
+// Test sinking of fir.address_of for a constant global (no acc.declare attribute).
+fir.global @global_constant constant : f32 {
+ %0 = arith.constant 42.0 : f32
+ fir.has_value %0 : f32
+}
+
+func.func private @use_ref(!fir.ref<f32>) -> ()
+
+func.func @test_address_of_constant_global_sink() {
+ %addr = fir.address_of(@global_constant) : !fir.ref<f32>
+ acc.serial {
+ fir.call @use_ref(%addr) : (!fir.ref<f32>) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_address_of_constant_global_sink
+// CHECK: acc.serial {
+// CHECK: %[[ADDR:.*]] = fir.address_of(@global_constant)
+// CHECK: fir.call @use_ref(%[[ADDR]])
+
+// -----
+
+// Test that fir.convert is traced (ViewLikeOpInterface) back to fir.address_of; both are sunk.
+fir.global @global_for_convert {acc.declare = #acc.declare<dataClause = acc_copyin>} : f32 {
+ %0 = arith.constant 0.0 : f32
+ fir.has_value %0 : f32
+}
+
+func.func private @use_ptr(!fir.ptr<f32>) -> ()
+
+func.func @test_address_of_with_convert_sink() {
+ %addr = fir.address_of(@global_for_convert) : !fir.ref<f32>
+ %converted = fir.convert %addr : (!fir.ref<f32>) -> !fir.ptr<f32>
+ acc.serial {
+ fir.call @use_ptr(%converted) : (!fir.ptr<f32>) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_address_of_with_convert_sink
+// CHECK: acc.serial {
+// CHECK: fir.address_of(@global_for_convert)
+// CHECK: fir.convert
+
+// -----
+
+// Test that fir.declare is traced (PartialEntityAccessOpInterface) back to fir.address_of; address_of, shape and declare are sunk.
+fir.global @global_for_declare {acc.declare = #acc.declare<dataClause = acc_copyin>} : !fir.array<10xf32> {
+ %0 = fir.zero_bits !fir.array<10xf32>
+ fir.has_value %0 : !fir.array<10xf32>
+}
+
+func.func private @use_ref(!fir.ref<!fir.array<10xf32>>) -> ()
+
+func.func @test_address_of_through_declare_sink() {
+ %c10 = arith.constant 10 : index
+ %addr = fir.address_of(@global_for_declare) : !fir.ref<!fir.array<10xf32>>
+ %shape = fir.shape %c10 : (index) -> !fir.shape<1>
+ %decl = fir.declare %addr(%shape) {uniq_name = "global"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xf32>>
+ acc.serial {
+ fir.call @use_ref(%decl) : (!fir.ref<!fir.array<10xf32>>) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_address_of_through_declare_sink
+// CHECK: acc.serial {
+// CHECK-DAG: fir.address_of(@global_for_declare)
+// CHECK-DAG: fir.shape
+// CHECK: fir.declare
+
+// -----
+
+// Test 2-D fir.shape with array ops: %shape is also used by the outer fir.declare, so it is cloned (not sunk) into acc.serial.
+func.func @test_2d_shape_sink() {
+ %c3 = arith.constant 3 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ %0 = fir.alloca !fir.array<3x32xf32>
+ %shape = fir.shape %c3, %c32 : (index, index) -> !fir.shape<2>
+ %decl = fir.declare %0(%shape) {uniq_name = "_QFEa"} : (!fir.ref<!fir.array<3x32xf32>>, !fir.shape<2>) -> !fir.ref<!fir.array<3x32xf32>>
+ acc.serial {
+ %coor = fir.array_coor %decl(%shape) %c1, %c1 : (!fir.ref<!fir.array<3x32xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_2d_shape_sink
+// CHECK: acc.serial {
+// CHECK: fir.shape
+// CHECK: fir.array_coor
+
+// -----
+
+// Test acc.bounds sinking with FIR types: the bounds are only used by acc.private inside acc.serial.
+func.func @test_accbounds_sink_fir() {
+ %c1 = arith.constant 1 : index
+ %bounds = acc.bounds upperbound(%c1 : index)
+ acc.serial {
+ %local = fir.alloca i32
+ %priv = acc.private varPtr(%local : !fir.ref<i32>) bounds(%bounds) -> !fir.ref<i32>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_accbounds_sink_fir
+// CHECK: acc.serial {
+// CHECK: acc.bounds
+
+// -----
+
+// Test acc.bounds rematerialization with FIR types: the bounds are used by acc.private both outside and inside acc.serial.
+func.func @test_accbounds_rematerialize_fir() {
+ %c1 = arith.constant 1 : index
+ %bounds = acc.bounds upperbound(%c1 : index)
+ %local = fir.alloca i32
+ %priv = acc.private varPtr(%local : !fir.ref<i32>) bounds(%bounds) -> !fir.ref<i32>
+ acc.serial {
+ %priv2 = acc.private varPtr(%local : !fir.ref<i32>) bounds(%bounds) -> !fir.ref<i32>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_accbounds_rematerialize_fir
+// CHECK: acc.bounds
+// CHECK: acc.serial {
+// CHECK: acc.bounds
diff --git a/mlir/test/Dialect/OpenACC/offload-livein-value-canonicalization.mlir b/mlir/test/Dialect/OpenACC/offload-livein-value-canonicalization.mlir
new file mode 100644
index 0000000000000..7be62a789e6fa
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/offload-livein-value-canonicalization.mlir
@@ -0,0 +1,240 @@
+// RUN: mlir-opt %s -offload-livein-value-canonicalization -split-input-file | FileCheck %s
+
+// -----
+
+// Test constant sinking: when all uses are inside the region, the op is moved into it.
+func.func private @use_i64(i64) -> ()
+
+func.func @test_constant_sink() {
+ %c1 = arith.constant 1 : i64
+ acc.serial {
+ func.call @use_i64(%c1) : (i64) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_constant_sink
+// CHECK-NEXT: acc.serial {
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : i64
+// CHECK-NEXT: func.call @use_i64(%[[C1]])
+
+// -----
+
+// Test constant rematerialization: when uses exist both inside and outside the
+// region, the op is cloned inside the region and the original stays outside.
+func.func private @use_i64(i64) -> ()
+
+func.func @test_constant_rematerialize() {
+ %c1 = arith.constant 1 : i64
+ func.call @use_i64(%c1) : (i64) -> ()
+ acc.serial {
+ func.call @use_i64(%c1) : (i64) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_constant_rematerialize
+// CHECK: %[[C1_OUTER:.*]] = arith.constant 1 : i64
+// CHECK: call @use_i64(%[[C1_OUTER]])
+// CHECK: acc.serial {
+// CHECK: %[[C1_INNER:.*]] = arith.constant 1 : i64
+// CHECK: func.call @use_i64(%[[C1_INNER]])
+
+// -----
+
+// Test acc.bounds sinking (its constant operands are sunk along with it).
+// Note: an orphan acc.copyin inside a compute region is not strictly valid IR,
+// but using acc.private or similar would require recipe declarations, which
+// complicates the test. What matters here is that the bounds are sunk.
+func.func @test_accbounds_sink(%arg0: memref<10xf32>) {
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %bounds = acc.bounds lowerbound(%c0 : index) upperbound(%c10 : index)
+ acc.serial {
+ %copy = acc.copyin varPtr(%arg0 : memref<10xf32>) bounds(%bounds) -> memref<10xf32>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_accbounds_sink
+// CHECK: acc.serial {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK: %[[BOUNDS:.*]] = acc.bounds lowerbound(%[[C0]] : index) upperbound(%[[C10]] : index)
+// CHECK: acc.copyin varPtr({{.*}}) bounds(%[[BOUNDS]])
+
+// -----
+
+// Test acc.bounds rematerialization: the bounds are used both outside and inside acc.serial.
+// Note: orphan acc.copyin is not strictly valid IR (see @test_accbounds_sink).
+func.func @test_accbounds_rematerialize(%arg0: memref<10xf32>) {
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %bounds = acc.bounds lowerbound(%c0 : index) upperbound(%c10 : index)
+ %copy_outer = acc.copyin varPtr(%arg0 : memref<10xf32>) bounds(%bounds) -> memref<10xf32>
+ acc.serial {
+ %copy_inner = acc.copyin varPtr(%arg0 : memref<10xf32>) bounds(%bounds) -> memref<10xf32>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_accbounds_rematerialize
+// CHECK: %[[BOUNDS_OUTER:.*]] = acc.bounds
+// CHECK: acc.copyin varPtr({{.*}}) bounds(%[[BOUNDS_OUTER]])
+// CHECK: acc.serial {
+// CHECK: %[[BOUNDS_INNER:.*]] = acc.bounds
+// CHECK: acc.copyin varPtr({{.*}}) bounds(%[[BOUNDS_INNER]])
+
+// -----
+
+// Test sinking of memref.get_global for a global carrying an acc.declare attribute.
+memref.global @memref_global_with_declare : memref<10xf32> = dense<0.0> {acc.declare = #acc.declare<dataClause = acc_copyin>}
+
+func.func private @use_memref(memref<10xf32>) -> ()
+
+func.func @test_memref_get_global_sink() {
+ %memref = memref.get_global @memref_global_with_declare : memref<10xf32>
+ acc.serial {
+ func.call @use_memref(%memref) : (memref<10xf32>) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_memref_get_global_sink
+// CHECK: acc.serial {
+// CHECK: %[[MEM:.*]] = memref.get_global @memref_global_with_declare
+// CHECK: func.call @use_memref(%[[MEM]])
+
+// -----
+
+// Test that memref.reinterpret_cast is traced back to memref.get_global; both are sunk.
+memref.global @memref_global_reinterpret : memref<2x5xf32> = dense<0.0> {acc.declare = #acc.declare<dataClause = acc_copyin>}
+
+func.func private @use_memref_1d(memref<10xf32>) -> ()
+
+func.func @test_memref_reinterpret_cast_sink() {
+ %memref = memref.get_global @memref_global_reinterpret : memref<2x5xf32>
+ %reinterpreted = memref.reinterpret_cast %memref to offset: [0], sizes: [10], strides: [1] : memref<2x5xf32> to memref<10xf32>
+ acc.serial {
+ func.call @use_memref_1d(%reinterpreted) : (memref<10xf32>) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_memref_reinterpret_cast_sink
+// CHECK: acc.serial {
+// CHECK: memref.get_global @memref_global_reinterpret
+// CHECK: memref.reinterpret_cast
+
+// -----
+
+// Test constant sinking into acc.parallel (another OffloadRegionOpInterface op).
+func.func private @use_i32(i32) -> ()
+
+func.func @test_parallel_region() {
+ %c42 = arith.constant 42 : i32
+ acc.parallel {
+ func.call @use_i32(%c42) : (i32) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_parallel_region
+// CHECK: acc.parallel {
+// CHECK: %[[C42:.*]] = arith.constant 42 : i32
+// CHECK: func.call @use_i32(%[[C42]])
+
+// -----
+
+// Test constant sinking into acc.kernels (another OffloadRegionOpInterface op).
+func.func private @use_f32(f32) -> ()
+
+func.func @test_kernels_region() {
+ %cst = arith.constant 3.14 : f32
+ acc.kernels {
+ func.call @use_f32(%cst) : (f32) -> ()
+ acc.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: @test_kernels_region
+// CHECK: acc.kernels {
+// CHECK: %[[CST:.*]] = arith.constant 3.14{{.*}} : f32
+// CHECK: func.call @use_f32(%[[CST]])
+
+// -----
+
+// Test mixed handling: %c1 is also used outside (cloned), %c2 is only used inside (sunk).
+func.func private @use_index(index) -> ()
+
+func.func @test_multiple_constants() {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ func.call @use_index(%c1) : (index) -> ()
+ acc.serial {
+ func.call @use_index(%c1) : (index) -> ()
+ func.call @use_index(%c2) : (index) -> ()
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @test_multiple_constants
+// CHECK: %[[C1_OUTER:.*]] = arith.constant 1 : index
+// CHECK: call @use_index(%[[C1_OUTER]])
+// CHECK: acc.serial {
+// CHECK-DAG: arith.constant 1 : index
+// CHECK-DAG: arith.constant 2 : index
+
+// -----
+
+// Test constant sinking into gpu.launch (another OffloadRegionOpInterface op).
+func.func private @use_index(index) -> ()
+
+func.func @test_gpu_launch_region() {
+ %c1 = arith.constant 1 : index
+ %c42 = arith.constant 42 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
+ func.call @use_index(%c42) : (index) -> ()
+ gpu.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: @test_gpu_launch_region
+// CHECK: gpu.launch
+// CHECK: %[[C42:.*]] = arith.constant 42 : index
+// CHECK: func.call @use_index(%[[C42]])
+
+// -----
+
+// Test constant rematerialization into gpu.launch: %c42 is also used outside, so it is cloned.
+func.func private @use_index(index) -> ()
+
+func.func @test_gpu_launch_rematerialize() {
+ %c1 = arith.constant 1 : index
+ %c42 = arith.constant 42 : index
+ func.call @use_index(%c42) : (index) -> ()
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
+ func.call @use_index(%c42) : (index) -> ()
+ gpu.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: @test_gpu_launch_rematerialize
+// CHECK: %[[C42_OUTER:.*]] = arith.constant 42 : index
+// CHECK: call @use_index(%[[C42_OUTER]])
+// CHECK: gpu.launch
+// CHECK: %[[C42_INNER:.*]] = arith.constant 42 : index
+// CHECK: func.call @use_index(%[[C42_INNER]])
>From e0d9a00147a91f1caa4eae82a64023f9ba7d1587 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Wed, 7 Jan 2026 09:06:58 -0800
Subject: [PATCH 3/3] remove braces
---
.../OpenACC/Transforms/OffloadLiveInValueCanonicalization.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/OffloadLiveInValueCanonicalization.cpp b/mlir/lib/Dialect/OpenACC/Transforms/OffloadLiveInValueCanonicalization.cpp
index 4651068e1e251..1392bddb123d0 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/OffloadLiveInValueCanonicalization.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/OffloadLiveInValueCanonicalization.cpp
@@ -99,10 +99,9 @@ namespace {
/// Returns true if all users of the given value are inside the region.
static bool allUsersAreInsideRegion(Value val, Region &region) {
- for (Operation *user : val.getUsers()) {
+ for (Operation *user : val.getUsers())
if (!region.isAncestor(user->getParentRegion()))
return false;
- }
return true;
}
More information about the Mlir-commits
mailing list