[Mlir-commits] [mlir] 33da12a - [acc] Lower acc if with multi-block host fallback via scf.execute_region (#188350)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 25 08:09:48 PDT 2026
Author: Susan Tan (ス-ザン タン)
Date: 2026-03-25T11:09:37-04:00
New Revision: 33da12aae70ce26568aa06538329fab0481dcb4d
URL: https://github.com/llvm/llvm-project/commit/33da12aae70ce26568aa06538329fab0481dcb4d
DIFF: https://github.com/llvm/llvm-project/commit/33da12aae70ce26568aa06538329fab0481dcb4d.diff
LOG: [acc] Lower acc if with multi-block host fallback via scf.execute_region (#188350)
Handle multi-block host fallback regions by wrapping them in
scf.execute_region, instead of rejecting them with `not yet implemented:
region with multiple blocks`.
Added:
Modified:
mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
index 9095d7c915fa8..71df75958a134 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
@@ -59,6 +59,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtilsLoop.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
@@ -215,22 +216,28 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
// Host execution path (false branch)
- if (!computeConstructOp.getRegion().hasOneBlock()) {
- accSupport->emitNYI(computeConstructOp.getLoc(),
- "region with multiple blocks");
- return;
- }
-
- // Don't need to clone original ops, just take them and legalize for host
- ifOp.getElseRegion().takeBody(computeConstructOp.getRegion());
-
- // Swap acc yield for scf yield
- Block &elseBlock = ifOp.getElseRegion().front();
- elseBlock.getTerminator()->erase();
- rewriter.setInsertionPointToEnd(&elseBlock);
- scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
+ Region &hostRegion = computeConstructOp.getRegion();
+ if (hostRegion.hasOneBlock()) {
+ // Don't need to clone original ops, just take them and legalize for host.
+ ifOp.getElseRegion().takeBody(hostRegion);
+
+ // Swap acc yield for scf yield.
+ Block &elseBlock = ifOp.getElseRegion().front();
+ elseBlock.getTerminator()->erase();
+ rewriter.setInsertionPointToEnd(&elseBlock);
+ scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
- convertHostRegion(computeConstructOp, ifOp.getElseRegion());
+ convertHostRegion(computeConstructOp, ifOp.getElseRegion());
+ } else {
+ // scf.if regions must stay single-block. Wrap the original multi-block ACC
+ // body in scf.execute_region so it can be hosted in the else branch.
+ Block &elseBlock = ifOp.getElseRegion().front();
+ rewriter.setInsertionPoint(elseBlock.getTerminator());
+ IRMapping hostMapping;
+ auto hostExecuteRegion = wrapMultiBlockRegionWithSCFExecuteRegion(
+ hostRegion, hostMapping, computeConstructOp.getLoc(), rewriter);
+ convertHostRegion(computeConstructOp, hostExecuteRegion.getRegion());
+ }
// The original op is now empty and can be erased
eraseOps.push_back(computeConstructOp);
diff --git a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
index 75f3a5cd211e0..4c88df432b6c7 100644
--- a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
+++ b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
@@ -37,6 +37,42 @@ func.func @test_parallel_if(%arg0: memref<10xi32>, %cond: i1) {
// -----
+// Test acc.parallel if lowering when the host fallback region has multiple blocks; the else branch is expected to wrap that CFG in scf.execute_region.
+// CHECK-LABEL: func.func @test_parallel_if_multiblock
+func.func @test_parallel_if_multiblock(%cond: i1, %n: i32) {
+ %c0_i32 = arith.constant 0 : i32
+ %c1_i32 = arith.constant 1 : i32
+ %counter = memref.alloca() : memref<i32>
+ memref.store %n, %counter[] : memref<i32>
+
+ // CHECK-NOT: acc.parallel if
+ // CHECK: scf.if %{{.*}} {
+ // CHECK: acc.parallel {
+ // CHECK: } else {
+ // CHECK: scf.execute_region {
+ // CHECK: ^bb
+ // CHECK: cf.cond_br
+ // CHECK: scf.yield
+ // CHECK: }
+ // CHECK: }
+ acc.parallel if(%cond) {
+ cf.br ^bb1
+ ^bb1:
+ %v = memref.load %counter[] : memref<i32>
+ %pred = arith.cmpi sgt, %v, %c0_i32 : i32
+ cf.cond_br %pred, ^bb2, ^bb3
+ ^bb2:
+ %next = arith.subi %v, %c1_i32 : i32
+ memref.store %next, %counter[] : memref<i32>
+ cf.br ^bb1
+ ^bb3:
+ acc.yield
+ }
+ return
+}
+
+// -----
+
// Test acc.kernels with if condition
// CHECK-LABEL: func.func @test_kernels_if
func.func @test_kernels_if(%arg0: memref<5xi32>, %cond: i1) {
More information about the Mlir-commits
mailing list