[Mlir-commits] [mlir] [MLIR] [OpenMP] Initial support for OMP ALLOCATE directive op. (PR #147900)

Raghu Maddhipatla llvmlistbot at llvm.org
Thu Jul 10 07:29:42 PDT 2025


https://github.com/raghavendhra updated https://github.com/llvm/llvm-project/pull/147900

>From 4c1539cdb861b33c3874b0cb981b7ca6ae00b8c0 Mon Sep 17 00:00:00 2001
From: Raghu Maddhipatla <Raghu.Maddhipatla at amd.com>
Date: Thu, 10 Jul 2025 01:06:10 -0500
Subject: [PATCH] [MLIR] [OpenMP] Initial support for OMP ALLOCATE directive
 op.

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 51 +++++++++++++++++++
 .../mlir/Dialect/OpenMP/OpenMPEnums.td        | 30 +++++++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 27 ++++++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 14 +++++
 mlir/test/Dialect/OpenMP/invalid.mlir         | 24 +++++++++
 mlir/test/Dialect/OpenMP/ops.mlir             | 33 ++++++++++++
 6 files changed, 179 insertions(+)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 16c14ef085d6d..311c57fb4446c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -22,6 +22,31 @@
 include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
 include "mlir/IR/SymbolInterfaces.td"
 
+//===----------------------------------------------------------------------===//
+// V5.2: [6.3] `align` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_AlignClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+  let arguments = (ins
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$align
+  );
+
+  let optAssemblyFormat = [{
+    `align` `(` $align `)`
+  }];
+
+  let description = [{
+    The `align` clause is used to specify the byte alignment to use for
+    allocations associated with the construct on which the clause appears.
+  }];
+}
+
+def OpenMP_AlignClause : OpenMP_AlignClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [5.11] `aligned` clause
 //===----------------------------------------------------------------------===//
@@ -84,6 +109,32 @@ class OpenMP_AllocateClauseSkip<
 
 def OpenMP_AllocateClause : OpenMP_AllocateClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V5.2: [6.4] `allocator` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_AllocatorClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+
+  let arguments = (ins
+    OptionalAttr<AllocatorHandleAttr>:$allocator
+  );
+
+  let optAssemblyFormat = [{
+    `allocator` `(` custom<ClauseAttr>($allocator) `)`
+  }];
+
+  let description = [{
+    `allocator` specifies the memory allocator to be used for allocations
+    associated with the construct on which the clause appears.
+  }];
+}
+
+def OpenMP_AllocatorClause : OpenMP_AllocatorClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // LLVM OpenMP extension `ompx_bare` clause
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 9dbe6897a3304..c080c3fac87d4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -263,4 +263,34 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr<VariableCaptureKind,
   let assemblyFormat = "`(` $value `)`";
 }
 
+
+//===----------------------------------------------------------------------===//
+// allocator_handle enum.
+//===----------------------------------------------------------------------===//
+
+def OpenMP_AllocatorHandleNullAllocator : I32EnumAttrCase<"omp_null_allocator", 0>;
+def OpenMP_AllocatorHandleDefaultMemAlloc : I32EnumAttrCase<"omp_default_mem_alloc", 1>;
+def OpenMP_AllocatorHandleLargeCapMemAlloc : I32EnumAttrCase<"omp_large_cap_mem_alloc", 2>;
+def OpenMP_AllocatorHandleConstMemAlloc : I32EnumAttrCase<"omp_const_mem_alloc", 3>;
+def OpenMP_AllocatorHandleHighBwMemAlloc : I32EnumAttrCase<"omp_high_bw_mem_alloc", 4>;
+def OpenMP_AllocatorHandleLowLatMemAlloc : I32EnumAttrCase<"omp_low_lat_mem_alloc", 5>;
+def OpenMP_AllocatorHandleCgroupMemAlloc : I32EnumAttrCase<"omp_cgroup_mem_alloc", 6>;
+def OpenMP_AllocatorHandlePteamMemAlloc : I32EnumAttrCase<"omp_pteam_mem_alloc", 7>;
+def OpenMP_AllocatorHandleThreadMemAlloc : I32EnumAttrCase<"omp_thread_mem_alloc", 8>;
+
+def AllocatorHandle : OpenMP_I32EnumAttr<
+    "AllocatorHandle",
+    "OpenMP allocator_handle", [
+      OpenMP_AllocatorHandleNullAllocator,
+      OpenMP_AllocatorHandleDefaultMemAlloc,
+      OpenMP_AllocatorHandleLargeCapMemAlloc,
+      OpenMP_AllocatorHandleConstMemAlloc,
+      OpenMP_AllocatorHandleHighBwMemAlloc,
+      OpenMP_AllocatorHandleLowLatMemAlloc,
+      OpenMP_AllocatorHandleCgroupMemAlloc,
+      OpenMP_AllocatorHandlePteamMemAlloc,
+      OpenMP_AllocatorHandleThreadMemAlloc
+    ]>;
+
+def AllocatorHandleAttr : OpenMP_EnumAttr<AllocatorHandle, "allocator_handle">;
 #endif // OPENMP_ENUMS
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index dffb3c5e14b62..1fcbc63aab9f8 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1883,4 +1883,31 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// [Spec 5.2] 6.5 allocate Directive
+//===----------------------------------------------------------------------===//
+def AllocateDirOp : OpenMP_Op<"allocate_dir", clauses = [
+    OpenMP_AlignClause, OpenMP_AllocatorClause
+  ]> {
+  let summary = "allocate directive";
+  let description = [{
+    The storage for each list item that appears in the allocate directive is
+    provided an allocation through the memory allocator.
+  }] # clausesDescription;
+
+  let arguments = !con((ins Variadic<AnyType>:$varList),
+                       clausesArgs);
+
+  // Override inherited assembly format to include `varList`.
+  let assemblyFormat = " `(` $varList `:` type($varList) `)` oilist(" #
+                       clausesOptAssemblyFormat #
+                       ") attr-dict ";
+
+  let builders = [
+    OpBuilder<(ins CArg<"const AllocateDirOperands &">:$clauses)>
+  ];
+
+  let hasVerifier = 1;
+}
+
 #endif // OPENMP_OPS
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index ffc84781f77ff..b632c618acf7a 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3512,6 +3512,20 @@ LogicalResult ScanOp::verify() {
                    "reduction modifier");
 }
 
+/// Verifies the `align` clause of the allocate directive: when present, the
+/// alignment must be a power of 2.
+LogicalResult AllocateDirOp::verify() {
+  // `uint64_t` rather than the non-standard `u_int64_t`, which is a BSD/glibc
+  // typedef that is not available on every toolchain (e.g. MSVC).
+  std::optional<uint64_t> align = getAlign();
+
+  // A positive value with more than one bit set is not a power of 2.
+  if (align && *align > 0 && (*align & (*align - 1)) != 0)
+    return emitError() << "ALIGN value : " << *align << " must be power of 2";
+
+  return success();
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 7608ad57c7967..5088f2dfa7d7a 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2993,3 +2993,27 @@ llvm.func @invalid_mapper(%0 : !llvm.ptr) {
   }
   llvm.return
 }
+
+// -----
+func.func @invalid_allocate_align_1(%arg0 : memref<i32>) -> () {
+  // expected-error @below {{failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
+  omp.allocate_dir (%arg0 : memref<i32>) align(-1)
+
+  return
+}
+
+// -----
+func.func @invalid_allocate_align_2(%arg0 : memref<i32>) -> () {
+  // expected-error @below {{must be power of 2}}
+  omp.allocate_dir (%arg0 : memref<i32>) align(3)
+
+  return
+}
+
+// -----
+func.func @invalid_allocate_allocator(%arg0 : memref<i32>) -> () {
+  // expected-error @below {{invalid clause value}}
+  omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_small_cap_mem_alloc)
+
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 47cfc5278a5d0..4c50ed3230976 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3197,3 +3197,36 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
   }
   return
 }
+
+// CHECK-LABEL: func.func @omp_allocate_dir(
+// CHECK-SAME: %[[ARG0:.*]]: memref<i32>,
+// CHECK-SAME: %[[ARG1:.*]]: memref<i32>) {
+func.func @omp_allocate_dir(%arg0 : memref<i32>, %arg1 : memref<i32>) -> () {
+
+  // Test with one data var
+  // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>)
+  omp.allocate_dir (%arg0 : memref<i32>)
+
+  // Test with two data vars
+  // CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>)
+  omp.allocate_dir (%arg0, %arg1: memref<i32>, memref<i32>)
+
+  // Test with one data var and align clause
+  // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) align(2)
+  omp.allocate_dir (%arg0 : memref<i32>) align(2)
+
+  // Test with one data var and allocator clause
+  // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) allocator(omp_pteam_mem_alloc)
+  omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_pteam_mem_alloc)
+
+  // Test with one data var, align clause and allocator clause
+  // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) align(2) allocator(omp_thread_mem_alloc)
+  omp.allocate_dir (%arg0 : memref<i32>) align(2) allocator(omp_thread_mem_alloc)
+
+  // Test with two data vars, align clause and allocator clause
+  // CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
+  omp.allocate_dir (%arg0, %arg1 : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
+
+  return
+}
+



More information about the Mlir-commits mailing list