[Mlir-commits] [mlir] [MLIR] [OpenMP] Modify definition of ALLOCATOR clause to support allocator type defined in user program. (PR #157399)

Raghu Maddhipatla llvmlistbot at llvm.org
Tue Sep 9 12:03:10 PDT 2025


https://github.com/raghavendhra updated https://github.com/llvm/llvm-project/pull/157399

>From 0dfa7015d409927536f10635c3b565c977a734e3 Mon Sep 17 00:00:00 2001
From: Raghu Maddhipatla <Raghu.Maddhipatla at amd.com>
Date: Mon, 8 Sep 2025 02:13:39 -0500
Subject: [PATCH 1/2] [MLIR] [OpenMP] Modify definition of ALLOCATOR clause to
 support allocator type defined in user program.

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |  4 +-
 .../mlir/Dialect/OpenMP/OpenMPEnums.td        | 30 -------
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 87 +++++++++++++++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         |  8 --
 4 files changed, 89 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 311c57fb4446c..2fbde606056c0 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -120,11 +120,11 @@ class OpenMP_AllocatorClauseSkip<
                     extraClassDeclaration> {
 
   let arguments = (ins
-    OptionalAttr<AllocatorHandleAttr>:$allocator
+    DefaultValuedOptionalAttr<I64Attr, "0">:$allocator
   );
 
   let optAssemblyFormat = [{
-    `allocator` `(` custom<ClauseAttr>($allocator) `)`
+    `allocator` `(` custom<AllocatorHandle>($allocator) `)`
   }];
 
   let description = [{
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index c080c3fac87d4..9dbe6897a3304 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -263,34 +263,4 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr<VariableCaptureKind,
   let assemblyFormat = "`(` $value `)`";
 }
 
-
-//===----------------------------------------------------------------------===//
-// allocator_handle enum.
-//===----------------------------------------------------------------------===//
-
-def OpenMP_AllocatorHandleNullAllocator : I32EnumAttrCase<"omp_null_allocator", 0>;
-def OpenMP_AllocatorHandleDefaultMemAlloc : I32EnumAttrCase<"omp_default_mem_alloc", 1>;
-def OpenMP_AllocatorHandleLargeCapMemAlloc : I32EnumAttrCase<"omp_large_cap_mem_alloc", 2>;
-def OpenMP_AllocatorHandleConstMemAlloc : I32EnumAttrCase<"omp_const_mem_alloc", 3>;
-def OpenMP_AllocatorHandleHighBwMemAlloc : I32EnumAttrCase<"omp_high_bw_mem_alloc", 4>;
-def OpenMP_AllocatorHandleLowLatMemAlloc : I32EnumAttrCase<"omp_low_lat_mem_alloc", 5>;
-def OpenMP_AllocatorHandleCgroupMemAlloc : I32EnumAttrCase<"omp_cgroup_mem_alloc", 6>;
-def OpenMP_AllocatorHandlePteamMemAlloc : I32EnumAttrCase<"omp_pteam_mem_alloc", 7>;
-def OpenMP_AllocatorHandlethreadMemAlloc : I32EnumAttrCase<"omp_thread_mem_alloc", 8>;
-
-def AllocatorHandle : OpenMP_I32EnumAttr<
-    "AllocatorHandle",
-    "OpenMP allocator_handle", [
-      OpenMP_AllocatorHandleNullAllocator,
-      OpenMP_AllocatorHandleDefaultMemAlloc,
-      OpenMP_AllocatorHandleLargeCapMemAlloc,
-      OpenMP_AllocatorHandleConstMemAlloc,
-      OpenMP_AllocatorHandleHighBwMemAlloc,
-      OpenMP_AllocatorHandleLowLatMemAlloc,
-      OpenMP_AllocatorHandleCgroupMemAlloc,
-      OpenMP_AllocatorHandlePteamMemAlloc,
-      OpenMP_AllocatorHandlethreadMemAlloc
-    ]>;
-
-def AllocatorHandleAttr : OpenMP_EnumAttr<AllocatorHandle, "allocator_handle">;
 #endif // OPENMP_ENUMS
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6e43f28e8d93d..dfef05d242465 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1245,6 +1245,93 @@ verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Parser and printer for the allocator clause (Section 8.4 in OpenMP 6.0)
+//===----------------------------------------------------------------------===//
+
+/// Parses an allocator clause into an i64 attribute. The handle is written as
+/// the name of a predefined `omp_allocator_handle_t` (or `none`, meaning 0); if
+/// a comma-separated list of names is given, the last one wins.
+///
+/// allocator-clause = `allocator` `(` allocator-value `)`
+static ParseResult parseAllocatorHandle(OpAsmParser &parser,
+                                        IntegerAttr &allocatorHandleAttr) {
+  StringRef allocatorKeyword;
+  int64_t allocator = 0;
+  if (succeeded(parser.parseOptionalKeyword("none"))) {
+    allocatorHandleAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
+    return success();
+  }
+  auto parseKeyword = [&]() -> ParseResult {
+    if (failed(parser.parseKeyword(&allocatorKeyword)))
+      return failure();
+    if (allocatorKeyword == "omp_null_allocator")
+      allocator = 0;
+    else if (allocatorKeyword == "omp_default_mem_alloc")
+      allocator = 1;
+    else if (allocatorKeyword == "omp_large_cap_mem_alloc")
+      allocator = 2;
+    else if (allocatorKeyword == "omp_const_mem_alloc")
+      allocator = 3;
+    else if (allocatorKeyword == "omp_high_bw_mem_alloc")
+      allocator = 4;
+    else if (allocatorKeyword == "omp_low_lat_mem_alloc")
+      allocator = 5;
+    else if (allocatorKeyword == "omp_cgroup_mem_alloc")
+      allocator = 6;
+    else if (allocatorKeyword == "omp_pteam_mem_alloc")
+      allocator = 7;
+    else if (allocatorKeyword == "omp_thread_mem_alloc")
+      allocator = 8;
+    else
+      return parser.emitError(parser.getCurrentLocation())
+             << allocatorKeyword << " is not a valid allocator";
+    return success();
+  };
+  if (parser.parseCommaSeparatedList(parseKeyword))
+    return failure();
+  allocatorHandleAttr =
+      IntegerAttr::get(parser.getBuilder().getI64Type(), allocator);
+  return success();
+}
+
+/// Prints an allocator clause by handle name; other values print nothing.
+static void printAllocatorHandle(OpAsmPrinter &p, Operation *op,
+                                 IntegerAttr allocatorHandleAttr) {
+  int64_t allocator = allocatorHandleAttr.getInt();
+  StringRef allocatorHandle;
+  switch (allocator) {
+  case 0:
+    allocatorHandle = "omp_null_allocator";
+    break;
+  case 1:
+    allocatorHandle = "omp_default_mem_alloc";
+    break;
+  case 2:
+    allocatorHandle = "omp_large_cap_mem_alloc";
+    break;
+  case 3:
+    allocatorHandle = "omp_const_mem_alloc";
+    break;
+  case 4:
+    allocatorHandle = "omp_high_bw_mem_alloc";
+    break;
+  case 5:
+    allocatorHandle = "omp_low_lat_mem_alloc";
+    break;
+  case 6:
+    allocatorHandle = "omp_cgroup_mem_alloc";
+    break;
+  case 7:
+    allocatorHandle = "omp_pteam_mem_alloc";
+    break;
+  case 8:
+    allocatorHandle = "omp_thread_mem_alloc";
+    break;
+  }
+  p << allocatorHandle;
+}
+
 //===----------------------------------------------------------------------===//
 // Parser, printer and verifier for Copyprivate
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 986c3844d0bb9..4e8266813a47e 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3010,14 +3010,6 @@ func.func @invalid_allocate_align_2(%arg0 : memref<i32>) -> () {
   return
 }
 
-// -----
-func.func @invalid_allocate_allocator(%arg0 : memref<i32>) -> () {
-  // expected-error @below {{invalid clause value}}
-  omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_small_cap_mem_alloc)
-
-  return
-}
-
 // -----
 func.func @invalid_workdistribute_empty_region() -> () {
   omp.teams {

>From 2e248fda5579f59da28c8c91a43d43529e610a1d Mon Sep 17 00:00:00 2001
From: Raghu Maddhipatla <Raghu.Maddhipatla at amd.com>
Date: Tue, 9 Sep 2025 01:10:20 -0500
Subject: [PATCH 2/2] Accept integer (user-defined) allocator handle values and add a test case

---
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 53 ++++++++++----------
 mlir/test/Dialect/OpenMP/ops.mlir            |  6 ++-
 2 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1262,34 +1262,45 @@ static ParseResult parseAllocatorHandle(OpAsmParser &parser,
     allocatorHandleAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
     return success();
   }
-  auto parseKeyword = [&]() -> ParseResult {
-    if (failed(parser.parseKeyword(&allocatorKeyword)))
-      return failure();
-    if (allocatorKeyword == "omp_null_allocator")
-      allocator = 0;
-    else if (allocatorKeyword == "omp_default_mem_alloc")
-      allocator = 1;
-    else if (allocatorKeyword == "omp_large_cap_mem_alloc")
-      allocator = 2;
-    else if (allocatorKeyword == "omp_const_mem_alloc")
-      allocator = 3;
-    else if (allocatorKeyword == "omp_high_bw_mem_alloc")
-      allocator = 4;
-    else if (allocatorKeyword == "omp_low_lat_mem_alloc")
-      allocator = 5;
-    else if (allocatorKeyword == "omp_cgroup_mem_alloc")
-      allocator = 6;
-    else if (allocatorKeyword == "omp_pteam_mem_alloc")
-      allocator = 7;
-    else if (allocatorKeyword == "omp_thread_mem_alloc")
-      allocator = 8;
-    else
-      return parser.emitError(parser.getCurrentLocation())
-             << allocatorKeyword << " is not a valid allocator";
-    return success();
-  };
-  if (parser.parseCommaSeparatedList(parseKeyword))
-    return failure();
+  // A user-defined allocator is written as its integer handle value.
+  OptionalParseResult parsedInteger = parser.parseOptionalInteger(allocator);
+  if (parsedInteger.has_value()) {
+    // An integer token was present but malformed (e.g. out of range for
+    // int64_t); the parser has already emitted a diagnostic.
+    if (failed(*parsedInteger))
+      return failure();
+    allocatorHandleAttr =
+        IntegerAttr::get(parser.getBuilder().getI64Type(), allocator);
+    return success();
+  }
+  // Otherwise expect the name of a predefined `omp_allocator_handle_t`.
+  // Unknown names are rejected rather than silently mapped to
+  // omp_null_allocator (0). The location is captured before the keyword is
+  // consumed so the diagnostic points at the offending name.
+  llvm::SMLoc keywordLoc = parser.getCurrentLocation();
+  if (failed(parser.parseKeyword(&allocatorKeyword)))
+    return failure();
+  if (allocatorKeyword == "omp_null_allocator")
+    allocator = 0;
+  else if (allocatorKeyword == "omp_default_mem_alloc")
+    allocator = 1;
+  else if (allocatorKeyword == "omp_large_cap_mem_alloc")
+    allocator = 2;
+  else if (allocatorKeyword == "omp_const_mem_alloc")
+    allocator = 3;
+  else if (allocatorKeyword == "omp_high_bw_mem_alloc")
+    allocator = 4;
+  else if (allocatorKeyword == "omp_low_lat_mem_alloc")
+    allocator = 5;
+  else if (allocatorKeyword == "omp_cgroup_mem_alloc")
+    allocator = 6;
+  else if (allocatorKeyword == "omp_pteam_mem_alloc")
+    allocator = 7;
+  else if (allocatorKeyword == "omp_thread_mem_alloc")
+    allocator = 8;
+  else
+    return parser.emitError(keywordLoc)
+           << allocatorKeyword << " is not a valid allocator";
   allocatorHandleAttr =
       IntegerAttr::get(parser.getBuilder().getI64Type(), allocator);
   return success();
@@ -1328,6 +1339,11 @@ static void printAllocatorHandle(OpAsmPrinter &p, Operation *op,
   case 8:
     allocatorHandle = "omp_thread_mem_alloc";
     break;
+  default:
+    // Not a predefined handle (e.g. a user-defined allocator): print the raw
+    // integer value so that it round-trips through the parser.
+    p << allocator;
+    return;
   }
   p << allocatorHandle;
 }
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3010,6 +3010,14 @@ func.func @invalid_allocate_align_2(%arg0 : memref<i32>) -> () {
   return
 }
 
+// -----
+func.func @invalid_allocate_allocator(%arg0 : memref<i32>) -> () {
+  // expected-error @below {{omp_small_cap_mem_alloc is not a valid allocator}}
+  omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_small_cap_mem_alloc)
+
+  return
+}
+
 // -----
 func.func @invalid_workdistribute_empty_region() -> () {
   omp.teams {
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3235,6 +3235,14 @@ func.func @omp_allocate_dir(%arg0 : memref<i32>, %arg1 : memref<i32>) -> () {
   // CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
   omp.allocate_dir (%arg0, %arg1 : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
 
+  // Test that the integer spelling of a predefined allocator prints by name
+  // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) allocator(omp_pteam_mem_alloc)
+  omp.allocate_dir (%arg0 : memref<i32>) allocator(7)
+
+  // Test with one data var and user defined allocator clause
+  // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) allocator(9)
+  omp.allocate_dir (%arg0 : memref<i32>) allocator(9)
+
   return
 }
 



More information about the Mlir-commits mailing list