[Mlir-commits] [mlir] [MLIR] [OpenMP] Modify definition of ALLOCATOR clause to support allocator type defined in user program. (PR #157399)

Raghu Maddhipatla llvmlistbot at llvm.org
Mon Sep 8 00:37:27 PDT 2025


https://github.com/raghavendhra updated https://github.com/llvm/llvm-project/pull/157399

>From 218d6b18727b459a3a713fa0476c2e86c35a3173 Mon Sep 17 00:00:00 2001
From: Raghu Maddhipatla <Raghu.Maddhipatla at amd.com>
Date: Mon, 8 Sep 2025 02:13:39 -0500
Subject: [PATCH] [MLIR] [OpenMP] Modify definition of ALLOCATOR clause to
 support allocator type defined in user program.

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |  4 +-
 .../mlir/Dialect/OpenMP/OpenMPEnums.td        | 30 -------
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 87 +++++++++++++++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         |  8 --
 4 files changed, 89 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 311c57fb4446c..2fbde606056c0 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -120,11 +120,11 @@ class OpenMP_AllocatorClauseSkip<
                     extraClassDeclaration> {
 
   let arguments = (ins
-    OptionalAttr<AllocatorHandleAttr>:$allocator
+    DefaultValuedOptionalAttr<I64Attr, "0">:$allocator
   );
 
   let optAssemblyFormat = [{
-    `allocator` `(` custom<ClauseAttr>($allocator) `)`
+    `allocator` `(` custom<AllocatorHandle>($allocator) `)`
   }];
 
   let description = [{
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index c080c3fac87d4..9dbe6897a3304 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -263,34 +263,4 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr<VariableCaptureKind,
   let assemblyFormat = "`(` $value `)`";
 }
 
-
-//===----------------------------------------------------------------------===//
-// allocator_handle enum.
-//===----------------------------------------------------------------------===//
-
-def OpenMP_AllocatorHandleNullAllocator : I32EnumAttrCase<"omp_null_allocator", 0>;
-def OpenMP_AllocatorHandleDefaultMemAlloc : I32EnumAttrCase<"omp_default_mem_alloc", 1>;
-def OpenMP_AllocatorHandleLargeCapMemAlloc : I32EnumAttrCase<"omp_large_cap_mem_alloc", 2>;
-def OpenMP_AllocatorHandleConstMemAlloc : I32EnumAttrCase<"omp_const_mem_alloc", 3>;
-def OpenMP_AllocatorHandleHighBwMemAlloc : I32EnumAttrCase<"omp_high_bw_mem_alloc", 4>;
-def OpenMP_AllocatorHandleLowLatMemAlloc : I32EnumAttrCase<"omp_low_lat_mem_alloc", 5>;
-def OpenMP_AllocatorHandleCgroupMemAlloc : I32EnumAttrCase<"omp_cgroup_mem_alloc", 6>;
-def OpenMP_AllocatorHandlePteamMemAlloc : I32EnumAttrCase<"omp_pteam_mem_alloc", 7>;
-def OpenMP_AllocatorHandlethreadMemAlloc : I32EnumAttrCase<"omp_thread_mem_alloc", 8>;
-
-def AllocatorHandle : OpenMP_I32EnumAttr<
-    "AllocatorHandle",
-    "OpenMP allocator_handle", [
-      OpenMP_AllocatorHandleNullAllocator,
-      OpenMP_AllocatorHandleDefaultMemAlloc,
-      OpenMP_AllocatorHandleLargeCapMemAlloc,
-      OpenMP_AllocatorHandleConstMemAlloc,
-      OpenMP_AllocatorHandleHighBwMemAlloc,
-      OpenMP_AllocatorHandleLowLatMemAlloc,
-      OpenMP_AllocatorHandleCgroupMemAlloc,
-      OpenMP_AllocatorHandlePteamMemAlloc,
-      OpenMP_AllocatorHandlethreadMemAlloc
-    ]>;
-
-def AllocatorHandleAttr : OpenMP_EnumAttr<AllocatorHandle, "allocator_handle">;
 #endif // OPENMP_ENUMS
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6e43f28e8d93d..dfef05d242465 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1245,6 +1245,93 @@ verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Parser and printer for Allocator (Section 8.4 in OpenMP 6.0)
+//===----------------------------------------------------------------------===//
+
+/// Spellings of the predefined `omp_allocator_handle_t` values accepted by the
+/// allocator clause. The position of each entry is the integer handle stored
+/// in the `allocator` attribute, so the parser and the printer share this one
+/// table instead of keeping two hand-written mappings in sync.
+///
+/// Entries must stay in handle order: the parser stores the index of the
+/// matched name and the printer indexes the table with the stored value.
+static constexpr llvm::StringLiteral predefinedAllocators[] = {
+    "omp_null_allocator",      // 0
+    "omp_default_mem_alloc",   // 1
+    "omp_large_cap_mem_alloc", // 2
+    "omp_const_mem_alloc",     // 3
+    "omp_high_bw_mem_alloc",   // 4
+    "omp_low_lat_mem_alloc",   // 5
+    "omp_cgroup_mem_alloc",    // 6
+    "omp_pteam_mem_alloc",     // 7
+    "omp_thread_mem_alloc",    // 8
+};
+
+/// Parses the value of an allocator clause into a 64-bit integer attribute.
+///
+/// allocator-clause = `allocator` `(` allocator-value `)`
+/// allocator-value  = `none` | predefined-allocator | integer-literal
+///
+/// - `none` yields 0, the same value as `omp_null_allocator`.
+/// - A predefined allocator name yields its index in `predefinedAllocators`;
+///   any other keyword is diagnosed at the keyword itself.
+/// - An integer literal is stored unchanged. The printer emits this form for
+///   handles that have no predefined name, so such handles round-trip.
+///
+/// For example:
+///   `allocator(omp_high_bw_mem_alloc)` stores 4
+///   `allocator(42)` stores 42
+///
+/// Exactly one value is accepted: a comma-separated list would have to drop
+/// all but one entry, because the attribute holds a single handle.
+static ParseResult parseAllocatorHandle(OpAsmParser &parser,
+                                        IntegerAttr &allocatorHandleAttr) {
+  int64_t allocator = 0;
+  OptionalParseResult integerResult = parser.parseOptionalInteger(allocator);
+  if (integerResult.has_value()) {
+    // An integer token was present; give up if it was malformed.
+    if (failed(*integerResult))
+      return failure();
+  } else if (succeeded(parser.parseOptionalKeyword("none"))) {
+    // `none` is an alias for the null allocator.
+    allocator = 0;
+  } else {
+    // Record where the keyword starts so the diagnostic points at the
+    // offending token rather than at whatever follows it.
+    llvm::SMLoc keywordLoc = parser.getCurrentLocation();
+    StringRef allocatorKeyword;
+    if (failed(parser.parseKeyword(&allocatorKeyword)))
+      return failure();
+    const auto *match = llvm::find(predefinedAllocators, allocatorKeyword);
+    if (match == std::end(predefinedAllocators))
+      return parser.emitError(keywordLoc)
+             << allocatorKeyword << " is not a valid allocator";
+    allocator = match - std::begin(predefinedAllocators);
+  }
+  allocatorHandleAttr =
+      IntegerAttr::get(parser.getBuilder().getI64Type(), allocator);
+  return success();
+}
+
+/// Prints an allocator clause value in a form `parseAllocatorHandle` accepts.
+///
+/// A handle that indexes into `predefinedAllocators` is printed by name. Any
+/// other handle is printed as an integer literal; printing nothing for it
+/// would produce `allocator()`, which cannot be parsed back.
+///
+/// `none` is never printed: 0 is always spelled `omp_null_allocator`.
+static void printAllocatorHandle(OpAsmPrinter &p, Operation *op,
+                                 IntegerAttr allocatorHandleAttr) {
+  int64_t allocator = allocatorHandleAttr.getInt();
+  int64_t numPredefined = std::size(predefinedAllocators);
+  // Negative handles and handles past the end of the table have no name.
+  if (allocator >= 0 && allocator < numPredefined)
+    p << predefinedAllocators[allocator];
+  else
+    p << allocator;
+}
+
 //===----------------------------------------------------------------------===//
 // Parser, printer and verifier for Copyprivate
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 986c3844d0bb9..4e8266813a47e 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3010,14 +3010,6 @@ func.func @invalid_allocate_align_2(%arg0 : memref<i32>) -> () {
   return
 }
 
-// -----
-func.func @invalid_allocate_allocator(%arg0 : memref<i32>) -> () {
-  // expected-error @below {{invalid clause value}}
-  omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_small_cap_mem_alloc)
-
-  return
-}
-
 // -----
 func.func @invalid_workdistribute_empty_region() -> () {
   omp.teams {



More information about the Mlir-commits mailing list