[Mlir-commits] [mlir] 1b433e9 - [mlir][acc] Add canonicalization patterns for compute_region (#192376)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 16 09:00:21 PDT 2026
Author: Razvan Lupusoru
Date: 2026-04-16T09:00:17-07:00
New Revision: 1b433e936fbeef8fc1c649ad223719df897d311f
URL: https://github.com/llvm/llvm-project/commit/1b433e936fbeef8fc1c649ad223719df897d311f
DIFF: https://github.com/llvm/llvm-project/commit/1b433e936fbeef8fc1c649ad223719df897d311f.diff
LOG: [mlir][acc] Add canonicalization patterns for compute_region (#192376)
This PR improves the APIs for navigating through acc.compute_region
block arguments and also adds canonicalization patterns for those
arguments to remove unused ones and merge duplicates.
Added:
mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir
Modified:
mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
index 63fc8476c08d4..76902a6d2690e 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
@@ -297,6 +297,11 @@ def OpenACC_ComputeRegionOp
region (e.g., `"acc.parallel"`, `"acc.kernels"`). This is intended to
be solely informational.
+ Canonicalization may simplify `ins` captures: duplicate `ins` operands
+ (same SSA value threaded more than once) are merged by reusing the first
+ block argument, and unused `ins` operands (block arguments with no uses)
+ are removed. `launch` operands are never merged or dropped.
+
Example:
```mlir
@@ -327,6 +332,8 @@ def OpenACC_ComputeRegionOp
let regions = (region AnyRegion:$region);
+ let hasCanonicalizer = 1;
+
let extraClassDeclaration = [{
/// Look up the par_width op for the given dimension among launch args.
std::optional<mlir::Value> getLaunchArg(
@@ -365,9 +372,17 @@ def OpenACC_ComputeRegionOp
return &getRegion().back().back();
}
- /// Map a block argument back to its corresponding operand
- /// ($launchArgs or $inputArgs).
+ /// Return the `launch` or `ins` operand threaded to `blockArg`, or a null
+ /// `Value` if `blockArg` is not an argument of `getBody()` or its index is
+ /// out of range for this op's `launch` and `ins` operands.
::mlir::Value getOperand(::mlir::BlockArgument blockArg);
+
+ /// If `value` is a `launch` or `ins` operand, return the body block argument
+ /// it is threaded through; otherwise `std::nullopt`. `launch` operands are
+ /// searched before `ins` operands, and within each group the first match is
+ /// returned (canonicalization may merge duplicate `ins` values; duplicate
+ /// `launch` operands are not folded).
+ std::optional<::mlir::BlockArgument> getBlockArg(::mlir::Value value);
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
index 04f8c848c7287..de243f3dadd74 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LogicalResult.h"
@@ -107,6 +108,90 @@ struct RemoveEmptyKernelEnvironment
}
};
+/// Rewrites the operand-segment-size attribute of `op` so that its `ins`
+/// segment holds `numInput` operands. Intended to be called after `ins`
+/// operands have been erased directly from the operand list.
+///
+/// The size of the trailing stream segment is derived from the raw operand
+/// count instead of `op.getStream()`: until the attribute is rewritten here,
+/// the stored segment sizes still describe the pre-erasure layout, so any
+/// segment-based accessor for operands located after the `ins` segment would
+/// be resolved against stale offsets. The `launch` segment comes first, so
+/// `getLaunchArgs()` is not affected by erasing `ins` operands.
+static void updateComputeRegionInputOperandSegments(ComputeRegionOp op,
+                                                    PatternRewriter &rewriter,
+                                                    size_t numInput) {
+  const size_t numLaunch = op.getLaunchArgs().size();
+  const size_t numOperands = op->getNumOperands();
+  assert(numOperands >= numLaunch + numInput &&
+         "fewer operands than launch + ins segments");
+  // Whatever is left after the launch and ins operands is the stream segment.
+  const size_t numStream = numOperands - numLaunch - numInput;
+  assert(numStream <= 1 && "expected at most one stream operand");
+  op->setAttr(ComputeRegionOp::getOperandSegmentSizeAttr(),
+              rewriter.getDenseI32ArrayAttr({static_cast<int32_t>(numLaunch),
+                                             static_cast<int32_t>(numInput),
+                                             static_cast<int32_t>(numStream)}));
+}
+
+/// Canonicalization pattern that merges duplicate `ins` operands of a
+/// `ComputeRegionOp`. When the same SSA value is passed as more than one `ins`
+/// operand, uses of the later block argument are redirected to the block
+/// argument of the first occurrence, and the later operand/argument pair is
+/// erased. `launch` operands are never inspected or changed.
+///
+/// Returns failure (without touching the IR) when no duplicate exists.
+struct ComputeRegionRemoveDuplicateArgs
+    : public OpRewritePattern<ComputeRegionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ComputeRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *body = op.getBody();
+    const unsigned numLaunch = op.getLaunchArgs().size();
+    unsigned numInput = op.getInputArgs().size();
+    assert(body->getNumArguments() == numLaunch + numInput &&
+           "region args mismatch");
+
+    // Position (within the `ins` segment) of the first operand carrying the
+    // same value as `ins` operand `j`; equals `j` when no earlier operand
+    // matches. Raw operand indices are used so the lookup stays valid while
+    // operands are being erased and the segment sizes are not yet updated.
+    auto firstOccurrence = [&](unsigned j) -> unsigned {
+      Value candidate = op->getOperand(numLaunch + j);
+      for (unsigned i = 0; i < j; ++i)
+        if (op->getOperand(numLaunch + i) == candidate)
+          return i;
+      return j;
+    };
+
+    // Match phase: decide whether anything needs to change before mutating.
+    bool hasDuplicate = false;
+    for (unsigned j = 1; j < numInput && !hasDuplicate; ++j)
+      hasDuplicate = firstOccurrence(j) != j;
+    if (!hasDuplicate)
+      return failure();
+
+    // Rewrite phase: all in-place changes to `op` happen inside the
+    // rewriter's modification bracket so the driver is notified.
+    rewriter.modifyOpInPlace(op, [&] {
+      // Single forward pass: every operand before `j` is already unique, so
+      // after erasing operand `j` the next candidate slides into position `j`
+      // and no restart is needed.
+      for (unsigned j = 1; j < numInput;) {
+        const unsigned i = firstOccurrence(j);
+        if (i == j) {
+          ++j;
+          continue;
+        }
+        const unsigned keepIdx = numLaunch + i;
+        const unsigned dropIdx = numLaunch + j;
+        rewriter.replaceAllUsesWith(body->getArgument(dropIdx),
+                                    body->getArgument(keepIdx));
+        body->eraseArgument(dropIdx);
+        op->eraseOperand(dropIdx);
+        --numInput;
+      }
+      updateComputeRegionInputOperandSegments(op, rewriter, numInput);
+    });
+    return success();
+  }
+};
+
+/// Canonicalization pattern that removes `ins` operands of a
+/// `ComputeRegionOp` whose corresponding body block argument has no uses.
+/// Both the operand and the block argument are erased. `launch` operands are
+/// never dropped, even when their block argument is unused.
+///
+/// Returns failure (without touching the IR) when every `ins` block argument
+/// has at least one use.
+struct ComputeRegionRemoveUnusedArgs
+    : public OpRewritePattern<ComputeRegionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ComputeRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *body = op.getBody();
+    const unsigned numLaunch = op.getLaunchArgs().size();
+    unsigned numInput = op.getInputArgs().size();
+    assert(body->getNumArguments() == numLaunch + numInput &&
+           "region args mismatch");
+
+    // Match phase: look for at least one unused `ins` block argument before
+    // mutating anything.
+    bool hasUnused = false;
+    for (unsigned k = numLaunch, e = numLaunch + numInput; k < e; ++k) {
+      if (body->getArgument(k).use_empty()) {
+        hasUnused = true;
+        break;
+      }
+    }
+    if (!hasUnused)
+      return failure();
+
+    // Rewrite phase: all in-place changes to `op` happen inside the
+    // rewriter's modification bracket so the driver is notified.
+    rewriter.modifyOpInPlace(op, [&] {
+      // `k` only advances past arguments that are kept; after an erase the
+      // next candidate slides into position `k`.
+      for (unsigned k = numLaunch; k < numLaunch + numInput;) {
+        if (!body->getArgument(k).use_empty()) {
+          ++k;
+          continue;
+        }
+        body->eraseArgument(k);
+        op->eraseOperand(k);
+        --numInput;
+      }
+      updateComputeRegionInputOperandSegments(op, rewriter, numInput);
+    });
+    return success();
+  }
+};
+
template <typename EffectTy>
static void addOperandEffect(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
@@ -441,15 +526,39 @@ SmallVector<GPUParallelDimAttr> ComputeRegionOp::getLaunchParDims() {
}
Value ComputeRegionOp::getOperand(BlockArgument blockArg) { // Maps a body block argument to its launch/ins operand; null Value if there is none.
+ Block *body = getBody();
+ if (blockArg.getOwner() != body) // Argument belongs to a different block.
+ return Value();
unsigned argNumber = blockArg.getArgNumber();
unsigned numLaunchArgs = getLaunchArgs().size();
- assert(argNumber < (numLaunchArgs + getInputArgs().size()) &&
- "invalid block argument");
+ unsigned numInputArgs = getInputArgs().size();
+ if (argNumber >= numLaunchArgs + numInputArgs) // Index is past the launch and ins operands.
+ return Value();
if (argNumber < numLaunchArgs) // Leading block arguments correspond to launch operands.
return getLaunchArgs()[argNumber];
return getInputArgs()[argNumber - numLaunchArgs]; // The rest correspond to ins operands.
}
+/// Finds the body block argument through which `value` enters the region.
+/// `launch` operands are scanned first, then `ins` operands; within each
+/// group the first operand equal to `value` wins. Yields `std::nullopt` when
+/// `value` is neither a `launch` nor an `ins` operand of this op.
+std::optional<BlockArgument> ComputeRegionOp::getBlockArg(Value value) {
+  Block *regionBody = getBody();
+
+  auto launchOperands = getLaunchArgs();
+  const unsigned launchCount = launchOperands.size();
+  for (unsigned pos = 0; pos < launchCount; ++pos)
+    if (launchOperands[pos] == value)
+      return regionBody->getArgument(pos);
+
+  // `ins` block arguments follow the `launch` ones in the body signature.
+  auto inputOperands = getInputArgs();
+  const unsigned inputCount = inputOperands.size();
+  for (unsigned pos = 0; pos < inputCount; ++pos)
+    if (inputOperands[pos] == value)
+      return regionBody->getArgument(launchCount + pos);
+
+  return std::nullopt;
+}
+
+/// Registers the canonicalization patterns of `ComputeRegionOp`: merging of
+/// duplicate `ins` operands followed by removal of unused `ins` operands.
+void ComputeRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  results.add<ComputeRegionRemoveDuplicateArgs>(context);
+  results.add<ComputeRegionRemoveUnusedArgs>(context);
+}
+
/// Forwards to `parDimToWidth` using the `GPUParallelDimAttr` built from
/// `processor` in this op's context.
BlockArgument ComputeRegionOp::gpuParWidth(gpu::Processor processor) {
  auto parDim = GPUParallelDimAttr::get(getContext(), processor);
  return parDimToWidth(parDim);
}
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 1cc313206a99f..f20ace4398696 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -37,6 +37,8 @@ mlir::Operation *mlir::acc::getACCDataClauseOpForBlockArg(mlir::Value v) {
return nullptr;
mlir::Value orig = computeReg.getOperand(barg);
+ if (!orig)
+ return nullptr;
mlir::Operation *def = orig.getDefiningOp();
return mlir::isa_and_nonnull<ACC_DATA_ENTRY_OPS>(def) ? def : nullptr;
}
diff --git a/mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir b/mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir
new file mode 100644
index 0000000000000..68b1193508ad6
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt -canonicalize -split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func @merge_duplicate_ins
+func.func @merge_duplicate_ins() -> i32 { // Input: %m is captured twice (as %a and %b); the region body only uses %a.
+ %c0 = arith.constant 0 : i32
+ %m = memref.alloca() : memref<i32>
+ memref.store %c0, %m[] : memref<i32>
+ acc.compute_region ins(%a = %m, %b = %m) : (memref<i32>, memref<i32>) {
+ %c1 = arith.constant 1 : i32
+ %v = memref.load %a[] : memref<i32>
+ %x = arith.addi %v, %c1 : i32
+ memref.store %x, %a[] : memref<i32>
+ acc.yield
+ } {origin = "acc.serial"}
+ %r = memref.load %m[] : memref<i32>
+ return %r : i32
+}
+// CHECK: acc.compute_region ins({{.*}}) : (memref<i32>) {
+
+// -----
+
+// CHECK-LABEL: func @merge_duplicate_ins_complex_pattern
+func.func @merge_duplicate_ins_complex_pattern() -> i32 { // Input: seven ins captures over three distinct values (%ma x3, %mb x2, %mc x2), interleaved.
+ %c0 = arith.constant 0 : i32
+ %ma = memref.alloca() : memref<i32>
+ %mb = memref.alloca() : memref<i32>
+ %mc = memref.alloca() : memref<i32>
+ memref.store %c0, %ma[] : memref<i32>
+ memref.store %c0, %mb[] : memref<i32>
+ memref.store %c0, %mc[] : memref<i32>
+ acc.compute_region ins(%a0 = %ma, %b0 = %mb, %a1 = %ma, %mc0 = %mc, %mc1 = %mc, %b1 = %mb, %a2 = %ma) : (memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>) {
+ %one = arith.constant 1 : i32
+ %v0 = memref.load %a0[] : memref<i32> // Every one of the seven block arguments is loaded from, so none is unused.
+ %v1 = memref.load %b0[] : memref<i32>
+ %v2 = memref.load %a1[] : memref<i32>
+ %v3 = memref.load %mc0[] : memref<i32>
+ %v4 = memref.load %mc1[] : memref<i32>
+ %v5 = memref.load %b1[] : memref<i32>
+ %v6 = memref.load %a2[] : memref<i32>
+ %sum1 = arith.addi %v0, %v1 : i32
+ %sum2 = arith.addi %sum1, %v2 : i32
+ %sum3 = arith.addi %sum2, %v3 : i32
+ %sum4 = arith.addi %sum3, %v4 : i32
+ %sum5 = arith.addi %sum4, %v5 : i32
+ %sum6 = arith.addi %sum5, %v6 : i32
+ %out = arith.addi %sum6, %one : i32
+ memref.store %out, %a0[] : memref<i32>
+ acc.yield
+ } {origin = "acc.serial"}
+ %r = memref.load %ma[] : memref<i32>
+ return %r : i32
+}
+// CHECK: acc.compute_region ins({{.*}}) : (memref<i32>, memref<i32>, memref<i32>) {
+
+// -----
+
+// CHECK-LABEL: func @drop_unused_ins
+func.func @drop_unused_ins() -> i32 { // Input: three distinct ins captures; the region body only uses %a, leaving %b and %c without uses.
+ %c0 = arith.constant 0 : i32
+ %ma = memref.alloca() : memref<i32>
+ %mb = memref.alloca() : memref<i32>
+ %mc = memref.alloca() : memref<i32>
+ memref.store %c0, %ma[] : memref<i32>
+ memref.store %c0, %mb[] : memref<i32>
+ memref.store %c0, %mc[] : memref<i32>
+ acc.compute_region ins(%a = %ma, %b = %mb, %c = %mc) : (memref<i32>, memref<i32>, memref<i32>) {
+ %c1 = arith.constant 1 : i32
+ %v = memref.load %a[] : memref<i32>
+ %x = arith.addi %v, %c1 : i32
+ memref.store %x, %a[] : memref<i32>
+ acc.yield
+ } {origin = "acc.serial"}
+ %r = memref.load %ma[] : memref<i32>
+ return %r : i32
+}
+// CHECK: acc.compute_region ins({{.*}}) : (memref<i32>) {
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
index e7e5974ed5c70..6fe0ffb2d54fe 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
@@ -216,5 +216,10 @@ TEST_F(OpenACCUtilsCGTest, buildComputeRegionWithInputArgsToMap) {
}
EXPECT_TRUE(foundAddI);
+ EXPECT_EQ(cr.getOperand(crBlock.getArgument(0)), deviceBlock->getArgument(0));
+ ASSERT_TRUE(cr.getBlockArg(deviceBlock->getArgument(0)).has_value());
+ EXPECT_EQ(*cr.getBlockArg(deviceBlock->getArgument(0)),
+ crBlock.getArgument(0));
+
func::ReturnOp::create(rewriter, loc);
}
More information about the Mlir-commits
mailing list