[Mlir-commits] [mlir] 66f06f5 - [mlir][acc] Sink constants into acc.compute_region when creating (#187777)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 20 12:57:02 PDT 2026
Author: Razvan Lupusoru
Date: 2026-03-20T12:56:58-07:00
New Revision: 66f06f54cb4d9fda87aed346b9d5747d0bc0215e
URL: https://github.com/llvm/llvm-project/commit/66f06f54cb4d9fda87aed346b9d5747d0bc0215e
DIFF: https://github.com/llvm/llvm-project/commit/66f06f54cb4d9fda87aed346b9d5747d0bc0215e.diff
LOG: [mlir][acc] Sink constants into acc.compute_region when creating (#187777)
When converting OpenACC compute constructs to acc.compute_region, also
sink constants inside so they do not become live-ins.
Added:
Modified:
mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp
mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp
index e0b0acff57cae..9cc36312d3615 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp
@@ -52,6 +52,7 @@
#include "mlir/Dialect/OpenACC/OpenACCUtilsLoop.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -126,6 +127,36 @@ static void setParDimsAttr(Operation *op, GPUParallelDimsAttr attr) {
op->setAttr(GPUParallelDimsAttr::name, attr);
}
+/// Clone the defining ops of constant live-in values into `region`, redirect
+/// uses inside the region to the clones, and remove those values from
+/// `liveInValues` so they are not threaded through `acc.compute_region` ins.
+///
+/// All IR mutations go through `rewriter` so that the pattern driver and any
+/// other listeners are notified of the new ops and of the modified users. The
+/// original constants outside the region are not erased here; they may still
+/// have other users.
+static void materializeConstantLiveInsIntoRegion(Region &region,
+                                                 SetVector<Value> &liveInValues,
+                                                 RewriterBase &rewriter) {
+  SmallVector<Value> constantLiveIns;
+  for (Value v : liveInValues) {
+    Operation *defOp = v.getDefiningOp();
+    if (defOp && matchPattern(defOp, m_Constant())) {
+      // As per the definition of ConstantLike trait, constants must have a
+      // single result.
+      assert(defOp->getNumResults() == 1 &&
+             "constants must have a single result");
+      constantLiveIns.push_back(v);
+    }
+  }
+  if (constantLiveIns.empty())
+    return;
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(&region.front());
+
+  for (Value v : constantLiveIns) {
+    Value newV = rewriter.clone(*v.getDefiningOp())->getResult(0);
+    // Only uses nested anywhere under `region` are redirected. Going through
+    // the rewriter (instead of replaceAllUsesInRegionWith, which sets operands
+    // directly) reports each in-place operand update to the listener.
+    rewriter.replaceUsesWithIf(v, newV, [&](OpOperand &use) {
+      return region.isAncestor(use.getOwner()->getParentRegion());
+    });
+    liveInValues.remove(v);
+  }
+}
+
/// Insert a parallel dimension into the list, maintaining order by
/// GPUParallelDimAttr::getOrder (descending).
static void insertParDim(SmallVectorImpl<GPUParallelDimAttr> &parDims,
@@ -320,6 +351,7 @@ class ComputeOpConversion : public OpRewritePattern<ComputeConstructT> {
Region ®ion = computeOp.getRegion();
SetVector<Value> liveInValues;
getUsedValuesDefinedAbove(region, region, liveInValues);
+ materializeConstantLiveInsIntoRegion(region, liveInValues, rewriter);
IRMapping mapping;
auto computeRegion = buildComputeRegion(
computeOp->getLoc(), launchArgs, liveInValues.getArrayRef(),
diff --git a/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir b/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir
index 77c4ba94c4f18..ee177aaf6e7a7 100644
--- a/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir
+++ b/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir
@@ -105,3 +105,27 @@ func.func @kernels_loop(%buf: memref<8xi32>) {
acc.copyout accPtr(%dev : memref<8xi32>) to varPtr(%buf : memref<8xi32>)
return
}
+
+// -----
+
+// Constant live-ins are cloned into the compute region body so they are not
+// passed through `acc.compute_region` arguments.
+
+// CHECK-LABEL: func.func @constant_livein_materialized_into_compute_region
+func.func @constant_livein_materialized_into_compute_region(%buf: memref<1xi32>) {
+ // Both constants are defined outside the acc.serial region below, so without
+ // the sinking they would be live-ins of the generated compute region.
+ %c0 = arith.constant 0 : index
+ %c42 = arith.constant 42 : i32
+ %dev = acc.copyin varPtr(%buf : memref<1xi32>) -> memref<1xi32>
+ // The ins type list holding only the device buffer shows that neither
+ // constant is threaded through acc.compute_region; the DAG directives then
+ // require both constants to be materialized in the region body, in either
+ // order, ahead of the store that uses them.
+ // CHECK: acc.kernel_environment
+ // CHECK: acc.compute_region ins({{.*}}) : (memref<1xi32>) {
+ // CHECK-DAG: arith.constant 42 : i32
+ // CHECK-DAG: arith.constant 0 : index
+ // CHECK: memref.store
+ // CHECK: acc.yield
+ acc.serial dataOperands(%dev : memref<1xi32>) {
+ memref.store %c42, %dev[%c0] : memref<1xi32>
+ acc.yield
+ }
+ acc.copyout accPtr(%dev : memref<1xi32>) to varPtr(%buf : memref<1xi32>)
+ return
+}
More information about the Mlir-commits
mailing list