[Mlir-commits] [mlir] [mlir][OpenMP] Add optional alloc region to reduction decl (PR #102522)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 8 12:12:53 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Tom Eccles (tblah)

<details>
<summary>Changes</summary>

This region is intended to separate alloca operations from reduction variable initialization. This makes it easier to hoist the allocas into the entry block, ahead of any control flow and complex initialization code.

The verifier checks that there is at most one block in the alloc region. This is not sufficient to rule out control flow in general MLIR (operations with nested regions can still introduce it within a single block), but by the time we are converting to LLVM IR, structured control flow should already have been lowered to the cf dialect.

---
Full diff: https://github.com/llvm/llvm-project/pull/102522.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+31-7) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+55-17) 
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+84-1) 
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+30) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 68f92e6952694b..f54e8b3f924bad 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1523,18 +1523,29 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove,
     Declares an OpenMP reduction kind. This requires two mandatory and two
     optional regions.
 
-      1. The initializer region specifies how to initialize the thread-local
+      1. The optional alloc region specifies how to allocate the thread-local
+         reduction value. This region should not contain control flow and all
+         IR should be suitable for inlining straight into an entry block. In
+         the common case this is expected to contain only allocas. It is
+         expected to `omp.yield` the allocated value on all control paths.
+         If allocation is conditional (e.g. only allocate if the mold is
+         allocated), this should be done in the initializer region and this
+         region not included. The alloc region is not used for by-value
+         reductions (where allocation is implicit).
+      2. The initializer region specifies how to initialize the thread-local
          reduction value. This is usually the neutral element of the reduction.
          For convenience, the region has an argument that contains the value
-         of the reduction accumulator at the start of the reduction. It is
-         expected to `omp.yield` the new value on all control flow paths.
-      2. The reduction region specifies how to combine two values into one, i.e.
+         of the reduction accumulator at the start of the reduction. If an alloc
+         region is specified, there is a second block argument containing the
+         address of the allocated memory. The initializer region is expected to
+         `omp.yield` the new value on all control flow paths.
+      3. The reduction region specifies how to combine two values into one, i.e.
          the reduction operator. It accepts the two values as arguments and is
          expected to `omp.yield` the combined value on all control flow paths.
-      3. The atomic reduction region is optional and specifies how two values
+      4. The atomic reduction region is optional and specifies how two values
          can be combined atomically given local accumulator variables. It is
          expected to store the combined value in the first accumulator variable.
-      4. The cleanup region is optional and specifies how to clean up any memory
+      5. The cleanup region is optional and specifies how to clean up any memory
          allocated by the initializer region. The region has an argument that
          contains the value of the thread-local reduction accumulator. This will
          be executed after the reduction has completed.
@@ -1550,12 +1561,14 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove,
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttr:$type);
 
-  let regions = (region AnyRegion:$initializerRegion,
+  let regions = (region MaxSizedRegion<1>:$allocRegion,
+                        AnyRegion:$initializerRegion,
                         AnyRegion:$reductionRegion,
                         AnyRegion:$atomicReductionRegion,
                         AnyRegion:$cleanupRegion);
 
   let assemblyFormat = "$sym_name `:` $type attr-dict-with-keyword "
+                       "custom<AllocReductionRegion>($allocRegion) "
                        "`init` $initializerRegion "
                        "`combiner` $reductionRegion "
                        "custom<AtomicReductionRegion>($atomicReductionRegion) "
@@ -1568,6 +1581,17 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove,
 
       return cast<PointerLikeType>(getAtomicReductionRegion().front().getArgument(0).getType());
     }
+
+    Value getInitializerMoldArg() {
+      return getInitializerRegion().front().getArgument(0);
+    }
+
+    Value getInitializerAllocArg() {
+      if (getAllocRegion().empty() ||
+          getInitializerRegion().front().getNumArguments() != 2)
+        return {nullptr};
+      return getInitializerRegion().front().getArgument(1);
+    }
   }];
   let hasRegionVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 11780f84697b15..7acbb9a8f37c01 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1839,46 +1839,84 @@ LogicalResult DistributeOp::verify() {
 // DeclareReductionOp
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
-                                              Region &region) {
-  if (parser.parseOptionalKeyword("atomic"))
+static ParseResult parseOptionalReductionRegion(OpAsmParser &parser,
+                                                Region &region,
+                                                StringRef keyword) {
+  if (parser.parseOptionalKeyword(keyword))
     return success();
   return parser.parseRegion(region);
 }
 
-static void printAtomicReductionRegion(OpAsmPrinter &printer,
-                                       DeclareReductionOp op, Region &region) {
+static void printOptionalReductionRegion(OpAsmPrinter &printer, Region &region,
+                                         StringRef keyword) {
   if (region.empty())
     return;
-  printer << "atomic ";
+  printer << keyword << " ";
   printer.printRegion(region);
 }
 
+static ParseResult parseAllocReductionRegion(OpAsmParser &parser,
+                                             Region &region) {
+  return parseOptionalReductionRegion(parser, region, "alloc");
+}
+
+static void printAllocReductionRegion(OpAsmPrinter &printer,
+                                      DeclareReductionOp op, Region &region) {
+  printOptionalReductionRegion(printer, region, "alloc");
+}
+
+static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
+                                              Region &region) {
+  return parseOptionalReductionRegion(parser, region, "atomic");
+}
+
+static void printAtomicReductionRegion(OpAsmPrinter &printer,
+                                       DeclareReductionOp op, Region &region) {
+  printOptionalReductionRegion(printer, region, "atomic");
+}
+
 static ParseResult parseCleanupReductionRegion(OpAsmParser &parser,
                                                Region &region) {
-  if (parser.parseOptionalKeyword("cleanup"))
-    return success();
-  return parser.parseRegion(region);
+  return parseOptionalReductionRegion(parser, region, "cleanup");
 }
 
 static void printCleanupReductionRegion(OpAsmPrinter &printer,
                                         DeclareReductionOp op, Region &region) {
-  if (region.empty())
-    return;
-  printer << "cleanup ";
-  printer.printRegion(region);
+  printOptionalReductionRegion(printer, region, "cleanup");
 }
 
 LogicalResult DeclareReductionOp::verifyRegions() {
+  if (!getAllocRegion().empty()) {
+    for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
+      if (yieldOp.getResults().size() != 1 ||
+          yieldOp.getResults().getTypes()[0] != getType())
+        return emitOpError() << "expects alloc region to yield a value "
+                                "of the reduction type";
+    }
+  }
+
   if (getInitializerRegion().empty())
     return emitOpError() << "expects non-empty initializer region";
   Block &initializerEntryBlock = getInitializerRegion().front();
-  if (initializerEntryBlock.getNumArguments() != 1 ||
-      initializerEntryBlock.getArgument(0).getType() != getType()) {
-    return emitOpError() << "expects initializer region with one argument "
-                            "of the reduction type";
+
+  if (initializerEntryBlock.getNumArguments() == 1) {
+    if (!getAllocRegion().empty())
+      return emitOpError() << "expects two arguments to the initializer region "
+                              "when an allocation region is used";
+  } else if (initializerEntryBlock.getNumArguments() == 2) {
+    if (getAllocRegion().empty())
+      return emitOpError() << "expects one argument to the initializer region "
+                              "when no allocation region is used";
+  } else {
+    return emitOpError()
+           << "expects one or two arguments to the initializer region";
   }
 
+  for (mlir::Value arg : initializerEntryBlock.getArguments())
+    if (arg.getType() != getType())
+      return emitOpError() << "expects initializer region argument to match "
+                              "the reduction type";
+
   for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
     if (yieldOp.getResults().size() != 1 ||
         yieldOp.getResults().getTypes()[0] != getType())
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 1d1d93f0977588..e62e5a0b368a30 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -565,7 +565,63 @@ func.func @omp_simd_pretty_simdlen_safelen(%lb : index, %ub : index, %step : ind
 
 // -----
 
-// expected-error @below {{op expects initializer region with one argument of the reduction type}}
+// expected-error @below {{op expects alloc region to yield a value of the reduction type}}
+omp.declare_reduction @add_f32 : f32
+alloc {
+^bb0(%arg: f32):
+// nonsense test code
+  %0 = arith.constant 0.0 : f64
+  omp.yield (%0 : f64)
+}
+init {
+^bb0(%arg0: f32, %arg1: f32):
+  %0 = arith.constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = arith.addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+
+// -----
+
+// expected-error @below {{op expects two arguments to the initializer region when an allocation region is used}}
+omp.declare_reduction @add_f32 : f32
+alloc {
+^bb0(%arg: f32):
+// nonsense test code
+  omp.yield (%arg : f32)
+}
+init {
+^bb0(%arg0: f32):
+  %0 = arith.constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = arith.addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+
+// -----
+
+// expected-error @below {{op expects one argument to the initializer region when no allocation region is used}}
+omp.declare_reduction @add_f32 : f32
+init {
+^bb0(%arg: f32, %arg2: f32):
+  %0 = arith.constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = arith.addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+
+// -----
+
+// expected-error @below {{op expects initializer region argument to match the reduction type}}
 omp.declare_reduction @add_f32 : f64
 init {
 ^bb0(%arg: f32):
@@ -683,6 +739,33 @@ cleanup {
 
 // -----
 
+// expected-error @below {{op region #0 ('allocRegion') failed to verify constraint: region with at most 1 blocks}}
+omp.declare_reduction @alloc_reduction : !llvm.ptr
+alloc {
+^bb0(%arg: !llvm.ptr):
+  %c1 = arith.constant 1 : i32
+  %0 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
+  cf.br ^bb1(%0: !llvm.ptr)
+^bb1(%ret: !llvm.ptr):
+  omp.yield (%ret : !llvm.ptr)
+}
+init {
+^bb0(%arg: !llvm.ptr):
+  %cst = arith.constant 1.0 : f32
+  llvm.store %cst, %arg : f32, !llvm.ptr
+  omp.yield (%arg : !llvm.ptr)
+}
+combiner {
+^bb1(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+  %0 = llvm.load %arg0 : !llvm.ptr -> f32
+  %1 = llvm.load %arg1 : !llvm.ptr -> f32
+  %2 = arith.addf %0, %1 : f32
+  llvm.store %2, %arg0 : f32, !llvm.ptr
+  omp.yield (%arg0 : !llvm.ptr)
+}
+
+// -----
+
 func.func @foo(%lb : index, %ub : index, %step : index) {
   %c1 = arith.constant 1 : i32
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index d2924998f41b87..c75a542a8f7e1c 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2520,6 +2520,36 @@ atomic {
   omp.yield
 }
 
+// CHECK-LABEL: @alloc_reduction
+// CHECK-SAME:  alloc {
+// CHECK-NEXT:  ^bb0(%[[ARG0:.*]]: !llvm.ptr):
+// ...
+// CHECK:         omp.yield
+// CHECK-NEXT:  } init {
+// CHECK:       } combiner {
+// CHECK:       }
+omp.declare_reduction @alloc_reduction : !llvm.ptr
+alloc {
+^bb0(%arg: !llvm.ptr):
+  %c1 = arith.constant 1 : i32
+  %0 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
+  omp.yield (%0 : !llvm.ptr)
+}
+init {
+^bb0(%mold: !llvm.ptr, %alloc: !llvm.ptr):
+  %cst = arith.constant 1.0 : f32
+  llvm.store %cst, %alloc : f32, !llvm.ptr
+  omp.yield (%alloc : !llvm.ptr)
+}
+combiner {
+^bb1(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+  %0 = llvm.load %arg0 : !llvm.ptr -> f32
+  %1 = llvm.load %arg1 : !llvm.ptr -> f32
+  %2 = arith.addf %0, %1 : f32
+  llvm.store %2, %arg0 : f32, !llvm.ptr
+  omp.yield (%arg0 : !llvm.ptr)
+}
+
 // CHECK-LABEL: omp_targets_with_map_bounds
 // CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr)
 func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> () {

``````````

</details>


https://github.com/llvm/llvm-project/pull/102522


More information about the Mlir-commits mailing list