[Mlir-commits] [mlir] 1b433e9 - [mlir][acc] Add canonicalization patterns for compute_region (#192376)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 16 09:00:21 PDT 2026


Author: Razvan Lupusoru
Date: 2026-04-16T09:00:17-07:00
New Revision: 1b433e936fbeef8fc1c649ad223719df897d311f

URL: https://github.com/llvm/llvm-project/commit/1b433e936fbeef8fc1c649ad223719df897d311f
DIFF: https://github.com/llvm/llvm-project/commit/1b433e936fbeef8fc1c649ad223719df897d311f.diff

LOG: [mlir][acc] Add canonicalization patterns for compute_region (#192376)

This PR improves the APIs for navigating through acc.compute_region
block arguments and also adds canonicalization patterns for those
arguments to remove unused ones and merge duplicates.

Added: 
    mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir

Modified: 
    mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
    mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
    mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
index 63fc8476c08d4..76902a6d2690e 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
@@ -297,6 +297,11 @@ def OpenACC_ComputeRegionOp
     region (e.g., `"acc.parallel"`, `"acc.kernels"`). This is intended to
     be solely informational.
 
+    Canonicalization may simplify `ins` captures: duplicate `ins` operands
+    (same SSA value threaded more than once) are merged by reusing the first
+    block argument, and unused `ins` operands (block arguments with no uses)
+    are removed. `launch` operands are never merged or dropped.
+
     Example:
 
     ```mlir
@@ -327,6 +332,8 @@ def OpenACC_ComputeRegionOp
 
   let regions = (region AnyRegion:$region);
 
+  let hasCanonicalizer = 1;
+
   let extraClassDeclaration = [{
     /// Look up the par_width op for the given dimension among launch args.
     std::optional<mlir::Value> getLaunchArg(
@@ -365,9 +372,17 @@ def OpenACC_ComputeRegionOp
       return &getRegion().back().back();
     }
 
-    /// Map a block argument back to its corresponding operand
-    /// ($launchArgs or $inputArgs).
+    /// Return the `launch` or `ins` operand threaded to `blockArg`, or a null
+    /// `Value` if `blockArg` is not an argument of `getBody()` or its index is
+    /// out of range for this op's `launch` and `ins` operands.
     ::mlir::Value getOperand(::mlir::BlockArgument blockArg);
+
+    /// If `value` is a `launch` or `ins` operand, return the body block
+    /// argument it is threaded through; otherwise `std::nullopt`. `launch`
+    /// operands are searched before `ins` operands and the first match wins,
+    /// so a value threaded more than once (e.g. duplicate `ins` captures not
+    /// yet merged by canonicalization) maps to its earliest block argument.
+    std::optional<::mlir::BlockArgument> getBlockArg(::mlir::Value value);
   }];
 
   let hasVerifier = 1;

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
index 04f8c848c7287..de243f3dadd74 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Region.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
@@ -107,6 +108,90 @@ struct RemoveEmptyKernelEnvironment
   }
 };
 
+/// Rewrite the `operandSegmentSizes` attribute of `op` after some of its `ins`
+/// operands were erased in place; `numInput` is the new `ins` count. The
+/// `launch` segment is never modified by the patterns below, so its size can
+/// still be read through the accessor. The optional `stream` segment size is
+/// derived from the total operand count instead of `op.getStream()`: ODS
+/// segment accessors compute operand positions from the (now stale) segment
+/// sizes and would index past the end of the shrunken operand list.
+/// Callers in this file invoke it inside `rewriter.modifyOpInPlace`.
+static void updateComputeRegionInputOperandSegments(ComputeRegionOp op,
+                                                    PatternRewriter &rewriter,
+                                                    size_t numInput) {
+  const size_t numLaunch = op.getLaunchArgs().size();
+  const size_t numOperands = op->getNumOperands();
+  assert(numOperands >= numLaunch + numInput && "operand count mismatch");
+  const size_t numStream = numOperands - numLaunch - numInput;
+  assert(numStream <= 1 && "expected at most one stream operand");
+  op->setAttr(ComputeRegionOp::getOperandSegmentSizeAttr(),
+              rewriter.getDenseI32ArrayAttr({static_cast<int32_t>(numLaunch),
+                                             static_cast<int32_t>(numInput),
+                                             static_cast<int32_t>(numStream)}));
+}
+
+/// Merge `ins` operands that thread the same SSA value: uses of every later
+/// duplicate block argument are redirected to the block argument of the first
+/// occurrence, then the duplicate argument and operand are erased. `launch`
+/// operands are left untouched.
+struct ComputeRegionRemoveDuplicateArgs
+    : public OpRewritePattern<ComputeRegionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ComputeRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *body = op.getBody();
+    const unsigned numLaunch = op.getLaunchArgs().size();
+    unsigned numInput = op.getInputArgs().size();
+    assert(body->getNumArguments() == numLaunch + numInput &&
+           "region args mismatch");
+
+    // Position (relative to the `ins` segment) of the first `ins` operand
+    // before position `j` that carries the same value as operand `j`.
+    auto findFirstOccurrence = [&](unsigned j) -> std::optional<unsigned> {
+      Value value = op->getOperand(numLaunch + j);
+      for (unsigned i = 0; i < j; ++i)
+        if (op->getOperand(numLaunch + i) == value)
+          return i;
+      return std::nullopt;
+    };
+
+    // Match first so that a failing pattern leaves the op untouched.
+    bool hasDuplicate = false;
+    for (unsigned j = 1; j < numInput && !hasDuplicate; ++j)
+      hasDuplicate = findFirstOccurrence(j).has_value();
+    if (!hasDuplicate)
+      return failure();
+
+    // All in-place IR changes go through the rewriter so the pattern driver
+    // is notified that `op` was modified.
+    rewriter.modifyOpInPlace(op, [&] {
+      // Single forward pass: erasing position `j` shifts the remaining
+      // operands down, so only advance when nothing was erased.
+      for (unsigned j = 1; j < numInput;) {
+        std::optional<unsigned> keep = findFirstOccurrence(j);
+        if (!keep) {
+          ++j;
+          continue;
+        }
+        const unsigned dropIdx = numLaunch + j;
+        rewriter.replaceAllUsesWith(body->getArgument(dropIdx),
+                                    body->getArgument(numLaunch + *keep));
+        body->eraseArgument(dropIdx);
+        op->eraseOperand(dropIdx);
+        --numInput;
+      }
+      updateComputeRegionInputOperandSegments(op, rewriter, numInput);
+    });
+    return success();
+  }
+};
+
+/// Drop `ins` operands whose body block argument has no uses. `launch`
+/// operands are never dropped, even when unused.
+struct ComputeRegionRemoveUnusedArgs
+    : public OpRewritePattern<ComputeRegionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ComputeRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *body = op.getBody();
+    const unsigned numLaunch = op.getLaunchArgs().size();
+    unsigned numInput = op.getInputArgs().size();
+    assert(body->getNumArguments() == numLaunch + numInput &&
+           "region args mismatch");
+
+    // Match first so that a failing pattern leaves the op untouched.
+    auto inputBlockArgs = body->getArguments().drop_front(numLaunch);
+    if (llvm::none_of(inputBlockArgs,
+                      [](BlockArgument arg) { return arg.use_empty(); }))
+      return failure();
+
+    // All in-place IR changes go through the rewriter so the pattern driver
+    // is notified that `op` was modified.
+    rewriter.modifyOpInPlace(op, [&] {
+      // Erasing index `k` shifts later arguments down, so only advance past
+      // arguments that are kept.
+      for (unsigned k = numLaunch; k < numLaunch + numInput;) {
+        if (!body->getArgument(k).use_empty()) {
+          ++k;
+          continue;
+        }
+        body->eraseArgument(k);
+        op->eraseOperand(k);
+        --numInput;
+      }
+      updateComputeRegionInputOperandSegments(op, rewriter, numInput);
+    });
+    return success();
+  }
+};
+
 template <typename EffectTy>
 static void addOperandEffect(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
@@ -441,15 +526,39 @@ SmallVector<GPUParallelDimAttr> ComputeRegionOp::getLaunchParDims() {
 }
 
+/// Map a body block argument back to the operand threaded into it. Body
+/// arguments mirror the `launch` operands followed by the `ins` operands; a
+/// null `Value` is returned for arguments of any other block or for an index
+/// past both segments.
 Value ComputeRegionOp::getOperand(BlockArgument blockArg) {
+  // Arguments owned by a different block have no corresponding operand.
+  Block *body = getBody();
+  if (blockArg.getOwner() != body)
+    return Value();
   unsigned argNumber = blockArg.getArgNumber();
   unsigned numLaunchArgs = getLaunchArgs().size();
-  assert(argNumber < (numLaunchArgs + getInputArgs().size()) &&
-         "invalid block argument");
+  // Out-of-range indices are reported as a null Value rather than asserting.
+  unsigned numInputArgs = getInputArgs().size();
+  if (argNumber >= numLaunchArgs + numInputArgs)
+    return Value();
   if (argNumber < numLaunchArgs)
     return getLaunchArgs()[argNumber];
   return getInputArgs()[argNumber - numLaunchArgs];
 }
 
+/// Inverse of `getOperand(BlockArgument)`: find `value` among the operands
+/// and return the body argument at the same position. Body arguments mirror
+/// the operands positionally (`launch` first, then `ins`), so scanning both
+/// segments in that order makes the earliest match win. Returns
+/// `std::nullopt` when `value` is not threaded into the region.
+std::optional<BlockArgument> ComputeRegionOp::getBlockArg(Value value) {
+  auto launchArgs = getLaunchArgs();
+  auto inputArgs = getInputArgs();
+  const unsigned launchCount = launchArgs.size();
+  const unsigned totalCount = launchCount + inputArgs.size();
+  for (unsigned pos = 0; pos < totalCount; ++pos) {
+    Value candidate = pos < launchCount ? launchArgs[pos]
+                                        : inputArgs[pos - launchCount];
+    if (candidate == value)
+      return getBody()->getArgument(pos);
+  }
+  return std::nullopt;
+}
+
+void ComputeRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  // Simplify the `ins` captures: merge duplicates and drop unused ones.
+  results.add<ComputeRegionRemoveDuplicateArgs>(context);
+  results.add<ComputeRegionRemoveUnusedArgs>(context);
+}
+
+/// Convenience wrapper: wrap `processor` in a `GPUParallelDimAttr` and
+/// delegate to `parDimToWidth`.
 BlockArgument ComputeRegionOp::gpuParWidth(gpu::Processor processor) {
   return parDimToWidth(GPUParallelDimAttr::get(getContext(), processor));
 }

diff  --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 1cc313206a99f..f20ace4398696 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -37,6 +37,8 @@ mlir::Operation *mlir::acc::getACCDataClauseOpForBlockArg(mlir::Value v) {
     return nullptr;
 
   mlir::Value orig = computeReg.getOperand(barg);
+  if (!orig)
+    return nullptr;
   mlir::Operation *def = orig.getDefiningOp();
   return mlir::isa_and_nonnull<ACC_DATA_ENTRY_OPS>(def) ? def : nullptr;
 }

diff  --git a/mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir b/mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir
new file mode 100644
index 0000000000000..68b1193508ad6
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt -canonicalize -split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func @merge_duplicate_ins
+// Both captures are used in the body, so only duplicate merging (not unused
+// argument removal) can reduce the region to a single `ins` operand.
+// CHECK: %[[M:.*]] = memref.alloca()
+func.func @merge_duplicate_ins() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %m = memref.alloca() : memref<i32>
+  memref.store %c0, %m[] : memref<i32>
+  acc.compute_region ins(%a = %m, %b = %m) : (memref<i32>, memref<i32>) {
+    %c1 = arith.constant 1 : i32
+    %v = memref.load %b[] : memref<i32>
+    %x = arith.addi %v, %c1 : i32
+    memref.store %x, %a[] : memref<i32>
+    acc.yield
+  } {origin = "acc.serial"}
+  %r = memref.load %m[] : memref<i32>
+  return %r : i32
+}
+// CHECK: acc.compute_region ins(%{{.*}} = %[[M]]) : (memref<i32>) {
+
+// -----
+
+// CHECK-LABEL: func @merge_duplicate_ins_complex_pattern
+// Every capture is used, so the seven `ins` operands can only shrink through
+// duplicate merging; the survivors keep first-occurrence order (ma, mb, mc).
+// CHECK: %[[MA:.*]] = memref.alloca()
+// CHECK: %[[MB:.*]] = memref.alloca()
+// CHECK: %[[MC:.*]] = memref.alloca()
+func.func @merge_duplicate_ins_complex_pattern() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %ma = memref.alloca() : memref<i32>
+  %mb = memref.alloca() : memref<i32>
+  %mc = memref.alloca() : memref<i32>
+  memref.store %c0, %ma[] : memref<i32>
+  memref.store %c0, %mb[] : memref<i32>
+  memref.store %c0, %mc[] : memref<i32>
+  acc.compute_region ins(%a0 = %ma, %b0 = %mb, %a1 = %ma, %mc0 = %mc, %mc1 = %mc, %b1 = %mb, %a2 = %ma) : (memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>) {
+    %one = arith.constant 1 : i32
+    %v0 = memref.load %a0[] : memref<i32>
+    %v1 = memref.load %b0[] : memref<i32>
+    %v2 = memref.load %a1[] : memref<i32>
+    %v3 = memref.load %mc0[] : memref<i32>
+    %v4 = memref.load %mc1[] : memref<i32>
+    %v5 = memref.load %b1[] : memref<i32>
+    %v6 = memref.load %a2[] : memref<i32>
+    %sum1 = arith.addi %v0, %v1 : i32
+    %sum2 = arith.addi %sum1, %v2 : i32
+    %sum3 = arith.addi %sum2, %v3 : i32
+    %sum4 = arith.addi %sum3, %v4 : i32
+    %sum5 = arith.addi %sum4, %v5 : i32
+    %sum6 = arith.addi %sum5, %v6 : i32
+    %out = arith.addi %sum6, %one : i32
+    memref.store %out, %a0[] : memref<i32>
+    acc.yield
+  } {origin = "acc.serial"}
+  %r = memref.load %ma[] : memref<i32>
+  return %r : i32
+}
+// CHECK: acc.compute_region ins(%{{.*}} = %[[MA]], %{{.*}} = %[[MB]], %{{.*}} = %[[MC]]) : (memref<i32>, memref<i32>, memref<i32>) {
+
+// -----
+
+// CHECK-LABEL: func @drop_unused_ins
+// Only %a is used in the body: %b and %c are dropped and %ma stays threaded.
+// CHECK: %[[MA:.*]] = memref.alloca()
+func.func @drop_unused_ins() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %ma = memref.alloca() : memref<i32>
+  %mb = memref.alloca() : memref<i32>
+  %mc = memref.alloca() : memref<i32>
+  memref.store %c0, %ma[] : memref<i32>
+  memref.store %c0, %mb[] : memref<i32>
+  memref.store %c0, %mc[] : memref<i32>
+  acc.compute_region ins(%a = %ma, %b = %mb, %c = %mc) : (memref<i32>, memref<i32>, memref<i32>) {
+    %c1 = arith.constant 1 : i32
+    %v = memref.load %a[] : memref<i32>
+    %x = arith.addi %v, %c1 : i32
+    memref.store %x, %a[] : memref<i32>
+    acc.yield
+  } {origin = "acc.serial"}
+  %r = memref.load %ma[] : memref<i32>
+  return %r : i32
+}
+// CHECK: acc.compute_region ins(%{{.*}} = %[[MA]]) : (memref<i32>) {

diff  --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
index e7e5974ed5c70..6fe0ffb2d54fe 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
@@ -216,5 +216,10 @@ TEST_F(OpenACCUtilsCGTest, buildComputeRegionWithInputArgsToMap) {
   }
   EXPECT_TRUE(foundAddI);
 
+  EXPECT_EQ(cr.getOperand(crBlock.getArgument(0)), deviceBlock->getArgument(0));
+  ASSERT_TRUE(cr.getBlockArg(deviceBlock->getArgument(0)).has_value());
+  EXPECT_EQ(*cr.getBlockArg(deviceBlock->getArgument(0)),
+            crBlock.getArgument(0));
+
   func::ReturnOp::create(rewriter, loc);
 }


        


More information about the Mlir-commits mailing list