[Mlir-commits] [mlir] [MLIR] [OpenMP] Modify definition of ALLOCATOR clause to support allocator type defined in user program. (PR #157399)

Raghu Maddhipatla llvmlistbot at llvm.org
Thu Sep 18 07:04:05 PDT 2025


https://github.com/raghavendhra updated https://github.com/llvm/llvm-project/pull/157399

>From a6b5d960f8ce6217800b90597b1b572cca4dccbf Mon Sep 17 00:00:00 2001
From: Raghu Maddhipatla <Raghu.Maddhipatla at amd.com>
Date: Mon, 8 Sep 2025 02:13:39 -0500
Subject: [PATCH 1/3] [MLIR] [OpenMP] Modify definition of ALLOCATOR clause to
 support allocator type defined in user program.

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |  4 +-
 .../mlir/Dialect/OpenMP/OpenMPEnums.td        | 30 -------
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 87 +++++++++++++++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         |  8 --
 4 files changed, 89 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 5f40abe62a0f6..675f62902e75b 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -120,11 +120,11 @@ class OpenMP_AllocatorClauseSkip<
                     extraClassDeclaration> {
 
   let arguments = (ins
-    OptionalAttr<AllocatorHandleAttr>:$allocator
+    DefaultValuedOptionalAttr<I64Attr, "0">:$allocator
   );
 
   let optAssemblyFormat = [{
-    `allocator` `(` custom<ClauseAttr>($allocator) `)`
+    `allocator` `(` custom<AllocatorHandle>($allocator) `)`
   }];
 
   let description = [{
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index c080c3fac87d4..9dbe6897a3304 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -263,34 +263,4 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr<VariableCaptureKind,
   let assemblyFormat = "`(` $value `)`";
 }
 
-
-//===----------------------------------------------------------------------===//
-// allocator_handle enum.
-//===----------------------------------------------------------------------===//
-
-def OpenMP_AllocatorHandleNullAllocator : I32EnumAttrCase<"omp_null_allocator", 0>;
-def OpenMP_AllocatorHandleDefaultMemAlloc : I32EnumAttrCase<"omp_default_mem_alloc", 1>;
-def OpenMP_AllocatorHandleLargeCapMemAlloc : I32EnumAttrCase<"omp_large_cap_mem_alloc", 2>;
-def OpenMP_AllocatorHandleConstMemAlloc : I32EnumAttrCase<"omp_const_mem_alloc", 3>;
-def OpenMP_AllocatorHandleHighBwMemAlloc : I32EnumAttrCase<"omp_high_bw_mem_alloc", 4>;
-def OpenMP_AllocatorHandleLowLatMemAlloc : I32EnumAttrCase<"omp_low_lat_mem_alloc", 5>;
-def OpenMP_AllocatorHandleCgroupMemAlloc : I32EnumAttrCase<"omp_cgroup_mem_alloc", 6>;
-def OpenMP_AllocatorHandlePteamMemAlloc : I32EnumAttrCase<"omp_pteam_mem_alloc", 7>;
-def OpenMP_AllocatorHandlethreadMemAlloc : I32EnumAttrCase<"omp_thread_mem_alloc", 8>;
-
-def AllocatorHandle : OpenMP_I32EnumAttr<
-    "AllocatorHandle",
-    "OpenMP allocator_handle", [
-      OpenMP_AllocatorHandleNullAllocator,
-      OpenMP_AllocatorHandleDefaultMemAlloc,
-      OpenMP_AllocatorHandleLargeCapMemAlloc,
-      OpenMP_AllocatorHandleConstMemAlloc,
-      OpenMP_AllocatorHandleHighBwMemAlloc,
-      OpenMP_AllocatorHandleLowLatMemAlloc,
-      OpenMP_AllocatorHandleCgroupMemAlloc,
-      OpenMP_AllocatorHandlePteamMemAlloc,
-      OpenMP_AllocatorHandlethreadMemAlloc
-    ]>;
-
-def AllocatorHandleAttr : OpenMP_EnumAttr<AllocatorHandle, "allocator_handle">;
 #endif // OPENMP_ENUMS
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 3d70e28ed23ab..cee9230f1ff0f 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1250,6 +1250,93 @@ verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Parser, printer and verifier for Allocator (Section 8.4 in OpenMP 6.0)
+//===----------------------------------------------------------------------===//
+
+/// Parses a allocator clause. The value of allocator handle is an integer
+/// which is a combination of different allocator handles from
+/// `omp_allocator_handle_t`.
+///
+/// allocator-clause = `allocator` `(` allocator-value `)`
+static ParseResult parseAllocatorHandle(OpAsmParser &parser,
+                                        IntegerAttr &allocatorHandleAttr) {
+  StringRef allocatorKeyword;
+  int64_t allocator = 0;
+  if (succeeded(parser.parseOptionalKeyword("none"))) {
+    allocatorHandleAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
+    return success();
+  }
+  auto parseKeyword = [&]() -> ParseResult {
+    if (failed(parser.parseKeyword(&allocatorKeyword)))
+      return failure();
+    if (allocatorKeyword == "omp_null_allocator")
+      allocator = 0;
+    else if (allocatorKeyword == "omp_default_mem_alloc")
+      allocator = 1;
+    else if (allocatorKeyword == "omp_large_cap_mem_alloc")
+      allocator = 2;
+    else if (allocatorKeyword == "omp_const_mem_alloc")
+      allocator = 3;
+    else if (allocatorKeyword == "omp_high_bw_mem_alloc")
+      allocator = 4;
+    else if (allocatorKeyword == "omp_low_lat_mem_alloc")
+      allocator = 5;
+    else if (allocatorKeyword == "omp_cgroup_mem_alloc")
+      allocator = 6;
+    else if (allocatorKeyword == "omp_pteam_mem_alloc")
+      allocator = 7;
+    else if (allocatorKeyword == "omp_thread_mem_alloc")
+      allocator = 8;
+    else
+      return parser.emitError(parser.getCurrentLocation())
+             << allocatorKeyword << " is not a valid allocator";
+    return success();
+  };
+  if (parser.parseCommaSeparatedList(parseKeyword))
+    return failure();
+  allocatorHandleAttr =
+      IntegerAttr::get(parser.getBuilder().getI64Type(), allocator);
+  return success();
+}
+
+/// Prints a allocator clause
+static void printAllocatorHandle(OpAsmPrinter &p, Operation *op,
+                                 IntegerAttr allocatorHandleAttr) {
+  int64_t allocator = allocatorHandleAttr.getInt();
+  StringRef allocatorHandle;
+  switch (allocator) {
+  case 0:
+    allocatorHandle = "omp_null_allocator";
+    break;
+  case 1:
+    allocatorHandle = "omp_default_mem_alloc";
+    break;
+  case 2:
+    allocatorHandle = "omp_large_cap_mem_alloc";
+    break;
+  case 3:
+    allocatorHandle = "omp_const_mem_alloc";
+    break;
+  case 4:
+    allocatorHandle = "omp_high_bw_mem_alloc";
+    break;
+  case 5:
+    allocatorHandle = "omp_low_lat_mem_alloc";
+    break;
+  case 6:
+    allocatorHandle = "omp_cgroup_mem_alloc";
+    break;
+  case 7:
+    allocatorHandle = "omp_pteam_mem_alloc";
+    break;
+  case 8:
+    allocatorHandle = "omp_thread_mem_alloc";
+    break;
+  }
+  p << allocatorHandle;
+}
+
 //===----------------------------------------------------------------------===//
 // Parser, printer and verifier for Copyprivate
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 763f41c5420b8..af24d969064ab 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3033,14 +3033,6 @@ func.func @invalid_allocate_align_2(%arg0 : memref<i32>) -> () {
   return
 }
 
-// -----
-func.func @invalid_allocate_allocator(%arg0 : memref<i32>) -> () {
-  // expected-error @below {{invalid clause value}}
-  omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_small_cap_mem_alloc)
-
-  return
-}
-
 // -----
 func.func @invalid_workdistribute_empty_region() -> () {
   omp.teams {

>From d455dcafef67f2457bc51eb5608d786af28fec15 Mon Sep 17 00:00:00 2001
From: Raghu Maddhipatla <Raghu.Maddhipatla at amd.com>
Date: Tue, 9 Sep 2025 01:10:20 -0500
Subject: [PATCH 2/3] Add test-case for user-defined allocator value

---
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 53 ++++++++++----------
 mlir/test/Dialect/OpenMP/ops.mlir            |  6 ++-
 2 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index cee9230f1ff0f..7d15eab5c5232 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1267,34 +1267,32 @@ static ParseResult parseAllocatorHandle(OpAsmParser &parser,
     allocatorHandleAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
     return success();
   }
-  auto parseKeyword = [&]() -> ParseResult {
-    if (failed(parser.parseKeyword(&allocatorKeyword)))
-      return failure();
-    if (allocatorKeyword == "omp_null_allocator")
-      allocator = 0;
-    else if (allocatorKeyword == "omp_default_mem_alloc")
-      allocator = 1;
-    else if (allocatorKeyword == "omp_large_cap_mem_alloc")
-      allocator = 2;
-    else if (allocatorKeyword == "omp_const_mem_alloc")
-      allocator = 3;
-    else if (allocatorKeyword == "omp_high_bw_mem_alloc")
-      allocator = 4;
-    else if (allocatorKeyword == "omp_low_lat_mem_alloc")
-      allocator = 5;
-    else if (allocatorKeyword == "omp_cgroup_mem_alloc")
-      allocator = 6;
-    else if (allocatorKeyword == "omp_pteam_mem_alloc")
-      allocator = 7;
-    else if (allocatorKeyword == "omp_thread_mem_alloc")
-      allocator = 8;
-    else
-      return parser.emitError(parser.getCurrentLocation())
-             << allocatorKeyword << " is not a valid allocator";
+  OptionalParseResult parsedInteger = parser.parseOptionalInteger(allocator);
+  if (parsedInteger.has_value()) {
+    allocatorHandleAttr =
+        IntegerAttr::get(parser.getBuilder().getI64Type(), allocator);
     return success();
-  };
-  if (parser.parseCommaSeparatedList(parseKeyword))
+  }
+  if (failed(parser.parseKeyword(&allocatorKeyword)))
     return failure();
+  if (allocatorKeyword == "omp_null_allocator")
+    allocator = 0;
+  else if (allocatorKeyword == "omp_default_mem_alloc")
+    allocator = 1;
+  else if (allocatorKeyword == "omp_large_cap_mem_alloc")
+    allocator = 2;
+  else if (allocatorKeyword == "omp_const_mem_alloc")
+    allocator = 3;
+  else if (allocatorKeyword == "omp_high_bw_mem_alloc")
+    allocator = 4;
+  else if (allocatorKeyword == "omp_low_lat_mem_alloc")
+    allocator = 5;
+  else if (allocatorKeyword == "omp_cgroup_mem_alloc")
+    allocator = 6;
+  else if (allocatorKeyword == "omp_pteam_mem_alloc")
+    allocator = 7;
+  else if (allocatorKeyword == "omp_thread_mem_alloc")
+    allocator = 8;
   allocatorHandleAttr =
       IntegerAttr::get(parser.getBuilder().getI64Type(), allocator);
   return success();
@@ -1333,6 +1331,9 @@ static void printAllocatorHandle(OpAsmPrinter &p, Operation *op,
   case 8:
     allocatorHandle = "omp_thread_mem_alloc";
     break;
+  default:
+    p << Twine(allocator).str();
+    return;
   }
   p << allocatorHandle;
 }
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 60b1f61135ac2..79046a72006d7 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3279,7 +3279,7 @@ func.func @omp_allocate_dir(%arg0 : memref<i32>, %arg1 : memref<i32>) -> () {
 
   // Test with one data var and allocator clause
   // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) allocator(omp_pteam_mem_alloc)
-  omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_pteam_mem_alloc)
+  omp.allocate_dir (%arg0 : memref<i32>) allocator(7)
 
   // Test with one data var, align clause and allocator clause
   // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) align(2) allocator(omp_thread_mem_alloc)
@@ -3289,6 +3289,10 @@ func.func @omp_allocate_dir(%arg0 : memref<i32>, %arg1 : memref<i32>) -> () {
   // CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
   omp.allocate_dir (%arg0, %arg1 : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
 
+  // Test with one data var and user defined allocator clause
+  // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) allocator(9)
+  omp.allocate_dir (%arg0 : memref<i32>) allocator(9)
+
   return
 }
 

>From c89d0ff593e18ef6264316b2913a2f7cdcaa0a3e Mon Sep 17 00:00:00 2001
From: Raghu Maddhipatla <Raghu.Maddhipatla at amd.com>
Date: Wed, 17 Sep 2025 17:38:03 -0500
Subject: [PATCH 3/3] Change allocator clause definition to take an i64 SSA
 value operand instead of an IntegerAttr.

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |  4 +-
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  2 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 88 -------------------
 mlir/test/Dialect/OpenMP/ops.mlir             | 29 ++++--
 4 files changed, 24 insertions(+), 99 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 675f62902e75b..1eda5e4bc1618 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -120,11 +120,11 @@ class OpenMP_AllocatorClauseSkip<
                     extraClassDeclaration> {
 
   let arguments = (ins
-    DefaultValuedOptionalAttr<I64Attr, "0">:$allocator
+    Optional<I64>:$allocator
   );
 
   let optAssemblyFormat = [{
-    `allocator` `(` custom<AllocatorHandle>($allocator) `)`
+    `allocator` `(` $allocator `)`
   }];
 
   let description = [{
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 830b36f440098..5c77e215467e4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2100,7 +2100,7 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
 //===----------------------------------------------------------------------===//
 // [Spec 5.2] 6.5 allocate Directive
 //===----------------------------------------------------------------------===//
-def AllocateDirOp : OpenMP_Op<"allocate_dir", clauses = [
+def AllocateDirOp : OpenMP_Op<"allocate_dir", [AttrSizedOperandSegments], clauses = [
     OpenMP_AlignClause, OpenMP_AllocatorClause
   ]> {
   let summary = "allocate directive";
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 7d15eab5c5232..3d70e28ed23ab 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1250,94 +1250,6 @@ verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// Parser, printer and verifier for Allocator (Section 8.4 in OpenMP 6.0)
-//===----------------------------------------------------------------------===//
-
-/// Parses a allocator clause. The value of allocator handle is an integer
-/// which is a combination of different allocator handles from
-/// `omp_allocator_handle_t`.
-///
-/// allocator-clause = `allocator` `(` allocator-value `)`
-static ParseResult parseAllocatorHandle(OpAsmParser &parser,
-                                        IntegerAttr &allocatorHandleAttr) {
-  StringRef allocatorKeyword;
-  int64_t allocator = 0;
-  if (succeeded(parser.parseOptionalKeyword("none"))) {
-    allocatorHandleAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
-    return success();
-  }
-  OptionalParseResult parsedInteger = parser.parseOptionalInteger(allocator);
-  if (parsedInteger.has_value()) {
-    allocatorHandleAttr =
-        IntegerAttr::get(parser.getBuilder().getI64Type(), allocator);
-    return success();
-  }
-  if (failed(parser.parseKeyword(&allocatorKeyword)))
-    return failure();
-  if (allocatorKeyword == "omp_null_allocator")
-    allocator = 0;
-  else if (allocatorKeyword == "omp_default_mem_alloc")
-    allocator = 1;
-  else if (allocatorKeyword == "omp_large_cap_mem_alloc")
-    allocator = 2;
-  else if (allocatorKeyword == "omp_const_mem_alloc")
-    allocator = 3;
-  else if (allocatorKeyword == "omp_high_bw_mem_alloc")
-    allocator = 4;
-  else if (allocatorKeyword == "omp_low_lat_mem_alloc")
-    allocator = 5;
-  else if (allocatorKeyword == "omp_cgroup_mem_alloc")
-    allocator = 6;
-  else if (allocatorKeyword == "omp_pteam_mem_alloc")
-    allocator = 7;
-  else if (allocatorKeyword == "omp_thread_mem_alloc")
-    allocator = 8;
-  allocatorHandleAttr =
-      IntegerAttr::get(parser.getBuilder().getI64Type(), allocator);
-  return success();
-}
-
-/// Prints a allocator clause
-static void printAllocatorHandle(OpAsmPrinter &p, Operation *op,
-                                 IntegerAttr allocatorHandleAttr) {
-  int64_t allocator = allocatorHandleAttr.getInt();
-  StringRef allocatorHandle;
-  switch (allocator) {
-  case 0:
-    allocatorHandle = "omp_null_allocator";
-    break;
-  case 1:
-    allocatorHandle = "omp_default_mem_alloc";
-    break;
-  case 2:
-    allocatorHandle = "omp_large_cap_mem_alloc";
-    break;
-  case 3:
-    allocatorHandle = "omp_const_mem_alloc";
-    break;
-  case 4:
-    allocatorHandle = "omp_high_bw_mem_alloc";
-    break;
-  case 5:
-    allocatorHandle = "omp_low_lat_mem_alloc";
-    break;
-  case 6:
-    allocatorHandle = "omp_cgroup_mem_alloc";
-    break;
-  case 7:
-    allocatorHandle = "omp_pteam_mem_alloc";
-    break;
-  case 8:
-    allocatorHandle = "omp_thread_mem_alloc";
-    break;
-  default:
-    p << Twine(allocator).str();
-    return;
-  }
-  p << allocatorHandle;
-}
-
 //===----------------------------------------------------------------------===//
 // Parser, printer and verifier for Copyprivate
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 79046a72006d7..cbd863f88fd1f 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3260,6 +3260,10 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
   return
 }

+func.func @omp_init_allocator(%custom_allocator : i64) -> i64 {
+  return %custom_allocator : i64
+}
+
 // CHECK-LABEL: func.func @omp_allocate_dir(
 // CHECK-SAME: %[[ARG0:.*]]: memref<i32>,
 // CHECK-SAME: %[[ARG1:.*]]: memref<i32>) {
@@ -3278,20 +3282,29 @@ func.func @omp_allocate_dir(%arg0 : memref<i32>, %arg1 : memref<i32>) -> () {
   omp.allocate_dir (%arg0 : memref<i32>) align(2)

   // Test with one data var and allocator clause
-  // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) allocator(omp_pteam_mem_alloc)
-  omp.allocate_dir (%arg0 : memref<i32>) allocator(7)
+  // CHECK: %[[VAL_1:.*]] = arith.constant 1 : i64
+  %omp_default_mem_alloc = arith.constant 1 : i64
+  // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) allocator(%[[VAL_1]])
+  omp.allocate_dir (%arg0 : memref<i32>) allocator(%omp_default_mem_alloc)

   // Test with one data var, align clause and allocator clause
-  // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) align(2) allocator(omp_thread_mem_alloc)
-  omp.allocate_dir (%arg0 : memref<i32>) align(2) allocator(omp_thread_mem_alloc)
+  // CHECK: %[[VAL_2:.*]] = arith.constant 7 : i64
+  %omp_pteam_mem_alloc = arith.constant 7 : i64
+  // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) align(4) allocator(%[[VAL_2]])
+  omp.allocate_dir (%arg0 : memref<i32>) align(4) allocator(%omp_pteam_mem_alloc)

   // Test with two data vars, align clause and allocator clause
-  // CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
-  omp.allocate_dir (%arg0, %arg1 : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
+  // CHECK: %[[VAL_3:.*]] = arith.constant 6 : i64
+  %omp_cgroup_mem_alloc = arith.constant 6 : i64
+  // CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>) align(8) allocator(%[[VAL_3]])
+  omp.allocate_dir (%arg0, %arg1 : memref<i32>, memref<i32>) align(8) allocator(%omp_cgroup_mem_alloc)

   // Test with one data var and user defined allocator clause
-  // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) allocator(9)
-  omp.allocate_dir (%arg0 : memref<i32>) allocator(9)
+  %custom_allocator = arith.constant 9 : i64
+  // CHECK: %[[VAL_4:.*]] = call @omp_init_allocator(%{{.*}}) : (i64) -> i64
+  %custom_mem_alloc = func.call @omp_init_allocator(%custom_allocator) : (i64) -> (i64)
+  // CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) allocator(%[[VAL_4]])
+  omp.allocate_dir (%arg0 : memref<i32>) allocator(%custom_mem_alloc)

   return
 }



More information about the Mlir-commits mailing list